package discovery

import (
	"context"
	"fmt"
	"net"
	"strconv"
	"strings"
	"time"

	log "github.com/sirupsen/logrus"
)

// Start binds the UDP discovery socket on 0.0.0.0 at the configured port,
// launches the stale-node cleanup goroutine and then reads datagrams in a
// loop, passing each one to handleUDPMessage.
//
// It returns an error only when the socket cannot be opened. Once listening,
// the read loop has no exit path, so Start blocks for the lifetime of the
// process; read errors are logged and the loop continues.
func (nd *NodeDiscovery) Start() error {
	// Sized for the largest possible UDP datagram so that long payloads
	// (for example node/update JSON) are not silently truncated.
	const maxDatagramSize = 65535

	addr := net.UDPAddr{
		Port: parsePort(nd.udpPort),
		IP:   net.ParseIP("0.0.0.0"),
	}

	conn, err := net.ListenUDP("udp", &addr)
	if err != nil {
		return fmt.Errorf("failed to listen on UDP port %s: %w", nd.udpPort, err)
	}
	// Release the socket if this function is ever unwound.
	defer conn.Close()

	nd.logger.WithField("port", nd.udpPort).Info("UDP heartbeat server listening")

	// Stale-node detection runs concurrently with the read loop.
	go nd.startCleanupRoutine()

	// The buffer is reused across reads; string(buffer[:n]) copies the bytes,
	// so handlers never observe a later datagram's data.
	buffer := make([]byte, maxDatagramSize)
	for {
		n, remoteAddr, err := conn.ReadFromUDP(buffer)
		if err != nil {
			nd.logger.WithError(err).Error("Error reading UDP message")
			continue
		}

		nd.handleUDPMessage(string(buffer[:n]), remoteAddr)
	}
}

// Shutdown logs that discovery is shutting down and returns nil.
//
// It does not currently stop the read loop started by Start or the cleanup
// goroutine, and ctx is unused.
func (nd *NodeDiscovery) Shutdown(ctx context.Context) error {
	nd.logger.Info("Shutting down node discovery")
	return nil
}

// MessageHandler processes the payload of one UDP message type together with
// the address of the sender.
type MessageHandler func(payload string, remoteAddr *net.UDPAddr)

// handleUDPMessage logs the raw datagram, splits it on the first ":" into a
// topic and a payload, and dispatches on the topic:
//
//   - "cluster/heartbeat": the payload is used as the sender's hostname
//   - "node/update":       forwarded to handleNodeUpdate
//   - "raw", "cluster/event": logged at debug level only
//
// Messages without a ":" separator and messages with an unknown topic are
// logged at debug level and dropped.
func (nd *NodeDiscovery) handleUDPMessage(message string, remoteAddr *net.UDPAddr) {
	nd.logger.WithFields(log.Fields{
		"message": message,
		"from":    remoteAddr.String(),
	}).Debug("UDP message received")

	message = strings.TrimSpace(message)

	// Only the first ":" separates topic from payload; the payload may
	// contain further colons.
	parts := strings.SplitN(message, ":", 2)
	if len(parts) < 2 {
		nd.logger.WithField("message", message).Debug("Invalid message format - missing ':' separator")
		return
	}
	topic, payload := parts[0], parts[1]

	// A switch avoids rebuilding a handler table for every datagram.
	switch topic {
	case "cluster/heartbeat":
		nd.updateNodeFromHeartbeat(remoteAddr.IP.String(), remoteAddr.Port, payload)
	case "node/update":
		// handleNodeUpdate expects the full "node/update:hostname:{json}"
		// form, so the topic prefix is put back.
		nd.handleNodeUpdate(remoteAddr.IP.String(), "node/update:"+payload)
	case "raw":
		nd.logger.WithField("message", "raw:"+payload).Debug("Received raw message")
	case "cluster/event":
		nd.logger.WithField("message", "cluster/event:"+payload).Debug("Received cluster/event message")
	default:
		nd.logger.WithField("topic", topic).Debug("Received unknown UDP message type")
	}
}

