Files
spore-gateway/internal/server/server.go

740 lines
22 KiB
Go

package server
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"
	"time"

	"spore-gateway/internal/discovery"
	"spore-gateway/internal/websocket"
	"spore-gateway/pkg/client"

	"github.com/gorilla/mux"
	log "github.com/sirupsen/logrus"
)
// HTTPServer is the gateway's HTTP front end. It bundles the gorilla/mux
// router (REST routes under /api plus the /ws endpoint), the underlying
// *http.Server, the WebSocket server, and a per-node cache of SPORE clients.
// Construct it with NewHTTPServer.
type HTTPServer struct {
// port is the TCP port to listen on; NewHTTPServer binds to ":" + port.
port string
// router dispatches all requests; routes and middleware are registered by
// setupRoutes and setupMiddleware.
router *mux.Router
// nodeDiscovery supplies the discovered node set and the current primary node.
nodeDiscovery *discovery.NodeDiscovery
// sporeClients caches one SPORE client per node IP, filled lazily by
// getSporeClient. net/http runs handlers on separate goroutines, so any
// access to this map must be synchronized.
sporeClients map[string]*client.SporeClient
// webSocketServer serves the /ws endpoint and is shut down together with
// the HTTP server in Shutdown.
webSocketServer *websocket.WebSocketServer
// server is the configured *http.Server (address and timeouts are set in
// NewHTTPServer).
server *http.Server
}
// NewHTTPServer assembles an HTTPServer for the given port. It creates the
// WebSocket server backed by nodeDiscovery, registers routes and middleware
// on a fresh gorilla/mux router, and prepares an http.Server bound to
// ":"+port with 30s read/write and 60s idle timeouts. Nothing is started
// here; call Start to begin serving.
//
// NOTE(review): http.Server's WriteTimeout also bounds slow handlers such as
// the firmware upload, which waits on the target node before responding.
// Confirm 30s is enough for that path.
func NewHTTPServer(port string, nodeDiscovery *discovery.NodeDiscovery) *HTTPServer {
	const (
		readTimeout  = 30 * time.Second
		writeTimeout = 30 * time.Second
		idleTimeout  = 60 * time.Second
	)

	// The WebSocket server is created first, as before, so any construction
	// side effects happen in the same order.
	ws := websocket.NewWebSocketServer(nodeDiscovery)

	srv := &HTTPServer{
		port:            port,
		router:          mux.NewRouter(),
		nodeDiscovery:   nodeDiscovery,
		sporeClients:    make(map[string]*client.SporeClient),
		webSocketServer: ws,
	}
	srv.setupRoutes()
	srv.setupMiddleware()

	srv.server = &http.Server{
		Addr:         ":" + port,
		Handler:      srv.router,
		ReadTimeout:  readTimeout,
		WriteTimeout: writeTimeout,
		IdleTimeout:  idleTimeout,
	}
	return srv
}
// setupMiddleware installs the router-wide middleware chain. gorilla/mux
// runs middleware in the order it was added (first added is outermost), so
// a request passes through CORS headers, then the JSON Content-Type default,
// then request logging, before reaching its handler. Router middleware only
// runs for requests that match a registered route.
func (hs *HTTPServer) setupMiddleware() {
	hs.router.Use(hs.corsMiddleware, hs.jsonMiddleware, hs.loggingMiddleware)
}
// corsMiddleware adds permissive CORS headers (any origin; GET, POST, PUT,
// DELETE and OPTIONS; Content-Type and Authorization headers) to every
// response. OPTIONS requests are treated as preflight and answered with
// 200 immediately, without calling the wrapped handler.
func (hs *HTTPServer) corsMiddleware(next http.Handler) http.Handler {
	withCORS := func(w http.ResponseWriter, r *http.Request) {
		headers := w.Header()
		headers.Set("Access-Control-Allow-Origin", "*")
		headers.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		headers.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")

		if r.Method != http.MethodOptions {
			next.ServeHTTP(w, r)
			return
		}
		// Preflight: the headers above are the whole answer.
		w.WriteHeader(http.StatusOK)
	}
	return http.HandlerFunc(withCORS)
}
// jsonMiddleware defaults every response's Content-Type to application/json
// before the wrapped handler runs. A handler can still override it before
// writing the status line (http.Error, for example, replaces it with
// text/plain).
func (hs *HTTPServer) jsonMiddleware(next http.Handler) http.Handler {
	withJSON := func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(withJSON)
}
// loggingMiddleware emits one Debug-level log entry per request after the
// wrapped handler returns. The entry records method, path, remote address,
// user agent and the handler's wall-clock duration.
//
// The response status is not logged. Capturing it would require wrapping the
// ResponseWriter, and such a wrapper must keep http.Hijacker working for the
// /ws upgrade.
func (hs *HTTPServer) loggingMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()
		next.ServeHTTP(w, r)
		elapsed := time.Since(began)

		entry := log.Fields{
			"method":      r.Method,
			"path":        r.URL.Path,
			"remote_addr": r.RemoteAddr,
			"user_agent":  r.UserAgent(),
			"duration":    elapsed,
		}
		log.WithFields(entry).Debug("HTTP request")
	})
}
// setupRoutes registers the REST API under /api and the WebSocket endpoint
// at /ws on hs.router. Routes are registered, and therefore matched, in the
// order listed. Each API route accepts only its listed methods. The mutating
// routes also list OPTIONS, presumably so CORS preflight requests match a
// route and reach corsMiddleware.
func (hs *HTTPServer) setupRoutes() {
	api := hs.router.PathPrefix("/api").Subrouter()
	// CORS is attached to the subrouter too. corsMiddleware uses Header().Set,
	// so applying it more than once yields the same headers.
	api.Use(hs.corsMiddleware)

	apiRoutes := []struct {
		path    string
		handler func(http.ResponseWriter, *http.Request)
		methods []string
	}{
		// Discovery
		{"/discovery/nodes", hs.getDiscoveryNodes, []string{"GET"}},
		{"/discovery/refresh", hs.refreshDiscovery, []string{"POST", "OPTIONS"}},
		{"/discovery/random-primary", hs.selectRandomPrimary, []string{"POST", "OPTIONS"}},
		{"/discovery/primary/{ip}", hs.setPrimaryNode, []string{"POST", "OPTIONS"}},
		// Cluster
		{"/cluster/members", hs.getClusterMembers, []string{"GET"}},
		{"/cluster/refresh", hs.refreshCluster, []string{"POST", "OPTIONS"}},
		// Tasks
		{"/tasks/status", hs.getTaskStatus, []string{"GET"}},
		// Node
		{"/node/status", hs.getNodeStatus, []string{"GET"}},
		{"/node/status/{ip}", hs.getNodeStatusByIP, []string{"GET"}},
		{"/node/endpoints", hs.getNodeEndpoints, []string{"GET"}},
		{"/node/update", hs.updateNodeFirmware, []string{"POST", "OPTIONS"}},
		// Proxy
		{"/proxy-call", hs.proxyCall, []string{"POST", "OPTIONS"}},
		// Test
		{"/test/websocket", hs.testWebSocket, []string{"POST", "OPTIONS"}},
		// Health check
		{"/health", hs.healthCheck, []string{"GET"}},
	}
	for _, rt := range apiRoutes {
		api.HandleFunc(rt.path, rt.handler).Methods(rt.methods...)
	}

	// The WebSocket endpoint lives on the root router (outside /api) and is
	// wrapped in CORS explicitly.
	upgrade := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		err := hs.webSocketServer.HandleWebSocket(w, r)
		if err == nil {
			return
		}
		log.WithError(err).Error("WebSocket connection failed")
		// NOTE(review): if HandleWebSocket already wrote a response or hijacked
		// the connection before failing, this write cannot reach the client.
		// Confirm HandleWebSocket's error contract.
		http.Error(w, "WebSocket upgrade failed", http.StatusBadRequest)
	})
	hs.router.HandleFunc("/ws", hs.corsMiddleware(upgrade).ServeHTTP)
}
// Start serves HTTP on the configured port and blocks until the server
// stops. As documented for net/http's ListenAndServe, the returned error is
// never nil. After Shutdown it is http.ErrServerClosed, which callers should
// treat as a normal exit.
func (hs *HTTPServer) Start() error {
	srv := hs.server
	log.WithField("port", hs.port).Info("Starting HTTP server")
	return srv.ListenAndServe()
}
// Shutdown stops the WebSocket server, then gracefully shuts down the HTTP
// server. Both receive ctx. A WebSocket shutdown error is only logged, so it
// never blocks the HTTP shutdown. The HTTP server's shutdown error is
// returned to the caller.
func (hs *HTTPServer) Shutdown(ctx context.Context) error {
	log.Info("Shutting down HTTP server")

	wsErr := hs.webSocketServer.Shutdown(ctx)
	if wsErr != nil {
		log.WithError(wsErr).Error("WebSocket server shutdown error")
	}

	return hs.server.Shutdown(ctx)
}
// sporeClientsMu guards every HTTPServer.sporeClients map. net/http runs
// handlers on separate goroutines, and unsynchronized map access from them
// is a data race that can crash the process with "concurrent map writes".
// It is declared at package level so that the fix needs no change to the
// HTTPServer struct. As a result it is shared by all HTTPServer instances,
// which is acceptable for a short critical section.
var sporeClientsMu sync.Mutex

