// Package websocket exposes a WebSocket endpoint that pushes cluster state and
// node discovery events to connected clients.
package websocket

import (
	"context"
	"encoding/json"
	"net/http"
	"sync"
	"time"

	"github.com/gorilla/websocket"
	log "github.com/sirupsen/logrus"

	"spore-gateway/internal/discovery"
	"spore-gateway/pkg/client"
)

const (
	// pongWait is how long a connection may stay silent before the read
	// deadline expires.
	pongWait = 60 * time.Second
	// pingPeriod is the interval between keep-alive pings; it must be shorter
	// than pongWait so a healthy peer always answers in time.
	pingPeriod = 54 * time.Second
	// controlWriteWait bounds the write of a ping control frame.
	controlWriteWait = 10 * time.Second
	// initialStateWriteWait bounds the write of the initial cluster snapshot.
	initialStateWriteWait = 10 * time.Second
	// broadcastWriteWait bounds each per-client write during a broadcast.
	broadcastWriteWait = 5 * time.Second
	// closeWriteWait bounds the write of the close frame during Shutdown when
	// the context carries no earlier deadline.
	closeWriteWait = 5 * time.Second
)

// upgrader turns HTTP requests into WebSocket connections.
//
// NOTE(review): CheckOrigin accepts every origin. That is convenient in
// development but allows cross-site WebSocket connections; restrict it before
// exposing the gateway to untrusted browsers.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool {
		return true // Allow connections from any origin in development
	},
}

// clusterUpdateMessage is the JSON payload of a "cluster_update" event.
type clusterUpdateMessage struct {
	Type        string                 `json:"type"`
	Members     []client.ClusterMember `json:"members"`
	PrimaryNode string                 `json:"primaryNode"`
	TotalNodes  int                    `json:"totalNodes"`
	Timestamp   string                 `json:"timestamp"`
}

// nodeDiscoveryMessage is the JSON payload of a "node_discovery" event.
type nodeDiscoveryMessage struct {
	Type      string `json:"type"`
	Action    string `json:"action"`
	NodeIP    string `json:"nodeIp"`
	Timestamp string `json:"timestamp"`
}

// WebSocketServer manages WebSocket connections and broadcasts.
type WebSocketServer struct {
	nodeDiscovery *discovery.NodeDiscovery
	sporeClients  map[string]*client.SporeClient
	clients       map[*websocket.Conn]bool
	mutex         sync.RWMutex // guards clients
	writeMutex    sync.Mutex   // serializes data-frame writes to WebSocket connections
	sporeMutex    sync.Mutex   // guards sporeClients
	logger        *log.Logger
}

// NewWebSocketServer creates a WebSocket server and registers it for node
// update callbacks on nodeDiscovery.
func NewWebSocketServer(nodeDiscovery *discovery.NodeDiscovery) *WebSocketServer {
	wss := &WebSocketServer{
		nodeDiscovery: nodeDiscovery,
		sporeClients:  make(map[string]*client.SporeClient),
		clients:       make(map[*websocket.Conn]bool),
		logger:        log.New(),
	}

	// Register callback for node updates.
	nodeDiscovery.AddCallback(wss.handleNodeUpdate)

	return wss
}

// HandleWebSocket upgrades the request to a WebSocket connection, registers
// the connection, and starts the goroutines that send the initial cluster
// state and service the connection. It returns the upgrade error, if any.
func (wss *WebSocketServer) HandleWebSocket(w http.ResponseWriter, r *http.Request) error {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		wss.logger.WithError(err).Error("Failed to upgrade WebSocket connection")
		return err
	}

	wss.mutex.Lock()
	wss.clients[conn] = true
	wss.mutex.Unlock()

	wss.logger.Debug("WebSocket client connected")

	// Send current cluster state to the newly connected client.
	go wss.sendCurrentClusterState(conn)

	// Handle client messages and disconnection.
	go wss.handleClient(conn)

	return nil
}

// handleClient reads from conn until the read fails, then unregisters and
// closes the connection. A companion goroutine sends periodic pings and is
// stopped as soon as this function returns.
func (wss *WebSocketServer) handleClient(conn *websocket.Conn) {
	defer func() {
		wss.mutex.Lock()
		delete(wss.clients, conn)
		wss.mutex.Unlock()
		conn.Close()
		wss.logger.Debug("WebSocket client disconnected")
	}()

	// Closing done stops the ping goroutine immediately instead of letting it
	// linger until its next write fails.
	done := make(chan struct{})
	defer close(done)

	// Set read deadline and pong handler.
	conn.SetReadDeadline(time.Now().Add(pongWait))
	conn.SetPongHandler(func(string) error {
		conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})

	go wss.pingLoop(conn, done)

	// Read messages (we don't expect any, but reading is what processes pongs
	// and detects disconnection).
	for {
		if _, _, err := conn.ReadMessage(); err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				wss.logger.WithError(err).Error("WebSocket error")
			}
			return
		}
	}
}