// updateNodeFromHeartbeat records a heartbeat from sourceIP.
//
// For a known IP it refreshes LastSeen, overwrites the hostname and marks the
// node active. For an unknown IP it creates a new entry and, when no primary
// node is set, makes the new node the primary. Registered callbacks are
// notified with one of the actions "discovered", "node became active",
// "hostname update" or "active".
func (nd *NodeDiscovery) updateNodeFromHeartbeat(sourceIP string, sourcePort int, hostname string) {
	nd.mutex.Lock()
	defer nd.mutex.Unlock()

	now := time.Now()

	existingNode, exists := nd.discoveredNodes[sourceIP]
	if !exists {
		nd.discoveredNodes[sourceIP] = &NodeInfo{
			IP:           sourceIP,
			Port:         sourcePort,
			Hostname:     hostname,
			Status:       NodeStatusActive,
			DiscoveredAt: now,
			LastSeen:     now,
		}

		nd.logger.WithFields(log.Fields{
			"ip":       sourceIP,
			"hostname": hostname,
			"total":    len(nd.discoveredNodes),
		}).Info("New node discovered via heartbeat")

		// The first node ever seen becomes the primary.
		if nd.primaryNode == "" {
			nd.primaryNode = sourceIP
			nd.logger.WithField("ip", sourceIP).Info("Set as primary node")
		}

		nd.notifyCallbacks(sourceIP, "discovered")
		return
	}

	// Capture the previous state before overwriting it; it decides which
	// action is reported below.
	wasStale := existingNode.Status == NodeStatusInactive
	oldHostname := existingNode.Hostname
	hostnameChanged := oldHostname != hostname

	existingNode.LastSeen = now
	existingNode.Hostname = hostname
	existingNode.Status = NodeStatusActive

	nd.logger.WithFields(log.Fields{
		"ip":       sourceIP,
		"hostname": hostname,
		"total":    len(nd.discoveredNodes),
	}).Debug("Heartbeat from existing node")

	if hostnameChanged {
		nd.logger.WithFields(log.Fields{
			"ip":           sourceIP,
			"old_hostname": oldHostname,
			"new_hostname": hostname,
		}).Info("Hostname updated")
	}

	// A node returning from the inactive state takes precedence over a
	// hostname change when choosing the reported action.
	action := "active"
	switch {
	case wasStale:
		action = "node became active"
	case hostnameChanged:
		action = "hostname update"
	}
	nd.notifyCallbacks(sourceIP, action)
}

// handleNodeUpdate processes a message of the form
// "NODE_UPDATE:hostname:{json}" or "node/update:hostname:{json}".
//
// Only the hostname is used at present: the node registered under sourceIP
// gets the new hostname, a fresh LastSeen and active status, and callbacks
// are notified with "node update". Malformed messages and updates for IPs
// that are not yet known are logged as warnings and ignored.
func (nd *NodeDiscovery) handleNodeUpdate(sourceIP, message string) {
	parts := strings.SplitN(message, ":", 3)
	if len(parts) < 3 {
		nd.logger.WithField("message", message).Warn("Invalid NODE_UPDATE message format")
		return
	}

	hostname := parts[1]
	// parts[2] holds the JSON document.
	// TODO: Parse the JSON data and update other fields from it.

	nd.mutex.Lock()
	defer nd.mutex.Unlock()

	existingNode := nd.discoveredNodes[sourceIP]
	if existingNode == nil {
		nd.logger.WithField("ip", sourceIP).Warn("Received NODE_UPDATE for unknown node")
		return
	}

	existingNode.Hostname = hostname
	existingNode.LastSeen = time.Now()
	existingNode.Status = NodeStatusActive

	nd.logger.WithFields(log.Fields{
		"ip":       sourceIP,
		"hostname": hostname,
	}).Debug("Updated node from NODE_UPDATE")

	nd.notifyCallbacks(sourceIP, "node update")
}