// getSporeClient returns the cached SPORE client for nodeIP. On first use it
// creates one targeting "http://<nodeIP>" and caches it. It is safe for
// concurrent use.
//
// NOTE(review): nodeIP can come straight from request input (query string,
// path variable, proxy-call body). Every distinct value therefore adds a
// cache entry that is never evicted. Consider validating against discovered
// nodes.
func (hs *HTTPServer) getSporeClient(nodeIP string) *client.SporeClient {
	sporeClientsMu.Lock()
	defer sporeClientsMu.Unlock()

	if c, ok := hs.sporeClients[nodeIP]; ok {
		return c
	}
	c := client.NewSporeClient(fmt.Sprintf("http://%s", nodeIP))
	hs.sporeClients[nodeIP] = c
	return c
}
// performWithFailover runs operation against discovered SPORE nodes until one
// call succeeds, and returns that call's result.
//
// Candidate order: the current primary node first (if set and still among
// the discovered nodes), then the remaining nodes in Go map iteration order,
// which is unspecified. If a non-primary node succeeds while a primary was
// set, that node is promoted to primary.
//
// Returns an error if no nodes are discovered. If every candidate fails, it
// returns the last candidate's error.
func (hs *HTTPServer) performWithFailover(operation func(*client.SporeClient) (interface{}, error)) (interface{}, error) {
	primaryNode := hs.nodeDiscovery.GetPrimaryNode()
	nodes := hs.nodeDiscovery.GetNodes()
	if len(nodes) == 0 {
		return nil, fmt.Errorf("no SPORE nodes discovered")
	}

	candidateIPs := make([]string, 0, len(nodes))
	if _, ok := nodes[primaryNode]; ok && primaryNode != "" {
		candidateIPs = append(candidateIPs, primaryNode)
	}
	for _, node := range nodes {
		if node.IP != primaryNode {
			candidateIPs = append(candidateIPs, node.IP)
		}
	}

	var lastErr error
	for _, ip := range candidateIPs {
		result, err := operation(hs.getSporeClient(ip))
		if err != nil {
			log.WithFields(log.Fields{
				"ip":  ip,
				"err": err,
			}).Warn("Node attempt failed")
			lastErr = err
			continue
		}

		// A non-primary node answered: promote it. Only log the switch if
		// discovery actually accepted it.
		if ip != primaryNode && primaryNode != "" {
			if setErr := hs.nodeDiscovery.SetPrimaryNode(ip); setErr != nil {
				log.WithError(setErr).WithField("ip", ip).Warn("Failover: could not switch primary node")
			} else {
				log.WithField("ip", ip).Info("Failover: switched primary node")
			}
		}
		return result, nil
	}

	// Defensive: never report (nil, nil) if no candidate could be built.
	if lastErr == nil {
		lastErr = fmt.Errorf("no candidate SPORE nodes available")
	}
	return nil, lastErr
}
// API endpoint handlers

