From 9c86e215fe983f03fb56a1d88ef2441c343e4123 Mon Sep 17 00:00:00 2001 From: 0x1d Date: Tue, 21 Oct 2025 21:02:28 +0200 Subject: [PATCH] feat: rollout --- docs/Rollout.md | 14 + internal/server/server.go | 446 ++++++++++++++++++++++++++++++++ internal/websocket/websocket.go | 105 +++++++- pkg/client/client.go | 40 +++ pkg/registry/registry.go | 285 ++++++++++++++++++++ 5 files changed, 883 insertions(+), 7 deletions(-) create mode 100644 docs/Rollout.md create mode 100644 pkg/registry/registry.go diff --git a/docs/Rollout.md b/docs/Rollout.md new file mode 100644 index 0000000..a84fcb2 --- /dev/null +++ b/docs/Rollout.md @@ -0,0 +1,14 @@ +# Rollout + +The rollout feature works together with the spore-registry. +It provides an endpoint `/cluster/node/versions` to determin which version are installed on which nodes through the `version` label. +A rollout can be started by calling the `/rollout` endpoint and providing a set of labels. +The endpoint will then search the the corresponding firmware in the spore-registry and checks the cluster members that match the labels. +The gateway will then upload the firmware that was found to the matching cluster members in the background. Rollout and upload progress is sent through websocket. +Before the upload starts, the `version` label on the member node is updated with the firmware version from the registry. + +The spore-ui provides a rollout button on each firmware version. When clicked, the existing drawer is shown with the Rollout panel. +The gateway is consulted (endpoint `/cluster/node/versions`) o return the list of matching members that are affected by the rollout and displayed inside the Rollout panel. +The button `Rollout` will, once clicked, trigger the `/rollout` endpoint with the label set of the selected firmware that needs to be rolled out. +Rollout and upload progress is received through websocket and the Rollout panel updated in realtime. +Any UI interaction is blocked during rollout and the UI behaves like the Firmware Deploy on the cluster view (also with backdrop and info message). \ No newline at end of file diff --git a/internal/server/server.go b/internal/server/server.go index 72086a0..b03deca 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -7,11 +7,13 @@ import ( "io" "net/http" "strings" + "sync" "time" "spore-gateway/internal/discovery" "spore-gateway/internal/websocket" "spore-gateway/pkg/client" + "spore-gateway/pkg/registry" "github.com/gorilla/mux" log "github.com/sirupsen/logrus" @@ -24,6 +26,7 @@ type HTTPServer struct { nodeDiscovery *discovery.NodeDiscovery sporeClients map[string]*client.SporeClient webSocketServer *websocket.WebSocketServer + registryClient *registry.RegistryClient server *http.Server } @@ -32,12 +35,16 @@ func NewHTTPServer(port string, nodeDiscovery *discovery.NodeDiscovery) *HTTPSer // Initialize WebSocket server wsServer := websocket.NewWebSocketServer(nodeDiscovery) + // Initialize registry client + registryClient := registry.NewRegistryClient("http://localhost:3002") + hs := &HTTPServer{ port: port, router: mux.NewRouter(), nodeDiscovery: nodeDiscovery, sporeClients: make(map[string]*client.SporeClient), webSocketServer: wsServer, + registryClient: registryClient, } hs.setupRoutes() @@ -122,6 +129,8 @@ func (hs *HTTPServer) setupRoutes() { // Cluster endpoints api.HandleFunc("/cluster/members", hs.getClusterMembers).Methods("GET") api.HandleFunc("/cluster/refresh", hs.refreshCluster).Methods("POST", "OPTIONS") + api.HandleFunc("/cluster/node/versions", hs.getClusterNodeVersions).Methods("GET") + api.HandleFunc("/rollout", hs.startRollout).Methods("POST", "OPTIONS") // Task endpoints api.HandleFunc("/tasks/status", hs.getTaskStatus).Methods("GET") @@ -135,6 +144,13 @@ func (hs *HTTPServer) setupRoutes() { // Proxy endpoints api.HandleFunc("/proxy-call", hs.proxyCall).Methods("POST", "OPTIONS") + // Registry proxy endpoints + api.HandleFunc("/registry/health", hs.getRegistryHealth).Methods("GET") + api.HandleFunc("/registry/firmware", hs.listRegistryFirmware).Methods("GET") + api.HandleFunc("/registry/firmware", hs.uploadRegistryFirmware).Methods("POST", "OPTIONS") + api.HandleFunc("/registry/firmware/{name}/{version}", hs.downloadRegistryFirmware).Methods("GET") + api.HandleFunc("/registry/firmware/{name}/{version}", hs.updateRegistryFirmware).Methods("PUT", "OPTIONS") + // Test endpoints api.HandleFunc("/test/websocket", hs.testWebSocket).Methods("POST", "OPTIONS") @@ -787,3 +803,433 @@ func (hs *HTTPServer) healthCheck(w http.ResponseWriter, r *http.Request) { w.WriteHeader(statusCode) json.NewEncoder(w).Encode(health) } + +// RolloutRequest represents a rollout request +type RolloutRequest struct { + Firmware FirmwareInfo `json:"firmware"` + Nodes []NodeInfo `json:"nodes"` +} + +// FirmwareInfo represents firmware information +type FirmwareInfo struct { + Name string `json:"name"` + Version string `json:"version"` + Labels map[string]string `json:"labels"` +} + +// NodeInfo represents node information +type NodeInfo struct { + IP string `json:"ip"` + Version string `json:"version"` + Labels map[string]string `json:"labels"` +} + +// RolloutResponse represents a rollout response +type RolloutResponse struct { + Success bool `json:"success"` + Message string `json:"message"` + RolloutID string `json:"rolloutId"` + TotalNodes int `json:"totalNodes"` + FirmwareURL string `json:"firmwareUrl"` +} + +// NodeVersionInfo represents node version information +type NodeVersionInfo struct { + IP string `json:"ip"` + Version string `json:"version"` + Labels map[string]string `json:"labels"` +} + +// ClusterNodeVersionsResponse represents the response for cluster node versions +type ClusterNodeVersionsResponse struct { + Nodes []NodeVersionInfo `json:"nodes"` +} + +// GET /api/cluster/node/versions +func (hs *HTTPServer) getClusterNodeVersions(w http.ResponseWriter, r *http.Request) { + result, err := hs.performWithFailover(func(client *client.SporeClient) (interface{}, error) { + return client.GetClusterStatus() + }) + + if err != nil { + log.WithError(err).Error("Error fetching cluster members for versions") + http.Error(w, fmt.Sprintf(`{"error": "Failed to fetch cluster members", "message": "%s"}`, err.Error()), http.StatusBadGateway) + return + } + + clusterStatus, ok := result.(*client.ClusterStatusResponse) + if !ok { + http.Error(w, `{"error": "Invalid cluster status response"}`, http.StatusInternalServerError) + return + } + + // Extract version information from cluster members + var nodeVersions []NodeVersionInfo + for _, member := range clusterStatus.Members { + version := "unknown" + if v, exists := member.Labels["version"]; exists { + version = v + } + + nodeVersions = append(nodeVersions, NodeVersionInfo{ + IP: member.IP, + Version: version, + Labels: member.Labels, + }) + } + + response := ClusterNodeVersionsResponse{ + Nodes: nodeVersions, + } + + json.NewEncoder(w).Encode(response) +} + +// POST /api/rollout +func (hs *HTTPServer) startRollout(w http.ResponseWriter, r *http.Request) { + var request RolloutRequest + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + http.Error(w, `{"error": "Invalid JSON", "message": "Failed to parse request body"}`, http.StatusBadRequest) + return + } + + if len(request.Nodes) == 0 { + http.Error(w, `{"error": "No nodes", "message": "No nodes provided for rollout"}`, http.StatusBadRequest) + return + } + + if request.Firmware.Name == "" || request.Firmware.Version == "" { + http.Error(w, `{"error": "Missing firmware info", "message": "Firmware name and version are required"}`, http.StatusBadRequest) + return + } + + log.WithFields(log.Fields{ + "firmware_name": request.Firmware.Name, + "firmware_version": request.Firmware.Version, + "node_count": len(request.Nodes), + }).Info("Starting rollout") + + // Look up firmware in registry by name and version + firmware, err := hs.registryClient.FindFirmwareByNameAndVersion(request.Firmware.Name, request.Firmware.Version) + if err != nil { + log.WithError(err).Error("Failed to find firmware in registry") + http.Error(w, fmt.Sprintf(`{"error": "Firmware not found", "message": "No firmware found with name %s and version %s: %s"}`, request.Firmware.Name, request.Firmware.Version, err.Error()), http.StatusNotFound) + return + } + + firmwareURL := fmt.Sprintf("http://localhost:3002/firmware/%s/%s", firmware.Name, firmware.Version) + rolloutID := fmt.Sprintf("rollout_%d", time.Now().Unix()) + + log.WithFields(log.Fields{ + "rollout_id": rolloutID, + "matching_nodes": len(request.Nodes), + "firmware_name": request.Firmware.Name, + "firmware_version": request.Firmware.Version, + }).Info("Rollout initiated") + + // Send immediate response + response := RolloutResponse{ + Success: true, + Message: fmt.Sprintf("Rollout started for %d nodes", len(request.Nodes)), + RolloutID: rolloutID, + TotalNodes: len(request.Nodes), + FirmwareURL: firmwareURL, + } + + json.NewEncoder(w).Encode(response) + + // Start rollout process in background + go hs.processRollout(rolloutID, request.Nodes, request.Firmware) +} + +// nodeMatchesLabels checks if a node's labels match the rollout labels +func (hs *HTTPServer) nodeMatchesLabels(nodeLabels, rolloutLabels map[string]string) bool { + for key, value := range rolloutLabels { + if nodeValue, exists := nodeLabels[key]; !exists || nodeValue != value { + return false + } + } + return true +} + +// processRollout handles the actual rollout process in the background +func (hs *HTTPServer) processRollout(rolloutID string, nodes []NodeInfo, firmwareInfo FirmwareInfo) { + log.WithField("rollout_id", rolloutID).Info("Starting background rollout process") + + // Download firmware from registry + firmwareData, err := hs.registryClient.DownloadFirmware(firmwareInfo.Name, firmwareInfo.Version) + if err != nil { + log.WithError(err).Error("Failed to download firmware for rollout") + return + } + + log.WithFields(log.Fields{ + "rollout_id": rolloutID, + "firmware": fmt.Sprintf("%s/%s", firmwareInfo.Name, firmwareInfo.Version), + "size": len(firmwareData), + "total_nodes": len(nodes), + }).Info("Downloaded firmware for rollout") + + // Process nodes in parallel using goroutines + var wg sync.WaitGroup + + for i, node := range nodes { + wg.Add(1) + go func(nodeIndex int, node NodeInfo) { + defer wg.Done() + + log.WithFields(log.Fields{ + "rollout_id": rolloutID, + "node_ip": node.IP, + "progress": fmt.Sprintf("%d/%d", nodeIndex+1, len(nodes)), + }).Info("Processing node in rollout") + + // Update version label on the node before upload + client := hs.getSporeClient(node.IP) + + // Create updated labels with the new version + updatedLabels := make(map[string]string) + for k, v := range node.Labels { + updatedLabels[k] = v + } + + // Ensure version label is properly formatted + versionToSet := firmwareInfo.Version + // Remove 'v' prefix if present to ensure consistency + versionToSet = strings.TrimPrefix(versionToSet, "v") + updatedLabels["version"] = versionToSet + + log.WithFields(log.Fields{ + "rollout_id": rolloutID, + "node_ip": node.IP, + "old_version": node.Labels["version"], + "new_version": versionToSet, + "original_firmware_version": firmwareInfo.Version, + "all_labels": updatedLabels, + }).Info("Updating version label on node") + + // Broadcast label update progress + hs.webSocketServer.BroadcastRolloutProgress(rolloutID, node.IP, "updating_labels", nodeIndex+1, len(nodes)) + + if err := client.UpdateNodeLabels(updatedLabels); err != nil { + log.WithFields(log.Fields{ + "rollout_id": rolloutID, + "node_ip": node.IP, + "error": err.Error(), + }).Error("Failed to update version label on node") + + // Broadcast failure + hs.webSocketServer.BroadcastRolloutProgress(rolloutID, node.IP, "failed", nodeIndex+1, len(nodes)) + return + } + + log.WithFields(log.Fields{ + "rollout_id": rolloutID, + "node_ip": node.IP, + "version": versionToSet, + }).Info("Successfully updated version label on node") + + // Broadcast upload progress + hs.webSocketServer.BroadcastRolloutProgress(rolloutID, node.IP, "uploading", nodeIndex+1, len(nodes)) + + // Upload firmware to node + result, err := client.UpdateFirmware(firmwareData, fmt.Sprintf("%s-%s.bin", firmwareInfo.Name, firmwareInfo.Version)) + + if err != nil { + log.WithFields(log.Fields{ + "rollout_id": rolloutID, + "node_ip": node.IP, + "error": err.Error(), + }).Error("Failed to upload firmware to node") + + // Broadcast failure + hs.webSocketServer.BroadcastRolloutProgress(rolloutID, node.IP, "failed", nodeIndex+1, len(nodes)) + return + } + + // Check if the device reported a failure + if result.Status == "FAIL" { + log.WithFields(log.Fields{ + "rollout_id": rolloutID, + "node_ip": node.IP, + "message": result.Message, + }).Error("Device reported firmware update failure") + + // Broadcast failure + hs.webSocketServer.BroadcastRolloutProgress(rolloutID, node.IP, "failed", nodeIndex+1, len(nodes)) + return + } + + log.WithFields(log.Fields{ + "rollout_id": rolloutID, + "node_ip": node.IP, + "result": result.Status, + }).Info("Firmware upload completed successfully") + + // Broadcast completion + hs.webSocketServer.BroadcastRolloutProgress(rolloutID, node.IP, "completed", nodeIndex+1, len(nodes)) + }(i, node) + } + + // Wait for all goroutines to complete + wg.Wait() + + log.WithField("rollout_id", rolloutID).Info("Rollout process completed") +} + +// Registry proxy handlers + +// GET /api/registry/health +func (hs *HTTPServer) getRegistryHealth(w http.ResponseWriter, r *http.Request) { + health, err := hs.registryClient.GetHealth() + if err != nil { + log.WithError(err).Error("Failed to get registry health") + http.Error(w, fmt.Sprintf(`{"error": "Registry health check failed", "message": "%s"}`, err.Error()), http.StatusServiceUnavailable) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(health) +} + +// GET /api/registry/firmware +func (hs *HTTPServer) listRegistryFirmware(w http.ResponseWriter, r *http.Request) { + // Get query parameters + name := r.URL.Query().Get("name") + version := r.URL.Query().Get("version") + + firmwareList, err := hs.registryClient.ListFirmware() + if err != nil { + log.WithError(err).Error("Failed to list registry firmware") + http.Error(w, fmt.Sprintf(`{"error": "Failed to list firmware", "message": "%s"}`, err.Error()), http.StatusInternalServerError) + return + } + + // Filter by name and version if provided + if name != "" || version != "" { + filtered := []registry.GroupedFirmware{} + for _, group := range firmwareList { + if name != "" && group.Name != name { + continue + } + + filteredFirmware := []registry.FirmwareRecord{} + for _, firmware := range group.Firmware { + if version != "" && firmware.Version != version { + continue + } + filteredFirmware = append(filteredFirmware, firmware) + } + + if len(filteredFirmware) > 0 { + group.Firmware = filteredFirmware + filtered = append(filtered, group) + } + } + firmwareList = filtered + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(firmwareList) +} + +// POST /api/registry/firmware +func (hs *HTTPServer) uploadRegistryFirmware(w http.ResponseWriter, r *http.Request) { + // Parse multipart form + err := r.ParseMultipartForm(32 << 20) // 32MB max + if err != nil { + log.WithError(err).Error("Failed to parse multipart form") + http.Error(w, `{"error": "Invalid form data", "message": "Failed to parse multipart form"}`, http.StatusBadRequest) + return + } + + // Get metadata from form + metadataJSON := r.FormValue("metadata") + if metadataJSON == "" { + http.Error(w, `{"error": "Missing metadata", "message": "Metadata field is required"}`, http.StatusBadRequest) + return + } + + var metadata registry.FirmwareMetadata + if err := json.Unmarshal([]byte(metadataJSON), &metadata); err != nil { + log.WithError(err).Error("Invalid metadata JSON") + http.Error(w, `{"error": "Invalid metadata", "message": "Failed to parse metadata JSON"}`, http.StatusBadRequest) + return + } + + // Get firmware file + file, _, err := r.FormFile("firmware") + if err != nil { + log.WithError(err).Error("Missing firmware file") + http.Error(w, `{"error": "Missing firmware file", "message": "Firmware file is required"}`, http.StatusBadRequest) + return + } + defer file.Close() + + // Upload to registry + result, err := hs.registryClient.UploadFirmware(metadata, file) + if err != nil { + log.WithError(err).Error("Failed to upload firmware to registry") + http.Error(w, fmt.Sprintf(`{"error": "Upload failed", "message": "%s"}`, err.Error()), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(result) +} + +// GET /api/registry/firmware/{name}/{version} +func (hs *HTTPServer) downloadRegistryFirmware(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + name := vars["name"] + version := vars["version"] + + if name == "" || version == "" { + http.Error(w, `{"error": "Missing parameters", "message": "Name and version are required"}`, http.StatusBadRequest) + return + } + + firmwareData, err := hs.registryClient.DownloadFirmware(name, version) + if err != nil { + log.WithError(err).Error("Failed to download firmware from registry") + http.Error(w, fmt.Sprintf(`{"error": "Download failed", "message": "%s"}`, err.Error()), http.StatusNotFound) + return + } + + // Set appropriate headers for file download + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s-%s.bin\"", name, version)) + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(firmwareData))) + + w.Write(firmwareData) +} + +// PUT /api/registry/firmware/{name}/{version} +func (hs *HTTPServer) updateRegistryFirmware(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + name := vars["name"] + version := vars["version"] + + if name == "" || version == "" { + http.Error(w, `{"error": "Missing parameters", "message": "Name and version are required"}`, http.StatusBadRequest) + return + } + + var metadata registry.FirmwareMetadata + if err := json.NewDecoder(r.Body).Decode(&metadata); err != nil { + log.WithError(err).Error("Invalid metadata JSON") + http.Error(w, `{"error": "Invalid metadata", "message": "Failed to parse metadata JSON"}`, http.StatusBadRequest) + return + } + + // Update firmware metadata in registry + result, err := hs.registryClient.UpdateFirmwareMetadata(name, version, metadata) + if err != nil { + log.WithError(err).Error("Failed to update firmware metadata in registry") + http.Error(w, fmt.Sprintf(`{"error": "Update failed", "message": "%s"}`, err.Error()), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(result) +} diff --git a/internal/websocket/websocket.go b/internal/websocket/websocket.go index 87ec3dd..4fa7f8d 100644 --- a/internal/websocket/websocket.go +++ b/internal/websocket/websocket.go @@ -3,6 +3,7 @@ package websocket import ( "context" "encoding/json" + "fmt" "net/http" "sync" "time" @@ -90,13 +91,10 @@ func (wss *WebSocketServer) handleClient(conn *websocket.Conn) { ticker := time.NewTicker(54 * time.Second) defer ticker.Stop() - for { - select { - case <-ticker.C: - conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) - if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { - return - } + for range ticker.C { + conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return } } }() @@ -334,6 +332,99 @@ func (wss *WebSocketServer) BroadcastFirmwareUploadStatus(nodeIP, status, filena } } +// BroadcastRolloutProgress sends rollout progress updates to all clients +func (wss *WebSocketServer) BroadcastRolloutProgress(rolloutID, nodeIP, status string, current, total int) { + wss.mutex.RLock() + clients := make([]*websocket.Conn, 0, len(wss.clients)) + for client := range wss.clients { + clients = append(clients, client) + } + wss.mutex.RUnlock() + + if len(clients) == 0 { + return + } + + message := struct { + Type string `json:"type"` + RolloutID string `json:"rolloutId"` + NodeIP string `json:"nodeIp"` + Status string `json:"status"` + Current int `json:"current"` + Total int `json:"total"` + Progress int `json:"progress"` + Timestamp string `json:"timestamp"` + }{ + Type: "rollout_progress", + RolloutID: rolloutID, + NodeIP: nodeIP, + Status: status, + Current: current, + Total: total, + Progress: wss.calculateProgress(current, total, status), + Timestamp: time.Now().Format(time.RFC3339), + } + + data, err := json.Marshal(message) + if err != nil { + wss.logger.WithError(err).Error("Failed to marshal rollout progress") + return + } + + wss.logger.WithFields(log.Fields{ + "rollout_id": rolloutID, + "node_ip": nodeIP, + "status": status, + "progress": fmt.Sprintf("%d/%d", current, total), + "clients": len(clients), + }).Debug("Broadcasting rollout progress to WebSocket clients") + + // Send to all clients with write synchronization + wss.writeMutex.Lock() + defer wss.writeMutex.Unlock() + + for _, client := range clients { + client.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err := client.WriteMessage(websocket.TextMessage, data); err != nil { + wss.logger.WithError(err).Error("Failed to send rollout progress to client") + } + } +} + +// calculateProgress calculates the correct progress percentage based on current status +func (wss *WebSocketServer) calculateProgress(current, total int, status string) int { + if total == 0 { + return 0 + } + + // Base progress is based on completed nodes + completedNodes := current - 1 + if status == "completed" { + completedNodes = current + } + + // Calculate base progress (completed nodes / total nodes) + baseProgress := float64(completedNodes) / float64(total) * 100 + + // If currently updating labels or uploading, add partial progress for the current node + if status == "updating_labels" { + // Add 25% of one node's progress (label update is quick) + nodeProgress := 100.0 / float64(total) * 0.25 + baseProgress += nodeProgress + } else if status == "uploading" { + // Add 50% of one node's progress (so uploading shows as halfway through that node) + nodeProgress := 100.0 / float64(total) * 0.5 + baseProgress += nodeProgress + } + + // Ensure we don't exceed 100% + if baseProgress > 100 { + baseProgress = 100 + } + + return int(baseProgress) +} + // getCurrentClusterMembers fetches real cluster data from SPORE nodes func (wss *WebSocketServer) getCurrentClusterMembers() ([]client.ClusterMember, error) { nodes := wss.nodeDiscovery.GetNodes() diff --git a/pkg/client/client.go b/pkg/client/client.go index b7901cb..83ad04c 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -273,6 +273,46 @@ func (c *SporeClient) UpdateFirmware(firmwareData []byte, filename string) (*Fir return &updateResponse, nil } +// UpdateNodeLabels updates the labels on a SPORE node +func (c *SporeClient) UpdateNodeLabels(labels map[string]string) error { + targetURL := fmt.Sprintf("%s/api/node/config", c.BaseURL) + + // Convert labels to JSON + labelsJSON, err := json.Marshal(labels) + if err != nil { + return fmt.Errorf("failed to marshal labels: %w", err) + } + + // Create form data + data := url.Values{} + data.Set("labels", string(labelsJSON)) + + req, err := http.NewRequest("POST", targetURL, strings.NewReader(data.Encode())) + if err != nil { + return fmt.Errorf("failed to create labels update request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return fmt.Errorf("failed to update node labels: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("node labels update failed with status %d: %s", resp.StatusCode, string(body)) + } + + log.WithFields(log.Fields{ + "node_ip": c.BaseURL, + "labels": labels, + }).Info("Node labels updated successfully") + + return nil +} + // ProxyCall makes a generic HTTP request to a SPORE node endpoint func (c *SporeClient) ProxyCall(method, uri string, params map[string]interface{}) (*http.Response, error) { // Build target URL diff --git a/pkg/registry/registry.go b/pkg/registry/registry.go new file mode 100644 index 0000000..d67a873 --- /dev/null +++ b/pkg/registry/registry.go @@ -0,0 +1,285 @@ +package registry + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "time" + + log "github.com/sirupsen/logrus" +) + +// RegistryClient represents a client for communicating with the SPORE registry +type RegistryClient struct { + BaseURL string + HTTPClient *http.Client +} + +// NewRegistryClient creates a new registry API client +func NewRegistryClient(baseURL string) *RegistryClient { + return &RegistryClient{ + BaseURL: baseURL, + HTTPClient: &http.Client{ + Timeout: 30 * time.Second, + }, + } +} + +// FirmwareRecord represents a firmware record from the registry +type FirmwareRecord struct { + Name string `json:"name"` + Version string `json:"version"` + Size int64 `json:"size"` + Labels map[string]string `json:"labels"` + Path string `json:"download_url"` +} + +// GroupedFirmware represents firmware grouped by name +type GroupedFirmware struct { + Name string `json:"name"` + Firmware []FirmwareRecord `json:"firmware"` +} + +// FindFirmwareByNameAndVersion finds firmware in the registry by name and version +func (c *RegistryClient) FindFirmwareByNameAndVersion(name, version string) (*FirmwareRecord, error) { + // Get all firmware from registry + firmwareList, err := c.ListFirmware() + if err != nil { + return nil, fmt.Errorf("failed to list firmware: %w", err) + } + + // Search through all firmware groups + for _, group := range firmwareList { + if group.Name == name { + for _, firmware := range group.Firmware { + if firmware.Version == version { + return &firmware, nil + } + } + } + } + + return nil, fmt.Errorf("no firmware found with name %s and version %s", name, version) +} + +// GetHealth checks the health of the registry +func (c *RegistryClient) GetHealth() (map[string]interface{}, error) { + url := fmt.Sprintf("%s/health", c.BaseURL) + + resp, err := c.HTTPClient.Get(url) + if err != nil { + return nil, fmt.Errorf("failed to get registry health: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("registry health check failed with status %d", resp.StatusCode) + } + + var health map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&health); err != nil { + return nil, fmt.Errorf("failed to decode health response: %w", err) + } + + return health, nil +} + +// UploadFirmware uploads firmware to the registry +func (c *RegistryClient) UploadFirmware(metadata FirmwareMetadata, firmwareFile io.Reader) (map[string]interface{}, error) { + url := fmt.Sprintf("%s/firmware", c.BaseURL) + + // Create multipart form data + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + // Add metadata + metadataJSON, err := json.Marshal(metadata) + if err != nil { + return nil, fmt.Errorf("failed to marshal metadata: %w", err) + } + + metadataPart, err := writer.CreateFormField("metadata") + if err != nil { + return nil, fmt.Errorf("failed to create metadata field: %w", err) + } + metadataPart.Write(metadataJSON) + + // Add firmware file + firmwarePart, err := writer.CreateFormFile("firmware", fmt.Sprintf("%s-%s.bin", metadata.Name, metadata.Version)) + if err != nil { + return nil, fmt.Errorf("failed to create firmware field: %w", err) + } + + if _, err := io.Copy(firmwarePart, firmwareFile); err != nil { + return nil, fmt.Errorf("failed to copy firmware data: %w", err) + } + + writer.Close() + + req, err := http.NewRequest("POST", url, body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", writer.FormDataContentType()) + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to upload firmware: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("firmware upload failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode upload response: %w", err) + } + + return result, nil +} + +// UpdateFirmwareMetadata updates firmware metadata in the registry +func (c *RegistryClient) UpdateFirmwareMetadata(name, version string, metadata FirmwareMetadata) (map[string]interface{}, error) { + url := fmt.Sprintf("%s/firmware/%s/%s", c.BaseURL, name, version) + + metadataJSON, err := json.Marshal(metadata) + if err != nil { + return nil, fmt.Errorf("failed to marshal metadata: %w", err) + } + + req, err := http.NewRequest("PUT", url, bytes.NewBuffer(metadataJSON)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to update firmware metadata: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("firmware metadata update failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode update response: %w", err) + } + + return result, nil +} + +// FirmwareMetadata represents firmware metadata for uploads +type FirmwareMetadata struct { + Name string `json:"name"` + Version string `json:"version"` + Labels map[string]string `json:"labels"` +} + +// FindFirmwareByLabels finds firmware in the registry that matches the given labels +func (c *RegistryClient) FindFirmwareByLabels(labels map[string]string) (*FirmwareRecord, error) { + // Get all firmware from registry + firmwareList, err := c.ListFirmware() + if err != nil { + return nil, fmt.Errorf("failed to list firmware: %w", err) + } + + // Search through all firmware groups + for _, group := range firmwareList { + for _, firmware := range group.Firmware { + if c.firmwareMatchesLabels(firmware.Labels, labels) { + return &firmware, nil + } + } + } + + return nil, fmt.Errorf("no firmware found matching labels: %v", labels) +} + +// firmwareMatchesLabels checks if firmware labels match the rollout labels +func (c *RegistryClient) firmwareMatchesLabels(firmwareLabels, rolloutLabels map[string]string) bool { + for key, value := range rolloutLabels { + if firmwareValue, exists := firmwareLabels[key]; !exists || firmwareValue != value { + return false + } + } + return true +} + +// ListFirmware retrieves all firmware from the registry +func (c *RegistryClient) ListFirmware() ([]GroupedFirmware, error) { + url := fmt.Sprintf("%s/firmware", c.BaseURL) + + resp, err := c.HTTPClient.Get(url) + if err != nil { + return nil, fmt.Errorf("failed to get firmware list: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("firmware list request failed with status %d", resp.StatusCode) + } + + var firmwareList []GroupedFirmware + if err := json.NewDecoder(resp.Body).Decode(&firmwareList); err != nil { + return nil, fmt.Errorf("failed to decode firmware list response: %w", err) + } + + return firmwareList, nil +} + +// DownloadFirmware downloads firmware binary from the registry +func (c *RegistryClient) DownloadFirmware(name, version string) ([]byte, error) { + url := fmt.Sprintf("%s/firmware/%s/%s", c.BaseURL, name, version) + + resp, err := c.HTTPClient.Get(url) + if err != nil { + return nil, fmt.Errorf("failed to download firmware: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("firmware download request failed with status %d", resp.StatusCode) + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read firmware data: %w", err) + } + + log.WithFields(log.Fields{ + "name": name, + "version": version, + "size": len(data), + }).Info("Downloaded firmware from registry") + + return data, nil +} + +// HealthCheck checks if the registry is healthy +func (c *RegistryClient) HealthCheck() error { + url := fmt.Sprintf("%s/health", c.BaseURL) + + resp, err := c.HTTPClient.Get(url) + if err != nil { + return fmt.Errorf("failed to check registry health: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("registry health check failed with status %d", resp.StatusCode) + } + + return nil +}