feat: initial gateway implementation

This commit is contained in:
2025-10-19 21:54:14 +02:00
commit 5e1c39b0bf
10 changed files with 2223 additions and 0 deletions

View File

@@ -0,0 +1,369 @@
package discovery
import (
"context"
"fmt"
"net"
"strconv"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
// Start opens the UDP heartbeat socket on all interfaces and then processes
// incoming datagrams forever. The stale-node cleanup goroutine is launched as
// soon as the socket is open.
//
// An error is returned only when the socket cannot be opened; once listening,
// the read loop never exits, so callers should run Start in its own goroutine.
func (nd *NodeDiscovery) Start() error {
	listenAddr := &net.UDPAddr{
		Port: parsePort(nd.udpPort),
		IP:   net.ParseIP("0.0.0.0"),
	}

	socket, listenErr := net.ListenUDP("udp", listenAddr)
	if listenErr != nil {
		return fmt.Errorf("failed to listen on UDP port %s: %w", nd.udpPort, listenErr)
	}
	nd.logger.WithField("port", nd.udpPort).Info("UDP heartbeat server listening")

	// Background sweep that flags silent nodes and re-selects the primary.
	go nd.startCleanupRoutine()

	// A single 1 KiB receive buffer is reused; each payload is copied into a
	// string before the next read overwrites it.
	datagram := make([]byte, 1024)
	for {
		size, sender, readErr := socket.ReadFromUDP(datagram)
		if readErr == nil {
			nd.handleUDPMessage(string(datagram[:size]), sender)
			continue
		}
		// Read failures are logged and the loop keeps going.
		nd.logger.WithError(readErr).Error("Error reading UDP message")
	}
}
// Shutdown logs that node discovery is stopping and returns nil.
// As written it releases nothing: the UDP socket opened by Start and the
// cleanup goroutine are left running, and ctx is not consulted.
func (nd *NodeDiscovery) Shutdown(ctx context.Context) error {
nd.logger.Info("Shutting down node discovery")
// Close any connections if needed
return nil
}
// handleUDPMessage logs a received datagram and dispatches it by prefix:
// CLUSTER_HEARTBEAT refreshes the sender's node entry, NODE_UPDATE is passed
// to handleNodeUpdate, RAW is ignored silently, and anything else is logged
// at debug level.
func (nd *NodeDiscovery) handleUDPMessage(message string, remoteAddr *net.UDPAddr) {
	const (
		heartbeatPrefix = "CLUSTER_HEARTBEAT:"
		updatePrefix    = "NODE_UPDATE:"
		rawPrefix       = "RAW:"
	)

	// The payload is logged exactly as received, before trimming.
	nd.logger.WithFields(log.Fields{
		"message": message,
		"from":    remoteAddr.String(),
	}).Debug("UDP message received")

	payload := strings.TrimSpace(message)

	switch {
	case strings.HasPrefix(payload, heartbeatPrefix):
		// Everything after the prefix is the sender's hostname.
		senderHostname := strings.TrimPrefix(payload, heartbeatPrefix)
		nd.updateNodeFromHeartbeat(remoteAddr.IP.String(), remoteAddr.Port, senderHostname)
	case strings.HasPrefix(payload, updatePrefix):
		nd.handleNodeUpdate(remoteAddr.IP.String(), payload)
	case strings.HasPrefix(payload, rawPrefix):
		// RAW traffic is deliberately not handled here.
	default:
		nd.logger.WithField("message", payload).Debug("Received unknown UDP message")
	}
}
// updateNodeFromHeartbeat records a heartbeat received from sourceIP.
//
// A sender that is not yet tracked is registered as an active node and becomes
// the primary when no primary is set. For a tracked sender the hostname,
// last-seen time and status are refreshed. Either way the registered callbacks
// are notified with an action string describing what happened.
func (nd *NodeDiscovery) updateNodeFromHeartbeat(sourceIP string, sourcePort int, hostname string) {
	nd.mutex.Lock()
	defer nd.mutex.Unlock()

	seenAt := time.Now()

	node, known := nd.discoveredNodes[sourceIP]
	if !known {
		// First heartbeat from this address: register the node.
		nd.discoveredNodes[sourceIP] = &NodeInfo{
			IP:           sourceIP,
			Port:         sourcePort,
			Hostname:     hostname,
			Status:       NodeStatusActive,
			DiscoveredAt: seenAt,
			LastSeen:     seenAt,
		}
		nd.logger.WithFields(log.Fields{
			"ip":       sourceIP,
			"hostname": hostname,
			"total":    len(nd.discoveredNodes),
		}).Info("New node discovered via heartbeat")

		// With no primary selected yet, the newcomer takes the role.
		if nd.primaryNode == "" {
			nd.primaryNode = sourceIP
			nd.logger.WithField("ip", sourceIP).Info("Set as primary node")
		}

		nd.notifyCallbacks(sourceIP, "discovered")
		return
	}

	// Capture the previous state before it is overwritten.
	previousHostname := node.Hostname
	revived := node.Status == NodeStatusInactive
	renamed := previousHostname != hostname

	node.LastSeen = seenAt
	node.Hostname = hostname
	node.Status = NodeStatusActive

	nd.logger.WithFields(log.Fields{
		"ip":       sourceIP,
		"hostname": hostname,
		"total":    len(nd.discoveredNodes),
	}).Debug("Heartbeat from existing node")

	if renamed {
		nd.logger.WithFields(log.Fields{
			"ip":           sourceIP,
			"old_hostname": previousHostname,
			"new_hostname": hostname,
		}).Info("Hostname updated")
	}

	// Reactivation takes precedence over a hostname change when both apply.
	action := "active"
	switch {
	case revived:
		action = "node became active"
	case renamed:
		action = "hostname update"
	}
	nd.notifyCallbacks(sourceIP, action)
}
// handleNodeUpdate applies a NODE_UPDATE message from sourceIP to an already
// known node. Messages that do not split into three colon-separated parts, or
// that come from an unknown address, are logged as warnings and dropped.
func (nd *NodeDiscovery) handleNodeUpdate(sourceIP, message string) {
	// Wire format: "NODE_UPDATE:hostname:{json}". Splitting into at most three
	// parts keeps any colons inside the JSON payload intact.
	fields := strings.SplitN(message, ":", 3)
	if len(fields) < 3 {
		nd.logger.WithField("message", message).Warn("Invalid NODE_UPDATE message format")
		return
	}
	reportedHostname := fields[1]
	// fields[2] carries the JSON payload. TODO: Parse JSON data when needed

	nd.mutex.Lock()
	defer nd.mutex.Unlock()

	node := nd.discoveredNodes[sourceIP]
	if node == nil {
		nd.logger.WithField("ip", sourceIP).Warn("Received NODE_UPDATE for unknown node")
		return
	}

	// Only the hostname and liveness are refreshed for now.
	// TODO: Parse and update other fields from JSON
	node.Hostname = reportedHostname
	node.LastSeen = time.Now()
	node.Status = NodeStatusActive

	nd.logger.WithFields(log.Fields{
		"ip":       sourceIP,
		"hostname": reportedHostname,
	}).Debug("Updated node from NODE_UPDATE")

	nd.notifyCallbacks(sourceIP, "node update")
}
// startCleanupRoutine runs forever, waking every two seconds to mark silent
// nodes inactive and then re-evaluate the primary node. It has no stop
// condition, so it is meant to be launched in its own goroutine.
func (nd *NodeDiscovery) startCleanupRoutine() {
	const sweepInterval = 2 * time.Second

	sweep := time.NewTicker(sweepInterval)
	defer sweep.Stop()

	for {
		<-sweep.C
		nd.markStaleNodes()
		nd.updatePrimaryNode()
	}
}
// markStaleNodes flags every node whose last heartbeat is older than
// staleThreshold as inactive, notifies callbacks with the "stale" action, and
// clears the primary selection when the primary itself went silent.
func (nd *NodeDiscovery) markStaleNodes() {
	nd.mutex.Lock()
	defer nd.mutex.Unlock()

	checkedAt := time.Now()
	anyMarked := false

	for ip, node := range nd.discoveredNodes {
		silentFor := checkedAt.Sub(node.LastSeen)
		if silentFor <= nd.staleThreshold {
			continue
		}
		if node.Status == NodeStatusInactive {
			// Already flagged on an earlier sweep.
			continue
		}

		nd.logger.WithFields(log.Fields{
			"ip":              ip,
			"hostname":        node.Hostname,
			"time_since_seen": silentFor,
			"threshold":       nd.staleThreshold,
		}).Info("Node marked inactive")

		node.Status = NodeStatusInactive
		anyMarked = true
		nd.notifyCallbacks(ip, "stale")

		// Losing the primary leaves the selection empty until the next
		// updatePrimaryNode pass picks a replacement.
		if ip == nd.primaryNode {
			nd.primaryNode = ""
			nd.logger.Info("Primary node became stale, clearing selection")
		}
	}

	if anyMarked {
		nd.logger.Debug("Nodes marked stale, triggering cluster update")
	}
}
// updatePrimaryNode keeps the current primary when it is still a known, active
// node and otherwise asks selectBestPrimaryNode for a replacement.
//
// The primary is looked up with an existence check so that a primary IP that
// is missing from the node map (or maps to a nil entry) triggers re-selection
// instead of a nil-pointer dereference.
func (nd *NodeDiscovery) updatePrimaryNode() {
	nd.mutex.Lock()
	defer nd.mutex.Unlock()

	if nd.primaryNode != "" {
		current, known := nd.discoveredNodes[nd.primaryNode]
		if known && current != nil && current.Status == NodeStatusActive {
			return
		}
	}

	nd.selectBestPrimaryNode()
}
// selectBestPrimaryNode promotes the active node with the newest LastSeen
// time to primary and notifies callbacks of the change. It does nothing when
// no nodes are known, when the current primary is still active, or when no
// active candidate exists. It takes no lock itself; the visible caller
// (updatePrimaryNode) holds nd.mutex.
func (nd *NodeDiscovery) selectBestPrimaryNode() {
	if len(nd.discoveredNodes) == 0 {
		return
	}

	// A healthy primary is never replaced.
	if nd.primaryNode != "" {
		current, known := nd.discoveredNodes[nd.primaryNode]
		if known && current.Status == NodeStatusActive {
			return
		}
	}

	var (
		candidate string
		newest    time.Time
	)
	for ip, node := range nd.discoveredNodes {
		if node.Status != NodeStatusActive {
			continue
		}
		if !node.LastSeen.After(newest) {
			continue
		}
		newest, candidate = node.LastSeen, ip
	}

	if candidate == "" || candidate == nd.primaryNode {
		return
	}

	nd.primaryNode = candidate
	nd.logger.WithField("ip", candidate).Info("Selected new primary node")
	nd.notifyCallbacks(candidate, "primary node change")
}
// notifyCallbacks invokes every registered callback with nodeIP and action,
// each in its own goroutine, so the caller does not wait for them to finish.
func (nd *NodeDiscovery) notifyCallbacks(nodeIP, action string) {
	for i := range nd.callbacks {
		go nd.callbacks[i](nodeIP, action)
	}
}
// GetNodes returns a snapshot of all discovered nodes keyed by IP.
//
// Every NodeInfo is copied while the read lock is held, so callers can read
// (or JSON-encode) the result without racing against heartbeat processing,
// which mutates the live entries under the write lock. The Labels and
// Resources maps are copied one level deep; values nested inside Resources
// are still shared.
func (nd *NodeDiscovery) GetNodes() map[string]*NodeInfo {
	nd.mutex.RLock()
	defer nd.mutex.RUnlock()

	snapshot := make(map[string]*NodeInfo, len(nd.discoveredNodes))
	for ip, node := range nd.discoveredNodes {
		if node == nil {
			// Callers dereference every entry, so nil entries are left out.
			continue
		}

		clone := *node
		if node.Labels != nil {
			clone.Labels = make(map[string]string, len(node.Labels))
			for key, value := range node.Labels {
				clone.Labels[key] = value
			}
		}
		if node.Resources != nil {
			clone.Resources = make(map[string]interface{}, len(node.Resources))
			for key, value := range node.Resources {
				clone.Resources[key] = value
			}
		}
		snapshot[ip] = &clone
	}
	return snapshot
}
// GetPrimaryNode returns the IP of the current primary node, or the empty
// string when none is selected.
func (nd *NodeDiscovery) GetPrimaryNode() string {
	nd.mutex.RLock()
	primary := nd.primaryNode
	nd.mutex.RUnlock()
	return primary
}
// SetPrimaryNode makes ip the primary node and notifies callbacks. It returns
// an error when ip is not among the discovered nodes; the node's status is
// not checked.
func (nd *NodeDiscovery) SetPrimaryNode(ip string) error {
	nd.mutex.Lock()
	defer nd.mutex.Unlock()

	_, known := nd.discoveredNodes[ip]
	if !known {
		return fmt.Errorf("node %s not found", ip)
	}

	nd.primaryNode = ip
	nd.notifyCallbacks(ip, "manual primary node setting")
	return nil
}
// SelectRandomPrimaryNode switches the primary to one of the active nodes
// other than the current primary and returns its IP. When no nodes are known
// it returns ""; when there is no other active node it returns the current
// primary unchanged.
func (nd *NodeDiscovery) SelectRandomPrimaryNode() string {
	nd.mutex.Lock()
	defer nd.mutex.Unlock()

	if len(nd.discoveredNodes) == 0 {
		return ""
	}

	// Candidates are all active nodes except the present primary.
	var candidates []string
	for ip, node := range nd.discoveredNodes {
		if node.Status != NodeStatusActive || ip == nd.primaryNode {
			continue
		}
		candidates = append(candidates, ip)
	}

	if len(candidates) == 0 {
		return nd.primaryNode
	}

	// The index is derived from the current nanosecond clock reading.
	pick := candidates[time.Now().UnixNano()%int64(len(candidates))]

	nd.primaryNode = pick
	nd.logger.WithField("ip", pick).Info("Randomly selected new primary node")
	nd.notifyCallbacks(pick, "random primary node selection")
	return pick
}
// AddCallback appends callback to the list that notifyCallbacks invokes on
// node changes.
func (nd *NodeDiscovery) AddCallback(callback NodeUpdateCallback) {
	nd.mutex.Lock()
	nd.callbacks = append(nd.callbacks, callback)
	nd.mutex.Unlock()
}
// GetClusterStatus reports the primary node, the number of discovered nodes
// and the configured UDP port. ServerRunning is always reported as true.
func (nd *NodeDiscovery) GetClusterStatus() ClusterStatus {
	nd.mutex.RLock()
	defer nd.mutex.RUnlock()

	var status ClusterStatus
	status.PrimaryNode = nd.primaryNode
	status.TotalNodes = len(nd.discoveredNodes)
	status.UDPPort = nd.udpPort
	status.ServerRunning = true
	return status
}
// parsePort converts portStr into a UDP port number, falling back to 4210
// when the value is unusable.
//
// Accepted inputs are a decimal port in the range 0-65535 or a service name
// that net.LookupPort can resolve for UDP. The empty string, out-of-range or
// negative numbers, and unknown names all yield the default. (The empty
// string is checked explicitly because net.LookupPort maps it to port 0.)
func parsePort(portStr string) int {
	const (
		defaultPort = 4210
		maxPort     = 65535
	)

	if portStr == "" {
		return defaultPort
	}

	// Numeric input: accept only values that are valid port numbers.
	if n, err := strconv.Atoi(portStr); err == nil {
		if n < 0 || n > maxPort {
			return defaultPort
		}
		return n
	}

	// Non-numeric input may be a service name.
	if p, err := net.LookupPort("udp", portStr); err == nil {
		return p
	}
	return defaultPort
}

View File

@@ -0,0 +1,63 @@
package discovery
import (
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// NodeStatus is the liveness state recorded for a discovered node. It is a
// string type, so it serializes to JSON as the literal values below.
type NodeStatus string
// Known NodeStatus values.
const (
// NodeStatusActive is set whenever a heartbeat or NODE_UPDATE arrives from the node.
NodeStatusActive NodeStatus = "active"
// NodeStatusInactive is set by the cleanup sweep once a node has been silent
// for longer than the stale threshold.
NodeStatusInactive NodeStatus = "inactive"
// NodeStatusStale is declared but not assigned anywhere in the visible
// discovery code. NOTE(review): confirm whether other packages use it.
NodeStatusStale NodeStatus = "stale"
)
// NodeInfo holds what the gateway knows about one discovered SPORE node. It
// is serialized to JSON using the field tags below.
type NodeInfo struct {
// IP is the source address of the node's heartbeats; discovery also uses it
// as the key of its node map.
IP string `json:"ip"`
// Port is the UDP source port of the heartbeat that first registered the node.
Port int `json:"port"`
// Hostname is the name the node reported in its latest heartbeat or NODE_UPDATE.
Hostname string `json:"hostname"`
// Status is the node's current liveness state.
Status NodeStatus `json:"status"`
// DiscoveredAt is when the first heartbeat from the node was processed.
DiscoveredAt time.Time `json:"discoveredAt"`
// LastSeen is when the most recent heartbeat or NODE_UPDATE was processed.
LastSeen time.Time `json:"lastSeen"`
// The remaining fields are optional and omitted from JSON when empty. The
// visible discovery code never populates them.
// NOTE(review): presumably filled from NODE_UPDATE JSON or by other packages — confirm.
Uptime string `json:"uptime,omitempty"`
Labels map[string]string `json:"labels,omitempty"`
Latency int64 `json:"latency,omitempty"`
Resources map[string]interface{} `json:"resources,omitempty"`
}
// ClusterStatus is the summary returned by NodeDiscovery.GetClusterStatus.
type ClusterStatus struct {
// PrimaryNode is the IP of the current primary node, or "" when none is selected.
PrimaryNode string `json:"primaryNode"`
// TotalNodes counts every discovered node regardless of status.
TotalNodes int `json:"totalNodes"`
// UDPPort is the port string the discovery instance was configured with.
UDPPort string `json:"udpPort"`
// ServerRunning is always set to true by GetClusterStatus.
ServerRunning bool `json:"serverRunning"`
}
// NodeUpdateCallback is invoked when node information changes. nodeIP
// identifies the affected node and action is a short description of the
// event; the discovery code emits "discovered", "active",
// "node became active", "hostname update", "node update", "stale",
// "primary node change", "manual primary node setting" and
// "random primary node selection". Each invocation runs in its own goroutine.
type NodeUpdateCallback func(nodeIP string, action string)
// NodeDiscovery tracks SPORE nodes announced over UDP and selects a primary
// node among them.
type NodeDiscovery struct {
// udpPort is the configured listen port, kept as a string and parsed when Start runs.
udpPort string
// discoveredNodes maps a node's IP to its tracked state.
discoveredNodes map[string]*NodeInfo
// primaryNode is the IP of the current primary, or "" when none is selected.
primaryNode string
// mutex guards discoveredNodes, primaryNode and callbacks.
mutex sync.RWMutex
// callbacks are the listeners invoked by notifyCallbacks.
callbacks []NodeUpdateCallback
// staleThreshold is how long a node may stay silent before the cleanup sweep
// marks it inactive.
staleThreshold time.Duration
// logger receives all discovery log output.
logger *log.Logger
}
// NewNodeDiscovery returns a discovery instance configured for udpPort with
// an empty node table, its own logger, and an 8-second stale threshold (chosen
// to tolerate a 5-second heartbeat interval). Nothing listens until Start is
// called.
func NewNodeDiscovery(udpPort string) *NodeDiscovery {
	nd := new(NodeDiscovery)
	nd.udpPort = udpPort
	nd.discoveredNodes = make(map[string]*NodeInfo)
	nd.staleThreshold = 8 * time.Second
	nd.logger = log.New()
	return nd
}

745
internal/server/server.go Normal file
View File

@@ -0,0 +1,745 @@
package server
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/gorilla/mux"
	log "github.com/sirupsen/logrus"

	"spore-gateway/internal/discovery"
	"spore-gateway/internal/websocket"
	"spore-gateway/pkg/client"
)
// HTTPServer is the gateway's HTTP front end: it exposes the REST API and the
// WebSocket endpoint and forwards requests to SPORE nodes.
type HTTPServer struct {
// port is the TCP port the server listens on, without the leading colon.
port string
// router holds every registered route and middleware.
router *mux.Router
// nodeDiscovery supplies the discovered nodes and the current primary.
nodeDiscovery *discovery.NodeDiscovery
// sporeClients caches one client per node IP; entries are created on demand
// by getSporeClient.
sporeClients map[string]*client.SporeClient
// webSocketServer serves the /ws endpoint.
webSocketServer *websocket.WebSocketServer
// server is the underlying net/http server built in NewHTTPServer.
server *http.Server
}
// NewHTTPServer builds an HTTPServer for port backed by nodeDiscovery. It
// creates the WebSocket server, registers routes and middleware, and prepares
// (but does not start) the underlying http.Server with 30-second read/write
// timeouts and a 60-second idle timeout.
func NewHTTPServer(port string, nodeDiscovery *discovery.NodeDiscovery) *HTTPServer {
	sockets := websocket.NewWebSocketServer(nodeDiscovery)

	hs := new(HTTPServer)
	hs.port = port
	hs.router = mux.NewRouter()
	hs.nodeDiscovery = nodeDiscovery
	hs.sporeClients = make(map[string]*client.SporeClient)
	hs.webSocketServer = sockets

	// Routes are registered before the middleware chain is attached.
	hs.setupRoutes()
	hs.setupMiddleware()

	httpServer := new(http.Server)
	httpServer.Addr = ":" + port
	httpServer.Handler = hs.router
	httpServer.ReadTimeout = 30 * time.Second
	httpServer.WriteTimeout = 30 * time.Second
	httpServer.IdleTimeout = 60 * time.Second
	hs.server = httpServer

	return hs
}
// setupMiddleware attaches the router-level middleware chain in order:
// CORS headers, JSON content type, then request logging.
func (hs *HTTPServer) setupMiddleware() {
	chain := []mux.MiddlewareFunc{
		hs.corsMiddleware,
		hs.jsonMiddleware,
		hs.loggingMiddleware,
	}
	for _, middleware := range chain {
		hs.router.Use(middleware)
	}
}
// corsMiddleware adds permissive CORS headers to every response and answers
// OPTIONS requests with 200 OK without calling the wrapped handler.
func (hs *HTTPServer) corsMiddleware(next http.Handler) http.Handler {
	withCORS := func(w http.ResponseWriter, r *http.Request) {
		headers := w.Header()
		headers.Set("Access-Control-Allow-Origin", "*")
		headers.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		headers.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")

		if r.Method != http.MethodOptions {
			next.ServeHTTP(w, r)
			return
		}
		// Preflight: the headers above are the whole answer.
		w.WriteHeader(http.StatusOK)
	}
	return http.HandlerFunc(withCORS)
}
// jsonMiddleware presets the response Content-Type to application/json before
// handing the request to the wrapped handler, which may still override it.
func (hs *HTTPServer) jsonMiddleware(next http.Handler) http.Handler {
	asJSON := func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(asJSON)
}
// loggingMiddleware runs the wrapped handler and then emits a debug log entry
// with the request method, path, remote address, user agent and the time the
// handler took.
func (hs *HTTPServer) loggingMiddleware(next http.Handler) http.Handler {
	logged := func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()
		next.ServeHTTP(w, r)

		entry := log.WithFields(log.Fields{
			"method":      r.Method,
			"path":        r.URL.Path,
			"remote_addr": r.RemoteAddr,
			"user_agent":  r.UserAgent(),
			"duration":    time.Since(began),
		})
		entry.Debug("HTTP request")
	}
	return http.HandlerFunc(logged)
}
// setupRoutes registers every REST endpoint under /api and the WebSocket
// endpoint at /ws.
//
// POST routes also accept OPTIONS so that CORS preflight requests match a
// route and reach corsMiddleware.
func (hs *HTTPServer) setupRoutes() {
	api := hs.router.PathPrefix("/api").Subrouter()

	// Apply CORS middleware to API subrouter as well
	api.Use(hs.corsMiddleware)

	// Discovery endpoints
	api.HandleFunc("/discovery/nodes", hs.getDiscoveryNodes).Methods("GET")
	api.HandleFunc("/discovery/refresh", hs.refreshDiscovery).Methods("POST", "OPTIONS")
	api.HandleFunc("/discovery/random-primary", hs.selectRandomPrimary).Methods("POST", "OPTIONS")
	api.HandleFunc("/discovery/primary/{ip}", hs.setPrimaryNode).Methods("POST", "OPTIONS")

	// Cluster endpoints
	api.HandleFunc("/cluster/members", hs.getClusterMembers).Methods("GET")
	api.HandleFunc("/cluster/refresh", hs.refreshCluster).Methods("POST", "OPTIONS")

	// Task endpoints
	api.HandleFunc("/tasks/status", hs.getTaskStatus).Methods("GET")

	// Node endpoints
	api.HandleFunc("/node/status", hs.getNodeStatus).Methods("GET")
	api.HandleFunc("/node/status/{ip}", hs.getNodeStatusByIP).Methods("GET")
	api.HandleFunc("/node/endpoints", hs.getNodeEndpoints).Methods("GET")
	api.HandleFunc("/node/update", hs.updateNodeFirmware).Methods("POST", "OPTIONS")

	// Proxy endpoints
	api.HandleFunc("/proxy-call", hs.proxyCall).Methods("POST", "OPTIONS")

	// Test endpoints
	api.HandleFunc("/test/websocket", hs.testWebSocket).Methods("POST", "OPTIONS")

	// Health check
	api.HandleFunc("/health", hs.healthCheck).Methods("GET")

	// WebSocket endpoint - apply CORS middleware
	hs.router.HandleFunc("/ws", hs.corsMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := hs.webSocketServer.HandleWebSocket(w, r); err != nil {
			// HandleWebSocket only fails when the upgrade fails, and the
			// upgrader has already written the HTTP error response by then.
			// Writing a second one would only append to that response, so the
			// failure is just logged.
			log.WithError(err).Error("WebSocket connection failed")
		}
	})).ServeHTTP)
}
// Start logs the configured port and serves HTTP until the server stops.
// It blocks for the lifetime of the server and returns whatever
// ListenAndServe returns, which is always a non-nil error
// (http.ErrServerClosed after a Shutdown).
func (hs *HTTPServer) Start() error {
log.WithField("port", hs.port).Info("Starting HTTP server")
return hs.server.ListenAndServe()
}
// Shutdown stops the WebSocket server first and then the HTTP server, both
// bounded by ctx. A WebSocket shutdown failure is logged but does not prevent
// the HTTP shutdown; only the HTTP server's result is returned.
func (hs *HTTPServer) Shutdown(ctx context.Context) error {
	log.Info("Shutting down HTTP server")

	wsErr := hs.webSocketServer.Shutdown(ctx)
	if wsErr != nil {
		log.WithError(wsErr).Error("WebSocket server shutdown error")
	}

	return hs.server.Shutdown(ctx)
}
// sporeClientsMu serializes access to HTTPServer.sporeClients. HTTP handlers
// run concurrently, and an unsynchronized map read/write pair is a data race
// that the Go runtime may abort the process for.
var sporeClientsMu sync.Mutex