// GET /api/discovery/nodes
//
// getDiscoveryNodes lists every discovered node, flagging the current primary
// with isPrimary. The response also carries the node count, whether a primary
// is set (clientInitialized), and the discovery subsystem's cluster status.
// Node order follows map iteration and is unspecified.
//
// NOTE(review): clientBaseUrl is always the empty string here. Confirm
// whether the frontend still reads it.
func (hs *HTTPServer) getDiscoveryNodes(w http.ResponseWriter, r *http.Request) {
	nodes := hs.nodeDiscovery.GetNodes()
	primary := hs.nodeDiscovery.GetPrimaryNode()
	status := hs.nodeDiscovery.GetClusterStatus()

	// NodeResponse is a NodeInfo plus an isPrimary flag.
	type NodeResponse struct {
		*discovery.NodeInfo
		IsPrimary bool `json:"isPrimary"`
	}

	// Non-nil even when empty, so the JSON array is [] rather than null.
	entries := make([]NodeResponse, 0, len(nodes))
	for _, n := range nodes {
		entries = append(entries, NodeResponse{NodeInfo: n, IsPrimary: n.IP == primary})
	}

	json.NewEncoder(w).Encode(struct {
		PrimaryNode       string                  `json:"primaryNode"`
		TotalNodes        int                     `json:"totalNodes"`
		Nodes             []NodeResponse          `json:"nodes"`
		ClientInitialized bool                    `json:"clientInitialized"`
		ClientBaseURL     string                  `json:"clientBaseUrl"`
		ClusterStatus     discovery.ClusterStatus `json:"clusterStatus"`
	}{
		PrimaryNode:       primary,
		TotalNodes:        len(nodes),
		Nodes:             entries,
		ClientInitialized: primary != "",
		ClientBaseURL:     "",
		ClusterStatus:     status,
	})
}
// POST /api/discovery/refresh
//
// refreshDiscovery reports the current discovery state: primary node, node
// count, and whether a primary is set. It does not trigger rediscovery
// itself. Stale-node cleanup and primary re-selection are expected to be
// handled by the discovery subsystem's own routine (not verified here).
func (hs *HTTPServer) refreshDiscovery(w http.ResponseWriter, r *http.Request) {
	// Read the primary once so primaryNode and clientInitialized cannot
	// disagree if discovery changes the primary between two reads.
	primary := hs.nodeDiscovery.GetPrimaryNode()
	nodeCount := len(hs.nodeDiscovery.GetNodes())

	response := struct {
		Success           bool   `json:"success"`
		Message           string `json:"message"`
		PrimaryNode       string `json:"primaryNode"`
		TotalNodes        int    `json:"totalNodes"`
		ClientInitialized bool   `json:"clientInitialized"`
	}{
		Success:           true,
		Message:           "Cluster refresh completed",
		PrimaryNode:       primary,
		TotalNodes:        nodeCount,
		ClientInitialized: primary != "",
	}
	json.NewEncoder(w).Encode(response)
}
// POST /api/discovery/random-primary
//
// selectRandomPrimary asks the discovery subsystem to pick a random node as
// the new primary. It replies 404 when no nodes have been discovered, and 500
// when the selection returns an empty address. Otherwise it returns the
// chosen node with a timestamp.
func (hs *HTTPServer) selectRandomPrimary(w http.ResponseWriter, r *http.Request) {
	known := hs.nodeDiscovery.GetNodes()
	if len(known) == 0 {
		http.Error(w, `{"error": "No nodes available", "message": "No SPORE nodes have been discovered yet"}`, http.StatusNotFound)
		return
	}

	chosen := hs.nodeDiscovery.SelectRandomPrimaryNode()
	if chosen == "" {
		http.Error(w, `{"error": "Selection failed", "message": "Failed to select a random primary node"}`, http.StatusInternalServerError)
		return
	}

	type randomPrimaryResponse struct {
		Success           bool   `json:"success"`
		Message           string `json:"message"`
		PrimaryNode       string `json:"primaryNode"`
		TotalNodes        int    `json:"totalNodes"`
		ClientInitialized bool   `json:"clientInitialized"`
		Timestamp         string `json:"timestamp"`
	}
	reply := randomPrimaryResponse{
		Success:           true,
		Message:           fmt.Sprintf("Randomly selected new primary node: %s", chosen),
		PrimaryNode:       chosen,
		TotalNodes:        len(known),
		ClientInitialized: true,
		Timestamp:         time.Now().Format(time.RFC3339),
	}
	json.NewEncoder(w).Encode(reply)
}
// POST /api/discovery/primary/{ip}
//
// setPrimaryNode makes the node named by the {ip} path variable the primary.
// Any error from SetPrimaryNode is reported as 404 "Node not found". This
// assumes the only failure mode is an undiscovered node; the underlying
// error is logged so other causes remain visible.
func (hs *HTTPServer) setPrimaryNode(w http.ResponseWriter, r *http.Request) {
	requestedIP := mux.Vars(r)["ip"]

	if err := hs.nodeDiscovery.SetPrimaryNode(requestedIP); err != nil {
		log.WithError(err).WithField("ip", requestedIP).Warn("Failed to set primary node")
		w.WriteHeader(http.StatusNotFound)
		// requestedIP comes from the URL. Encoding it (instead of splicing it
		// into a JSON string) keeps quotes or backslashes from breaking the
		// body. Best effort: the status line is already written.
		_ = json.NewEncoder(w).Encode(map[string]string{
			"error":   "Node not found",
			"message": fmt.Sprintf("Node with IP %s has not been discovered", requestedIP),
		})
		return
	}

	response := struct {
		Success           bool   `json:"success"`
		Message           string `json:"message"`
		PrimaryNode       string `json:"primaryNode"`
		ClientInitialized bool   `json:"clientInitialized"`
	}{
		Success:           true,
		Message:           fmt.Sprintf("Primary node set to %s", requestedIP),
		PrimaryNode:       requestedIP,
		ClientInitialized: true,
	}
	json.NewEncoder(w).Encode(response)
}
// GET /api/cluster/members
//
// getClusterMembers fetches cluster membership by calling GetClusterStatus on
// the primary node, failing over to other discovered nodes through
// performWithFailover. It replies 502 with a JSON error body if no node
// answers.
func (hs *HTTPServer) getClusterMembers(w http.ResponseWriter, r *http.Request) {
	result, err := hs.performWithFailover(func(c *client.SporeClient) (interface{}, error) {
		return c.GetClusterStatus()
	})
	if err != nil {
		log.WithError(err).Error("Error fetching cluster members")
		w.WriteHeader(http.StatusBadGateway)
		// Encode rather than fmt.Sprintf into a JSON template. Error text with
		// quotes or backslashes would otherwise produce an invalid body, and
		// http.Error would also force a text/plain Content-Type. Best effort
		// once the status is written.
		_ = json.NewEncoder(w).Encode(map[string]string{
			"error":   "Failed to fetch cluster members",
			"message": err.Error(),
		})
		return
	}
	if err := json.NewEncoder(w).Encode(result); err != nil {
		log.WithError(err).Error("Error encoding cluster members response")
	}
}
// POST /api/cluster/refresh
//
// refreshCluster accepts an optional JSON body {"reason": "..."}. An empty
// body is allowed, and reason then defaults to "manual_refresh". It logs
// the request and replies with the reason and the current WebSocket client
// count. As written, it performs no refresh work beyond logging.
func (hs *HTTPServer) refreshCluster(w http.ResponseWriter, r *http.Request) {
	var requestBody struct {
		Reason string `json:"reason"`
	}
	// json.Decoder.Decode returns io.EOF for an empty body, which is
	// accepted. Compare the sentinel instead of matching on err.Error().
	if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil && err != io.EOF {
		w.WriteHeader(http.StatusBadRequest)
		// Best effort: the status line is already written.
		_ = json.NewEncoder(w).Encode(map[string]string{
			"error":   "Invalid JSON",
			"message": "Failed to parse request body",
		})
		return
	}

	reason := requestBody.Reason
	if reason == "" {
		reason = "manual_refresh"
	}

	log.WithField("reason", reason).Info("Manual cluster refresh triggered")

	response := struct {
		Success   bool   `json:"success"`
		Message   string `json:"message"`
		Reason    string `json:"reason"`
		WSclients int    `json:"wsClients"`
	}{
		Success:   true,
		Message:   "Cluster refresh triggered",
		Reason:    reason,
		WSclients: hs.webSocketServer.GetClientCount(),
	}
	json.NewEncoder(w).Encode(response)
}
// GET /api/tasks/status
//
// getTaskStatus returns task status from a SPORE node:
//   - With ?ip=<addr>, it asks that node directly and replies 500 on failure.
//   - Otherwise, it goes through performWithFailover and replies 502 if no
//     node answers.
//
// NOTE(review): ?ip is caller-controlled and not checked against the
// discovered nodes.
func (hs *HTTPServer) getTaskStatus(w http.ResponseWriter, r *http.Request) {
	// writeJSONError sends a properly escaped JSON error body and keeps the
	// Content-Type already set by middleware. http.Error would switch it to
	// text/plain, and splicing err.Error() into a JSON template breaks on
	// quotes.
	writeJSONError := func(status int, title, message string) {
		w.WriteHeader(status)
		// Best effort: the status line is already written.
		_ = json.NewEncoder(w).Encode(map[string]string{"error": title, "message": message})
	}
	writeJSON := func(v interface{}) {
		if err := json.NewEncoder(w).Encode(v); err != nil {
			log.WithError(err).Error("Error encoding task status response")
		}
	}

	if ip := r.URL.Query().Get("ip"); ip != "" {
		result, err := hs.getSporeClient(ip).GetTaskStatus()
		if err != nil {
			log.WithError(err).Error("Error fetching task status from specific node")
			writeJSONError(http.StatusInternalServerError, "Failed to fetch task status from node", err.Error())
			return
		}
		writeJSON(result)
		return
	}

	result, err := hs.performWithFailover(func(c *client.SporeClient) (interface{}, error) {
		return c.GetTaskStatus()
	})
	if err != nil {
		log.WithError(err).Error("Error fetching task status")
		writeJSONError(http.StatusBadGateway, "Failed to fetch task status", err.Error())
		return
	}
	writeJSON(result)
}
// GET /api/node/status
//
// getNodeStatus returns the system status of the primary node, failing over
// to other discovered nodes through performWithFailover. It replies 502 with
// a JSON error body if no node answers.
func (hs *HTTPServer) getNodeStatus(w http.ResponseWriter, r *http.Request) {
	result, err := hs.performWithFailover(func(c *client.SporeClient) (interface{}, error) {
		return c.GetSystemStatus()
	})
	if err != nil {
		log.WithError(err).Error("Error fetching system status")
		w.WriteHeader(http.StatusBadGateway)
		// Encode so error text containing quotes cannot yield invalid JSON.
		// Best effort once the status is written.
		_ = json.NewEncoder(w).Encode(map[string]string{
			"error":   "Failed to fetch system status",
			"message": err.Error(),
		})
		return
	}
	if err := json.NewEncoder(w).Encode(result); err != nil {
		log.WithError(err).Error("Error encoding system status response")
	}
}
// GET /api/node/status/{ip}
//
// getNodeStatusByIP returns the system status reported by the node named in
// the {ip} path variable. No failover is attempted. An upstream error yields
// 500 with a JSON error body.
//
// NOTE(review): {ip} is caller-controlled and not checked against the
// discovered nodes.
func (hs *HTTPServer) getNodeStatusByIP(w http.ResponseWriter, r *http.Request) {
	nodeIP := mux.Vars(r)["ip"]

	result, err := hs.getSporeClient(nodeIP).GetSystemStatus()
	if err != nil {
		log.WithError(err).WithField("ip", nodeIP).Error("Error fetching status from specific node")
		w.WriteHeader(http.StatusInternalServerError)
		// nodeIP (from the URL) and the error text are encoded, not spliced
		// into a JSON template, so quotes cannot break the body. Best effort
		// once the status is written.
		_ = json.NewEncoder(w).Encode(map[string]string{
			"error":   fmt.Sprintf("Failed to fetch status from node %s", nodeIP),
			"message": err.Error(),
		})
		return
	}
	if err := json.NewEncoder(w).Encode(result); err != nil {
		log.WithError(err).Error("Error encoding node status response")
	}
}
// GET /api/node/endpoints
//
// getNodeEndpoints returns the capabilities (endpoint list) reported by a
// SPORE node:
//   - With ?ip=<addr>, it asks that node directly and replies 500 on failure.
//   - Otherwise, it uses performWithFailover and replies 502 if no node
//     answers.
//
// NOTE(review): ?ip is caller-controlled and not checked against the
// discovered nodes.
func (hs *HTTPServer) getNodeEndpoints(w http.ResponseWriter, r *http.Request) {
	// writeJSONError sends a properly escaped JSON error body and keeps the
	// middleware-set Content-Type. http.Error would switch it to text/plain.
	writeJSONError := func(status int, title, message string) {
		w.WriteHeader(status)
		// Best effort: the status line is already written.
		_ = json.NewEncoder(w).Encode(map[string]string{"error": title, "message": message})
	}
	writeJSON := func(v interface{}) {
		if err := json.NewEncoder(w).Encode(v); err != nil {
			log.WithError(err).Error("Error encoding node endpoints response")
		}
	}

	if ip := r.URL.Query().Get("ip"); ip != "" {
		result, err := hs.getSporeClient(ip).GetCapabilities()
		if err != nil {
			log.WithError(err).Error("Error fetching endpoints from specific node")
			writeJSONError(http.StatusInternalServerError, "Failed to fetch endpoints from node", err.Error())
			return
		}
		writeJSON(result)
		return
	}

	result, err := hs.performWithFailover(func(c *client.SporeClient) (interface{}, error) {
		return c.GetCapabilities()
	})
	if err != nil {
		log.WithError(err).Error("Error fetching capabilities")
		writeJSONError(http.StatusBadGateway, "Failed to fetch capabilities", err.Error())
		return
	}
	writeJSON(result)
}
// POST /api/node/update
//
// updateNodeFirmware uploads a firmware image to one SPORE node.
//
// Request:
//   - Target IP: the "ip" query parameter, falling back to the X-Node-IP header.
//   - Image: the "file" part of a multipart/form-data body.
//   - Filename: the part's client-side filename is forwarded, defaulting to
//     "firmware.bin".
//
// Responses:
//   - 400: missing IP, unparsable or oversized form, missing file, or a device
//     that reports Status "FAIL".
//   - 500: the upload cannot be read, or the call to the node fails.
//   - 200: a success envelope containing the device's result.
func (hs *HTTPServer) updateNodeFirmware(w http.ResponseWriter, r *http.Request) {
	// maxUploadBytes caps the entire request body. The value was previously
	// passed only to ParseMultipartForm, where it is the in-memory budget
	// (larger parts spill to temporary files), so no size limit was enforced.
	const maxUploadBytes = 50 << 20 // 50 MiB

	// writeJSONError sends a properly escaped JSON error body and keeps the
	// middleware-set Content-Type. http.Error would switch it to text/plain.
	writeJSONError := func(status int, title, message string) {
		w.WriteHeader(status)
		// Best effort: the status line is already written.
		_ = json.NewEncoder(w).Encode(map[string]string{"error": title, "message": message})
	}

	nodeIP := r.URL.Query().Get("ip")
	if nodeIP == "" {
		nodeIP = r.Header.Get("X-Node-IP")
	}
	if nodeIP == "" {
		writeJSONError(http.StatusBadRequest, "Node IP address is required", "Please provide the target node IP address")
		return
	}

	r.Body = http.MaxBytesReader(w, r.Body, maxUploadBytes)
	if err := r.ParseMultipartForm(maxUploadBytes); err != nil {
		log.WithError(err).Error("Error parsing multipart form")
		writeJSONError(http.StatusBadRequest, "Failed to parse form", "Error parsing multipart form data")
		return
	}

	file, fileHeader, err := r.FormFile("file")
	if err != nil {
		log.WithError(err).Error("No file found in form")
		writeJSONError(http.StatusBadRequest, "No file data received", "Please select a firmware file to upload")
		return
	}
	defer file.Close()

	filename := fileHeader.Filename
	if filename == "" {
		filename = "firmware.bin"
	}

	// io.ReadAll treats io.EOF as normal termination. It replaces a manual
	// 1 KiB read loop that compared err.Error() to "EOF".
	fileData, err := io.ReadAll(file)
	if err != nil {
		log.WithError(err).Error("Error reading file data")
		writeJSONError(http.StatusInternalServerError, "Failed to read file", "Error reading uploaded file data")
		return
	}

	log.WithFields(log.Fields{
		"node_ip":   nodeIP,
		"file_size": len(fileData),
	}).Info("Firmware upload received")

	result, err := hs.getSporeClient(nodeIP).UpdateFirmware(fileData, filename)
	if err != nil {
		log.WithError(err).Error("Error uploading firmware")
		writeJSONError(http.StatusInternalServerError, "Failed to upload firmware", err.Error())
		return
	}

	// The node can accept the upload yet report failure in its payload.
	if result.Status == "FAIL" {
		log.WithField("message", result.Message).Error("Device reported firmware update failure")
		w.WriteHeader(http.StatusBadRequest)
		// Encoded so a device message containing quotes cannot break the
		// JSON. Best effort once the status is written.
		_ = json.NewEncoder(w).Encode(map[string]interface{}{
			"success": false,
			"error":   "Firmware update failed",
			"message": result.Message,
		})
		return
	}

	response := struct {
		Success  bool        `json:"success"`
		Message  string      `json:"message"`
		NodeIP   string      `json:"nodeIp"`
		FileSize int         `json:"fileSize"`
		Filename string      `json:"filename"`
		Result   interface{} `json:"result"`
	}{
		Success:  true,
		Message:  "Firmware uploaded successfully",
		NodeIP:   nodeIP,
		FileSize: len(fileData),
		Filename: filename,
		Result:   result,
	}
	if err := json.NewEncoder(w).Encode(response); err != nil {
		log.WithError(err).Error("Error encoding firmware update response")
	}
}
// POST /api/proxy-call
//
// proxyCall forwards an arbitrary call to a SPORE node. The JSON request body
// must provide "ip", "method" and "uri". "params" is an optional list of
// parameter objects, converted by buildProxyCallParams.
//
// The upstream status code and Content-Type are passed through. When the
// upstream Content-Type contains "application/json" and the body is valid
// JSON, the body is wrapped as {"data": <upstream JSON>, "status": <code>}
// for the frontend. Otherwise it is relayed unchanged.
//
// Gateway-side failures produce a JSON error body: 400 for bad input, 500
// for call or read errors. CORS headers come from corsMiddleware, as for
// every other /api route.
//
// NOTE(review): "ip" is caller-controlled and not checked against the
// discovered nodes, so this endpoint can reach any host the gateway can.
// Confirm that is intended.
func (hs *HTTPServer) proxyCall(w http.ResponseWriter, r *http.Request) {
	// writeJSONError sends a properly escaped JSON error body and keeps the
	// middleware-set Content-Type. http.Error would switch it to text/plain.
	writeJSONError := func(status int, title, message string) {
		w.WriteHeader(status)
		// Best effort: the status line is already written.
		_ = json.NewEncoder(w).Encode(map[string]string{"error": title, "message": message})
	}

	var requestBody struct {
		IP     string                   `json:"ip"`
		Method string                   `json:"method"`
		URI    string                   `json:"uri"`
		Params []map[string]interface{} `json:"params"`
	}
	if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
		writeJSONError(http.StatusBadRequest, "Invalid JSON", "Failed to parse request body")
		return
	}
	if requestBody.IP == "" || requestBody.Method == "" || requestBody.URI == "" {
		writeJSONError(http.StatusBadRequest, "Missing required fields", "Required: ip, method, uri")
		return
	}

	params := buildProxyCallParams(requestBody.Params)

	resp, err := hs.getSporeClient(requestBody.IP).ProxyCall(requestBody.Method, requestBody.URI, params)
	if err != nil {
		log.WithError(err).Error("Error in proxy call")
		writeJSONError(http.StatusInternalServerError, "Proxy call failed", err.Error())
		return
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		log.WithError(err).Error("Error reading proxy response")
		writeJSONError(http.StatusInternalServerError, "Failed to read response", "Error reading upstream response")
		return
	}

	contentType := resp.Header.Get("Content-Type")
	if contentType != "" {
		w.Header().Set("Content-Type", contentType)
	}

	if strings.Contains(contentType, "application/json") && json.Valid(body) {
		// Embed the upstream JSON verbatim as a RawMessage instead of decoding
		// it into interface{} and re-encoding. That round trip turned every
		// number into float64, silently corrupting integers above 2^53, and
		// reordered object keys. Marshalling a map containing a valid
		// RawMessage cannot fail.
		body, _ = json.Marshal(map[string]interface{}{
			"data":   json.RawMessage(body),
			"status": resp.StatusCode,
		})
	}

	w.WriteHeader(resp.StatusCode)
	if _, err := w.Write(body); err != nil {
		log.WithError(err).Debug("Error writing proxy response")
	}
}