// startCleanupRoutine marks stale nodes and re-evaluates the primary node
// every two seconds. The loop has no stop condition, so it runs until the
// process exits.
func (nd *NodeDiscovery) startCleanupRoutine() {
	ticker := time.NewTicker(2 * time.Second)
	defer ticker.Stop()

	for range ticker.C {
		nd.markStaleNodes()
		nd.updatePrimaryNode()
	}
}

// markStaleNodes marks every node whose LastSeen is older than
// nd.staleThreshold as inactive, notifies callbacks with "stale" and clears
// the primary selection if the primary itself went stale.
func (nd *NodeDiscovery) markStaleNodes() {
	nd.mutex.Lock()
	defer nd.mutex.Unlock()

	now := time.Now()
	var nodesMarkedStale bool

	for ip, node := range nd.discoveredNodes {
		timeSinceLastSeen := now.Sub(node.LastSeen)
		// Nodes that are already inactive are skipped so the transition is
		// reported only once.
		if timeSinceLastSeen <= nd.staleThreshold || node.Status == NodeStatusInactive {
			continue
		}

		nd.logger.WithFields(log.Fields{
			"ip":              ip,
			"hostname":        node.Hostname,
			"time_since_seen": timeSinceLastSeen,
			"threshold":       nd.staleThreshold,
		}).Info("Node marked inactive")

		node.Status = NodeStatusInactive
		nodesMarkedStale = true

		nd.notifyCallbacks(ip, "stale")

		if nd.primaryNode == ip {
			nd.primaryNode = ""
			nd.logger.Info("Primary node became stale, clearing selection")
		}
	}

	if nodesMarkedStale {
		nd.logger.Debug("Nodes marked stale, triggering cluster update")
	}
}

// updatePrimaryNode selects a new primary node when none is set or when the
// current primary is missing from the node table or not active.
func (nd *NodeDiscovery) updatePrimaryNode() {
	nd.mutex.Lock()
	defer nd.mutex.Unlock()

	if nd.primaryNode != "" {
		// Checked lookup: indexing the map directly and dereferencing the
		// result would panic if the primary's entry were absent.
		if node, ok := nd.discoveredNodes[nd.primaryNode]; ok && node != nil && node.Status == NodeStatusActive {
			return
		}
	}

	nd.selectBestPrimaryNode()
}

// selectBestPrimaryNode makes the most recently seen active node the primary,
// unless the current primary is still active. When the primary changes,
// callbacks are notified with "primary node change".
//
// It does not take nd.mutex itself; the caller visible in this file
// (updatePrimaryNode) holds the lock.
func (nd *NodeDiscovery) selectBestPrimaryNode() {
	if len(nd.discoveredNodes) == 0 {
		return
	}

	// An active primary is never replaced here.
	if nd.primaryNode != "" {
		if node, exists := nd.discoveredNodes[nd.primaryNode]; exists && node.Status == NodeStatusActive {
			return
		}
	}

	var bestNode string
	var mostRecent time.Time
	for ip, node := range nd.discoveredNodes {
		if node.Status == NodeStatusActive && node.LastSeen.After(mostRecent) {
			mostRecent = node.LastSeen
			bestNode = ip
		}
	}

	if bestNode != "" && bestNode != nd.primaryNode {
		nd.primaryNode = bestNode
		nd.logger.WithField("ip", bestNode).Info("Selected new primary node")
		nd.notifyCallbacks(bestNode, "primary node change")
	}
}

// notifyCallbacks invokes every registered callback with nodeIP and action,
// each in its own goroutine so that a callback cannot block the caller.
//
// It reads nd.callbacks without locking; every caller in this file holds
// nd.mutex.
func (nd *NodeDiscovery) notifyCallbacks(nodeIP, action string) {
	for _, callback := range nd.callbacks {
		go callback(nodeIP, action)
	}
}