// getSporeClient returns the cached SPORE client for nodeIP, creating and
// caching one that targets http://<nodeIP> on first use. It is safe to call
// from concurrent request handlers.
func (hs *HTTPServer) getSporeClient(nodeIP string) *client.SporeClient {
	sporeClientsMu.Lock()
	defer sporeClientsMu.Unlock()

	if cached, ok := hs.sporeClients[nodeIP]; ok {
		return cached
	}

	created := client.NewSporeClient(fmt.Sprintf("http://%s", nodeIP))
	hs.sporeClients[nodeIP] = created
	return created
}
// performWithFailover runs operation against the primary node and, if that
// fails, against each remaining node until one succeeds.
//
// Candidates are tried in this order: the primary (when it is a known node),
// then the other nodes with active ones first and, within each group, the
// most recently seen first. When a non-primary node succeeds while a primary
// was set, that node is promoted to primary.
//
// It returns the first successful result, an error when no nodes have been
// discovered, or the last error when every candidate failed.
func (hs *HTTPServer) performWithFailover(operation func(*client.SporeClient) (interface{}, error)) (interface{}, error) {
	primaryNode := hs.nodeDiscovery.GetPrimaryNode()
	nodes := hs.nodeDiscovery.GetNodes()

	if len(nodes) == 0 {
		return nil, fmt.Errorf("no SPORE nodes discovered")
	}

	// Collect the fallback nodes and order them by preference.
	fallbacks := make([]*discovery.NodeInfo, 0, len(nodes))
	for _, node := range nodes {
		if node != nil && node.IP != primaryNode {
			fallbacks = append(fallbacks, node)
		}
	}
	sort.SliceStable(fallbacks, func(i, j int) bool {
		a, b := fallbacks[i], fallbacks[j]
		aActive := a.Status == discovery.NodeStatusActive
		bActive := b.Status == discovery.NodeStatusActive
		if aActive != bActive {
			return aActive
		}
		return a.LastSeen.After(b.LastSeen)
	})

	candidateIPs := make([]string, 0, len(fallbacks)+1)
	if primaryNode != "" {
		if _, exists := nodes[primaryNode]; exists {
			candidateIPs = append(candidateIPs, primaryNode)
		}
	}
	for _, node := range fallbacks {
		candidateIPs = append(candidateIPs, node.IP)
	}

	var lastError error
	for _, ip := range candidateIPs {
		result, err := operation(hs.getSporeClient(ip))
		if err == nil {
			// Success on a fallback node: make it the new primary.
			if ip != primaryNode && primaryNode != "" {
				if setErr := hs.nodeDiscovery.SetPrimaryNode(ip); setErr != nil {
					log.WithError(setErr).WithField("ip", ip).Warn("Failover: could not switch primary node")
				} else {
					log.WithField("ip", ip).Info("Failover: switched primary node")
				}
			}
			return result, nil
		}

		failure := "Failover attempt failed"
		if ip == primaryNode {
			failure = "Primary attempt failed"
		}
		log.WithFields(log.Fields{
			"ip":  ip,
			"err": err,
		}).Warn(failure)
		lastError = err
	}

	return nil, lastError
}
// API endpoint handlers
// getDiscoveryNodes handles GET /api/discovery/nodes. It responds with the
// primary node, the node count, every discovered node (each flagged with
// whether it is the primary) and the cluster status.
func (hs *HTTPServer) getDiscoveryNodes(w http.ResponseWriter, r *http.Request) {
	// nodeEntry is a discovered node plus its primary flag.
	type nodeEntry struct {
		*discovery.NodeInfo
		IsPrimary bool `json:"isPrimary"`
	}
	type nodeListing struct {
		PrimaryNode       string                  `json:"primaryNode"`
		TotalNodes        int                     `json:"totalNodes"`
		Nodes             []nodeEntry             `json:"nodes"`
		ClientInitialized bool                    `json:"clientInitialized"`
		ClientBaseURL     string                  `json:"clientBaseUrl"`
		ClusterStatus     discovery.ClusterStatus `json:"clusterStatus"`
	}

	nodes := hs.nodeDiscovery.GetNodes()
	primaryNode := hs.nodeDiscovery.GetPrimaryNode()
	clusterStatus := hs.nodeDiscovery.GetClusterStatus()

	entries := make([]nodeEntry, 0, len(nodes))
	for _, info := range nodes {
		entries = append(entries, nodeEntry{
			NodeInfo:  info,
			IsPrimary: info.IP == primaryNode,
		})
	}

	listing := nodeListing{
		PrimaryNode:       primaryNode,
		TotalNodes:        len(nodes),
		Nodes:             entries,
		ClientInitialized: primaryNode != "",
		ClientBaseURL:     "",
		ClusterStatus:     clusterStatus,
	}
	json.NewEncoder(w).Encode(listing)
}
// refreshDiscovery handles POST /api/discovery/refresh. It triggers no work
// itself (stale marking and primary selection run in discovery's cleanup
// routine) and simply reports the current primary node and node count.
func (hs *HTTPServer) refreshDiscovery(w http.ResponseWriter, r *http.Request) {
	type refreshResult struct {
		Success           bool   `json:"success"`
		Message           string `json:"message"`
		PrimaryNode       string `json:"primaryNode"`
		TotalNodes        int    `json:"totalNodes"`
		ClientInitialized bool   `json:"clientInitialized"`
	}

	primary := hs.nodeDiscovery.GetPrimaryNode()
	total := len(hs.nodeDiscovery.GetNodes())
	initialized := hs.nodeDiscovery.GetPrimaryNode() != ""

	json.NewEncoder(w).Encode(refreshResult{
		Success:           true,
		Message:           "Cluster refresh completed",
		PrimaryNode:       primary,
		TotalNodes:        total,
		ClientInitialized: initialized,
	})
}
// selectRandomPrimary handles POST /api/discovery/random-primary. It answers
// 404 when no nodes are known, 500 when discovery returns no primary, and
// otherwise reports the newly selected primary together with a timestamp.
func (hs *HTTPServer) selectRandomPrimary(w http.ResponseWriter, r *http.Request) {
	type selectionResult struct {
		Success           bool   `json:"success"`
		Message           string `json:"message"`
		PrimaryNode       string `json:"primaryNode"`
		TotalNodes        int    `json:"totalNodes"`
		ClientInitialized bool   `json:"clientInitialized"`
		Timestamp         string `json:"timestamp"`
	}

	total := len(hs.nodeDiscovery.GetNodes())
	if total == 0 {
		http.Error(w, `{"error": "No nodes available", "message": "No SPORE nodes have been discovered yet"}`, http.StatusNotFound)
		return
	}

	chosen := hs.nodeDiscovery.SelectRandomPrimaryNode()
	if chosen == "" {
		http.Error(w, `{"error": "Selection failed", "message": "Failed to select a random primary node"}`, http.StatusInternalServerError)
		return
	}

	json.NewEncoder(w).Encode(selectionResult{
		Success:           true,
		Message:           fmt.Sprintf("Randomly selected new primary node: %s", chosen),
		PrimaryNode:       chosen,
		TotalNodes:        total,
		ClientInitialized: true,
		Timestamp:         time.Now().Format(time.RFC3339),
	})
}
// setPrimaryNode handles POST /api/discovery/primary/{ip}. It makes the node
// named in the URL the primary and answers 404 with a JSON error when that
// node has not been discovered.
func (hs *HTTPServer) setPrimaryNode(w http.ResponseWriter, r *http.Request) {
	requestedIP := mux.Vars(r)["ip"]

	if err := hs.nodeDiscovery.SetPrimaryNode(requestedIP); err != nil {
		// requestedIP comes straight from the URL, so the error body is built
		// with the JSON encoder rather than string formatting; quotes or
		// control characters in it can then never produce malformed JSON.
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusNotFound)
		json.NewEncoder(w).Encode(map[string]string{
			"error":   "Node not found",
			"message": fmt.Sprintf("Node with IP %s has not been discovered", requestedIP),
		})
		return
	}

	response := struct {
		Success           bool   `json:"success"`
		Message           string `json:"message"`
		PrimaryNode       string `json:"primaryNode"`
		ClientInitialized bool   `json:"clientInitialized"`
	}{
		Success:           true,
		Message:           fmt.Sprintf("Primary node set to %s", requestedIP),
		PrimaryNode:       requestedIP,
		ClientInitialized: true,
	}
	json.NewEncoder(w).Encode(response)
}
// getClusterMembers handles GET /api/cluster/members. It fetches the cluster
// status from the primary node, failing over to other nodes, and answers 502
// with a JSON error when every node fails.
func (hs *HTTPServer) getClusterMembers(w http.ResponseWriter, r *http.Request) {
	result, err := hs.performWithFailover(func(c *client.SporeClient) (interface{}, error) {
		return c.GetClusterStatus()
	})
	if err != nil {
		log.WithError(err).Error("Error fetching cluster members")
		// The upstream error text is arbitrary, so it is JSON-encoded instead
		// of being spliced into a hand-written JSON string.
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusBadGateway)
		json.NewEncoder(w).Encode(map[string]string{
			"error":   "Failed to fetch cluster members",
			"message": err.Error(),
		})
		return
	}

	json.NewEncoder(w).Encode(result)
}
// refreshCluster handles POST /api/cluster/refresh. The optional JSON body
// may carry a "reason"; an empty body is accepted and the reason then
// defaults to "manual_refresh". The handler logs the request and reports the
// reason together with the number of connected WebSocket clients.
func (hs *HTTPServer) refreshCluster(w http.ResponseWriter, r *http.Request) {
	var requestBody struct {
		Reason string `json:"reason"`
	}

	// io.EOF means the body was empty, which is allowed. Compare against the
	// sentinel value rather than the error's text.
	if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil && err != io.EOF {
		http.Error(w, `{"error": "Invalid JSON", "message": "Failed to parse request body"}`, http.StatusBadRequest)
		return
	}

	reason := requestBody.Reason
	if reason == "" {
		reason = "manual_refresh"
	}

	log.WithField("reason", reason).Info("Manual cluster refresh triggered")

	response := struct {
		Success   bool   `json:"success"`
		Message   string `json:"message"`
		Reason    string `json:"reason"`
		WSclients int    `json:"wsClients"`
	}{
		Success:   true,
		Message:   "Cluster refresh triggered",
		Reason:    reason,
		WSclients: hs.webSocketServer.GetClientCount(),
	}
	json.NewEncoder(w).Encode(response)
}
// getTaskStatus handles GET /api/tasks/status. With an "ip" query parameter
// it queries that node directly (500 on failure); without one it queries the
// primary node with failover (502 on failure).
func (hs *HTTPServer) getTaskStatus(w http.ResponseWriter, r *http.Request) {
	// fail writes a JSON error body. Messages come from upstream errors, so
	// they are encoded rather than interpolated into a JSON string.
	fail := func(status int, errLabel, message string) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{
			"error":   errLabel,
			"message": message,
		})
	}

	// NOTE(review): ip is caller-supplied and is not checked against the
	// discovered nodes before a client is created for it.
	if ip := r.URL.Query().Get("ip"); ip != "" {
		result, err := hs.getSporeClient(ip).GetTaskStatus()
		if err != nil {
			log.WithError(err).Error("Error fetching task status from specific node")
			fail(http.StatusInternalServerError, "Failed to fetch task status from node", err.Error())
			return
		}
		json.NewEncoder(w).Encode(result)
		return
	}

	result, err := hs.performWithFailover(func(c *client.SporeClient) (interface{}, error) {
		return c.GetTaskStatus()
	})
	if err != nil {
		log.WithError(err).Error("Error fetching task status")
		fail(http.StatusBadGateway, "Failed to fetch task status", err.Error())
		return
	}

	json.NewEncoder(w).Encode(result)
}
// getNodeStatus handles GET /api/node/status. It fetches the system status
// from the primary node, failing over to other nodes, and answers 502 with a
// JSON error when every node fails.
func (hs *HTTPServer) getNodeStatus(w http.ResponseWriter, r *http.Request) {
	result, err := hs.performWithFailover(func(c *client.SporeClient) (interface{}, error) {
		return c.GetSystemStatus()
	})
	if err != nil {
		log.WithError(err).Error("Error fetching system status")
		// Encode the upstream error text so it cannot break the JSON body.
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusBadGateway)
		json.NewEncoder(w).Encode(map[string]string{
			"error":   "Failed to fetch system status",
			"message": err.Error(),
		})
		return
	}

	json.NewEncoder(w).Encode(result)
}
// getNodeStatusByIP handles GET /api/node/status/{ip}. It fetches the system
// status directly from the node named in the URL and answers 500 with a JSON
// error when the request fails.
func (hs *HTTPServer) getNodeStatusByIP(w http.ResponseWriter, r *http.Request) {
	// NOTE(review): nodeIP is caller-supplied and is not checked against the
	// discovered nodes before a client is created for it.
	nodeIP := mux.Vars(r)["ip"]

	result, err := hs.getSporeClient(nodeIP).GetSystemStatus()
	if err != nil {
		log.WithError(err).Error("Error fetching status from specific node")
		// Both nodeIP (from the URL) and the error text are arbitrary, so the
		// body is JSON-encoded instead of assembled with string formatting.
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusInternalServerError)
		json.NewEncoder(w).Encode(map[string]string{
			"error":   fmt.Sprintf("Failed to fetch status from node %s", nodeIP),
			"message": err.Error(),
		})
		return
	}

	json.NewEncoder(w).Encode(result)
}
// getNodeEndpoints handles GET /api/node/endpoints. With an "ip" query
// parameter it fetches the capabilities of that node directly (500 on
// failure); without one it asks the primary node with failover (502 on
// failure).
func (hs *HTTPServer) getNodeEndpoints(w http.ResponseWriter, r *http.Request) {
	// fail writes a JSON error body. Messages come from upstream errors, so
	// they are encoded rather than interpolated into a JSON string.
	fail := func(status int, errLabel, message string) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{
			"error":   errLabel,
			"message": message,
		})
	}

	// NOTE(review): ip is caller-supplied and is not checked against the
	// discovered nodes before a client is created for it.
	if ip := r.URL.Query().Get("ip"); ip != "" {
		result, err := hs.getSporeClient(ip).GetCapabilities()
		if err != nil {
			log.WithError(err).Error("Error fetching endpoints from specific node")
			fail(http.StatusInternalServerError, "Failed to fetch endpoints from node", err.Error())
			return
		}
		json.NewEncoder(w).Encode(result)
		return
	}

	result, err := hs.performWithFailover(func(c *client.SporeClient) (interface{}, error) {
		return c.GetCapabilities()
	})
	if err != nil {
		log.WithError(err).Error("Error fetching capabilities")
		fail(http.StatusBadGateway, "Failed to fetch capabilities", err.Error())
		return
	}

	json.NewEncoder(w).Encode(result)
}
// updateNodeFirmware handles POST /api/node/update. The target node is taken
// from the "ip" query parameter or, failing that, the X-Node-IP header; the
// firmware image is the multipart form field "file".
//
// It answers 400 when the node IP or file is missing, the form cannot be
// parsed, or the device reports status "FAIL"; 500 when reading or uploading
// the file fails; and otherwise a JSON summary of the upload.
func (hs *HTTPServer) updateNodeFirmware(w http.ResponseWriter, r *http.Request) {
	// fail writes a JSON error body. Some messages come from upstream errors
	// or the device, so everything is encoded rather than interpolated.
	fail := func(status int, body map[string]interface{}) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(body)
	}

	nodeIP := r.URL.Query().Get("ip")
	if nodeIP == "" {
		nodeIP = r.Header.Get("X-Node-IP")
	}
	if nodeIP == "" {
		fail(http.StatusBadRequest, map[string]interface{}{
			"error":   "Node IP address is required",
			"message": "Please provide the target node IP address",
		})
		return
	}

	// The argument is the in-memory threshold, not an upload size limit:
	// up to 50 MB of file parts are buffered in memory and the remainder is
	// spooled to temporary files.
	if err := r.ParseMultipartForm(50 << 20); err != nil {
		log.WithError(err).Error("Error parsing multipart form")
		fail(http.StatusBadRequest, map[string]interface{}{
			"error":   "Failed to parse form",
			"message": "Error parsing multipart form data",
		})
		return
	}

	file, fileHeader, err := r.FormFile("file")
	if err != nil {
		log.WithError(err).Error("No file found in form")
		fail(http.StatusBadRequest, map[string]interface{}{
			"error":   "No file data received",
			"message": "Please select a firmware file to upload",
		})
		return
	}
	defer file.Close()

	// Fall back to a generic name when the upload did not carry one.
	filename := fileHeader.Filename
	if filename == "" {
		filename = "firmware.bin"
	}

	fileData, err := io.ReadAll(file)
	if err != nil {
		log.WithError(err).Error("Error reading file data")
		fail(http.StatusInternalServerError, map[string]interface{}{
			"error":   "Failed to read file",
			"message": "Error reading uploaded file data",
		})
		return
	}

	log.WithFields(log.Fields{
		"node_ip":   nodeIP,
		"file_size": len(fileData),
	}).Info("Firmware upload received")

	result, err := hs.getSporeClient(nodeIP).UpdateFirmware(fileData, filename)
	if err != nil {
		log.WithError(err).Error("Error uploading firmware")
		fail(http.StatusInternalServerError, map[string]interface{}{
			"error":   "Failed to upload firmware",
			"message": err.Error(),
		})
		return
	}

	// The upload itself succeeded but the device rejected the firmware.
	if result.Status == "FAIL" {
		log.WithField("message", result.Message).Error("Device reported firmware update failure")
		fail(http.StatusBadRequest, map[string]interface{}{
			"success": false,
			"error":   "Firmware update failed",
			"message": result.Message,
		})
		return
	}

	response := struct {
		Success  bool        `json:"success"`
		Message  string      `json:"message"`
		NodeIP   string      `json:"nodeIp"`
		FileSize int         `json:"fileSize"`
		Filename string      `json:"filename"`
		Result   interface{} `json:"result"`
	}{
		Success:  true,
		Message:  "Firmware uploaded successfully",
		NodeIP:   nodeIP,
		FileSize: len(fileData),
		Filename: filename,
		Result:   result,
	}
	json.NewEncoder(w).Encode(response)
}
// proxyCall handles POST /api/proxy-call. The JSON body names a node ("ip"),
// an HTTP method, a URI and an optional list of parameters; the call is
// forwarded to that node and the node's status code and body are relayed.
// JSON responses are wrapped as {"data": <upstream JSON>, "status": <code>}.
//
// It answers 400 for an unparsable body or missing ip/method/uri and 500 when
// the upstream call or reading its response fails.
func (hs *HTTPServer) proxyCall(w http.ResponseWriter, r *http.Request) {
	// fail writes a JSON error body. Messages may contain upstream error
	// text, so they are encoded rather than interpolated into a JSON string.
	fail := func(status int, errLabel, message string) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		json.NewEncoder(w).Encode(map[string]string{
			"error":   errLabel,
			"message": message,
		})
	}

	var requestBody struct {
		IP     string                   `json:"ip"`
		Method string                   `json:"method"`
		URI    string                   `json:"uri"`
		Params []map[string]interface{} `json:"params"`
	}
	if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
		fail(http.StatusBadRequest, "Invalid JSON", "Failed to parse request body")
		return
	}
	if requestBody.IP == "" || requestBody.Method == "" || requestBody.URI == "" {
		fail(http.StatusBadRequest, "Missing required fields", "Required: ip, method, uri")
		return
	}

	params := buildProxyParams(requestBody.Params)

	resp, err := hs.getSporeClient(requestBody.IP).ProxyCall(requestBody.Method, requestBody.URI, params)
	if err != nil {
		log.WithError(err).Error("Error in proxy call")
		fail(http.StatusInternalServerError, "Proxy call failed", err.Error())
		return
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		log.WithError(err).Error("Error reading proxy response")
		fail(http.StatusInternalServerError, "Failed to read response", "Error reading upstream response")
		return
	}

	// Relay the upstream content type when it provided one.
	contentType := resp.Header.Get("Content-Type")
	if contentType != "" {
		w.Header().Set("Content-Type", contentType)
	}

	// Set CORS headers for proxy responses
	w.Header().Set("Access-Control-Allow-Origin", "*")
	w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
	w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")

	w.WriteHeader(resp.StatusCode)

	// Wrap JSON responses in a "data" field for the frontend. A body that
	// does not parse, or a wrapper that cannot be encoded, is relayed as-is.
	if contentType != "" && strings.Contains(contentType, "application/json") {
		var upstream interface{}
		if err := json.Unmarshal(body, &upstream); err == nil {
			wrapped, marshalErr := json.Marshal(map[string]interface{}{
				"data":   upstream,
				"status": resp.StatusCode,
			})
			if marshalErr == nil {
				body = wrapped
			}
		}
	}

	if _, err := w.Write(body); err != nil {
		log.WithError(err).Error("Error writing proxy response")
	}
}