// buildProxyCallParams converts the UI's parameter list into the map passed
// to SporeClient.ProxyCall: name -> {"location", "type", "value"}.
//   - Entries without a string "name" are skipped.
//   - "location" and "type" default to "body" and "string" when absent or
//     empty.
//   - When "value" is absent, the whole parameter object is used as the value
//     (the pre-existing behavior). Values are passed through untouched; the
//     client handles them according to "type".
func buildProxyCallParams(raw []map[string]interface{}) map[string]interface{} {
	params := make(map[string]interface{}, len(raw))
	for _, param := range raw {
		name, ok := param["name"].(string)
		if !ok {
			continue
		}

		location := "body"
		if loc, ok := param["location"].(string); ok && loc != "" {
			location = loc
		}
		paramType := "string"
		if t, ok := param["type"].(string); ok && t != "" {
			paramType = t
		}
		value, ok := param["value"]
		if !ok {
			value = param
		}

		params[name] = map[string]interface{}{
			"location": location,
			"type":     paramType,
			"value":    value,
		}
	}
	return params
}
// POST /api/test/websocket
//
// testWebSocket logs a manual test trigger and reports how many WebSocket
// clients are connected and how many nodes have been discovered.
//
// NOTE(review): the reply says a test broadcast was sent, but this handler
// makes no broadcast call on the WebSocket server. Confirm whether one is
// missing.
func (hs *HTTPServer) testWebSocket(w http.ResponseWriter, r *http.Request) {
	log.Info("Manual WebSocket test triggered")

	clients := hs.webSocketServer.GetClientCount()
	nodeCount := len(hs.nodeDiscovery.GetNodes())

	reply := struct {
		Success    bool   `json:"success"`
		Message    string `json:"message"`
		WSclients  int    `json:"websocketClients"`
		TotalNodes int    `json:"totalNodes"`
	}{
		Success:    true,
		Message:    "WebSocket test broadcast sent",
		WSclients:  clients,
		TotalNodes: nodeCount,
	}
	json.NewEncoder(w).Encode(reply)
}
// GET /api/health
//
// healthCheck reports gateway health:
//   - "healthy" (200) only when at least one node has been discovered and a
//     primary node is set;
//   - otherwise "degraded" (503).
//
// The body also includes per-service flags and a cluster summary taken from
// the discovery subsystem's cluster status.
func (hs *HTTPServer) healthCheck(w http.ResponseWriter, r *http.Request) {
	primary := hs.nodeDiscovery.GetPrimaryNode()
	discovered := hs.nodeDiscovery.GetNodes()
	cs := hs.nodeDiscovery.GetClusterStatus()

	state, code := "healthy", http.StatusOK
	if len(discovered) == 0 || primary == "" {
		state, code = "degraded", http.StatusServiceUnavailable
	}

	report := struct {
		Status    string                 `json:"status"`
		Timestamp string                 `json:"timestamp"`
		Services  map[string]bool        `json:"services"`
		Cluster   map[string]interface{} `json:"cluster"`
	}{
		Status:    state,
		Timestamp: time.Now().Format(time.RFC3339),
		Services: map[string]bool{
			"http":        true,
			"udp":         cs.ServerRunning,
			"sporeClient": primary != "",
		},
		Cluster: map[string]interface{}{
			"totalNodes":    cs.TotalNodes,
			"primaryNode":   cs.PrimaryNode,
			"udpPort":       cs.UDPPort,
			"serverRunning": cs.ServerRunning,
		},
	}

	w.WriteHeader(code)
	json.NewEncoder(w).Encode(report)
}