From 9eb31437fc299b78db5fda7dc51d1dd5dad26bff Mon Sep 17 00:00:00 2001 From: 0x1d Date: Sun, 19 Oct 2025 22:30:53 +0200 Subject: [PATCH] fix: websocket race-condition and firmware upload --- internal/server/server.go | 124 +++++++++++++++++++++----------- internal/websocket/websocket.go | 14 +++- pkg/client/client.go | 35 ++++++++- 3 files changed, 124 insertions(+), 49 deletions(-) diff --git a/internal/server/server.go b/internal/server/server.go index d5d6bab..28e054d 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -491,7 +491,9 @@ func (hs *HTTPServer) updateNodeFirmware(w http.ResponseWriter, r *http.Request) } if nodeIP == "" { - http.Error(w, `{"error": "Node IP address is required", "message": "Please provide the target node IP address"}`, http.StatusBadRequest) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error": "Node IP address is required", "message": "Please provide the target node IP address"}`)) return } @@ -499,14 +501,18 @@ func (hs *HTTPServer) updateNodeFirmware(w http.ResponseWriter, r *http.Request) err := r.ParseMultipartForm(50 << 20) // 50MB limit if err != nil { log.WithError(err).Error("Error parsing multipart form") - http.Error(w, `{"error": "Failed to parse form", "message": "Error parsing multipart form data"}`, http.StatusBadRequest) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error": "Failed to parse form", "message": "Error parsing multipart form data"}`)) return } file, fileHeader, err := r.FormFile("file") if err != nil { log.WithError(err).Error("No file found in form") - http.Error(w, `{"error": "No file data received", "message": "Please select a firmware file to upload"}`, http.StatusBadRequest) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error": "No file data received", "message": "Please select a firmware file to upload"}`)) return } defer file.Close() @@ -517,22 +523,14 @@ func (hs *HTTPServer) updateNodeFirmware(w http.ResponseWriter, r *http.Request) filename = "firmware.bin" } - // Read file data - fileData := make([]byte, 0) - buffer := make([]byte, 1024) - for { - n, err := file.Read(buffer) - if n > 0 { - fileData = append(fileData, buffer[:n]...) - } - if err != nil { - if err.Error() == "EOF" { - break - } - log.WithError(err).Error("Error reading file data") - http.Error(w, `{"error": "Failed to read file", "message": "Error reading uploaded file data"}`, http.StatusInternalServerError) - return - } + // Read file data efficiently + fileData, err := io.ReadAll(file) + if err != nil { + log.WithError(err).Error("Error reading file data") + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error": "Failed to read file", "message": "Error reading uploaded file data"}`)) + return } log.WithFields(log.Fields{ @@ -540,38 +538,78 @@ func (hs *HTTPServer) updateNodeFirmware(w http.ResponseWriter, r *http.Request) "file_size": len(fileData), }).Info("Firmware upload received") - client := hs.getSporeClient(nodeIP) - result, err := client.UpdateFirmware(fileData, filename) - if err != nil { - log.WithError(err).Error("Error uploading firmware") - http.Error(w, fmt.Sprintf(`{"error": "Failed to upload firmware", "message": "%s"}`, err.Error()), http.StatusInternalServerError) - return - } - - // Check if the device reported a failure - if result.Status == "FAIL" { - log.WithField("message", result.Message).Error("Device reported firmware update failure") - http.Error(w, fmt.Sprintf(`{"success": false, "error": "Firmware update failed", "message": "%s"}`, result.Message), http.StatusBadRequest) - return - } - + // Send immediate acknowledgment to client response := struct { - Success bool `json:"success"` - Message string `json:"message"` - NodeIP string `json:"nodeIp"` - FileSize int `json:"fileSize"` - Filename string `json:"filename"` - Result interface{} `json:"result"` + Success bool `json:"success"` + Message string `json:"message"` + NodeIP string `json:"nodeIp"` + FileSize int `json:"fileSize"` + Filename string `json:"filename"` + Status string `json:"status"` }{ Success: true, - Message: "Firmware uploaded successfully", + Message: "Firmware upload received, processing...", NodeIP: nodeIP, FileSize: len(fileData), Filename: filename, - Result: result, + Status: "processing", } - json.NewEncoder(w).Encode(response) + log.WithFields(log.Fields{ + "node_ip": nodeIP, + "file_size": len(fileData), + "filename": filename, + }).Info("Sending immediate acknowledgment to client") + + // Set response headers to ensure immediate delivery + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-cache") + + if err := json.NewEncoder(w).Encode(response); err != nil { + log.WithError(err).Error("Failed to encode firmware upload acknowledgment") + return + } + + log.WithFields(log.Fields{ + "node_ip": nodeIP, + "status": "acknowledgment_sent", + "response": response, + }).Debug("Firmware upload acknowledgment sent to client") + + // Flush the response to ensure it's sent immediately + if f, ok := w.(http.Flusher); ok { + f.Flush() + log.WithField("node_ip", nodeIP).Debug("Acknowledgment flushed to client") + } + + // Now process the firmware upload in the background + go func() { + client := hs.getSporeClient(nodeIP) + result, err := client.UpdateFirmware(fileData, filename) + if err != nil { + log.WithFields(log.Fields{ + "node_ip": nodeIP, + "error": err.Error(), + }).Error("Error uploading firmware to device") + return + } + + // Check if the device reported a failure + if result.Status == "FAIL" { + log.WithFields(log.Fields{ + "node_ip": nodeIP, + "message": result.Message, + }).Error("Device reported firmware update failure") + return + } + + log.WithFields(log.Fields{ + "node_ip": nodeIP, + "file_size": len(fileData), + "filename": filename, + "result": result.Status, + }).Info("Firmware upload completed successfully") + }() } // POST /api/proxy-call diff --git a/internal/websocket/websocket.go b/internal/websocket/websocket.go index 8ddd7c9..3e7911b 100644 --- a/internal/websocket/websocket.go +++ b/internal/websocket/websocket.go @@ -26,6 +26,7 @@ type WebSocketServer struct { sporeClients map[string]*client.SporeClient clients map[*websocket.Conn]bool mutex sync.RWMutex + writeMutex sync.Mutex // Mutex to serialize writes to WebSocket connections logger *log.Logger } @@ -214,8 +215,11 @@ func (wss *WebSocketServer) broadcastClusterUpdate() { "prep_time": broadcastTime.Sub(startTime), }).Debug("Broadcasting cluster update to WebSocket clients") - // Send to all clients + // Send to all clients with write synchronization var failedClients int + wss.writeMutex.Lock() + defer wss.writeMutex.Unlock() + for _, client := range clients { client.SetWriteDeadline(time.Now().Add(5 * time.Second)) if err := client.WriteMessage(websocket.TextMessage, data); err != nil { @@ -224,7 +228,7 @@ func (wss *WebSocketServer) broadcastClusterUpdate() { } } - totalTime := time.Now().Sub(startTime) + totalTime := time.Since(startTime) wss.logger.WithFields(log.Fields{ "clients": len(clients), "failed_clients": failedClients, @@ -263,6 +267,10 @@ func (wss *WebSocketServer) broadcastNodeDiscovery(nodeIP, action string) { return } + // Send to all clients with write synchronization + wss.writeMutex.Lock() + defer wss.writeMutex.Unlock() + for _, client := range clients { client.SetWriteDeadline(time.Now().Add(5 * time.Second)) if err := client.WriteMessage(websocket.TextMessage, data); err != nil { @@ -302,7 +310,7 @@ func (wss *WebSocketServer) updateLocalNodesWithAPI(apiMembers []client.ClusterM wss.logger.WithField("members", len(apiMembers)).Debug("Updating local nodes with API data") for _, member := range apiMembers { - if member.Labels != nil && len(member.Labels) > 0 { + if len(member.Labels) > 0 { wss.logger.WithFields(log.Fields{ "ip": member.IP, "labels": member.Labels, diff --git a/pkg/client/client.go b/pkg/client/client.go index 4b778cc..b7901cb 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -225,20 +225,49 @@ func (c *SporeClient) UpdateFirmware(firmwareData []byte, filename string) (*Fir Timeout: 5 * time.Minute, // 5 minutes for firmware uploads } + log.WithFields(log.Fields{ + "node_ip": c.BaseURL, + "status": "sending_firmware", + }).Debug("Sending firmware to SPORE device") + resp, err := firmwareClient.Do(req) if err != nil { + log.WithFields(log.Fields{ + "node_ip": c.BaseURL, + "error": err.Error(), + }).Error("Failed to send firmware request to SPORE device") return nil, fmt.Errorf("failed to upload firmware: %w", err) } defer resp.Body.Close() + log.WithFields(log.Fields{ + "node_ip": c.BaseURL, + "status_code": resp.StatusCode, + "headers": resp.Header, + }).Debug("Received response from SPORE device") + if resp.StatusCode != http.StatusOK { + // Only try to read body for error cases body, _ := io.ReadAll(resp.Body) + log.WithFields(log.Fields{ + "node_ip": c.BaseURL, + "status": resp.StatusCode, + "error_body": string(body), + }).Error("SPORE device reported firmware upload failure") return nil, fmt.Errorf("firmware update failed with status %d: %s", resp.StatusCode, string(body)) } - var updateResponse FirmwareUpdateResponse - if err := json.NewDecoder(resp.Body).Decode(&updateResponse); err != nil { - return nil, fmt.Errorf("failed to decode firmware update response: %w", err) + // For successful firmware uploads, don't try to read the response body + // The SPORE device restarts immediately after sending the response, so reading the body + // would cause the connection to hang or timeout + log.WithFields(log.Fields{ + "node_ip": c.BaseURL, + "status": "success_no_body", + }).Info("Firmware upload completed successfully (device restarting)") + + updateResponse := FirmwareUpdateResponse{ + Status: "OK", + Message: "Firmware update completed successfully", } return &updateResponse, nil