// buildProxyParams converts the request's parameter list into the map passed
// to the SPORE client. Entries without a string "name" are skipped. Each
// remaining entry becomes {"location": "body", "type": ..., "value": ...}:
// the value is the entry's "value" field (or the whole entry when absent),
// and the type is "string" unless the value is a string that is either the
// "labels" parameter (kept as raw text) or parses as JSON (replaced by the
// parsed value), in which case the type is "json".
func buildProxyParams(raw []map[string]interface{}) map[string]interface{} {
	params := make(map[string]interface{})
	for _, param := range raw {
		name, ok := param["name"].(string)
		if !ok {
			continue
		}

		paramObj := map[string]interface{}{
			"location": "body",
			"type":     "string", // default type
		}
		if value, present := param["value"]; present {
			paramObj["value"] = value
		} else {
			paramObj["value"] = param
		}

		if text, isString := paramObj["value"].(string); isString {
			if name == "labels" {
				// The labels parameter expects the raw JSON string.
				paramObj["type"] = "json"
			} else {
				var parsed interface{}
				if err := json.Unmarshal([]byte(text), &parsed); err == nil {
					paramObj["type"] = "json"
					paramObj["value"] = parsed
				}
			}
		}

		params[name] = paramObj
	}
	return params
}
// testWebSocket handles POST /api/test/websocket. It logs the request and
// reports the number of connected WebSocket clients and discovered nodes; the
// handler itself does not send anything over the WebSocket connections.
func (hs *HTTPServer) testWebSocket(w http.ResponseWriter, r *http.Request) {
	type testResult struct {
		Success    bool   `json:"success"`
		Message    string `json:"message"`
		WSclients  int    `json:"websocketClients"`
		TotalNodes int    `json:"totalNodes"`
	}

	log.Info("Manual WebSocket test triggered")

	connected := hs.webSocketServer.GetClientCount()
	discovered := len(hs.nodeDiscovery.GetNodes())

	json.NewEncoder(w).Encode(testResult{
		Success:    true,
		Message:    "WebSocket test broadcast sent",
		WSclients:  connected,
		TotalNodes: discovered,
	})
}
// healthCheck handles GET /api/health. The gateway is "healthy" (200) when at
// least one node is discovered and a primary is selected, and "degraded"
// (503) otherwise. The body also lists per-service flags and a cluster
// summary.
func (hs *HTTPServer) healthCheck(w http.ResponseWriter, r *http.Request) {
	type healthReport struct {
		Status    string                 `json:"status"`
		Timestamp string                 `json:"timestamp"`
		Services  map[string]bool        `json:"services"`
		Cluster   map[string]interface{} `json:"cluster"`
	}

	primaryNode := hs.nodeDiscovery.GetPrimaryNode()
	nodes := hs.nodeDiscovery.GetNodes()
	clusterStatus := hs.nodeDiscovery.GetClusterStatus()

	// Either missing nodes or a missing primary degrades the service.
	status, statusCode := "healthy", http.StatusOK
	if len(nodes) == 0 || primaryNode == "" {
		status, statusCode = "degraded", http.StatusServiceUnavailable
	}

	report := healthReport{
		Status:    status,
		Timestamp: time.Now().Format(time.RFC3339),
		Services: map[string]bool{
			"http":        true,
			"udp":         clusterStatus.ServerRunning,
			"sporeClient": primaryNode != "",
		},
		Cluster: map[string]interface{}{
			"totalNodes":    clusterStatus.TotalNodes,
			"primaryNode":   clusterStatus.PrimaryNode,
			"udpPort":       clusterStatus.UDPPort,
			"serverRunning": clusterStatus.ServerRunning,
		},
	}

	w.WriteHeader(statusCode)
	json.NewEncoder(w).Encode(report)
}