// GetNodes returns a new map of all discovered nodes keyed by IP.
//
// Only the map is copied: the *NodeInfo values are the same pointers held
// internally.
func (nd *NodeDiscovery) GetNodes() map[string]*NodeInfo {
	nd.mutex.RLock()
	defer nd.mutex.RUnlock()

	nodes := make(map[string]*NodeInfo, len(nd.discoveredNodes))
	for ip, node := range nd.discoveredNodes {
		nodes[ip] = node
	}
	return nodes
}

// GetPrimaryNode returns the IP of the current primary node, or "" when no
// primary is selected.
func (nd *NodeDiscovery) GetPrimaryNode() string {
	nd.mutex.RLock()
	defer nd.mutex.RUnlock()
	return nd.primaryNode
}

// SetPrimaryNode makes ip the primary node and notifies callbacks with
// "manual primary node setting". It returns an error when ip is not a
// discovered node; the node's status is not checked.
func (nd *NodeDiscovery) SetPrimaryNode(ip string) error {
	nd.mutex.Lock()
	defer nd.mutex.Unlock()

	if _, exists := nd.discoveredNodes[ip]; !exists {
		return fmt.Errorf("node %s not found", ip)
	}

	nd.primaryNode = ip
	nd.notifyCallbacks(ip, "manual primary node setting")
	return nil
}

// SelectRandomPrimaryNode picks one of the active nodes other than the
// current primary, makes it the primary, notifies callbacks with
// "random primary node selection" and returns its IP.
//
// It returns "" when no nodes are known, and the unchanged current primary
// when there is no other active node.
func (nd *NodeDiscovery) SelectRandomPrimaryNode() string {
	nd.mutex.Lock()
	defer nd.mutex.Unlock()

	if len(nd.discoveredNodes) == 0 {
		return ""
	}

	// Candidates are the active nodes, excluding the current primary.
	var activeNodes []string
	for ip, node := range nd.discoveredNodes {
		if node.Status == NodeStatusActive && ip != nd.primaryNode {
			activeNodes = append(activeNodes, ip)
		}
	}

	if len(activeNodes) == 0 {
		return nd.primaryNode
	}

	// The index is derived from the wall clock rather than a random source.
	randomIndex := time.Now().UnixNano() % int64(len(activeNodes))
	randomNode := activeNodes[randomIndex]

	nd.primaryNode = randomNode
	nd.logger.WithField("ip", randomNode).Info("Randomly selected new primary node")
	nd.notifyCallbacks(randomNode, "random primary node selection")

	return randomNode
}

// AddCallback registers callback to be invoked on node changes.
func (nd *NodeDiscovery) AddCallback(callback NodeUpdateCallback) {
	nd.mutex.Lock()
	defer nd.mutex.Unlock()
	nd.callbacks = append(nd.callbacks, callback)
}

// GetClusterStatus returns the primary node, the number of discovered nodes
// and the configured UDP port. ServerRunning is always reported as true.
func (nd *NodeDiscovery) GetClusterStatus() ClusterStatus {
	nd.mutex.RLock()
	defer nd.mutex.RUnlock()

	return ClusterStatus{
		PrimaryNode:   nd.primaryNode,
		TotalNodes:    len(nd.discoveredNodes),
		UDPPort:       nd.udpPort,
		ServerRunning: true,
	}
}

// parsePort converts portStr to a port number, returning the default 4210
// when portStr is empty or cannot be resolved.
//
// portStr is resolved with net.LookupPort first, falling back to
// strconv.Atoi.
func parsePort(portStr string) int {
	const defaultPort = 4210

	// net.LookupPort maps "" to port 0 without an error, which would bind an
	// ephemeral port instead of the default.
	if portStr == "" {
		return defaultPort
	}

	if p, err := net.LookupPort("udp", portStr); err == nil {
		return p
	}
	if p, err := strconv.Atoi(portStr); err == nil {
		return p
	}
	return defaultPort
}