Files
spore-gateway/internal/discovery/discovery.go
2025-10-27 15:18:44 +01:00

452 lines
12 KiB
Go

package discovery
import (
"context"
"encoding/json"
"fmt"
"net"
"strconv"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
// Start opens the UDP discovery socket and processes incoming datagrams.
//
// It listens on 0.0.0.0 at the port derived from nd.udpPort (via parsePort),
// launches the stale-node cleanup routine in its own goroutine, and then
// blocks in the receive loop, handing every datagram to handleUDPMessage.
//
// An error is returned only when the socket cannot be opened. Read errors are
// logged and the loop carries on, so once listening this method does not
// return normally.
func (nd *NodeDiscovery) Start() error {
	// Largest value the 16-bit UDP length field can express. A datagram that
	// does not fit in the read buffer cannot be received intact, which would
	// corrupt JSON payloads carried by node/update and cluster/broadcast.
	const maxDatagramSize = 65535

	addr := net.UDPAddr{
		Port: parsePort(nd.udpPort),
		IP:   net.ParseIP("0.0.0.0"),
	}

	conn, err := net.ListenUDP("udp", &addr)
	if err != nil {
		return fmt.Errorf("failed to listen on UDP port %s: %w", nd.udpPort, err)
	}
	// Release the socket if this frame is ever unwound (e.g. by a panic in a
	// message handler).
	defer conn.Close()

	nd.logger.WithField("port", nd.udpPort).Info("UDP heartbeat server listening")

	// Start cleanup routine
	go nd.startCleanupRoutine()

	// Handle incoming UDP messages. The buffer is reused between reads; the
	// string conversion below copies the bytes before the next read overwrites
	// them.
	buffer := make([]byte, maxDatagramSize)
	for {
		n, remoteAddr, err := conn.ReadFromUDP(buffer)
		if err != nil {
			nd.logger.WithError(err).Error("Error reading UDP message")
			continue
		}

		message := string(buffer[:n])
		nd.handleUDPMessage(message, remoteAddr)
	}
}
// Shutdown logs that node discovery is shutting down and returns nil.
//
// NOTE(review): ctx is not used, and this method does not close any socket or
// stop any goroutine itself; it currently only emits the log line.
func (nd *NodeDiscovery) Shutdown(ctx context.Context) error {
nd.logger.Info("Shutting down node discovery")
// TODO: close any connections if needed (nothing is closed here yet)
return nil
}
// MessageHandler is the function signature for processing one kind of UDP
// message: it receives the message payload and the sender's UDP address.
type MessageHandler func(payload string, remoteAddr *net.UDPAddr)
// handleUDPMessage parses one incoming UDP message and dispatches it by topic.
//
// The message is expected to look like "<topic>:<payload>", where the topic is
// everything before the first ':'. Surrounding whitespace is trimmed before
// parsing. Messages without a ':' and messages with an unrecognised topic are
// logged at debug level and dropped.
//
// Recognised topics:
//   - "cluster/heartbeat": payload is passed to updateNodeFromHeartbeat as the hostname
//   - "node/update":       forwarded to handleNodeUpdate
//   - "RAW":               only logged
//   - "cluster/event":     forwarded to handleClusterEvent
//   - "cluster/broadcast": forwarded to handleClusterBroadcast
func (nd *NodeDiscovery) handleUDPMessage(message string, remoteAddr *net.UDPAddr) {
	nd.logger.WithFields(log.Fields{
		"message": message,
		"from":    remoteAddr.String(),
	}).Debug("UDP message received")

	message = strings.TrimSpace(message)

	// Extract topic by splitting on first ":"
	parts := strings.SplitN(message, ":", 2)
	if len(parts) < 2 {
		nd.logger.WithField("message", message).Debug("Invalid message format - missing ':' separator")
		return
	}
	topic, payload := parts[0], parts[1]

	// A switch avoids building a handler map and its closures for every
	// datagram received.
	switch topic {
	case "cluster/heartbeat":
		nd.updateNodeFromHeartbeat(remoteAddr.IP.String(), remoteAddr.Port, payload)
	case "node/update":
		// handleNodeUpdate expects the full "node/update:hostname:{json}"
		// form, so put the topic prefix back.
		nd.handleNodeUpdate(remoteAddr.IP.String(), "node/update:"+payload)
	case "RAW":
		nd.logger.WithField("message", "RAW:"+payload).Debug("Received raw message")
	case "cluster/event":
		nd.handleClusterEvent(payload, remoteAddr)
	case "cluster/broadcast":
		nd.handleClusterBroadcast(payload, remoteAddr)
	default:
		nd.logger.WithField("topic", topic).Debug("Received unknown UDP message type")
	}
}
// updateNodeFromHeartbeat records a heartbeat received from sourceIP.
//
// For an address that is not yet in nd.discoveredNodes a new active NodeInfo
// is stored (with sourcePort and hostname); if no primary node is set, this
// node becomes the primary, and callbacks are notified with "discovered".
//
// For a known address, LastSeen, Hostname and Status are refreshed (the stored
// Port is left as it was) and callbacks are notified with one of:
//   - "node became active" if the node was previously inactive
//   - "hostname update"    if the hostname differs from the stored one
//   - "active"             otherwise
//
// The whole operation runs while holding nd.mutex.
func (nd *NodeDiscovery) updateNodeFromHeartbeat(sourceIP string, sourcePort int, hostname string) {
	nd.mutex.Lock()
	defer nd.mutex.Unlock()

	seenAt := time.Now()

	node, known := nd.discoveredNodes[sourceIP]
	if !known {
		// First heartbeat from this address: register the node.
		nd.discoveredNodes[sourceIP] = &NodeInfo{
			IP:           sourceIP,
			Port:         sourcePort,
			Hostname:     hostname,
			Status:       NodeStatusActive,
			DiscoveredAt: seenAt,
			LastSeen:     seenAt,
		}

		nd.logger.WithFields(log.Fields{
			"ip":       sourceIP,
			"hostname": hostname,
			"total":    len(nd.discoveredNodes),
		}).Info("New node discovered via heartbeat")

		// The very first node seen is promoted to primary.
		if nd.primaryNode == "" {
			nd.primaryNode = sourceIP
			nd.logger.WithField("ip", sourceIP).Info("Set as primary node")
		}

		nd.notifyCallbacks(sourceIP, "discovered")
		return
	}

	// Capture the previous state before overwriting it, so the right
	// notification can be chosen below.
	reactivated := node.Status == NodeStatusInactive
	previousHostname := node.Hostname
	renamed := previousHostname != hostname

	node.LastSeen = seenAt
	node.Hostname = hostname
	node.Status = NodeStatusActive

	nd.logger.WithFields(log.Fields{
		"ip":       sourceIP,
		"hostname": hostname,
		"total":    len(nd.discoveredNodes),
	}).Debug("Heartbeat from existing node")

	if renamed {
		nd.logger.WithFields(log.Fields{
			"ip":           sourceIP,
			"old_hostname": previousHostname,
			"new_hostname": hostname,
		}).Info("Hostname updated")
	}

	// Reactivation takes precedence over a hostname change.
	var action string
	switch {
	case reactivated:
		action = "node became active"
	case renamed:
		action = "hostname update"
	default:
		action = "active"
	}
	nd.notifyCallbacks(sourceIP, action)
}
// handleNodeUpdate applies a node update message sent by sourceIP.
//
// message must have the form "NODE_UPDATE:hostname:{json}" or
// "node/update:hostname:{json}"; anything with fewer than three
// ':'-separated parts is logged as a warning and ignored. Updates for an IP
// that is not in nd.discoveredNodes are also logged and ignored.
//
// Only the hostname is applied for now: the node's Hostname is set, LastSeen
// is refreshed and Status becomes active, then callbacks are notified with
// "node update". The JSON part is not parsed yet.
func (nd *NodeDiscovery) handleNodeUpdate(sourceIP, message string) {
	fields := strings.SplitN(message, ":", 3)
	if len(fields) < 3 {
		nd.logger.WithField("message", message).Warn("Invalid NODE_UPDATE message format")
		return
	}
	// fields[2] holds the JSON document. TODO: parse it when needed.
	hostname := fields[1]

	nd.mutex.Lock()
	defer nd.mutex.Unlock()

	node := nd.discoveredNodes[sourceIP]
	if node == nil {
		nd.logger.WithField("ip", sourceIP).Warn("Received NODE_UPDATE for unknown node")
		return
	}

	// TODO: update further fields from the JSON data.
	node.Status = NodeStatusActive
	node.LastSeen = time.Now()
	node.Hostname = hostname

	nd.logger.WithFields(log.Fields{
		"ip":       sourceIP,
		"hostname": hostname,
	}).Debug("Updated node from NODE_UPDATE")

	nd.notifyCallbacks(sourceIP, "node update")
}
// startCleanupRoutine runs the periodic maintenance loop: every two seconds it
// calls markStaleNodes followed by updatePrimaryNode. The loop has no exit
// condition, so this method blocks for as long as it runs.
func (nd *NodeDiscovery) startCleanupRoutine() {
	const checkInterval = 2 * time.Second

	ticker := time.NewTicker(checkInterval)
	defer ticker.Stop()

	for {
		<-ticker.C
		nd.markStaleNodes()
		nd.updatePrimaryNode()
	}
}
// markStaleNodes flags every node whose LastSeen is older than
// nd.staleThreshold as inactive. Nodes are kept in the map; only their Status
// changes. For each node that transitions, callbacks are notified with
// "stale", and if that node was the primary, nd.primaryNode is cleared.
// Nodes that are already inactive are skipped. Runs while holding nd.mutex.
func (nd *NodeDiscovery) markStaleNodes() {
	nd.mutex.Lock()
	defer nd.mutex.Unlock()

	checkedAt := time.Now()
	staleCount := 0

	for ip, node := range nd.discoveredNodes {
		age := checkedAt.Sub(node.LastSeen)
		if age <= nd.staleThreshold || node.Status == NodeStatusInactive {
			continue
		}

		nd.logger.WithFields(log.Fields{
			"ip":              ip,
			"hostname":        node.Hostname,
			"time_since_seen": age,
			"threshold":       nd.staleThreshold,
		}).Info("Node marked inactive")

		node.Status = NodeStatusInactive
		staleCount++
		nd.notifyCallbacks(ip, "stale")

		// A stale node cannot remain the primary.
		if ip == nd.primaryNode {
			nd.primaryNode = ""
			nd.logger.Info("Primary node became stale, clearing selection")
		}
	}

	// Only a log line is emitted here; nothing else is triggered by this
	// method.
	if staleCount > 0 {
		nd.logger.Debug("Nodes marked stale, triggering cluster update")
	}
}
// updatePrimaryNode makes sure an active primary node is selected.
//
// While holding nd.mutex it keeps the current primary if that node is present
// in nd.discoveredNodes and active; otherwise (no primary set, primary missing
// from the map, or primary not active) it delegates to selectBestPrimaryNode.
func (nd *NodeDiscovery) updatePrimaryNode() {
	nd.mutex.Lock()
	defer nd.mutex.Unlock()

	if nd.primaryNode != "" {
		// Use the comma-ok form: indexing the map directly and reading
		// .Status would dereference a nil pointer if the primary's entry is
		// absent from the map.
		if node, ok := nd.discoveredNodes[nd.primaryNode]; ok && node != nil && node.Status == NodeStatusActive {
			return
		}
	}

	nd.selectBestPrimaryNode()
}
// selectBestPrimaryNode promotes the active node with the most recent LastSeen
// to primary, unless the current primary is still active. When the primary
// changes, the choice is logged and callbacks are notified with
// "primary node change". If no active node exists, nothing changes.
//
// This method does not lock nd.mutex itself; updatePrimaryNode calls it while
// holding the lock.
func (nd *NodeDiscovery) selectBestPrimaryNode() {
	if len(nd.discoveredNodes) == 0 {
		return
	}

	// Keep a primary that is still active.
	if current := nd.primaryNode; current != "" {
		node, exists := nd.discoveredNodes[current]
		if exists && node.Status == NodeStatusActive {
			return
		}
	}

	// Scan for the active node that was heard from most recently.
	var (
		candidate string
		newest    time.Time
	)
	for ip, node := range nd.discoveredNodes {
		if node.Status != NodeStatusActive {
			continue
		}
		if !node.LastSeen.After(newest) {
			continue
		}
		candidate, newest = ip, node.LastSeen
	}

	if candidate == "" || candidate == nd.primaryNode {
		return
	}

	nd.primaryNode = candidate
	nd.logger.WithField("ip", candidate).Info("Selected new primary node")
	nd.notifyCallbacks(candidate, "primary node change")
}
// notifyCallbacks invokes every callback in nd.callbacks with nodeIP and
// action, each in its own goroutine, and returns without waiting for them.
// It does not lock nd.mutex itself.
func (nd *NodeDiscovery) notifyCallbacks(nodeIP, action string) {
	registered := nd.callbacks
	for i := 0; i < len(registered); i++ {
		go registered[i](nodeIP, action)
	}
}
// GetNodes returns a new map holding every discovered node keyed by IP.
//
// The map itself is a copy taken under the read lock, so callers may add or
// remove entries freely. The *NodeInfo values, however, are the same pointers
// stored internally, not copies of the structs.
func (nd *NodeDiscovery) GetNodes() map[string]*NodeInfo {
	nd.mutex.RLock()
	defer nd.mutex.RUnlock()

	snapshot := make(map[string]*NodeInfo, len(nd.discoveredNodes))
	for ip := range nd.discoveredNodes {
		snapshot[ip] = nd.discoveredNodes[ip]
	}
	return snapshot
}
// GetPrimaryNode returns the IP of the current primary node, read under the
// read lock. The result is "" when no primary is selected.
func (nd *NodeDiscovery) GetPrimaryNode() string {
	nd.mutex.RLock()
	primary := nd.primaryNode
	nd.mutex.RUnlock()

	return primary
}
// SetPrimaryNode makes the node with the given IP the primary node.
//
// It returns an error if ip is not a key of nd.discoveredNodes. The node's
// status is not checked. On success, callbacks are notified with
// "manual primary node setting".
func (nd *NodeDiscovery) SetPrimaryNode(ip string) error {
	nd.mutex.Lock()
	defer nd.mutex.Unlock()

	_, known := nd.discoveredNodes[ip]
	if !known {
		return fmt.Errorf("node %s not found", ip)
	}

	nd.primaryNode = ip
	nd.notifyCallbacks(ip, "manual primary node setting")
	return nil
}
// SelectRandomPrimaryNode switches the primary to another active node and
// returns the resulting primary IP.
//
// Candidates are all active nodes other than the current primary. With no
// nodes at all it returns ""; with no candidates it leaves the primary
// unchanged and returns it. Otherwise one candidate is picked using the
// current time in nanoseconds modulo the candidate count, stored as primary,
// logged, and announced to callbacks with "random primary node selection".
func (nd *NodeDiscovery) SelectRandomPrimaryNode() string {
	nd.mutex.Lock()
	defer nd.mutex.Unlock()

	if len(nd.discoveredNodes) == 0 {
		return ""
	}

	// Collect the active nodes that are not already primary.
	var candidates []string
	for ip, node := range nd.discoveredNodes {
		if node.Status != NodeStatusActive || ip == nd.primaryNode {
			continue
		}
		candidates = append(candidates, ip)
	}

	if len(candidates) == 0 {
		// Nothing to switch to; the current primary stays.
		return nd.primaryNode
	}

	// Simple clock-based pick.
	chosen := candidates[time.Now().UnixNano()%int64(len(candidates))]

	nd.primaryNode = chosen
	nd.logger.WithField("ip", chosen).Info("Randomly selected new primary node")
	nd.notifyCallbacks(chosen, "random primary node selection")
	return chosen
}
// AddCallback appends callback to the list that notifyCallbacks invokes on
// node changes. The list is modified while holding nd.mutex.
func (nd *NodeDiscovery) AddCallback(callback NodeUpdateCallback) {
	nd.mutex.Lock()
	nd.callbacks = append(nd.callbacks, callback)
	nd.mutex.Unlock()
}
// SetClusterEventCallback stores callback as the broadcaster that receives
// cluster/event and cluster/broadcast messages, replacing any previous one.
// The field is written while holding nd.mutex.
func (nd *NodeDiscovery) SetClusterEventCallback(callback ClusterEventBroadcaster) {
	nd.mutex.Lock()
	nd.clusterEventCallback = callback
	nd.mutex.Unlock()
}
// handleClusterEvent handles a "cluster/event" UDP message.
//
// The payload and sender are logged at debug level; if a cluster event
// broadcaster has been registered, the raw payload is forwarded to it under
// the topic "cluster/event". Without a broadcaster the message is dropped.
func (nd *NodeDiscovery) handleClusterEvent(payload string, remoteAddr *net.UDPAddr) {
	nd.logger.WithFields(log.Fields{
		"payload": payload,
		"from":    remoteAddr.String(),
	}).Debug("Received cluster/event message")

	// SetClusterEventCallback writes this field under nd.mutex, so read it
	// under the lock as well. The lock is released before the broadcaster is
	// called so that the callback does not run while the lock is held.
	nd.mutex.RLock()
	broadcaster := nd.clusterEventCallback
	nd.mutex.RUnlock()

	// Forward to websocket if callback is set
	if broadcaster != nil {
		broadcaster.BroadcastClusterEvent("cluster/event", payload)
	}
}
// handleClusterBroadcast handles a "cluster/broadcast" UDP message.
//
// payload must be a JSON object; its "event" field (string) and "data" field
// (any JSON value) are extracted. If a cluster event broadcaster has been
// registered, it is called with the event as topic and the decoded data.
// Payloads that fail to parse are logged at error level and dropped.
func (nd *NodeDiscovery) handleClusterBroadcast(payload string, remoteAddr *net.UDPAddr) {
	nd.logger.WithFields(log.Fields{
		"payload": payload,
		"from":    remoteAddr.String(),
	}).Debug("Received cluster/broadcast message")

	// Parse the payload JSON to extract nested event and data
	var payloadData struct {
		Event string      `json:"event"`
		Data  interface{} `json:"data"`
	}
	if err := json.Unmarshal([]byte(payload), &payloadData); err != nil {
		nd.logger.WithError(err).Error("Failed to parse cluster/broadcast payload")
		return
	}

	nd.logger.WithFields(log.Fields{
		"event": payloadData.Event,
		"from":  remoteAddr.String(),
	}).Debug("Parsed cluster/broadcast payload")

	// SetClusterEventCallback writes this field under nd.mutex, so read it
	// under the lock as well. The lock is released before the broadcaster is
	// called so that the callback does not run while the lock is held.
	nd.mutex.RLock()
	broadcaster := nd.clusterEventCallback
	nd.mutex.RUnlock()

	// Forward to websocket if callback is set, mapping event to topic and data to data
	if broadcaster != nil {
		broadcaster.BroadcastClusterEvent(payloadData.Event, payloadData.Data)
	}
}
// GetClusterStatus returns a ClusterStatus built under the read lock: the
// current primary node IP, the number of entries in nd.discoveredNodes
// (whatever their status), and the configured UDP port. ServerRunning is
// always reported as true.
func (nd *NodeDiscovery) GetClusterStatus() ClusterStatus {
	nd.mutex.RLock()
	defer nd.mutex.RUnlock()

	var status ClusterStatus
	status.PrimaryNode = nd.primaryNode
	status.TotalNodes = len(nd.discoveredNodes)
	status.UDPPort = nd.udpPort
	status.ServerRunning = true
	return status
}
// parsePort converts a port setting into a UDP port number.
//
// portStr may be a decimal port (0-65535) or a UDP service name that
// net.LookupPort can resolve; surrounding whitespace is ignored. The default
// port 4210 is returned when portStr is empty, is a number outside the valid
// port range, or cannot be resolved.
func parsePort(portStr string) int {
	const (
		defaultPort = 4210
		maxPort     = 65535
	)

	portStr = strings.TrimSpace(portStr)
	if portStr == "" {
		// net.LookupPort treats an empty service as port 0, which would bind
		// an arbitrary port instead of the default.
		return defaultPort
	}

	// Numeric input: accept it only when it is a valid port number.
	if p, err := strconv.Atoi(portStr); err == nil {
		if p < 0 || p > maxPort {
			return defaultPort
		}
		return p
	}

	// Non-numeric input: try it as a service name.
	if p, err := net.LookupPort("udp", portStr); err == nil {
		return p
	}
	return defaultPort
}