View File

@@ -0,0 +1,370 @@
package websocket
import (
"context"
"encoding/json"
"net/http"
"sync"
"time"
"spore-gateway/internal/discovery"
"spore-gateway/pkg/client"
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
)
// upgrader converts incoming HTTP requests into WebSocket connections.
var upgrader = websocket.Upgrader{
	// Every origin is accepted, which the original author marked as a
	// development-time setting.
	// NOTE(review): this turns off the cross-origin check — confirm before
	// exposing the gateway outside a trusted network.
	CheckOrigin: func(_ *http.Request) bool { return true },
}
// WebSocketServer manages WebSocket connections and broadcasts cluster and
// node-discovery updates to every connected client.
type WebSocketServer struct {
// nodeDiscovery supplies the node list and primary node, and invokes the
// server's callback when node information changes.
nodeDiscovery *discovery.NodeDiscovery
// sporeClients caches one SPORE API client per node IP; entries are created
// on demand by getSporeClient.
sporeClients map[string]*client.SporeClient
// clients is the set of currently registered WebSocket connections.
clients map[*websocket.Conn]bool
// mutex is held whenever the clients map is read or modified.
mutex sync.RWMutex
// logger is this server's own logrus instance, created in NewWebSocketServer.
logger *log.Logger
}
// NewWebSocketServer returns a WebSocketServer backed by nodeDiscovery.
//
// The server starts with empty client and SPORE-client tables and its own
// logrus logger, and registers handleNodeUpdate as a discovery callback so
// that node changes are pushed to connected WebSocket clients.
func NewWebSocketServer(nodeDiscovery *discovery.NodeDiscovery) *WebSocketServer {
	server := new(WebSocketServer)
	server.nodeDiscovery = nodeDiscovery
	server.sporeClients = map[string]*client.SporeClient{}
	server.clients = map[*websocket.Conn]bool{}
	server.logger = log.New()

	// Subscribe to discovery so node changes trigger broadcasts.
	nodeDiscovery.AddCallback(server.handleNodeUpdate)

	return server
}
// HandleWebSocket upgrades the HTTP request to a WebSocket connection and
// registers it with the server.
//
// On a failed upgrade the error is logged and returned. On success the
// connection is added to the client set, and two goroutines are started: one
// that sends the current cluster state to the new client and one that
// services the connection until it closes. The function then returns nil
// without waiting for either.
//
// NOTE(review): the initial-state goroutine writes to the connection while a
// broadcast may be writing to it as well; gorilla/websocket permits only one
// concurrent writer — confirm whether writes need to be serialized.
func (wss *WebSocketServer) HandleWebSocket(w http.ResponseWriter, r *http.Request) error {
	ws, upgradeErr := upgrader.Upgrade(w, r, nil)
	if upgradeErr != nil {
		wss.logger.WithError(upgradeErr).Error("Failed to upgrade WebSocket connection")
		return upgradeErr
	}

	// Register the connection so broadcasts reach it.
	wss.mutex.Lock()
	wss.clients[ws] = true
	wss.mutex.Unlock()

	wss.logger.Debug("WebSocket client connected")

	// Push the current cluster snapshot to the new client.
	go wss.sendCurrentClusterState(ws)

	// Service reads, keep-alive and eventual disconnection.
	go wss.handleClient(ws)

	return nil
}
// handleClient services a single WebSocket connection until it fails or is
// closed by the peer.
//
// It arms a 60-second read deadline that is extended on every pong, starts a
// goroutine that pings the peer every 54 seconds, and then reads (and
// discards) incoming messages. When reading fails, the connection is removed
// from the client set and closed, and the ping goroutine is told to stop.
func (wss *WebSocketServer) handleClient(conn *websocket.Conn) {
	const (
		pongWait   = 60 * time.Second // maximum silence tolerated from the peer
		pingPeriod = 54 * time.Second // must stay below pongWait
		writeWait  = 10 * time.Second // time allowed to send a ping
	)

	// done is closed on exit so the ping goroutine stops immediately instead
	// of lingering until its next tick fails on the closed connection.
	done := make(chan struct{})

	defer func() {
		close(done)
		wss.mutex.Lock()
		delete(wss.clients, conn)
		wss.mutex.Unlock()
		conn.Close()
		wss.logger.Debug("WebSocket client disconnected")
	}()

	// Set read deadline and pong handler
	conn.SetReadDeadline(time.Now().Add(pongWait))
	conn.SetPongHandler(func(string) error {
		conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})

	// Start ping routine
	go func() {
		ticker := time.NewTicker(pingPeriod)
		defer ticker.Stop()
		for {
			select {
			case <-done:
				return
			case <-ticker.C:
				// WriteControl may be called concurrently with other
				// connection methods, so pings cannot collide with
				// broadcast writes issued from other goroutines.
				if err := conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(writeWait)); err != nil {
					return
				}
			}
		}
	}()

	// Read messages (none are expected, but reading is what lets pong and
	// close frames be processed and disconnections be detected).
	for {
		if _, _, err := conn.ReadMessage(); err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				wss.logger.WithError(err).Error("WebSocket error")
			}
			return
		}
	}
}
// sendCurrentClusterState writes a single "cluster_update" message to conn
// describing the cluster as it is right now.
//
// Nothing is sent when discovery knows no nodes. Failures to obtain the
// member list, to marshal the message or to write it are logged and
// otherwise ignored.
func (wss *WebSocketServer) sendCurrentClusterState(conn *websocket.Conn) {
	known := wss.nodeDiscovery.GetNodes()
	if len(known) == 0 {
		return
	}

	// Member data comes from the SPORE nodes (or the local fallback).
	members, err := wss.getCurrentClusterMembers()
	if err != nil {
		wss.logger.WithError(err).Error("Failed to get cluster data for WebSocket")
		return
	}

	var snapshot struct {
		Type        string                 `json:"type"`
		Members     []client.ClusterMember `json:"members"`
		PrimaryNode string                 `json:"primaryNode"`
		TotalNodes  int                    `json:"totalNodes"`
		Timestamp   string                 `json:"timestamp"`
	}
	snapshot.Type = "cluster_update"
	snapshot.Members = members
	snapshot.PrimaryNode = wss.nodeDiscovery.GetPrimaryNode()
	snapshot.TotalNodes = len(known)
	snapshot.Timestamp = time.Now().Format(time.RFC3339)

	payload, err := json.Marshal(snapshot)
	if err != nil {
		wss.logger.WithError(err).Error("Failed to marshal cluster data")
		return
	}

	// Give the write at most ten seconds.
	conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
	err = conn.WriteMessage(websocket.TextMessage, payload)
	if err != nil {
		wss.logger.WithError(err).Error("Failed to send initial cluster state")
	}
}
// handleNodeUpdate is the discovery callback registered in
// NewWebSocketServer. For every reported change it logs the node IP and
// action, broadcasts a full cluster update, and then broadcasts a
// node-discovery event carrying the same IP and action.
func (wss *WebSocketServer) handleNodeUpdate(nodeIP, action string) {
	fields := log.Fields{
		"node_ip": nodeIP,
		"action":  action,
	}
	wss.logger.WithFields(fields).Debug("Node update received, broadcasting to WebSocket clients")

	// Full cluster snapshot first, then the specific discovery event.
	wss.broadcastClusterUpdate()
	wss.broadcastNodeDiscovery(nodeIP, action)
}
// broadcastClusterUpdate sends a "cluster_update" message to every connected
// WebSocket client.
//
// The client set is copied under the read lock so that no lock is held while
// cluster data is fetched or while writing to the network. Nothing is done
// when there are no clients. Each write is given a five-second deadline;
// clients whose write fails are counted, logged and closed.
func (wss *WebSocketServer) broadcastClusterUpdate() {
	wss.mutex.RLock()
	conns := make([]*websocket.Conn, 0, len(wss.clients))
	for conn := range wss.clients {
		conns = append(conns, conn)
	}
	wss.mutex.RUnlock()

	if len(conns) == 0 {
		return
	}

	startTime := time.Now()

	// Fetch cluster members (a blocking call; its duration is reported as
	// prep_time below).
	clusterData, err := wss.getCurrentClusterMembers()
	if err != nil {
		wss.logger.WithError(err).Error("Failed to get cluster data for broadcast")
		return
	}

	message := struct {
		Type        string                 `json:"type"`
		Members     []client.ClusterMember `json:"members"`
		PrimaryNode string                 `json:"primaryNode"`
		TotalNodes  int                    `json:"totalNodes"`
		Timestamp   string                 `json:"timestamp"`
	}{
		Type:        "cluster_update",
		Members:     clusterData,
		PrimaryNode: wss.nodeDiscovery.GetPrimaryNode(),
		TotalNodes:  len(wss.nodeDiscovery.GetNodes()),
		Timestamp:   time.Now().Format(time.RFC3339),
	}

	data, err := json.Marshal(message)
	if err != nil {
		wss.logger.WithError(err).Error("Failed to marshal cluster update")
		return
	}

	wss.logger.WithFields(log.Fields{
		"clients":   len(conns),
		"prep_time": time.Since(startTime),
	}).Debug("Broadcasting cluster update to WebSocket clients")

	// Send to all clients
	var failedClients int
	for _, conn := range conns {
		conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
		if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
			wss.logger.WithError(err).Error("Failed to send cluster update to client")
			failedClients++
			// A connection whose write failed cannot be written to again,
			// so close it; the read loop in handleClient then fails and
			// removes it from the client set instead of it being retried
			// on every broadcast.
			conn.Close()
		}
	}

	wss.logger.WithFields(log.Fields{
		"clients":        len(conns),
		"failed_clients": failedClients,
		"total_time":     time.Since(startTime),
	}).Debug("Cluster update broadcast completed")
}
// broadcastNodeDiscovery sends a "node_discovery" event, carrying the given
// node IP and action plus a timestamp, to every connected WebSocket client.
//
// The client set is copied under the read lock so that no lock is held while
// writing. Nothing is done when there are no clients. Each write is given a
// five-second deadline; clients whose write fails are logged and closed.
func (wss *WebSocketServer) broadcastNodeDiscovery(nodeIP, action string) {
	wss.mutex.RLock()
	conns := make([]*websocket.Conn, 0, len(wss.clients))
	for conn := range wss.clients {
		conns = append(conns, conn)
	}
	wss.mutex.RUnlock()

	if len(conns) == 0 {
		return
	}

	message := struct {
		Type      string `json:"type"`
		Action    string `json:"action"`
		NodeIP    string `json:"nodeIp"`
		Timestamp string `json:"timestamp"`
	}{
		Type:      "node_discovery",
		Action:    action,
		NodeIP:    nodeIP,
		Timestamp: time.Now().Format(time.RFC3339),
	}

	data, err := json.Marshal(message)
	if err != nil {
		wss.logger.WithError(err).Error("Failed to marshal node discovery event")
		return
	}

	for _, conn := range conns {
		conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
		if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
			wss.logger.WithError(err).Error("Failed to send node discovery event to client")
			// A connection whose write failed cannot be written to again,
			// so close it; the read loop in handleClient then fails and
			// removes it from the client set.
			conn.Close()
		}
	}
}
// getCurrentClusterMembers returns the current cluster member list.
//
// With no discovered nodes it returns an empty, non-nil slice. Otherwise it
// asks the primary node's SPORE API for the cluster status and returns those
// members; if there is no primary node or the request fails (which is
// logged), it falls back to members built from local discovery data. The
// returned error is always nil in the current implementation.
func (wss *WebSocketServer) getCurrentClusterMembers() ([]client.ClusterMember, error) {
	if len(wss.nodeDiscovery.GetNodes()) == 0 {
		return []client.ClusterMember{}, nil
	}

	primary := wss.nodeDiscovery.GetPrimaryNode()
	if primary == "" {
		// No primary to query: use what discovery knows locally.
		return wss.getFallbackClusterMembers(), nil
	}

	status, err := wss.getSporeClient(primary).GetClusterStatus()
	if err != nil {
		wss.logger.WithError(err).Error("Failed to get cluster status from primary node")
		// The API call failed: use what discovery knows locally.
		return wss.getFallbackClusterMembers(), nil
	}

	// Let the local view observe the fresh API data before returning it.
	wss.updateLocalNodesWithAPI(status.Members)
	return status.Members, nil
}
// updateLocalNodesWithAPI receives the member list returned by the SPORE
// API. It currently does not modify any local state: it only logs the number
// of members and, for each member that carries labels, its IP and labels.
func (wss *WebSocketServer) updateLocalNodesWithAPI(apiMembers []client.ClusterMember) {
	wss.logger.WithField("members", len(apiMembers)).Debug("Updating local nodes with API data")

	for _, m := range apiMembers {
		// Members without labels have nothing worth logging.
		if len(m.Labels) == 0 {
			continue
		}
		details := log.Fields{
			"ip":     m.IP,
			"labels": m.Labels,
		}
		wss.logger.WithFields(details).Debug("API member labels")
	}
}
// getFallbackClusterMembers converts the nodes known to local discovery into
// cluster members, copying IP, hostname, status, latency, last-seen time (as
// Unix seconds) and labels. It is used when the SPORE API cannot be queried.
func (wss *WebSocketServer) getFallbackClusterMembers() []client.ClusterMember {
	known := wss.nodeDiscovery.GetNodes()
	fallback := make([]client.ClusterMember, 0, len(known))

	for _, n := range known {
		fallback = append(fallback, client.ClusterMember{
			IP:       n.IP,
			Hostname: n.Hostname,
			Status:   string(n.Status),
			Latency:  n.Latency,
			LastSeen: n.LastSeen.Unix(),
			Labels:   n.Labels,
		})
	}

	return fallback
}
// getSporeClient returns the cached SPORE client for nodeIP, creating and
// caching one that targets "http://<nodeIP>" on first use.
//
// The cache is reached from several goroutines (one per newly connected
// WebSocket client as well as the discovery callback), so the lookup and
// insertion are done under the server mutex; unsynchronized map access
// would be a data race. The mutex must not already be held by the caller.
func (wss *WebSocketServer) getSporeClient(nodeIP string) *client.SporeClient {
	wss.mutex.Lock()
	defer wss.mutex.Unlock()

	if cached, exists := wss.sporeClients[nodeIP]; exists {
		return cached
	}

	created := client.NewSporeClient("http://" + nodeIP)
	wss.sporeClients[nodeIP] = created
	return created
}
// GetClientCount reports how many WebSocket connections are currently
// registered with the server. The count is read under the read lock.
func (wss *WebSocketServer) GetClientCount() int {
	wss.mutex.RLock()
	count := len(wss.clients)
	wss.mutex.RUnlock()
	return count
}
// Shutdown notifies every connected WebSocket client that the server is
// going away and closes the connections.
//
// Each client is sent a close frame with code CloseGoingAway. Sending is
// bounded by a five-second deadline, or by ctx's deadline when that is
// earlier, so a stalled client cannot block shutdown indefinitely. Failures
// to deliver the close frame are logged at debug level; the connection is
// closed regardless. Shutdown always returns nil.
func (wss *WebSocketServer) Shutdown(ctx context.Context) error {
	wss.logger.Info("Shutting down WebSocket server")

	// Snapshot the connections so no lock is held during network writes.
	wss.mutex.RLock()
	conns := make([]*websocket.Conn, 0, len(wss.clients))
	for conn := range wss.clients {
		conns = append(conns, conn)
	}
	wss.mutex.RUnlock()

	deadline := time.Now().Add(5 * time.Second)
	if ctx != nil {
		if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) {
			deadline = ctxDeadline
		}
	}

	closeFrame := websocket.FormatCloseMessage(websocket.CloseGoingAway, "Server shutting down")

	// Close all client connections
	for _, conn := range conns {
		// WriteControl carries its own deadline and may be called
		// concurrently with other connection methods, so it cannot collide
		// with a broadcast that is still in flight.
		if err := conn.WriteControl(websocket.CloseMessage, closeFrame, deadline); err != nil {
			wss.logger.WithError(err).Debug("Failed to send close frame to WebSocket client")
		}
		// Best effort: the connection is being discarded either way.
		_ = conn.Close()
	}

	return nil
}