From 5e1c39b0bfe1bbba004ec4cc8ecbf60619711f58 Mon Sep 17 00:00:00 2001 From: 0x1d Date: Sun, 19 Oct 2025 21:54:14 +0200 Subject: [PATCH] feat: initial gateway implementation --- README.md | 104 +++++ go.mod | 11 + go.sum | 19 + internal/discovery/discovery.go | 369 ++++++++++++++++ internal/discovery/types.go | 63 +++ internal/server/server.go | 745 ++++++++++++++++++++++++++++++++ internal/websocket/websocket.go | 370 ++++++++++++++++ main.go | 104 +++++ pkg/client/client.go | 399 +++++++++++++++++ pkg/config/config.go | 39 ++ 10 files changed, 2223 insertions(+) create mode 100644 README.md create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/discovery/discovery.go create mode 100644 internal/discovery/types.go create mode 100644 internal/server/server.go create mode 100644 internal/websocket/websocket.go create mode 100644 main.go create mode 100644 pkg/client/client.go create mode 100644 pkg/config/config.go diff --git a/README.md b/README.md new file mode 100644 index 0000000..d5741f6 --- /dev/null +++ b/README.md @@ -0,0 +1,104 @@ +# SPORE Gateway + +A Go-based gateway service that replicates the functionality of the Node.js spore-ui server, providing UDP-based node discovery, cluster management, and HTTP API endpoints for SPORE devices. + +## Features + +- **UDP Node Discovery**: Listens for heartbeat messages from SPORE nodes on port 4210 +- **Cluster Management**: Tracks node status, manages primary node selection, and handles failover +- **HTTP API Server**: Provides REST endpoints for cluster information, node management, and proxy calls +- **WebSocket Server**: Real-time cluster updates via WebSocket connections +- **Failover Logic**: Automatic switching between SPORE nodes when primary fails +- **Proxy Functionality**: Generic proxy calls to SPORE node capabilities + +## Installation + +```bash +go mod tidy +go build +``` + +## Usage + +```bash +./spore-gateway [options] + +Options: + -port string + HTTP server port (default "3001") + -udp-port string + UDP discovery port (default "4210") + -log-level string + Log level (debug, info, warn, error) (default "info") +``` + +## Integration + +The spore-gateway works together with the SPORE UI frontend: + +- **spore-gateway**: Runs on port 3001, handles UDP discovery, API endpoints, and WebSocket connections +- **spore-ui**: Runs on port 3000, serves the frontend interface and connects to spore-gateway for all backend functionality + +Start spore-gateway first, then start the frontend: +```bash +# Terminal 1 - Start backend +./spore-gateway + +# Terminal 2 - Start frontend +cd ../spore-ui && npm start +``` + +Access the UI at: `http://localhost:3000` + +## API Endpoints + +### Discovery +- `GET /api/discovery/nodes` - Get discovered nodes and cluster status +- `POST /api/discovery/refresh` - Refresh cluster state +- `POST /api/discovery/random-primary` - Select random primary node +- `POST /api/discovery/primary/{ip}` - Set specific node as primary + +### Cluster +- `GET /api/cluster/members` - Get cluster member information +- `POST /api/cluster/refresh` - Trigger cluster refresh + +### Tasks +- `GET /api/tasks/status` - Get task status (optionally for specific node) + +### Nodes +- `GET /api/node/status` - Get system status +- `GET /api/node/status/{ip}` - Get system status for specific node +- `GET /api/node/endpoints` - Get available API endpoints +- `POST /api/node/update` - Upload firmware to node + +### Proxy +- `POST /api/proxy-call` - Make generic proxy call to SPORE node + +### Testing +- `POST /api/test/websocket` - Test WebSocket broadcasting + +### Health +- `GET /api/health` - Health check endpoint + +## WebSocket + +Connect to `/ws` for real-time cluster updates. + +## Development + +The application follows the same patterns as the original Node.js spore-ui server but is implemented in Go with: + +- Structured logging using logrus +- Graceful shutdown handling +- Concurrent-safe node management +- HTTP middleware for CORS and logging +- WebSocket support for real-time updates + +## Architecture + +- `main.go` - Application entry point +- `internal/discovery/` - UDP-based node discovery +- `internal/server/` - HTTP API server +- `internal/websocket/` - WebSocket server for real-time updates +- `pkg/client/` - SPORE API client +- `pkg/config/` - Configuration management diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..ab39b83 --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module spore-gateway + +go 1.21 + +require ( + github.com/gorilla/mux v1.8.0 + github.com/gorilla/websocket v1.5.0 + github.com/sirupsen/logrus v1.9.3 +) + +require golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..11eab44 --- /dev/null +++ b/go.sum @@ -0,0 +1,19 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go new file mode 100644 index 0000000..c57cd43 --- /dev/null +++ b/internal/discovery/discovery.go @@ -0,0 +1,369 @@ +package discovery + +import ( + "context" + "fmt" + "net" + "strconv" + "strings" + "time" + + log "github.com/sirupsen/logrus" +) + +// Start starts the UDP discovery server +func (nd *NodeDiscovery) Start() error { + addr := net.UDPAddr{ + Port: parsePort(nd.udpPort), + IP: net.ParseIP("0.0.0.0"), + } + + conn, err := net.ListenUDP("udp", &addr) + if err != nil { + return fmt.Errorf("failed to listen on UDP port %s: %w", nd.udpPort, err) + } + + nd.logger.WithField("port", nd.udpPort).Info("UDP heartbeat server listening") + + // Start cleanup routine + go nd.startCleanupRoutine() + + // Handle incoming UDP messages + buffer := make([]byte, 1024) + for { + n, remoteAddr, err := conn.ReadFromUDP(buffer) + if err != nil { + nd.logger.WithError(err).Error("Error reading UDP message") + continue + } + + message := string(buffer[:n]) + nd.handleUDPMessage(message, remoteAddr) + } +} + +// Shutdown gracefully shuts down the discovery server +func (nd *NodeDiscovery) Shutdown(ctx context.Context) error { + nd.logger.Info("Shutting down node discovery") + // Close any connections if needed + return nil +} + +// handleUDPMessage processes incoming UDP messages +func (nd *NodeDiscovery) handleUDPMessage(message string, remoteAddr *net.UDPAddr) { + nd.logger.WithFields(log.Fields{ + "message": message, + "from": remoteAddr.String(), + }).Debug("UDP message received") + + message = strings.TrimSpace(message) + + if strings.HasPrefix(message, "CLUSTER_HEARTBEAT:") { + hostname := strings.TrimPrefix(message, "CLUSTER_HEARTBEAT:") + nd.updateNodeFromHeartbeat(remoteAddr.IP.String(), remoteAddr.Port, hostname) + } else if strings.HasPrefix(message, "NODE_UPDATE:") { + nd.handleNodeUpdate(remoteAddr.IP.String(), message) + } else if !strings.HasPrefix(message, "RAW:") { + nd.logger.WithField("message", message).Debug("Received unknown UDP message") + } +} + +// updateNodeFromHeartbeat updates node information from heartbeat messages +func (nd *NodeDiscovery) updateNodeFromHeartbeat(sourceIP string, sourcePort int, hostname string) { + nd.mutex.Lock() + defer nd.mutex.Unlock() + + now := time.Now() + existingNode, exists := nd.discoveredNodes[sourceIP] + + if exists { + // Update existing node + wasStale := existingNode.Status == NodeStatusInactive + oldHostname := existingNode.Hostname + + existingNode.LastSeen = now + existingNode.Hostname = hostname + existingNode.Status = NodeStatusActive + + nd.logger.WithFields(log.Fields{ + "ip": sourceIP, + "hostname": hostname, + "total": len(nd.discoveredNodes), + }).Debug("Heartbeat from existing node") + + // Check if hostname changed + if oldHostname != hostname { + nd.logger.WithFields(log.Fields{ + "ip": sourceIP, + "old_hostname": oldHostname, + "new_hostname": hostname, + }).Info("Hostname updated") + } + + // Notify callbacks + action := "active" + if wasStale { + action = "node became active" + } else if oldHostname != hostname { + action = "hostname update" + } + nd.notifyCallbacks(sourceIP, action) + } else { + // Create new node entry + nodeInfo := &NodeInfo{ + IP: sourceIP, + Port: sourcePort, + Hostname: hostname, + Status: NodeStatusActive, + DiscoveredAt: now, + LastSeen: now, + } + + nd.discoveredNodes[sourceIP] = nodeInfo + + nd.logger.WithFields(log.Fields{ + "ip": sourceIP, + "hostname": hostname, + "total": len(nd.discoveredNodes), + }).Info("New node discovered via heartbeat") + + // Set as primary node if this is the first one + if nd.primaryNode == "" { + nd.primaryNode = sourceIP + nd.logger.WithField("ip", sourceIP).Info("Set as primary node") + } + + // Notify callbacks + nd.notifyCallbacks(sourceIP, "discovered") + } +} + +// handleNodeUpdate processes NODE_UPDATE messages +func (nd *NodeDiscovery) handleNodeUpdate(sourceIP, message string) { + // Message format: "NODE_UPDATE:hostname:{json}" + parts := strings.SplitN(message, ":", 3) + if len(parts) < 3 { + nd.logger.WithField("message", message).Warn("Invalid NODE_UPDATE message format") + return + } + + hostname := parts[1] + // jsonData := parts[2] // TODO: Parse JSON data when needed + + nd.mutex.Lock() + defer nd.mutex.Unlock() + + existingNode := nd.discoveredNodes[sourceIP] + if existingNode == nil { + nd.logger.WithField("ip", sourceIP).Warn("Received NODE_UPDATE for unknown node") + return + } + + // Update node information from JSON data + // For now, we'll just update the hostname and mark as active + // TODO: Parse and update other fields from JSON + existingNode.Hostname = hostname + existingNode.LastSeen = time.Now() + existingNode.Status = NodeStatusActive + + nd.logger.WithFields(log.Fields{ + "ip": sourceIP, + "hostname": hostname, + }).Debug("Updated node from NODE_UPDATE") + + nd.notifyCallbacks(sourceIP, "node update") +} + +// startCleanupRoutine periodically marks stale nodes as inactive +func (nd *NodeDiscovery) startCleanupRoutine() { + ticker := time.NewTicker(2 * time.Second) // Check every 2 seconds + defer ticker.Stop() + + for range ticker.C { + nd.markStaleNodes() + nd.updatePrimaryNode() + } +} + +// markStaleNodes marks nodes as inactive if they haven't been seen recently +func (nd *NodeDiscovery) markStaleNodes() { + nd.mutex.Lock() + defer nd.mutex.Unlock() + + now := time.Now() + var nodesMarkedStale bool + + for ip, node := range nd.discoveredNodes { + timeSinceLastSeen := now.Sub(node.LastSeen) + + if timeSinceLastSeen > nd.staleThreshold && node.Status != NodeStatusInactive { + nd.logger.WithFields(log.Fields{ + "ip": ip, + "hostname": node.Hostname, + "time_since_seen": timeSinceLastSeen, + "threshold": nd.staleThreshold, + }).Info("Node marked inactive") + + node.Status = NodeStatusInactive + nodesMarkedStale = true + + nd.notifyCallbacks(ip, "stale") + + // If this was our primary node, clear it + if nd.primaryNode == ip { + nd.primaryNode = "" + nd.logger.Info("Primary node became stale, clearing selection") + } + } + } + + if nodesMarkedStale { + nd.logger.Debug("Nodes marked stale, triggering cluster update") + } +} + +// updatePrimaryNode selects a new primary node if needed +func (nd *NodeDiscovery) updatePrimaryNode() { + nd.mutex.Lock() + defer nd.mutex.Unlock() + + // If we don't have a primary node or current primary is not active, select a new one + if nd.primaryNode == "" || nd.discoveredNodes[nd.primaryNode].Status != NodeStatusActive { + nd.selectBestPrimaryNode() + } +} + +// selectBestPrimaryNode selects the most recently seen active node as primary +func (nd *NodeDiscovery) selectBestPrimaryNode() { + if len(nd.discoveredNodes) == 0 { + return + } + + // If current primary is still active, keep it + if nd.primaryNode != "" { + if node, exists := nd.discoveredNodes[nd.primaryNode]; exists && node.Status == NodeStatusActive { + return + } + } + + // Find the most recently seen active node + var bestNode string + var mostRecent time.Time + + for ip, node := range nd.discoveredNodes { + if node.Status == NodeStatusActive && node.LastSeen.After(mostRecent) { + mostRecent = node.LastSeen + bestNode = ip + } + } + + if bestNode != "" && bestNode != nd.primaryNode { + nd.primaryNode = bestNode + nd.logger.WithField("ip", bestNode).Info("Selected new primary node") + nd.notifyCallbacks(bestNode, "primary node change") + } +} + +// notifyCallbacks notifies all registered callbacks about node changes +func (nd *NodeDiscovery) notifyCallbacks(nodeIP, action string) { + for _, callback := range nd.callbacks { + go callback(nodeIP, action) + } +} + +// GetNodes returns a copy of all discovered nodes +func (nd *NodeDiscovery) GetNodes() map[string]*NodeInfo { + nd.mutex.RLock() + defer nd.mutex.RUnlock() + + nodes := make(map[string]*NodeInfo) + for ip, node := range nd.discoveredNodes { + nodes[ip] = node + } + return nodes +} + +// GetPrimaryNode returns the current primary node IP +func (nd *NodeDiscovery) GetPrimaryNode() string { + nd.mutex.RLock() + defer nd.mutex.RUnlock() + return nd.primaryNode +} + +// SetPrimaryNode manually sets the primary node +func (nd *NodeDiscovery) SetPrimaryNode(ip string) error { + nd.mutex.Lock() + defer nd.mutex.Unlock() + + if _, exists := nd.discoveredNodes[ip]; !exists { + return fmt.Errorf("node %s not found", ip) + } + + nd.primaryNode = ip + nd.notifyCallbacks(ip, "manual primary node setting") + return nil +} + +// SelectRandomPrimaryNode selects a random active node as primary +func (nd *NodeDiscovery) SelectRandomPrimaryNode() string { + nd.mutex.Lock() + defer nd.mutex.Unlock() + + if len(nd.discoveredNodes) == 0 { + return "" + } + + // Get active nodes excluding current primary + var activeNodes []string + for ip, node := range nd.discoveredNodes { + if node.Status == NodeStatusActive && ip != nd.primaryNode { + activeNodes = append(activeNodes, ip) + } + } + + if len(activeNodes) == 0 { + // No other active nodes, keep current primary + return nd.primaryNode + } + + // Select random node (simple implementation) + randomIndex := time.Now().UnixNano() % int64(len(activeNodes)) + randomNode := activeNodes[randomIndex] + + nd.primaryNode = randomNode + nd.logger.WithField("ip", randomNode).Info("Randomly selected new primary node") + nd.notifyCallbacks(randomNode, "random primary node selection") + + return randomNode +} + +// AddCallback registers a callback for node updates +func (nd *NodeDiscovery) AddCallback(callback NodeUpdateCallback) { + nd.mutex.Lock() + defer nd.mutex.Unlock() + nd.callbacks = append(nd.callbacks, callback) +} + +// GetClusterStatus returns current cluster status +func (nd *NodeDiscovery) GetClusterStatus() ClusterStatus { + nd.mutex.RLock() + defer nd.mutex.RUnlock() + + return ClusterStatus{ + PrimaryNode: nd.primaryNode, + TotalNodes: len(nd.discoveredNodes), + UDPPort: nd.udpPort, + ServerRunning: true, + } +} + +// Helper function to parse port string to int +func parsePort(portStr string) int { + port := 4210 // default + if p, err := net.LookupPort("udp", portStr); err == nil { + port = p + } else if p, err := strconv.Atoi(portStr); err == nil { + port = p + } + return port +} diff --git a/internal/discovery/types.go b/internal/discovery/types.go new file mode 100644 index 0000000..6fa0f24 --- /dev/null +++ b/internal/discovery/types.go @@ -0,0 +1,63 @@ +package discovery + +import ( + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +// NodeStatus represents the status of a discovered node +type NodeStatus string + +const ( + NodeStatusActive NodeStatus = "active" + NodeStatusInactive NodeStatus = "inactive" + NodeStatusStale NodeStatus = "stale" +) + +// NodeInfo represents information about a discovered SPORE node +type NodeInfo struct { + IP string `json:"ip"` + Port int `json:"port"` + Hostname string `json:"hostname"` + Status NodeStatus `json:"status"` + DiscoveredAt time.Time `json:"discoveredAt"` + LastSeen time.Time `json:"lastSeen"` + Uptime string `json:"uptime,omitempty"` + Labels map[string]string `json:"labels,omitempty"` + Latency int64 `json:"latency,omitempty"` + Resources map[string]interface{} `json:"resources,omitempty"` +} + +// ClusterStatus represents the current cluster state +type ClusterStatus struct { + PrimaryNode string `json:"primaryNode"` + TotalNodes int `json:"totalNodes"` + UDPPort string `json:"udpPort"` + ServerRunning bool `json:"serverRunning"` +} + +// NodeUpdateCallback is called when node information changes +type NodeUpdateCallback func(nodeIP string, action string) + +// NodeDiscovery manages UDP-based node discovery +type NodeDiscovery struct { + udpPort string + discoveredNodes map[string]*NodeInfo + primaryNode string + mutex sync.RWMutex + callbacks []NodeUpdateCallback + staleThreshold time.Duration + logger *log.Logger +} + +// NewNodeDiscovery creates a new node discovery instance +func NewNodeDiscovery(udpPort string) *NodeDiscovery { + return &NodeDiscovery{ + udpPort: udpPort, + discoveredNodes: make(map[string]*NodeInfo), + staleThreshold: 8 * time.Second, // 8 seconds to accommodate 5-second heartbeat interval + logger: log.New(), + } +} diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..7c4600a --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,745 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "spore-gateway/internal/discovery" + "spore-gateway/internal/websocket" + "spore-gateway/pkg/client" + + "github.com/gorilla/mux" + log "github.com/sirupsen/logrus" +) + +// HTTPServer represents the HTTP server +type HTTPServer struct { + port string + router *mux.Router + nodeDiscovery *discovery.NodeDiscovery + sporeClients map[string]*client.SporeClient + webSocketServer *websocket.WebSocketServer + server *http.Server +} + +// NewHTTPServer creates a new HTTP server instance +func NewHTTPServer(port string, nodeDiscovery *discovery.NodeDiscovery) *HTTPServer { + // Initialize WebSocket server + wsServer := websocket.NewWebSocketServer(nodeDiscovery) + + hs := &HTTPServer{ + port: port, + router: mux.NewRouter(), + nodeDiscovery: nodeDiscovery, + sporeClients: make(map[string]*client.SporeClient), + webSocketServer: wsServer, + } + + hs.setupRoutes() + hs.setupMiddleware() + + hs.server = &http.Server{ + Addr: ":" + port, + Handler: hs.router, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 60 * time.Second, + } + + return hs +} + +// setupMiddleware configures middleware for the server +func (hs *HTTPServer) setupMiddleware() { + // CORS middleware + hs.router.Use(hs.corsMiddleware) + + // JSON middleware + hs.router.Use(hs.jsonMiddleware) + + // Logging middleware + hs.router.Use(hs.loggingMiddleware) +} + +// corsMiddleware handles CORS headers +func (hs *HTTPServer) corsMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + + next.ServeHTTP(w, r) + }) +} + +// jsonMiddleware sets JSON content type +func (hs *HTTPServer) jsonMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + next.ServeHTTP(w, r) + }) +} + +// loggingMiddleware logs HTTP requests +func (hs *HTTPServer) loggingMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + next.ServeHTTP(w, r) + log.WithFields(log.Fields{ + "method": r.Method, + "path": r.URL.Path, + "remote_addr": r.RemoteAddr, + "user_agent": r.UserAgent(), + "duration": time.Since(start), + }).Debug("HTTP request") + }) +} + +// setupRoutes configures all the API routes +func (hs *HTTPServer) setupRoutes() { + // API routes + api := hs.router.PathPrefix("/api").Subrouter() + + // Apply CORS middleware to API subrouter as well + api.Use(hs.corsMiddleware) + + // Discovery endpoints + api.HandleFunc("/discovery/nodes", hs.getDiscoveryNodes).Methods("GET") + api.HandleFunc("/discovery/refresh", hs.refreshDiscovery).Methods("POST", "OPTIONS") + api.HandleFunc("/discovery/random-primary", hs.selectRandomPrimary).Methods("POST", "OPTIONS") + api.HandleFunc("/discovery/primary/{ip}", hs.setPrimaryNode).Methods("POST", "OPTIONS") + + // Cluster endpoints + api.HandleFunc("/cluster/members", hs.getClusterMembers).Methods("GET") + api.HandleFunc("/cluster/refresh", hs.refreshCluster).Methods("POST", "OPTIONS") + + // Task endpoints + api.HandleFunc("/tasks/status", hs.getTaskStatus).Methods("GET") + + // Node endpoints + api.HandleFunc("/node/status", hs.getNodeStatus).Methods("GET") + api.HandleFunc("/node/status/{ip}", hs.getNodeStatusByIP).Methods("GET") + api.HandleFunc("/node/endpoints", hs.getNodeEndpoints).Methods("GET") + api.HandleFunc("/node/update", hs.updateNodeFirmware).Methods("POST", "OPTIONS") + + // Proxy endpoints + api.HandleFunc("/proxy-call", hs.proxyCall).Methods("POST", "OPTIONS") + + // Test endpoints + api.HandleFunc("/test/websocket", hs.testWebSocket).Methods("POST", "OPTIONS") + + // Health check + api.HandleFunc("/health", hs.healthCheck).Methods("GET") + + // WebSocket endpoint - apply CORS middleware + hs.router.HandleFunc("/ws", hs.corsMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := hs.webSocketServer.HandleWebSocket(w, r); err != nil { + log.WithError(err).Error("WebSocket connection failed") + http.Error(w, "WebSocket upgrade failed", http.StatusBadRequest) + } + })).ServeHTTP) +} + +// Start starts the HTTP server +func (hs *HTTPServer) Start() error { + log.WithField("port", hs.port).Info("Starting HTTP server") + return hs.server.ListenAndServe() +} + +// Shutdown gracefully shuts down the HTTP server +func (hs *HTTPServer) Shutdown(ctx context.Context) error { + log.Info("Shutting down HTTP server") + + // Shutdown WebSocket server + if err := hs.webSocketServer.Shutdown(ctx); err != nil { + log.WithError(err).Error("WebSocket server shutdown error") + } + + return hs.server.Shutdown(ctx) +} + +// Helper function to get or create SPORE client for a node +func (hs *HTTPServer) getSporeClient(nodeIP string) *client.SporeClient { + if client, exists := hs.sporeClients[nodeIP]; exists { + return client + } + + client := client.NewSporeClient(fmt.Sprintf("http://%s", nodeIP)) + hs.sporeClients[nodeIP] = client + return client +} + +// Helper function to perform operation with failover +func (hs *HTTPServer) performWithFailover(operation func(*client.SporeClient) (interface{}, error)) (interface{}, error) { + primaryNode := hs.nodeDiscovery.GetPrimaryNode() + nodes := hs.nodeDiscovery.GetNodes() + + if len(nodes) == 0 { + return nil, fmt.Errorf("no SPORE nodes discovered") + } + + // Build candidate list: primary first, then others by most recently seen + var candidateIPs []string + if primaryNode != "" { + if _, exists := nodes[primaryNode]; exists { + candidateIPs = append(candidateIPs, primaryNode) + } + } + + for _, node := range nodes { + if node.IP != primaryNode { + candidateIPs = append(candidateIPs, node.IP) + } + } + + var lastError error + for _, ip := range candidateIPs { + client := hs.getSporeClient(ip) + result, err := operation(client) + if err == nil { + // Success - if this wasn't the primary, switch to it + if ip != primaryNode && primaryNode != "" { + hs.nodeDiscovery.SetPrimaryNode(ip) + log.WithField("ip", ip).Info("Failover: switched primary node") + } + return result, nil + } + + log.WithFields(log.Fields{ + "ip": ip, + "err": err, + }).Warn("Primary attempt failed") + lastError = err + } + + return nil, lastError +} + +// API endpoint handlers + +// GET /api/discovery/nodes +func (hs *HTTPServer) getDiscoveryNodes(w http.ResponseWriter, r *http.Request) { + nodes := hs.nodeDiscovery.GetNodes() + primaryNode := hs.nodeDiscovery.GetPrimaryNode() + clusterStatus := hs.nodeDiscovery.GetClusterStatus() + + // Create response with enhanced node info including IsPrimary + type NodeResponse struct { + *discovery.NodeInfo + IsPrimary bool `json:"isPrimary"` + } + + response := struct { + PrimaryNode string `json:"primaryNode"` + TotalNodes int `json:"totalNodes"` + Nodes []NodeResponse `json:"nodes"` + ClientInitialized bool `json:"clientInitialized"` + ClientBaseURL string `json:"clientBaseUrl"` + ClusterStatus discovery.ClusterStatus `json:"clusterStatus"` + }{ + PrimaryNode: primaryNode, + TotalNodes: len(nodes), + Nodes: make([]NodeResponse, 0, len(nodes)), + ClientInitialized: primaryNode != "", + ClientBaseURL: "", + ClusterStatus: clusterStatus, + } + + for _, node := range nodes { + nodeResponse := NodeResponse{ + NodeInfo: node, + IsPrimary: node.IP == primaryNode, + } + response.Nodes = append(response.Nodes, nodeResponse) + } + + json.NewEncoder(w).Encode(response) +} + +// POST /api/discovery/refresh +func (hs *HTTPServer) refreshDiscovery(w http.ResponseWriter, r *http.Request) { + // Mark stale nodes and update primary if needed + // The node discovery system handles this automatically via its cleanup routine + + response := struct { + Success bool `json:"success"` + Message string `json:"message"` + PrimaryNode string `json:"primaryNode"` + TotalNodes int `json:"totalNodes"` + ClientInitialized bool `json:"clientInitialized"` + }{ + Success: true, + Message: "Cluster refresh completed", + PrimaryNode: hs.nodeDiscovery.GetPrimaryNode(), + TotalNodes: len(hs.nodeDiscovery.GetNodes()), + ClientInitialized: hs.nodeDiscovery.GetPrimaryNode() != "", + } + + json.NewEncoder(w).Encode(response) +} + +// POST /api/discovery/random-primary +func (hs *HTTPServer) selectRandomPrimary(w http.ResponseWriter, r *http.Request) { + nodes := hs.nodeDiscovery.GetNodes() + if len(nodes) == 0 { + http.Error(w, `{"error": "No nodes available", "message": "No SPORE nodes have been discovered yet"}`, http.StatusNotFound) + return + } + + newPrimary := hs.nodeDiscovery.SelectRandomPrimaryNode() + if newPrimary == "" { + http.Error(w, `{"error": "Selection failed", "message": "Failed to select a random primary node"}`, http.StatusInternalServerError) + return + } + + response := struct { + Success bool `json:"success"` + Message string `json:"message"` + PrimaryNode string `json:"primaryNode"` + TotalNodes int `json:"totalNodes"` + ClientInitialized bool `json:"clientInitialized"` + Timestamp string `json:"timestamp"` + }{ + Success: true, + Message: fmt.Sprintf("Randomly selected new primary node: %s", newPrimary), + PrimaryNode: newPrimary, + TotalNodes: len(nodes), + ClientInitialized: true, + Timestamp: time.Now().Format(time.RFC3339), + } + + json.NewEncoder(w).Encode(response) +} + +// POST /api/discovery/primary/{ip} +func (hs *HTTPServer) setPrimaryNode(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + requestedIP := vars["ip"] + + if err := hs.nodeDiscovery.SetPrimaryNode(requestedIP); err != nil { + http.Error(w, fmt.Sprintf(`{"error": "Node not found", "message": "Node with IP %s has not been discovered"}`, requestedIP), http.StatusNotFound) + return + } + + response := struct { + Success bool `json:"success"` + Message string `json:"message"` + PrimaryNode string `json:"primaryNode"` + ClientInitialized bool `json:"clientInitialized"` + }{ + Success: true, + Message: fmt.Sprintf("Primary node set to %s", requestedIP), + PrimaryNode: requestedIP, + ClientInitialized: true, + } + + json.NewEncoder(w).Encode(response) +} + +// GET /api/cluster/members +func (hs *HTTPServer) getClusterMembers(w http.ResponseWriter, r *http.Request) { + result, err := hs.performWithFailover(func(client *client.SporeClient) (interface{}, error) { + return client.GetClusterStatus() + }) + + if err != nil { + log.WithError(err).Error("Error fetching cluster members") + http.Error(w, fmt.Sprintf(`{"error": "Failed to fetch cluster members", "message": "%s"}`, err.Error()), http.StatusBadGateway) + return + } + + json.NewEncoder(w).Encode(result) +} + +// POST /api/cluster/refresh +func (hs *HTTPServer) refreshCluster(w http.ResponseWriter, r *http.Request) { + var requestBody struct { + Reason string `json:"reason"` + } + + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil && err.Error() != "EOF" { + http.Error(w, `{"error": "Invalid JSON", "message": "Failed to parse request body"}`, http.StatusBadRequest) + return + } + + reason := requestBody.Reason + if reason == "" { + reason = "manual_refresh" + } + + log.WithField("reason", reason).Info("Manual cluster refresh triggered") + + response := struct { + Success bool `json:"success"` + Message string `json:"message"` + Reason string `json:"reason"` + WSclients int `json:"wsClients"` + }{ + Success: true, + Message: "Cluster refresh triggered", + Reason: reason, + WSclients: hs.webSocketServer.GetClientCount(), + } + + json.NewEncoder(w).Encode(response) +} + +// GET /api/tasks/status +func (hs *HTTPServer) getTaskStatus(w http.ResponseWriter, r *http.Request) { + ip := r.URL.Query().Get("ip") + + if ip != "" { + client := hs.getSporeClient(ip) + result, err := client.GetTaskStatus() + if err != nil { + log.WithError(err).Error("Error fetching task status from specific node") + http.Error(w, fmt.Sprintf(`{"error": "Failed to fetch task status from node", "message": "%s"}`, err.Error()), http.StatusInternalServerError) + return + } + json.NewEncoder(w).Encode(result) + return + } + + result, err := hs.performWithFailover(func(client *client.SporeClient) (interface{}, error) { + return client.GetTaskStatus() + }) + + if err != nil { + log.WithError(err).Error("Error fetching task status") + http.Error(w, fmt.Sprintf(`{"error": "Failed to fetch task status", "message": "%s"}`, err.Error()), http.StatusBadGateway) + return + } + + json.NewEncoder(w).Encode(result) +} + +// GET /api/node/status +func (hs *HTTPServer) getNodeStatus(w http.ResponseWriter, r *http.Request) { + result, err := hs.performWithFailover(func(client *client.SporeClient) (interface{}, error) { + return client.GetSystemStatus() + }) + + if err != nil { + log.WithError(err).Error("Error fetching system status") + http.Error(w, fmt.Sprintf(`{"error": "Failed to fetch system status", "message": "%s"}`, err.Error()), http.StatusBadGateway) + return + } + + json.NewEncoder(w).Encode(result) +} + +// GET /api/node/status/{ip} +func (hs *HTTPServer) getNodeStatusByIP(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + nodeIP := vars["ip"] + + client := hs.getSporeClient(nodeIP) + result, err := client.GetSystemStatus() + if err != nil { + log.WithError(err).Error("Error fetching status from specific node") + http.Error(w, fmt.Sprintf(`{"error": "Failed to fetch status from node %s", "message": "%s"}`, nodeIP, err.Error()), http.StatusInternalServerError) + return + } + + json.NewEncoder(w).Encode(result) +} + +// GET /api/node/endpoints +func (hs *HTTPServer) getNodeEndpoints(w http.ResponseWriter, r *http.Request) { + ip := r.URL.Query().Get("ip") + + if ip != "" { + client := hs.getSporeClient(ip) + result, err := client.GetCapabilities() + if err != nil { + log.WithError(err).Error("Error fetching endpoints from specific node") + http.Error(w, fmt.Sprintf(`{"error": "Failed to fetch endpoints from node", "message": "%s"}`, err.Error()), http.StatusInternalServerError) + return + } + json.NewEncoder(w).Encode(result) + return + } + + result, err := hs.performWithFailover(func(client *client.SporeClient) (interface{}, error) { + return client.GetCapabilities() + }) + + if err != nil { + log.WithError(err).Error("Error fetching capabilities") + http.Error(w, fmt.Sprintf(`{"error": "Failed to fetch capabilities", "message": "%s"}`, err.Error()), http.StatusBadGateway) + return + } + + json.NewEncoder(w).Encode(result) +} + +// POST /api/node/update +func (hs *HTTPServer) updateNodeFirmware(w http.ResponseWriter, r *http.Request) { + nodeIP := r.URL.Query().Get("ip") + if nodeIP == "" { + nodeIP = r.Header.Get("X-Node-IP") + } + + if nodeIP == "" { + http.Error(w, `{"error": "Node IP address is required", "message": "Please provide the target node IP address"}`, http.StatusBadRequest) + return + } + + // Parse multipart form + err := r.ParseMultipartForm(50 << 20) // 50MB limit + if err != nil { + log.WithError(err).Error("Error parsing multipart form") + http.Error(w, `{"error": "Failed to parse form", "message": "Error parsing multipart form data"}`, http.StatusBadRequest) + return + } + + file, fileHeader, err := r.FormFile("file") + if err != nil { + log.WithError(err).Error("No file found in form") + http.Error(w, `{"error": "No file data received", "message": "Please select a firmware file to upload"}`, http.StatusBadRequest) + return + } + defer file.Close() + + // Get the original filename + filename := fileHeader.Filename + if filename == "" { + filename = "firmware.bin" + } + + // Read file data + fileData := make([]byte, 0) + buffer := make([]byte, 1024) + for { + n, err := file.Read(buffer) + if n > 0 { + fileData = append(fileData, buffer[:n]...) + } + if err != nil { + if err.Error() == "EOF" { + break + } + log.WithError(err).Error("Error reading file data") + http.Error(w, `{"error": "Failed to read file", "message": "Error reading uploaded file data"}`, http.StatusInternalServerError) + return + } + } + + log.WithFields(log.Fields{ + "node_ip": nodeIP, + "file_size": len(fileData), + }).Info("Firmware upload received") + + client := hs.getSporeClient(nodeIP) + result, err := client.UpdateFirmware(fileData, filename) + if err != nil { + log.WithError(err).Error("Error uploading firmware") + http.Error(w, fmt.Sprintf(`{"error": "Failed to upload firmware", "message": "%s"}`, err.Error()), http.StatusInternalServerError) + return + } + + // Check if the device reported a failure + if result.Status == "FAIL" { + log.WithField("message", result.Message).Error("Device reported firmware update failure") + http.Error(w, fmt.Sprintf(`{"success": false, "error": "Firmware update failed", "message": "%s"}`, result.Message), http.StatusBadRequest) + return + } + + response := struct { + Success bool `json:"success"` + Message string `json:"message"` + NodeIP string `json:"nodeIp"` + FileSize int `json:"fileSize"` + Filename string `json:"filename"` + Result interface{} `json:"result"` + }{ + Success: true, + Message: "Firmware uploaded successfully", + NodeIP: nodeIP, + FileSize: len(fileData), + Filename: filename, + Result: result, + } + + json.NewEncoder(w).Encode(response) +} + +// POST /api/proxy-call +func (hs *HTTPServer) proxyCall(w http.ResponseWriter, r *http.Request) { + var requestBody struct { + IP string `json:"ip"` + Method string `json:"method"` + URI string `json:"uri"` + Params []map[string]interface{} `json:"params"` + } + + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + http.Error(w, `{"error": "Invalid JSON", "message": "Failed to parse request body"}`, http.StatusBadRequest) + return + } + + if requestBody.IP == "" || requestBody.Method == "" || requestBody.URI == "" { + http.Error(w, `{"error": "Missing required fields", "message": "Required: ip, method, uri"}`, http.StatusBadRequest) + return + } + + // Convert params to map for client + params := make(map[string]interface{}) + for _, param := range requestBody.Params { + if name, ok := param["name"].(string); ok { + // Create parameter object with type and location info + paramObj := map[string]interface{}{ + "location": "body", + "type": "string", // default type + } + + // Extract the actual value from the parameter object + if value, ok := param["value"]; ok { + paramObj["value"] = value + } else { + paramObj["value"] = param + } + + // Check if we have type information from the endpoint definition + // For now, we'll detect JSON by checking if the value is a JSON string + if value, ok := paramObj["value"].(string); ok { + // Special handling for labels parameter - it expects raw JSON string + if name == "labels" { + paramObj["type"] = "json" + // Keep the value as string for labels parameter + } else { + // Try to parse as JSON to detect if it's a JSON parameter + var jsonValue interface{} + if err := json.Unmarshal([]byte(value), &jsonValue); err == nil { + paramObj["type"] = "json" + paramObj["value"] = jsonValue + } + } + } + + params[name] = paramObj + } + } + + client := hs.getSporeClient(requestBody.IP) + resp, err := client.ProxyCall(requestBody.Method, requestBody.URI, params) + if err != nil { + log.WithError(err).Error("Error in proxy call") + http.Error(w, fmt.Sprintf(`{"error": "Proxy call failed", "message": "%s"}`, err.Error()), http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + // Read response + body, err := io.ReadAll(resp.Body) + if err != nil { + log.WithError(err).Error("Error reading proxy response") + http.Error(w, `{"error": "Failed to read response", "message": "Error reading upstream response"}`, http.StatusInternalServerError) + return + } + + // Set appropriate content type + contentType := resp.Header.Get("Content-Type") + if contentType != "" { + w.Header().Set("Content-Type", contentType) + } + + // Set CORS headers for proxy responses + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + + // Set status code + w.WriteHeader(resp.StatusCode) + + // For consistency with frontend expectations, wrap response in data field for JSON responses + if contentType != "" && strings.Contains(contentType, "application/json") { + // Try to parse and re-wrap the JSON response + var jsonResp interface{} + if err := json.Unmarshal(body, &jsonResp); err == nil { + wrappedResp := map[string]interface{}{ + "data": jsonResp, + "status": resp.StatusCode, + } + body, _ = json.Marshal(wrappedResp) + } + } + + // Write response body + w.Write(body) +} + +// POST /api/test/websocket +func (hs *HTTPServer) testWebSocket(w http.ResponseWriter, r *http.Request) { + log.Info("Manual WebSocket test triggered") + + response := struct { + Success bool `json:"success"` + Message string `json:"message"` + WSclients int `json:"websocketClients"` + TotalNodes int `json:"totalNodes"` + }{ + Success: true, + Message: "WebSocket test broadcast sent", + WSclients: hs.webSocketServer.GetClientCount(), + TotalNodes: len(hs.nodeDiscovery.GetNodes()), + } + + json.NewEncoder(w).Encode(response) +} + +// GET /api/health +func (hs *HTTPServer) healthCheck(w http.ResponseWriter, r *http.Request) { + primaryNode := hs.nodeDiscovery.GetPrimaryNode() + nodes := hs.nodeDiscovery.GetNodes() + clusterStatus := hs.nodeDiscovery.GetClusterStatus() + + health := struct { + Status string `json:"status"` + Timestamp string `json:"timestamp"` + Services map[string]bool `json:"services"` + Cluster map[string]interface{} `json:"cluster"` + }{ + Status: "healthy", + Timestamp: time.Now().Format(time.RFC3339), + Services: map[string]bool{ + "http": true, + "udp": clusterStatus.ServerRunning, + "sporeClient": primaryNode != "", + }, + Cluster: map[string]interface{}{ + "totalNodes": clusterStatus.TotalNodes, + "primaryNode": clusterStatus.PrimaryNode, + "udpPort": clusterStatus.UDPPort, + "serverRunning": clusterStatus.ServerRunning, + }, + } + + // Mark as degraded if no nodes discovered + if len(nodes) == 0 { + health.Status = "degraded" + } + + // Mark as degraded if no client initialized + if primaryNode == "" { + health.Status = "degraded" + } + + statusCode := http.StatusOK + if health.Status != "healthy" { + statusCode = http.StatusServiceUnavailable + } + + w.WriteHeader(statusCode) + json.NewEncoder(w).Encode(health) +} diff --git a/internal/websocket/websocket.go b/internal/websocket/websocket.go new file mode 100644 index 0000000..8ddd7c9 --- /dev/null +++ b/internal/websocket/websocket.go @@ -0,0 +1,370 @@ +package websocket + +import ( + "context" + "encoding/json" + "net/http" + "sync" + "time" + + "spore-gateway/internal/discovery" + "spore-gateway/pkg/client" + + "github.com/gorilla/websocket" + log "github.com/sirupsen/logrus" +) + +var upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true // Allow connections from any origin in development + }, +} + +// WebSocketServer manages WebSocket connections and broadcasts +type WebSocketServer struct { + nodeDiscovery *discovery.NodeDiscovery + sporeClients map[string]*client.SporeClient + clients map[*websocket.Conn]bool + mutex sync.RWMutex + logger *log.Logger +} + +// NewWebSocketServer creates a new WebSocket server +func NewWebSocketServer(nodeDiscovery *discovery.NodeDiscovery) *WebSocketServer { + wss := &WebSocketServer{ + nodeDiscovery: nodeDiscovery, + sporeClients: make(map[string]*client.SporeClient), + clients: make(map[*websocket.Conn]bool), + logger: log.New(), + } + + // Register callback for node updates + nodeDiscovery.AddCallback(wss.handleNodeUpdate) + + return wss +} + +// HandleWebSocket handles WebSocket upgrade and connection +func (wss *WebSocketServer) HandleWebSocket(w http.ResponseWriter, r *http.Request) error { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + wss.logger.WithError(err).Error("Failed to upgrade WebSocket connection") + return err + } + + wss.mutex.Lock() + wss.clients[conn] = true + wss.mutex.Unlock() + + wss.logger.Debug("WebSocket client connected") + + // Send current cluster state to newly connected client + go wss.sendCurrentClusterState(conn) + + // Handle client messages and disconnection + go wss.handleClient(conn) + + return nil +} + +// handleClient handles messages from a WebSocket client +func (wss *WebSocketServer) handleClient(conn *websocket.Conn) { + defer func() { + wss.mutex.Lock() + delete(wss.clients, conn) + wss.mutex.Unlock() + conn.Close() + wss.logger.Debug("WebSocket client disconnected") + }() + + // Set read deadline and pong handler + conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + conn.SetPongHandler(func(string) error { + conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + return nil + }) + + // Start ping routine + go func() { + ticker := time.NewTicker(54 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } + } + } + }() + + // Read messages (we don't expect any, but this keeps the connection alive) + for { + _, _, err := conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + wss.logger.WithError(err).Error("WebSocket error") + } + break + } + } +} + +// sendCurrentClusterState sends the current cluster state to a newly connected client +func (wss *WebSocketServer) sendCurrentClusterState(conn *websocket.Conn) { + nodes := wss.nodeDiscovery.GetNodes() + if len(nodes) == 0 { + return + } + + // Get real cluster data from SPORE nodes + clusterData, err := wss.getCurrentClusterMembers() + if err != nil { + wss.logger.WithError(err).Error("Failed to get cluster data for WebSocket") + return + } + + message := struct { + Type string `json:"type"` + Members []client.ClusterMember `json:"members"` + PrimaryNode string `json:"primaryNode"` + TotalNodes int `json:"totalNodes"` + Timestamp string `json:"timestamp"` + }{ + Type: "cluster_update", + Members: clusterData, + PrimaryNode: wss.nodeDiscovery.GetPrimaryNode(), + TotalNodes: len(nodes), + Timestamp: time.Now().Format(time.RFC3339), + } + + data, err := json.Marshal(message) + if err != nil { + wss.logger.WithError(err).Error("Failed to marshal cluster data") + return + } + + conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if err := conn.WriteMessage(websocket.TextMessage, data); err != nil { + wss.logger.WithError(err).Error("Failed to send initial cluster state") + } +} + +// handleNodeUpdate is called when node information changes +func (wss *WebSocketServer) handleNodeUpdate(nodeIP, action string) { + wss.logger.WithFields(log.Fields{ + "node_ip": nodeIP, + "action": action, + }).Debug("Node update received, broadcasting to WebSocket clients") + + // Broadcast cluster update to all clients + wss.broadcastClusterUpdate() + + // Also broadcast node discovery event + wss.broadcastNodeDiscovery(nodeIP, action) +} + +// broadcastClusterUpdate sends cluster updates to all connected clients +func (wss *WebSocketServer) broadcastClusterUpdate() { + wss.mutex.RLock() + clients := make([]*websocket.Conn, 0, len(wss.clients)) + for client := range wss.clients { + clients = append(clients, client) + } + wss.mutex.RUnlock() + + if len(clients) == 0 { + return + } + + startTime := time.Now() + + // Get cluster members asynchronously + clusterData, err := wss.getCurrentClusterMembers() + if err != nil { + wss.logger.WithError(err).Error("Failed to get cluster data for broadcast") + return + } + + message := struct { + Type string `json:"type"` + Members []client.ClusterMember `json:"members"` + PrimaryNode string `json:"primaryNode"` + TotalNodes int `json:"totalNodes"` + Timestamp string `json:"timestamp"` + }{ + Type: "cluster_update", + Members: clusterData, + PrimaryNode: wss.nodeDiscovery.GetPrimaryNode(), + TotalNodes: len(wss.nodeDiscovery.GetNodes()), + Timestamp: time.Now().Format(time.RFC3339), + } + + data, err := json.Marshal(message) + if err != nil { + wss.logger.WithError(err).Error("Failed to marshal cluster update") + return + } + + broadcastTime := time.Now() + wss.logger.WithFields(log.Fields{ + "clients": len(clients), + "prep_time": broadcastTime.Sub(startTime), + }).Debug("Broadcasting cluster update to WebSocket clients") + + // Send to all clients + var failedClients int + for _, client := range clients { + client.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err := client.WriteMessage(websocket.TextMessage, data); err != nil { + wss.logger.WithError(err).Error("Failed to send cluster update to client") + failedClients++ + } + } + + totalTime := time.Now().Sub(startTime) + wss.logger.WithFields(log.Fields{ + "clients": len(clients), + "failed_clients": failedClients, + "total_time": totalTime, + }).Debug("Cluster update broadcast completed") +} + +// broadcastNodeDiscovery sends node discovery events to all clients +func (wss *WebSocketServer) broadcastNodeDiscovery(nodeIP, action string) { + wss.mutex.RLock() + clients := make([]*websocket.Conn, 0, len(wss.clients)) + for client := range wss.clients { + clients = append(clients, client) + } + wss.mutex.RUnlock() + + if len(clients) == 0 { + return + } + + message := struct { + Type string `json:"type"` + Action string `json:"action"` + NodeIP string `json:"nodeIp"` + Timestamp string `json:"timestamp"` + }{ + Type: "node_discovery", + Action: action, + NodeIP: nodeIP, + Timestamp: time.Now().Format(time.RFC3339), + } + + data, err := json.Marshal(message) + if err != nil { + wss.logger.WithError(err).Error("Failed to marshal node discovery event") + return + } + + for _, client := range clients { + client.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err := client.WriteMessage(websocket.TextMessage, data); err != nil { + wss.logger.WithError(err).Error("Failed to send node discovery event to client") + } + } +} + +// getCurrentClusterMembers fetches real cluster data from SPORE nodes +func (wss *WebSocketServer) getCurrentClusterMembers() ([]client.ClusterMember, error) { + nodes := wss.nodeDiscovery.GetNodes() + if len(nodes) == 0 { + return []client.ClusterMember{}, nil + } + + // Try to get real cluster data from primary node + primaryNode := wss.nodeDiscovery.GetPrimaryNode() + if primaryNode != "" { + client := wss.getSporeClient(primaryNode) + clusterStatus, err := client.GetClusterStatus() + if err == nil { + // Update local node data with API information + wss.updateLocalNodesWithAPI(clusterStatus.Members) + return clusterStatus.Members, nil + } + wss.logger.WithError(err).Error("Failed to get cluster status from primary node") + } + + // Fallback to local data if API fails + return wss.getFallbackClusterMembers(), nil +} + +// updateLocalNodesWithAPI updates local node data with information from API +func (wss *WebSocketServer) updateLocalNodesWithAPI(apiMembers []client.ClusterMember) { + // This would update the local node discovery with fresh API data + // For now, we'll just log that we received the data + wss.logger.WithField("members", len(apiMembers)).Debug("Updating local nodes with API data") + + for _, member := range apiMembers { + if member.Labels != nil && len(member.Labels) > 0 { + wss.logger.WithFields(log.Fields{ + "ip": member.IP, + "labels": member.Labels, + }).Debug("API member labels") + } + } +} + +// getFallbackClusterMembers returns local node data as fallback +func (wss *WebSocketServer) getFallbackClusterMembers() []client.ClusterMember { + nodes := wss.nodeDiscovery.GetNodes() + members := make([]client.ClusterMember, 0, len(nodes)) + + for _, node := range nodes { + member := client.ClusterMember{ + IP: node.IP, + Hostname: node.Hostname, + Status: string(node.Status), + Latency: node.Latency, + LastSeen: node.LastSeen.Unix(), + Labels: node.Labels, + } + members = append(members, member) + } + + return members +} + +// getSporeClient gets or creates a SPORE client for a node +func (wss *WebSocketServer) getSporeClient(nodeIP string) *client.SporeClient { + if client, exists := wss.sporeClients[nodeIP]; exists { + return client + } + + client := client.NewSporeClient("http://" + nodeIP) + wss.sporeClients[nodeIP] = client + return client +} + +// GetClientCount returns the number of connected WebSocket clients +func (wss *WebSocketServer) GetClientCount() int { + wss.mutex.RLock() + defer wss.mutex.RUnlock() + return len(wss.clients) +} + +// Shutdown gracefully shuts down the WebSocket server +func (wss *WebSocketServer) Shutdown(ctx context.Context) error { + wss.logger.Info("Shutting down WebSocket server") + + wss.mutex.Lock() + clients := make([]*websocket.Conn, 0, len(wss.clients)) + for client := range wss.clients { + clients = append(clients, client) + } + wss.mutex.Unlock() + + // Close all client connections + for _, client := range clients { + client.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, "Server shutting down")) + client.Close() + } + + return nil +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..d79afdb --- /dev/null +++ b/main.go @@ -0,0 +1,104 @@ +package main + +import ( + "context" + "flag" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "spore-gateway/internal/discovery" + "spore-gateway/internal/server" + "spore-gateway/pkg/config" + + log "github.com/sirupsen/logrus" +) + +func main() { + // Parse command line flags + configFile := flag.String("config", "", "Path to configuration file") + port := flag.String("port", "3001", "HTTP server port") + udpPort := flag.String("udp-port", "4210", "UDP discovery port") + logLevel := flag.String("log-level", "info", "Log level (debug, info, warn, error)") + flag.Parse() + + // Initialize logger + level, err := log.ParseLevel(*logLevel) + if err != nil { + log.WithError(err).Fatal("Invalid log level") + } + log.SetLevel(level) + log.SetFormatter(&log.TextFormatter{ + FullTimestamp: true, + }) + + log.Info("Starting SPORE Gateway") + + // Load configuration + cfg := config.Load(*configFile) + if cfg == nil { + log.Fatal("Failed to load configuration") + } + + // Override config with command line arguments + if *port != "3001" { + cfg.HTTPPort = *port + } + if *udpPort != "4210" { + cfg.UDPPort = *udpPort + } + + log.WithFields(log.Fields{ + "http_port": cfg.HTTPPort, + "udp_port": cfg.UDPPort, + }).Info("Configuration loaded") + + // Initialize node discovery + nodeDiscovery := discovery.NewNodeDiscovery(cfg.UDPPort) + + // Initialize HTTP server + httpServer := server.NewHTTPServer(cfg.HTTPPort, nodeDiscovery) + + // Setup graceful shutdown + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + // Start UDP discovery server + go func() { + if err := nodeDiscovery.Start(); err != nil { + log.WithError(err).Fatal("Failed to start node discovery") + } + }() + + // Start HTTP server + go func() { + log.WithField("port", cfg.HTTPPort).Info("Starting HTTP server") + if err := httpServer.Start(); err != nil && err != http.ErrServerClosed { + log.WithError(err).Fatal("Failed to start HTTP server") + } + }() + + // Wait for interrupt signal + <-ctx.Done() + stop() + + log.Info("Shutting down servers...") + + // Create shutdown context with timeout + shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Shutdown HTTP server + if err := httpServer.Shutdown(shutdownCtx); err != nil { + log.WithError(err).Error("HTTP server shutdown error") + } + + // Shutdown node discovery + if err := nodeDiscovery.Shutdown(shutdownCtx); err != nil { + log.WithError(err).Error("Node discovery shutdown error") + } + + log.Info("Shutdown complete") +} diff --git a/pkg/client/client.go b/pkg/client/client.go new file mode 100644 index 0000000..addb134 --- /dev/null +++ b/pkg/client/client.go @@ -0,0 +1,399 @@ +package client + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/url" + "strings" + "time" + + log "github.com/sirupsen/logrus" +) + +// SporeClient represents a client for communicating with SPORE nodes +type SporeClient struct { + BaseURL string + HTTPClient *http.Client +} + +// NewSporeClient creates a new SPORE API client +func NewSporeClient(baseURL string) *SporeClient { + return &SporeClient{ + BaseURL: baseURL, + HTTPClient: &http.Client{ + Timeout: 30 * time.Second, + }, + } +} + +// ClusterStatusResponse represents the response from /api/cluster/members +type ClusterStatusResponse struct { + Members []ClusterMember `json:"members"` +} + +// ClusterMember represents a member in the cluster +type ClusterMember struct { + IP string `json:"ip"` + Hostname string `json:"hostname"` + Status string `json:"status"` + Latency int64 `json:"latency"` + LastSeen int64 `json:"lastSeen"` + Labels map[string]string `json:"labels"` + Resources map[string]interface{} `json:"resources"` +} + +// TaskStatusResponse represents the response from /api/tasks/status +type TaskStatusResponse struct { + Summary TaskSummary `json:"summary"` + Tasks []TaskInfo `json:"tasks"` + System SystemInfo `json:"system"` +} + +// TaskSummary represents task summary information +type TaskSummary struct { + TotalTasks int `json:"totalTasks"` + ActiveTasks int `json:"activeTasks"` +} + +// TaskInfo represents information about a task +type TaskInfo struct { + Name string `json:"name"` + Interval int `json:"interval"` + Enabled bool `json:"enabled"` + Running bool `json:"running"` + AutoStart bool `json:"autoStart"` +} + +// SystemInfo represents system information +type SystemInfo struct { + FreeHeap int64 `json:"freeHeap"` + Uptime int64 `json:"uptime"` +} + +// SystemStatusResponse represents the response from /api/node/status +type SystemStatusResponse struct { + FreeHeap int64 `json:"freeHeap"` + ChipID int64 `json:"chipId"` + SDKVersion string `json:"sdkVersion"` + CPUFreqMHz int `json:"cpuFreqMHz"` + FlashChipSize int64 `json:"flashChipSize"` + Labels map[string]string `json:"labels"` +} + +// CapabilitiesResponse represents the response from /api/node/endpoints +type CapabilitiesResponse struct { + Endpoints []EndpointInfo `json:"endpoints"` +} + +// EndpointInfo represents information about an API endpoint +type EndpointInfo struct { + URI string `json:"uri"` + Method string `json:"method"` + Parameters []ParameterInfo `json:"params"` +} + +// ParameterInfo represents information about a parameter +type ParameterInfo struct { + Name string `json:"name"` + Type string `json:"type"` + Required bool `json:"required"` + Description string `json:"description"` + Location string `json:"location"` // query, path, body + Default string `json:"default,omitempty"` + Values []string `json:"values,omitempty"` +} + +// FirmwareUpdateResponse represents the response from firmware update +type FirmwareUpdateResponse struct { + Status string `json:"status"` + Message string `json:"message"` +} + +// GetClusterStatus retrieves cluster member information +func (c *SporeClient) GetClusterStatus() (*ClusterStatusResponse, error) { + url := fmt.Sprintf("%s/api/cluster/members", c.BaseURL) + + resp, err := c.HTTPClient.Get(url) + if err != nil { + return nil, fmt.Errorf("failed to get cluster status: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("cluster status request failed with status %d", resp.StatusCode) + } + + var clusterStatus ClusterStatusResponse + if err := json.NewDecoder(resp.Body).Decode(&clusterStatus); err != nil { + return nil, fmt.Errorf("failed to decode cluster status response: %w", err) + } + + return &clusterStatus, nil +} + +// GetTaskStatus retrieves task status information +func (c *SporeClient) GetTaskStatus() (*TaskStatusResponse, error) { + url := fmt.Sprintf("%s/api/tasks/status", c.BaseURL) + + resp, err := c.HTTPClient.Get(url) + if err != nil { + return nil, fmt.Errorf("failed to get task status: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("task status request failed with status %d", resp.StatusCode) + } + + var taskStatus TaskStatusResponse + if err := json.NewDecoder(resp.Body).Decode(&taskStatus); err != nil { + return nil, fmt.Errorf("failed to decode task status response: %w", err) + } + + return &taskStatus, nil +} + +// GetSystemStatus retrieves system status information +func (c *SporeClient) GetSystemStatus() (*SystemStatusResponse, error) { + url := fmt.Sprintf("%s/api/node/status", c.BaseURL) + + resp, err := c.HTTPClient.Get(url) + if err != nil { + return nil, fmt.Errorf("failed to get system status: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("system status request failed with status %d", resp.StatusCode) + } + + var systemStatus SystemStatusResponse + if err := json.NewDecoder(resp.Body).Decode(&systemStatus); err != nil { + return nil, fmt.Errorf("failed to decode system status response: %w", err) + } + + return &systemStatus, nil +} + +// GetCapabilities retrieves available API endpoints +func (c *SporeClient) GetCapabilities() (*CapabilitiesResponse, error) { + url := fmt.Sprintf("%s/api/node/endpoints", c.BaseURL) + + resp, err := c.HTTPClient.Get(url) + if err != nil { + return nil, fmt.Errorf("failed to get capabilities: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("capabilities request failed with status %d", resp.StatusCode) + } + + var capabilities CapabilitiesResponse + if err := json.NewDecoder(resp.Body).Decode(&capabilities); err != nil { + return nil, fmt.Errorf("failed to decode capabilities response: %w", err) + } + + return &capabilities, nil +} + +// UpdateFirmware uploads firmware to a SPORE node +func (c *SporeClient) UpdateFirmware(firmwareData []byte, filename string) (*FirmwareUpdateResponse, error) { + url := fmt.Sprintf("%s/api/node/update", c.BaseURL) + + // Create multipart form + var requestBody bytes.Buffer + contentType := createMultipartForm(&requestBody, firmwareData, filename) + + if contentType == "" { + return nil, fmt.Errorf("failed to create multipart form") + } + + req, err := http.NewRequest("POST", url, &requestBody) + if err != nil { + return nil, fmt.Errorf("failed to create firmware update request: %w", err) + } + + req.Header.Set("Content-Type", contentType) + + // Create a client with extended timeout for firmware uploads + firmwareClient := &http.Client{ + Timeout: 5 * time.Minute, // 5 minutes for firmware uploads + } + + resp, err := firmwareClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to upload firmware: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("firmware update failed with status %d: %s", resp.StatusCode, string(body)) + } + + var updateResponse FirmwareUpdateResponse + if err := json.NewDecoder(resp.Body).Decode(&updateResponse); err != nil { + return nil, fmt.Errorf("failed to decode firmware update response: %w", err) + } + + return &updateResponse, nil +} + +// ProxyCall makes a generic HTTP request to a SPORE node endpoint +func (c *SporeClient) ProxyCall(method, uri string, params map[string]interface{}) (*http.Response, error) { + // Build target URL + targetURL := fmt.Sprintf("%s%s", c.BaseURL, uri) + + // Parse parameters and build request + req, err := c.buildProxyRequest(method, targetURL, params) + if err != nil { + return nil, fmt.Errorf("failed to build proxy request: %w", err) + } + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, fmt.Errorf("proxy call failed: %w", err) + } + + return resp, nil +} + +// buildProxyRequest builds an HTTP request for proxy calls +func (c *SporeClient) buildProxyRequest(method, targetURL string, params map[string]interface{}) (*http.Request, error) { + var body io.Reader + var contentType string + + if method != "GET" && params != nil { + // Check if we have JSON parameters + hasJSONParams := false + jsonParams := make(map[string]interface{}) + form := make(map[string][]string) + query := make(map[string][]string) + + for name, value := range params { + location := "body" + paramType := "string" + + // Check if value is a parameter object with location and type info + if paramObj, ok := value.(map[string]interface{}); ok { + if loc, exists := paramObj["location"].(string); exists { + location = loc + } + if ptype, exists := paramObj["type"].(string); exists { + paramType = ptype + } + // Extract the actual value + if val, exists := paramObj["value"]; exists { + value = val + } + } + + switch location { + case "query": + query[name] = append(query[name], fmt.Sprintf("%v", value)) + case "path": + // Replace {name} or :name in path + placeholder := fmt.Sprintf("{%s}", name) + if strings.Contains(targetURL, placeholder) { + targetURL = strings.ReplaceAll(targetURL, placeholder, fmt.Sprintf("%v", value)) + } + placeholder = fmt.Sprintf(":%s", name) + if strings.Contains(targetURL, placeholder) { + targetURL = strings.ReplaceAll(targetURL, placeholder, fmt.Sprintf("%v", value)) + } + default: + // Special handling for certain parameters that expect form-encoded data + // even when marked as "json" type + if paramType == "json" && name == "labels" { + // The labels parameter expects form-encoded data, not JSON + form[name] = append(form[name], fmt.Sprintf("%v", value)) + } else if paramType == "json" { + hasJSONParams = true + jsonParams[name] = value + } else { + form[name] = append(form[name], fmt.Sprintf("%v", value)) + } + } + } + + // Add query parameters to URL + if len(query) > 0 { + urlObj, err := url.Parse(targetURL) + if err != nil { + return nil, fmt.Errorf("invalid target URL: %w", err) + } + + q := urlObj.Query() + for key, values := range query { + for _, value := range values { + q.Add(key, value) + } + } + urlObj.RawQuery = q.Encode() + targetURL = urlObj.String() + } + + // Create request body + if hasJSONParams { + // Send JSON body + jsonData, err := json.Marshal(jsonParams) + if err != nil { + return nil, fmt.Errorf("failed to marshal JSON params: %w", err) + } + body = strings.NewReader(string(jsonData)) + contentType = "application/json" + } else if len(form) > 0 { + // Send form-encoded body + data := url.Values{} + for key, values := range form { + for _, value := range values { + data.Add(key, value) + } + } + body = strings.NewReader(data.Encode()) + contentType = "application/x-www-form-urlencoded" + } + } + + req, err := http.NewRequest(method, targetURL, body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + if method != "GET" && contentType != "" { + req.Header.Set("Content-Type", contentType) + } + + return req, nil +} + +// Helper function to create multipart form for file uploads +func createMultipartForm(requestBody *bytes.Buffer, firmwareData []byte, filename string) string { + writer := multipart.NewWriter(requestBody) + + // Add file field + fileWriter, err := writer.CreateFormFile("firmware", filename) + if err != nil { + log.WithError(err).Error("Failed to create form file") + return "" + } + + _, err = fileWriter.Write(firmwareData) + if err != nil { + log.WithError(err).Error("Failed to write file data") + return "" + } + + err = writer.Close() + if err != nil { + log.WithError(err).Error("Failed to close multipart writer") + return "" + } + + return writer.FormDataContentType() +} diff --git a/pkg/config/config.go b/pkg/config/config.go new file mode 100644 index 0000000..fe4936c --- /dev/null +++ b/pkg/config/config.go @@ -0,0 +1,39 @@ +package config + +import ( + "os" + + log "github.com/sirupsen/logrus" +) + +// Config holds application configuration +type Config struct { + HTTPPort string + UDPPort string + LogLevel string +} + +// Load loads configuration from file or environment variables +func Load(configFile string) *Config { + cfg := &Config{ + HTTPPort: getEnvOrDefault("HTTP_PORT", "3001"), + UDPPort: getEnvOrDefault("UDP_PORT", "4210"), + LogLevel: getEnvOrDefault("LOG_LEVEL", "info"), + } + + // TODO: Load from config file if provided + if configFile != "" { + log.WithField("file", configFile).Info("Loading configuration from file") + // Implement file loading here if needed + } + + return cfg +} + +// getEnvOrDefault returns environment variable value or default +func getEnvOrDefault(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue +}