package websocket

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"sync"
	"time"

	"spore-gateway/internal/discovery"
	"spore-gateway/pkg/client"

	"github.com/gorilla/websocket"
	log "github.com/sirupsen/logrus"
)

var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool {
		return true // Allow connections from any origin in development
	},
}

// WebSocketServer manages WebSocket connections and broadcasts
type WebSocketServer struct {
	nodeDiscovery       *discovery.NodeDiscovery
	sporeClients        map[string]*client.SporeClient
	clients             map[*websocket.Conn]bool
	mutex               sync.RWMutex
	writeMutex          sync.Mutex // Mutex to serialize writes to WebSocket connections
	logger              *log.Logger
	clusterInfoTicker   *time.Ticker
	clusterInfoStopCh   chan bool
	clusterInfoInterval time.Duration
}

// NewWebSocketServer creates a new WebSocket server
func NewWebSocketServer(nodeDiscovery *discovery.NodeDiscovery) *WebSocketServer {
	wss := &WebSocketServer{
		nodeDiscovery:       nodeDiscovery,
		sporeClients:        make(map[string]*client.SporeClient),
		clients:             make(map[*websocket.Conn]bool),
		logger:              log.New(),
		clusterInfoStopCh:   make(chan bool),
		clusterInfoInterval: 5 * time.Second, // Fetch cluster info every 5 seconds
	}

	// Register callback for node updates
	nodeDiscovery.AddCallback(wss.handleNodeUpdate)

	// Start periodic cluster info fetching
	go wss.startPeriodicClusterInfoFetching()

	return wss
}

// HandleWebSocket handles WebSocket upgrade and connection
func (wss *WebSocketServer) HandleWebSocket(w http.ResponseWriter, r *http.Request) error {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		wss.logger.WithError(err).Error("Failed to upgrade WebSocket connection")
		return err
	}

	wss.mutex.Lock()
	wss.clients[conn] = true
	wss.mutex.Unlock()

	wss.logger.Debug("WebSocket client connected")

	// Send current cluster state to newly connected client
	go wss.sendCurrentClusterState(conn)

	// Handle client messages and disconnection
	go wss.handleClient(conn)

	return nil
}

// handleClient handles messages from a WebSocket client
func (wss *WebSocketServer) handleClient(conn *websocket.Conn) {
	defer func() {
		wss.mutex.Lock()
		delete(wss.clients, conn)
		wss.mutex.Unlock()
		conn.Close()
		wss.logger.Debug("WebSocket client disconnected")
	}()

	// Set read deadline and pong handler
	conn.SetReadDeadline(time.Now().Add(60 * time.Second))
	conn.SetPongHandler(func(string) error {
		conn.SetReadDeadline(time.Now().Add(60 * time.Second))
		return nil
	})

	// Start ping routine
	go func() {
		ticker := time.NewTicker(54 * time.Second)
		defer ticker.Stop()
		for range ticker.C {
			// Serialize with the broadcast paths: gorilla/websocket allows only one
			// concurrent writer per connection.
			wss.writeMutex.Lock()
			conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
			err := conn.WriteMessage(websocket.PingMessage, nil)
			wss.writeMutex.Unlock()
			if err != nil {
				return
			}
		}
	}()

	// Read messages (we don't expect any, but this keeps the connection alive)
	for {
		_, _, err := conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				wss.logger.WithError(err).Error("WebSocket error")
			}
			break
		}
	}
}
// sendCurrentClusterState sends the current cluster state to a newly connected client
func (wss *WebSocketServer) sendCurrentClusterState(conn *websocket.Conn) {
	nodes := wss.nodeDiscovery.GetNodes()
	if len(nodes) == 0 {
		return
	}

	// Get real cluster data from SPORE nodes
	clusterData, err := wss.getCurrentClusterMembers()
	if err != nil {
		wss.logger.WithError(err).Error("Failed to get cluster data for WebSocket")
		return
	}

	message := struct {
		Topic       string                 `json:"topic"`
		Members     []client.ClusterMember `json:"members"`
		PrimaryNode string                 `json:"primaryNode"`
		TotalNodes  int                    `json:"totalNodes"`
		Timestamp   string                 `json:"timestamp"`
	}{
		Topic:       "cluster/update",
		Members:     clusterData,
		PrimaryNode: wss.nodeDiscovery.GetPrimaryNode(),
		TotalNodes:  len(nodes),
		Timestamp:   time.Now().Format(time.RFC3339),
	}

	data, err := json.Marshal(message)
	if err != nil {
		wss.logger.WithError(err).Error("Failed to marshal cluster data")
		return
	}

	// Serialize with other writers on this connection
	wss.writeMutex.Lock()
	defer wss.writeMutex.Unlock()
	conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
	if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
		wss.logger.WithError(err).Error("Failed to send initial cluster state")
	}
}

// startPeriodicClusterInfoFetching starts a goroutine that periodically fetches cluster info
func (wss *WebSocketServer) startPeriodicClusterInfoFetching() {
	wss.clusterInfoTicker = time.NewTicker(wss.clusterInfoInterval)
	defer wss.clusterInfoTicker.Stop()

	wss.logger.WithField("interval", wss.clusterInfoInterval).Info("Starting periodic cluster info fetching")

	for {
		select {
		case <-wss.clusterInfoTicker.C:
			wss.fetchAndBroadcastClusterInfo()
		case <-wss.clusterInfoStopCh:
			wss.logger.Info("Stopping periodic cluster info fetching")
			return
		}
	}
}

// fetchAndBroadcastClusterInfo fetches cluster info and broadcasts it to clients
func (wss *WebSocketServer) fetchAndBroadcastClusterInfo() {
	// Only fetch if we have clients connected
	wss.mutex.RLock()
	clientCount := len(wss.clients)
	wss.mutex.RUnlock()

	if clientCount == 0 {
		return
	}

	wss.logger.Debug("Periodically fetching cluster info")
	wss.broadcastClusterUpdate()
}

// handleNodeUpdate is called when node information changes
func (wss *WebSocketServer) handleNodeUpdate(nodeIP, action string) {
	wss.logger.WithFields(log.Fields{
		"node_ip": nodeIP,
		"action":  action,
	}).Debug("Node update received, broadcasting node discovery event")

	// Only broadcast node discovery event, not cluster update
	// Cluster updates are now handled by periodic fetching
	wss.broadcastNodeDiscovery(nodeIP, action)
}
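// For reference, the "cluster/update" envelope built in sendCurrentClusterState above and
// in broadcastClusterUpdate below marshals to JSON along these lines (member objects follow
// whatever tags client.ClusterMember declares, so they are elided; the address, count, and
// timestamp are made-up illustrative values):
//
//	{
//	  "topic": "cluster/update",
//	  "members": [ ... ],
//	  "primaryNode": "192.168.1.10",
//	  "totalNodes": 3,
//	  "timestamp": "2024-01-01T12:00:00Z"
//	}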
// broadcastClusterUpdate sends cluster updates to all connected clients
func (wss *WebSocketServer) broadcastClusterUpdate() {
	wss.mutex.RLock()
	clients := make([]*websocket.Conn, 0, len(wss.clients))
	for client := range wss.clients {
		clients = append(clients, client)
	}
	wss.mutex.RUnlock()

	if len(clients) == 0 {
		return
	}

	startTime := time.Now()

	// Get cluster members asynchronously
	clusterData, err := wss.getCurrentClusterMembers()
	if err != nil {
		wss.logger.WithError(err).Error("Failed to get cluster data for broadcast")
		return
	}

	message := struct {
		Topic       string                 `json:"topic"`
		Members     []client.ClusterMember `json:"members"`
		PrimaryNode string                 `json:"primaryNode"`
		TotalNodes  int                    `json:"totalNodes"`
		Timestamp   string                 `json:"timestamp"`
	}{
		Topic:       "cluster/update",
		Members:     clusterData,
		PrimaryNode: wss.nodeDiscovery.GetPrimaryNode(),
		TotalNodes:  len(wss.nodeDiscovery.GetNodes()),
		Timestamp:   time.Now().Format(time.RFC3339),
	}

	data, err := json.Marshal(message)
	if err != nil {
		wss.logger.WithError(err).Error("Failed to marshal cluster update")
		return
	}

	broadcastTime := time.Now()
	wss.logger.WithFields(log.Fields{
		"clients":   len(clients),
		"prep_time": broadcastTime.Sub(startTime),
	}).Debug("Broadcasting cluster update to WebSocket clients")

	// Send to all clients with write synchronization
	var failedClients int
	wss.writeMutex.Lock()
	defer wss.writeMutex.Unlock()
	for _, client := range clients {
		client.SetWriteDeadline(time.Now().Add(5 * time.Second))
		if err := client.WriteMessage(websocket.TextMessage, data); err != nil {
			wss.logger.WithError(err).Error("Failed to send cluster update to client")
			failedClients++
		}
	}

	totalTime := time.Since(startTime)
	wss.logger.WithFields(log.Fields{
		"clients":        len(clients),
		"failed_clients": failedClients,
		"total_time":     totalTime,
	}).Debug("Cluster update broadcast completed")
}

// broadcastNodeDiscovery sends node discovery events to all clients
func (wss *WebSocketServer) broadcastNodeDiscovery(nodeIP, action string) {
	wss.mutex.RLock()
	clients := make([]*websocket.Conn, 0, len(wss.clients))
	for client := range wss.clients {
		clients = append(clients, client)
	}
	wss.mutex.RUnlock()

	if len(clients) == 0 {
		return
	}

	message := struct {
		Topic     string `json:"topic"`
		Action    string `json:"action"`
		NodeIP    string `json:"nodeIp"`
		Timestamp string `json:"timestamp"`
	}{
		Topic:     "node/discovery",
		Action:    action,
		NodeIP:    nodeIP,
		Timestamp: time.Now().Format(time.RFC3339),
	}

	data, err := json.Marshal(message)
	if err != nil {
		wss.logger.WithError(err).Error("Failed to marshal node discovery event")
		return
	}

	// Send to all clients with write synchronization
	wss.writeMutex.Lock()
	defer wss.writeMutex.Unlock()
	for _, client := range clients {
		client.SetWriteDeadline(time.Now().Add(5 * time.Second))
		if err := client.WriteMessage(websocket.TextMessage, data); err != nil {
			wss.logger.WithError(err).Error("Failed to send node discovery event to client")
		}
	}
}

// BroadcastFirmwareUploadStatus sends firmware upload status updates to all clients
func (wss *WebSocketServer) BroadcastFirmwareUploadStatus(nodeIP, status, filename string, fileSize int) {
	wss.mutex.RLock()
	clients := make([]*websocket.Conn, 0, len(wss.clients))
	for client := range wss.clients {
		clients = append(clients, client)
	}
	wss.mutex.RUnlock()

	if len(clients) == 0 {
		return
	}

	message := struct {
		Topic     string `json:"topic"`
		NodeIP    string `json:"nodeIp"`
		Status    string `json:"status"`
		Filename  string `json:"filename"`
		FileSize  int    `json:"fileSize"`
		Timestamp string `json:"timestamp"`
	}{
		Topic:     "firmware/upload/status",
		NodeIP:    nodeIP,
		Status:    status,
		Filename:  filename,
		FileSize:  fileSize,
		Timestamp: time.Now().Format(time.RFC3339),
	}

	data, err := json.Marshal(message)
	if err != nil {
		wss.logger.WithError(err).Error("Failed to marshal firmware upload status")
		return
	}

	wss.logger.WithFields(log.Fields{
		"node_ip":   nodeIP,
		"status":    status,
		"filename":  filename,
		"file_size": fileSize,
		"clients":   len(clients),
	}).Debug("Broadcasting firmware upload status to WebSocket clients")

	// Send to all clients with write synchronization
	wss.writeMutex.Lock()
	defer wss.writeMutex.Unlock()
	for _, client := range clients {
		client.SetWriteDeadline(time.Now().Add(5 * time.Second))
		if err := client.WriteMessage(websocket.TextMessage, data); err != nil {
			wss.logger.WithError(err).Error("Failed to send firmware upload status to client")
		}
	}
}
// BroadcastRolloutProgress sends rollout progress updates to all clients
func (wss *WebSocketServer) BroadcastRolloutProgress(rolloutID, nodeIP, status string, current, total int) {
	wss.mutex.RLock()
	clients := make([]*websocket.Conn, 0, len(wss.clients))
	for client := range wss.clients {
		clients = append(clients, client)
	}
	wss.mutex.RUnlock()

	if len(clients) == 0 {
		return
	}

	message := struct {
		Topic     string `json:"topic"`
		RolloutID string `json:"rolloutId"`
		NodeIP    string `json:"nodeIp"`
		Status    string `json:"status"`
		Current   int    `json:"current"`
		Total     int    `json:"total"`
		Progress  int    `json:"progress"`
		Timestamp string `json:"timestamp"`
	}{
		Topic:     "rollout/progress",
		RolloutID: rolloutID,
		NodeIP:    nodeIP,
		Status:    status,
		Current:   current,
		Total:     total,
		Progress:  wss.calculateProgress(current, total, status),
		Timestamp: time.Now().Format(time.RFC3339),
	}

	data, err := json.Marshal(message)
	if err != nil {
		wss.logger.WithError(err).Error("Failed to marshal rollout progress")
		return
	}

	wss.logger.WithFields(log.Fields{
		"rollout_id": rolloutID,
		"node_ip":    nodeIP,
		"status":     status,
		"progress":   fmt.Sprintf("%d/%d", current, total),
		"clients":    len(clients),
	}).Debug("Broadcasting rollout progress to WebSocket clients")

	// Send to all clients with write synchronization
	wss.writeMutex.Lock()
	defer wss.writeMutex.Unlock()
	for _, client := range clients {
		client.SetWriteDeadline(time.Now().Add(5 * time.Second))
		if err := client.WriteMessage(websocket.TextMessage, data); err != nil {
			wss.logger.WithError(err).Error("Failed to send rollout progress to client")
		}
	}
}

// calculateProgress calculates the correct progress percentage based on current status
func (wss *WebSocketServer) calculateProgress(current, total int, status string) int {
	if total == 0 {
		return 0
	}

	// Base progress is based on completed nodes
	completedNodes := current - 1
	if status == "completed" {
		completedNodes = current
	}

	// Calculate base progress (completed nodes / total nodes)
	baseProgress := float64(completedNodes) / float64(total) * 100

	// If currently updating labels or uploading, add partial progress for the current node
	if status == "updating_labels" {
		// Add 25% of one node's progress (label update is quick)
		nodeProgress := 100.0 / float64(total) * 0.25
		baseProgress += nodeProgress
	} else if status == "uploading" {
		// Add 50% of one node's progress (so uploading shows as halfway through that node)
		nodeProgress := 100.0 / float64(total) * 0.5
		baseProgress += nodeProgress
	}

	// Ensure we don't exceed 100%
	if baseProgress > 100 {
		baseProgress = 100
	}

	return int(baseProgress)
}
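// Worked example for calculateProgress (illustrative numbers): with total = 4 nodes and the
// second node currently "uploading" (current = 2), one node counts as complete, so
// baseProgress = 1/4*100 = 25, plus half of one node's share (100/4*0.5 = 12.5), which
// truncates to 37%. An "updating_labels" status on the first node (current = 1) yields
// 0 + 100/4*0.25, truncated to 6%, and "completed" on the last node (current = 4) reaches 100%.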
// getCurrentClusterMembers fetches real cluster data from SPORE nodes
func (wss *WebSocketServer) getCurrentClusterMembers() ([]client.ClusterMember, error) {
	nodes := wss.nodeDiscovery.GetNodes()
	if len(nodes) == 0 {
		wss.logger.Debug("No nodes available for cluster member retrieval")
		return []client.ClusterMember{}, nil
	}

	// Try to get real cluster data from primary node
	primaryNode := wss.nodeDiscovery.GetPrimaryNode()
	if primaryNode != "" {
		wss.logger.WithFields(log.Fields{
			"primary_node": primaryNode,
			"total_nodes":  len(nodes),
		}).Debug("Fetching cluster members from primary node")

		client := wss.getSporeClient(primaryNode)
		clusterStatus, err := client.GetClusterStatus()
		if err == nil {
			wss.logger.WithFields(log.Fields{
				"primary_node": primaryNode,
				"member_count": len(clusterStatus.Members),
			}).Debug("Successfully fetched cluster members from primary node")

			// Update local node data with API information but preserve heartbeat status
			wss.updateLocalNodesWithAPI(clusterStatus.Members)

			// Return merged data with heartbeat-based status override
			return wss.mergeAPIWithHeartbeatStatus(clusterStatus.Members), nil
		}

		wss.logger.WithFields(log.Fields{
			"primary_node": primaryNode,
			"error":        err.Error(),
		}).Debug("Failed to get cluster status from primary node, using fallback")
	} else {
		wss.logger.Debug("No primary node available, using fallback cluster members")
	}

	// Fallback to local data if API fails
	return wss.getFallbackClusterMembers(), nil
}

// updateLocalNodesWithAPI updates local node data with information from API
func (wss *WebSocketServer) updateLocalNodesWithAPI(apiMembers []client.ClusterMember) {
	wss.logger.WithField("members", len(apiMembers)).Debug("Updating local nodes with API data")

	for _, member := range apiMembers {
		// Update local node with API data, but preserve heartbeat-based status
		wss.updateNodeWithAPIData(member)
	}
}

// updateNodeWithAPIData updates a single node with API data while preserving heartbeat status
func (wss *WebSocketServer) updateNodeWithAPIData(apiMember client.ClusterMember) {
	nodes := wss.nodeDiscovery.GetNodes()
	if localNode, exists := nodes[apiMember.IP]; exists {
		// Update additional data from API but preserve heartbeat-based status
		localNode.Labels = apiMember.Labels
		localNode.Resources = apiMember.Resources
		localNode.Latency = apiMember.Latency

		// Only update hostname if it's different and not empty
		if apiMember.Hostname != "" && apiMember.Hostname != localNode.Hostname {
			localNode.Hostname = apiMember.Hostname
		}

		wss.logger.WithFields(log.Fields{
			"ip":     apiMember.IP,
			"labels": apiMember.Labels,
			"status": localNode.Status, // Keep heartbeat-based status
		}).Debug("Updated node with API data, preserved heartbeat status")
	}
}

// mergeAPIWithHeartbeatStatus merges API member data with heartbeat-based status
func (wss *WebSocketServer) mergeAPIWithHeartbeatStatus(apiMembers []client.ClusterMember) []client.ClusterMember {
	localNodes := wss.nodeDiscovery.GetNodes()
	mergedMembers := make([]client.ClusterMember, 0, len(apiMembers))

	for _, apiMember := range apiMembers {
		mergedMember := apiMember

		// Override status with heartbeat-based status if we have local data
		if localNode, exists := localNodes[apiMember.IP]; exists {
			mergedMember.Status = string(localNode.Status)
			mergedMember.LastSeen = localNode.LastSeen.Unix()

			wss.logger.WithFields(log.Fields{
				"ip":               apiMember.IP,
				"api_status":       apiMember.Status,
				"heartbeat_status": localNode.Status,
			}).Debug("Overriding API status with heartbeat status")
		}

		mergedMembers = append(mergedMembers, mergedMember)
	}

	return mergedMembers
}

// getFallbackClusterMembers returns local node data as fallback
func (wss *WebSocketServer) getFallbackClusterMembers() []client.ClusterMember {
	nodes := wss.nodeDiscovery.GetNodes()
	members := make([]client.ClusterMember, 0, len(nodes))

	for _, node := range nodes {
		member := client.ClusterMember{
			IP:       node.IP,
			Hostname: node.Hostname,
			Status:   string(node.Status),
			Latency:  node.Latency,
			LastSeen: node.LastSeen.Unix(),
			Labels:   node.Labels,
		}
		members = append(members, member)
	}

	return members
}

// getSporeClient gets or creates a SPORE client for a node
func (wss *WebSocketServer) getSporeClient(nodeIP string) *client.SporeClient {
	// Guard the client cache; this is reached from multiple goroutines
	wss.mutex.Lock()
	defer wss.mutex.Unlock()

	if client, exists := wss.sporeClients[nodeIP]; exists {
		return client
	}

	client := client.NewSporeClient("http://" + nodeIP)
	wss.sporeClients[nodeIP] = client
	return client
}

// GetClientCount returns the number of connected WebSocket clients
func (wss *WebSocketServer) GetClientCount() int {
	wss.mutex.RLock()
	defer wss.mutex.RUnlock()
	return len(wss.clients)
}
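// BroadcastClusterEvent below is the generic hook for ad-hoc topics. A hypothetical caller
// elsewhere in the gateway might use it like this (the topic name and payload are
// illustrative only, not an established contract):
//
//	wss.BroadcastClusterEvent("node/labels/updated", map[string]string{
//	    "ip":    "192.168.1.12",
//	    "label": "zone=lab",
//	})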
// BroadcastClusterEvent sends cluster events to all connected clients
func (wss *WebSocketServer) BroadcastClusterEvent(topic string, data interface{}) {
	wss.mutex.RLock()
	clients := make([]*websocket.Conn, 0, len(wss.clients))
	for client := range wss.clients {
		clients = append(clients, client)
	}
	wss.mutex.RUnlock()

	if len(clients) == 0 {
		return
	}

	message := struct {
		Topic     string      `json:"topic"`
		Data      interface{} `json:"data"`
		Timestamp string      `json:"timestamp"`
	}{
		Topic:     topic,
		Data:      data,
		Timestamp: time.Now().Format(time.RFC3339),
	}

	messageData, err := json.Marshal(message)
	if err != nil {
		wss.logger.WithError(err).Error("Failed to marshal cluster event")
		return
	}

	wss.logger.WithFields(log.Fields{
		"topic":   topic,
		"clients": len(clients),
	}).Debug("Broadcasting cluster event to WebSocket clients")

	// Send to all clients with write synchronization
	wss.writeMutex.Lock()
	defer wss.writeMutex.Unlock()
	for _, client := range clients {
		client.SetWriteDeadline(time.Now().Add(5 * time.Second))
		if err := client.WriteMessage(websocket.TextMessage, messageData); err != nil {
			wss.logger.WithError(err).Error("Failed to send cluster event to client")
		}
	}
}

// Shutdown gracefully shuts down the WebSocket server
func (wss *WebSocketServer) Shutdown(ctx context.Context) error {
	wss.logger.Info("Shutting down WebSocket server")

	// Stop periodic cluster info fetching
	close(wss.clusterInfoStopCh)

	wss.mutex.Lock()
	clients := make([]*websocket.Conn, 0, len(wss.clients))
	for client := range wss.clients {
		clients = append(clients, client)
	}
	wss.mutex.Unlock()

	// Close all client connections
	for _, client := range clients {
		client.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, "Server shutting down"))
		client.Close()
	}

	return nil
}
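// A minimal consumer sketch, assuming the gateway mounts HandleWebSocket at a /ws endpoint
// on localhost:8080 (both are assumptions of this example; only the "topic" field of each
// frame is guaranteed by the broadcast envelopes above):
//
//	conn, _, err := websocket.DefaultDialer.Dial("ws://localhost:8080/ws", nil)
//	if err != nil {
//	    return // handle dial error
//	}
//	defer conn.Close()
//	for {
//	    _, data, err := conn.ReadMessage()
//	    if err != nil {
//	        return
//	    }
//	    var envelope struct {
//	        Topic string `json:"topic"`
//	    }
//	    if err := json.Unmarshal(data, &envelope); err == nil {
//	        fmt.Println("received topic:", envelope.Topic)
//	    }
//	}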