diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index c57cd43..f7e29c2 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -49,6 +49,9 @@ func (nd *NodeDiscovery) Shutdown(ctx context.Context) error { return nil } +// MessageHandler processes a specific UDP message type +type MessageHandler func(payload string, remoteAddr *net.UDPAddr) + // handleUDPMessage processes incoming UDP messages func (nd *NodeDiscovery) handleUDPMessage(message string, remoteAddr *net.UDPAddr) { nd.logger.WithFields(log.Fields{ @@ -58,13 +61,39 @@ func (nd *NodeDiscovery) handleUDPMessage(message string, remoteAddr *net.UDPAdd message = strings.TrimSpace(message) - if strings.HasPrefix(message, "CLUSTER_HEARTBEAT:") { - hostname := strings.TrimPrefix(message, "CLUSTER_HEARTBEAT:") - nd.updateNodeFromHeartbeat(remoteAddr.IP.String(), remoteAddr.Port, hostname) - } else if strings.HasPrefix(message, "NODE_UPDATE:") { - nd.handleNodeUpdate(remoteAddr.IP.String(), message) - } else if !strings.HasPrefix(message, "RAW:") { - nd.logger.WithField("message", message).Debug("Received unknown UDP message") + // Extract topic by splitting on first ":" + parts := strings.SplitN(message, ":", 2) + if len(parts) < 2 { + nd.logger.WithField("message", message).Debug("Invalid message format - missing ':' separator") + return + } + + topic := parts[0] + payload := parts[1] + + // Handler map for different message types + handlers := map[string]MessageHandler{ + "cluster/heartbeat": func(payload string, remoteAddr *net.UDPAddr) { + nd.updateNodeFromHeartbeat(remoteAddr.IP.String(), remoteAddr.Port, payload) + }, + "node/update": func(payload string, remoteAddr *net.UDPAddr) { + // Reconstruct full message for handleNodeUpdate which expects "node/update:hostname:{json}" + fullMessage := "node/update:" + payload + nd.handleNodeUpdate(remoteAddr.IP.String(), fullMessage) + }, + "raw": func(payload string, remoteAddr *net.UDPAddr) { + nd.logger.WithField("message", "raw:"+payload).Debug("Received raw message") + }, + "cluster/event": func(payload string, remoteAddr *net.UDPAddr) { + nd.logger.WithField("message", "cluster/event:"+payload).Debug("Received cluster/event message") + }, + } + + // Look up and execute handler + if handler, exists := handlers[topic]; exists { + handler(payload, remoteAddr) + } else { + nd.logger.WithField("topic", topic).Debug("Received unknown UDP message type") } } @@ -138,9 +167,9 @@ func (nd *NodeDiscovery) updateNodeFromHeartbeat(sourceIP string, sourcePort int } } -// handleNodeUpdate processes NODE_UPDATE messages +// handleNodeUpdate processes NODE_UPDATE and node/update messages func (nd *NodeDiscovery) handleNodeUpdate(sourceIP, message string) { - // Message format: "NODE_UPDATE:hostname:{json}" + // Message format: "NODE_UPDATE:hostname:{json}" or "node/update:hostname:{json}" parts := strings.SplitN(message, ":", 3) if len(parts) < 3 { nd.logger.WithField("message", message).Warn("Invalid NODE_UPDATE message format") diff --git a/internal/websocket/websocket.go b/internal/websocket/websocket.go index e0e98a2..5f099e4 100644 --- a/internal/websocket/websocket.go +++ b/internal/websocket/websocket.go @@ -134,13 +134,13 @@ func (wss *WebSocketServer) sendCurrentClusterState(conn *websocket.Conn) { } message := struct { - Type string `json:"type"` + Topic string `json:"topic"` Members []client.ClusterMember `json:"members"` PrimaryNode string `json:"primaryNode"` TotalNodes int `json:"totalNodes"` Timestamp string `json:"timestamp"` }{ - Type: "cluster_update", + Topic: "cluster/update", Members: clusterData, PrimaryNode: wss.nodeDiscovery.GetPrimaryNode(), TotalNodes: len(nodes), @@ -227,13 +227,13 @@ func (wss *WebSocketServer) broadcastClusterUpdate() { } message := struct { - Type string `json:"type"` + Topic string `json:"topic"` Members []client.ClusterMember `json:"members"` PrimaryNode string `json:"primaryNode"` TotalNodes int `json:"totalNodes"` Timestamp string `json:"timestamp"` }{ - Type: "cluster_update", + Topic: "cluster/update", Members: clusterData, PrimaryNode: wss.nodeDiscovery.GetPrimaryNode(), TotalNodes: len(wss.nodeDiscovery.GetNodes()), @@ -287,12 +287,12 @@ func (wss *WebSocketServer) broadcastNodeDiscovery(nodeIP, action string) { } message := struct { - Type string `json:"type"` + Topic string `json:"topic"` Action string `json:"action"` NodeIP string `json:"nodeIp"` Timestamp string `json:"timestamp"` }{ - Type: "node_discovery", + Topic: "node/discovery", Action: action, NodeIP: nodeIP, Timestamp: time.Now().Format(time.RFC3339), @@ -330,14 +330,14 @@ func (wss *WebSocketServer) BroadcastFirmwareUploadStatus(nodeIP, status, filena } message := struct { - Type string `json:"type"` + Topic string `json:"topic"` NodeIP string `json:"nodeIp"` Status string `json:"status"` Filename string `json:"filename"` FileSize int `json:"fileSize"` Timestamp string `json:"timestamp"` }{ - Type: "firmware_upload_status", + Topic: "firmware/upload/status", NodeIP: nodeIP, Status: status, Filename: filename, @@ -385,7 +385,7 @@ func (wss *WebSocketServer) BroadcastRolloutProgress(rolloutID, nodeIP, status s } message := struct { - Type string `json:"type"` + Topic string `json:"topic"` RolloutID string `json:"rolloutId"` NodeIP string `json:"nodeIp"` Status string `json:"status"` @@ -394,7 +394,7 @@ func (wss *WebSocketServer) BroadcastRolloutProgress(rolloutID, nodeIP, status s Progress int `json:"progress"` Timestamp string `json:"timestamp"` }{ - Type: "rollout_progress", + Topic: "rollout/progress", RolloutID: rolloutID, NodeIP: nodeIP, Status: status,