// pingLoop sends a ping on conn every pingPeriod until done is closed or a
// ping fails.
func (wss *WebSocketServer) pingLoop(conn *websocket.Conn, done <-chan struct{}) {
	ticker := time.NewTicker(pingPeriod)
	defer ticker.Stop()

	for {
		select {
		case <-done:
			return
		case <-ticker.C:
			// WriteControl (unlike WriteMessage) may be called concurrently
			// with other writers, so pings cannot corrupt a broadcast that is
			// in flight on the same connection.
			deadline := time.Now().Add(controlWriteWait)
			if err := conn.WriteControl(websocket.PingMessage, nil, deadline); err != nil {
				return
			}
		}
	}
}

// sendCurrentClusterState sends the current cluster state to a newly
// connected client. Nothing is sent when no nodes are known.
func (wss *WebSocketServer) sendCurrentClusterState(conn *websocket.Conn) {
	nodes := wss.nodeDiscovery.GetNodes()
	if len(nodes) == 0 {
		return
	}

	// Get real cluster data from SPORE nodes.
	clusterData, err := wss.getCurrentClusterMembers()
	if err != nil {
		wss.logger.WithError(err).Error("Failed to get cluster data for WebSocket")
		return
	}

	data, err := json.Marshal(wss.newClusterUpdate(clusterData, len(nodes)))
	if err != nil {
		wss.logger.WithError(err).Error("Failed to marshal cluster data")
		return
	}

	// Take the write lock: a broadcast may target this connection at the same
	// time, and only one goroutine may write data frames to a connection.
	wss.writeMutex.Lock()
	defer wss.writeMutex.Unlock()

	conn.SetWriteDeadline(time.Now().Add(initialStateWriteWait))
	if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
		wss.logger.WithError(err).Error("Failed to send initial cluster state")
	}
}

// handleNodeUpdate is called when node information changes. It broadcasts a
// cluster update followed by a node discovery event.
func (wss *WebSocketServer) handleNodeUpdate(nodeIP, action string) {
	wss.logger.WithFields(log.Fields{
		"node_ip": nodeIP,
		"action":  action,
	}).Debug("Node update received, broadcasting to WebSocket clients")

	// Broadcast cluster update to all clients.
	wss.broadcastClusterUpdate()

	// Also broadcast node discovery event.
	wss.broadcastNodeDiscovery(nodeIP, action)
}

// broadcastClusterUpdate sends the current cluster membership to all
// connected clients. It does nothing when no client is connected.
func (wss *WebSocketServer) broadcastClusterUpdate() {
	conns := wss.snapshotClients()
	if len(conns) == 0 {
		return
	}

	startTime := time.Now()

	clusterData, err := wss.getCurrentClusterMembers()
	if err != nil {
		wss.logger.WithError(err).Error("Failed to get cluster data for broadcast")
		return
	}

	data, err := json.Marshal(wss.newClusterUpdate(clusterData, len(wss.nodeDiscovery.GetNodes())))
	if err != nil {
		wss.logger.WithError(err).Error("Failed to marshal cluster update")
		return
	}

	wss.logger.WithFields(log.Fields{
		"clients":   len(conns),
		"prep_time": time.Since(startTime),
	}).Debug("Broadcasting cluster update to WebSocket clients")

	failedClients := wss.broadcast(conns, data, "Failed to send cluster update to client")

	wss.logger.WithFields(log.Fields{
		"clients":        len(conns),
		"failed_clients": failedClients,
		"total_time":     time.Since(startTime),
	}).Debug("Cluster update broadcast completed")
}

// broadcastNodeDiscovery sends a node discovery event to all connected
// clients. It does nothing when no client is connected.
func (wss *WebSocketServer) broadcastNodeDiscovery(nodeIP, action string) {
	conns := wss.snapshotClients()
	if len(conns) == 0 {
		return
	}

	data, err := json.Marshal(nodeDiscoveryMessage{
		Type:      "node_discovery",
		Action:    action,
		NodeIP:    nodeIP,
		Timestamp: time.Now().Format(time.RFC3339),
	})
	if err != nil {
		wss.logger.WithError(err).Error("Failed to marshal node discovery event")
		return
	}

	wss.broadcast(conns, data, "Failed to send node discovery event to client")
}

// snapshotClients returns the currently registered connections so callers can
// iterate over them without holding the clients lock.
func (wss *WebSocketServer) snapshotClients() []*websocket.Conn {
	wss.mutex.RLock()
	defer wss.mutex.RUnlock()

	conns := make([]*websocket.Conn, 0, len(wss.clients))
	for conn := range wss.clients {
		conns = append(conns, conn)
	}
	return conns
}

// broadcast writes data as a text message to every connection in conns while
// holding the write lock. Each failed write is logged with failureMsg; the
// number of failed writes is returned.
func (wss *WebSocketServer) broadcast(conns []*websocket.Conn, data []byte, failureMsg string) int {
	wss.writeMutex.Lock()
	defer wss.writeMutex.Unlock()

	failed := 0
	for _, conn := range conns {
		conn.SetWriteDeadline(time.Now().Add(broadcastWriteWait))
		if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
			wss.logger.WithError(err).Error(failureMsg)
			failed++
		}
	}
	return failed
}

