package mock import ( "context" "encoding/json" "fmt" "net/http" "sync" "time" "github.com/gorilla/websocket" log "github.com/sirupsen/logrus" ) var upgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true // Allow connections from any origin }, } // MockWebSocketServer manages WebSocket connections and mock broadcasts type MockWebSocketServer struct { discovery *MockNodeDiscovery clients map[*websocket.Conn]bool mutex sync.RWMutex writeMutex sync.Mutex shutdownChan chan struct{} shutdownOnce sync.Once logger *log.Logger } // NewMockWebSocketServer creates a new mock WebSocket server func NewMockWebSocketServer(discovery *MockNodeDiscovery) *MockWebSocketServer { mws := &MockWebSocketServer{ discovery: discovery, clients: make(map[*websocket.Conn]bool), shutdownChan: make(chan struct{}), logger: log.New(), } // Register callback for node updates discovery.AddCallback(mws.handleNodeUpdate) // Start periodic broadcasts go mws.startPeriodicBroadcasts() return mws } // HandleWebSocket handles WebSocket upgrade and connection func (mws *MockWebSocketServer) HandleWebSocket(w http.ResponseWriter, r *http.Request) error { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { mws.logger.WithError(err).Error("Failed to upgrade WebSocket connection") return err } mws.mutex.Lock() mws.clients[conn] = true mws.mutex.Unlock() mws.logger.Debug("Mock WebSocket client connected") // Send current cluster state to newly connected client go mws.sendCurrentClusterState(conn) // Handle client messages and disconnection go mws.handleClient(conn) return nil } // handleClient handles messages from a WebSocket client func (mws *MockWebSocketServer) handleClient(conn *websocket.Conn) { defer func() { mws.mutex.Lock() delete(mws.clients, conn) mws.mutex.Unlock() conn.Close() mws.logger.Debug("Mock WebSocket client disconnected") }() // Set read deadline and pong handler conn.SetReadDeadline(time.Now().Add(60 * time.Second)) conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(60 * time.Second)) return nil }) // Start ping routine go func() { ticker := time.NewTicker(54 * time.Second) defer ticker.Stop() for { select { case <-mws.shutdownChan: return case <-ticker.C: conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { return } } } }() // Read messages for { _, _, err := conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { mws.logger.WithError(err).Error("WebSocket error") } break } } } // sendCurrentClusterState sends the current cluster state to a newly connected client func (mws *MockWebSocketServer) sendCurrentClusterState(conn *websocket.Conn) { nodes := mws.discovery.GetNodes() if len(nodes) == 0 { return } // Convert to mock format mockNodes := make(map[string]*NodeInfo) for ip, node := range nodes { mockNodes[ip] = &NodeInfo{ IP: node.IP, Hostname: node.Hostname, Status: string(node.Status), Latency: node.Latency, LastSeen: node.LastSeen, Labels: node.Labels, } } members := GenerateMockClusterMembers(mockNodes) message := struct { Type string `json:"type"` Members interface{} `json:"members"` PrimaryNode string `json:"primaryNode"` TotalNodes int `json:"totalNodes"` Timestamp string `json:"timestamp"` }{ Type: "cluster_update", Members: members, PrimaryNode: mws.discovery.GetPrimaryNode(), TotalNodes: len(nodes), Timestamp: time.Now().Format(time.RFC3339), } data, err := json.Marshal(message) if err != nil { mws.logger.WithError(err).Error("Failed to marshal cluster data") return } conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if err := conn.WriteMessage(websocket.TextMessage, data); err != nil { mws.logger.WithError(err).Error("Failed to send initial cluster state") } } // handleNodeUpdate is called when node information changes func (mws *MockWebSocketServer) handleNodeUpdate(nodeIP, action string) { mws.logger.WithFields(log.Fields{ "node_ip": nodeIP, "action": action, }).Debug("Mock node update received, broadcasting to WebSocket clients") // Broadcast cluster update mws.BroadcastClusterUpdate() // Also broadcast node discovery event mws.broadcastNodeDiscovery(nodeIP, action) } // startPeriodicBroadcasts sends periodic updates to keep clients informed func (mws *MockWebSocketServer) startPeriodicBroadcasts() { ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() for { select { case <-mws.shutdownChan: return case <-ticker.C: mws.BroadcastClusterUpdate() } } } // BroadcastClusterUpdate sends cluster updates to all connected clients func (mws *MockWebSocketServer) BroadcastClusterUpdate() { mws.mutex.RLock() clients := make([]*websocket.Conn, 0, len(mws.clients)) for client := range mws.clients { clients = append(clients, client) } mws.mutex.RUnlock() if len(clients) == 0 { return } nodes := mws.discovery.GetNodes() // Convert to mock format mockNodes := make(map[string]*NodeInfo) for ip, node := range nodes { mockNodes[ip] = &NodeInfo{ IP: node.IP, Hostname: node.Hostname, Status: string(node.Status), Latency: node.Latency, LastSeen: node.LastSeen, Labels: node.Labels, } } members := GenerateMockClusterMembers(mockNodes) message := struct { Type string `json:"type"` Members interface{} `json:"members"` PrimaryNode string `json:"primaryNode"` TotalNodes int `json:"totalNodes"` Timestamp string `json:"timestamp"` }{ Type: "cluster_update", Members: members, PrimaryNode: mws.discovery.GetPrimaryNode(), TotalNodes: len(nodes), Timestamp: time.Now().Format(time.RFC3339), } data, err := json.Marshal(message) if err != nil { mws.logger.WithError(err).Error("Failed to marshal cluster update") return } mws.logger.WithField("clients", len(clients)).Debug("Broadcasting mock cluster update") // Send to all clients with write synchronization mws.writeMutex.Lock() defer mws.writeMutex.Unlock() for _, client := range clients { client.SetWriteDeadline(time.Now().Add(5 * time.Second)) if err := client.WriteMessage(websocket.TextMessage, data); err != nil { mws.logger.WithError(err).Error("Failed to send cluster update to client") } } } // broadcastNodeDiscovery sends node discovery events to all clients func (mws *MockWebSocketServer) broadcastNodeDiscovery(nodeIP, action string) { mws.mutex.RLock() clients := make([]*websocket.Conn, 0, len(mws.clients)) for client := range mws.clients { clients = append(clients, client) } mws.mutex.RUnlock() if len(clients) == 0 { return } message := struct { Type string `json:"type"` Action string `json:"action"` NodeIP string `json:"nodeIp"` Timestamp string `json:"timestamp"` }{ Type: "node_discovery", Action: action, NodeIP: nodeIP, Timestamp: time.Now().Format(time.RFC3339), } data, err := json.Marshal(message) if err != nil { mws.logger.WithError(err).Error("Failed to marshal node discovery event") return } // Send to all clients with write synchronization mws.writeMutex.Lock() defer mws.writeMutex.Unlock() for _, client := range clients { client.SetWriteDeadline(time.Now().Add(5 * time.Second)) if err := client.WriteMessage(websocket.TextMessage, data); err != nil { mws.logger.WithError(err).Error("Failed to send node discovery event to client") } } } // BroadcastFirmwareUploadStatus sends firmware upload status updates to all clients func (mws *MockWebSocketServer) BroadcastFirmwareUploadStatus(nodeIP, status, filename string, fileSize int) { mws.mutex.RLock() clients := make([]*websocket.Conn, 0, len(mws.clients)) for client := range mws.clients { clients = append(clients, client) } mws.mutex.RUnlock() if len(clients) == 0 { return } message := struct { Type string `json:"type"` NodeIP string `json:"nodeIp"` Status string `json:"status"` Filename string `json:"filename"` FileSize int `json:"fileSize"` Timestamp string `json:"timestamp"` }{ Type: "firmware_upload_status", NodeIP: nodeIP, Status: status, Filename: filename, FileSize: fileSize, Timestamp: time.Now().Format(time.RFC3339), } data, err := json.Marshal(message) if err != nil { mws.logger.WithError(err).Error("Failed to marshal firmware upload status") return } mws.logger.WithFields(log.Fields{ "node_ip": nodeIP, "status": status, "clients": len(clients), }).Debug("Broadcasting mock firmware upload status") // Send to all clients with write synchronization mws.writeMutex.Lock() defer mws.writeMutex.Unlock() for _, client := range clients { client.SetWriteDeadline(time.Now().Add(5 * time.Second)) if err := client.WriteMessage(websocket.TextMessage, data); err != nil { mws.logger.WithError(err).Error("Failed to send firmware upload status to client") } } } // BroadcastRolloutProgress sends rollout progress updates to all clients func (mws *MockWebSocketServer) BroadcastRolloutProgress(rolloutID, nodeIP, status string, current, total int) { mws.mutex.RLock() clients := make([]*websocket.Conn, 0, len(mws.clients)) for client := range mws.clients { clients = append(clients, client) } mws.mutex.RUnlock() if len(clients) == 0 { return } message := struct { Type string `json:"type"` RolloutID string `json:"rolloutId"` NodeIP string `json:"nodeIp"` Status string `json:"status"` Current int `json:"current"` Total int `json:"total"` Progress int `json:"progress"` Timestamp string `json:"timestamp"` }{ Type: "rollout_progress", RolloutID: rolloutID, NodeIP: nodeIP, Status: status, Current: current, Total: total, Progress: calculateProgress(current, total, status), Timestamp: time.Now().Format(time.RFC3339), } data, err := json.Marshal(message) if err != nil { mws.logger.WithError(err).Error("Failed to marshal rollout progress") return } mws.logger.WithFields(log.Fields{ "rollout_id": rolloutID, "node_ip": nodeIP, "status": status, "progress": fmt.Sprintf("%d/%d", current, total), }).Debug("Broadcasting mock rollout progress") // Send to all clients with write synchronization mws.writeMutex.Lock() defer mws.writeMutex.Unlock() for _, client := range clients { client.SetWriteDeadline(time.Now().Add(5 * time.Second)) if err := client.WriteMessage(websocket.TextMessage, data); err != nil { mws.logger.WithError(err).Error("Failed to send rollout progress to client") } } } // calculateProgress calculates the correct progress percentage based on current status func calculateProgress(current, total int, status string) int { if total == 0 { return 0 } // Base progress is based on completed nodes completedNodes := current - 1 if status == "completed" { completedNodes = current } // Calculate base progress (completed nodes / total nodes) baseProgress := float64(completedNodes) / float64(total) * 100 // If currently updating labels or uploading, add partial progress for the current node if status == "updating_labels" { nodeProgress := 100.0 / float64(total) * 0.25 baseProgress += nodeProgress } else if status == "uploading" { nodeProgress := 100.0 / float64(total) * 0.5 baseProgress += nodeProgress } // Ensure we don't exceed 100% if baseProgress > 100 { baseProgress = 100 } return int(baseProgress) } // GetClientCount returns the number of connected WebSocket clients func (mws *MockWebSocketServer) GetClientCount() int { mws.mutex.RLock() defer mws.mutex.RUnlock() return len(mws.clients) } // Shutdown gracefully shuts down the WebSocket server func (mws *MockWebSocketServer) Shutdown(ctx context.Context) error { mws.shutdownOnce.Do(func() { mws.logger.Info("Shutting down mock WebSocket server") close(mws.shutdownChan) mws.mutex.Lock() clients := make([]*websocket.Conn, 0, len(mws.clients)) for client := range mws.clients { clients = append(clients, client) } mws.mutex.Unlock() // Close all client connections for _, client := range clients { client.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, "Server shutting down")) client.Close() } }) return nil }