// newClusterUpdate builds a "cluster_update" payload for the given members
// and node count, stamped with the current primary node and time.
func (wss *WebSocketServer) newClusterUpdate(members []client.ClusterMember, totalNodes int) clusterUpdateMessage {
	return clusterUpdateMessage{
		Type:        "cluster_update",
		Members:     members,
		PrimaryNode: wss.nodeDiscovery.GetPrimaryNode(),
		TotalNodes:  totalNodes,
		Timestamp:   time.Now().Format(time.RFC3339),
	}
}

// getCurrentClusterMembers returns the cluster members reported by the
// primary node. It falls back to locally discovered node data when there is
// no primary node or the request to it fails. The returned error is currently
// always nil.
func (wss *WebSocketServer) getCurrentClusterMembers() ([]client.ClusterMember, error) {
	nodes := wss.nodeDiscovery.GetNodes()
	if len(nodes) == 0 {
		return []client.ClusterMember{}, nil
	}

	// Try to get real cluster data from the primary node.
	if primaryNode := wss.nodeDiscovery.GetPrimaryNode(); primaryNode != "" {
		sporeClient := wss.getSporeClient(primaryNode)
		clusterStatus, err := sporeClient.GetClusterStatus()
		if err == nil {
			// Update local node data with API information.
			wss.updateLocalNodesWithAPI(clusterStatus.Members)
			return clusterStatus.Members, nil
		}
		wss.logger.WithError(err).Error("Failed to get cluster status from primary node")
	}

	// Fallback to local data if API fails.
	return wss.getFallbackClusterMembers(), nil
}

// updateLocalNodesWithAPI is a placeholder for updating local node discovery
// with fresh API data. For now it only logs the member count and the labels
// of every member that has any.
func (wss *WebSocketServer) updateLocalNodesWithAPI(apiMembers []client.ClusterMember) {
	wss.logger.WithField("members", len(apiMembers)).Debug("Updating local nodes with API data")

	for _, member := range apiMembers {
		if len(member.Labels) > 0 {
			wss.logger.WithFields(log.Fields{
				"ip":     member.IP,
				"labels": member.Labels,
			}).Debug("API member labels")
		}
	}
}

// getFallbackClusterMembers converts locally discovered nodes into cluster
// members for use when the primary node's API is unavailable.
func (wss *WebSocketServer) getFallbackClusterMembers() []client.ClusterMember {
	nodes := wss.nodeDiscovery.GetNodes()

	members := make([]client.ClusterMember, 0, len(nodes))
	for _, node := range nodes {
		members = append(members, client.ClusterMember{
			IP:       node.IP,
			Hostname: node.Hostname,
			Status:   string(node.Status),
			Latency:  node.Latency,
			LastSeen: node.LastSeen.Unix(),
			Labels:   node.Labels,
		})
	}
	return members
}

// getSporeClient returns the cached SPORE client for nodeIP, creating and
// caching one on first use. It is safe for concurrent use.
func (wss *WebSocketServer) getSporeClient(nodeIP string) *client.SporeClient {
	// The cache is reached from per-connection goroutines and from discovery
	// callbacks, so map access must be serialized.
	wss.sporeMutex.Lock()
	defer wss.sporeMutex.Unlock()

	if cached, ok := wss.sporeClients[nodeIP]; ok {
		return cached
	}

	sporeClient := client.NewSporeClient("http://" + nodeIP)
	wss.sporeClients[nodeIP] = sporeClient
	return sporeClient
}

// GetClientCount returns the number of connected WebSocket clients.
func (wss *WebSocketServer) GetClientCount() int {
	wss.mutex.RLock()
	defer wss.mutex.RUnlock()
	return len(wss.clients)
}

// Shutdown sends a close frame to every connected client and closes the
// connections. The close-frame write is bounded by ctx's deadline when that
// is earlier than closeWriteWait from now. Shutdown always returns nil.
func (wss *WebSocketServer) Shutdown(ctx context.Context) error {
	wss.logger.Info("Shutting down WebSocket server")

	conns := wss.snapshotClients()

	// Use an explicit deadline: the connection's stored write deadline was
	// set by an earlier write and may already have expired.
	deadline := time.Now().Add(closeWriteWait)
	if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) {
		deadline = ctxDeadline
	}
	closeFrame := websocket.FormatCloseMessage(websocket.CloseGoingAway, "Server shutting down")

	// Close all client connections.
	for _, conn := range conns {
		// WriteControl may run concurrently with an in-flight broadcast.
		// Failure to deliver the close frame is not fatal: the connection is
		// closed regardless.
		if err := conn.WriteControl(websocket.CloseMessage, closeFrame, deadline); err != nil {
			wss.logger.WithError(err).Debug("Failed to send close frame to WebSocket client")
		}
		conn.Close()
	}

	return nil
}