Files
spore-gateway/internal/websocket/websocket.go

434 lines
12 KiB
Go

package websocket
import (
"context"
"encoding/json"
"net/http"
"sync"
"time"
"spore-gateway/internal/discovery"
"spore-gateway/pkg/client"
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
)
// upgrader turns incoming HTTP requests into WebSocket connections.
// Its CheckOrigin hook accepts every request, whatever the Origin header says.
var upgrader = websocket.Upgrader{
	// Allow connections from any origin in development
	CheckOrigin: func(_ *http.Request) bool { return true },
}
// WebSocketServer keeps track of connected WebSocket clients and pushes
// cluster, node-discovery and firmware-upload events to them.
type WebSocketServer struct {
	// Source of node information; also drives update callbacks.
	nodeDiscovery *discovery.NodeDiscovery
	// SPORE API clients, keyed by node IP.
	sporeClients map[string]*client.SporeClient

	// Set of currently connected WebSocket clients.
	clients map[*websocket.Conn]bool
	// mutex is held whenever the clients set is read or modified.
	mutex sync.RWMutex
	// writeMutex serializes broadcast writes to WebSocket connections.
	writeMutex sync.Mutex

	logger *log.Logger
}
// NewWebSocketServer builds a WebSocketServer with empty client and
// SPORE-client registries and a fresh logger, then subscribes it to node
// updates from nodeDiscovery.
func NewWebSocketServer(nodeDiscovery *discovery.NodeDiscovery) *WebSocketServer {
	server := new(WebSocketServer)
	server.nodeDiscovery = nodeDiscovery
	server.sporeClients = map[string]*client.SporeClient{}
	server.clients = map[*websocket.Conn]bool{}
	server.logger = log.New()

	// Every node change reported by discovery triggers a broadcast.
	nodeDiscovery.AddCallback(server.handleNodeUpdate)

	return server
}
// HandleWebSocket upgrades the HTTP request to a WebSocket connection,
// registers the connection as a client and starts two goroutines: one that
// sends the current cluster state and one that services the connection until
// it closes. It returns the upgrade error, if any.
func (wss *WebSocketServer) HandleWebSocket(w http.ResponseWriter, r *http.Request) error {
	conn, upgradeErr := upgrader.Upgrade(w, r, nil)
	if upgradeErr != nil {
		wss.logger.WithError(upgradeErr).Error("Failed to upgrade WebSocket connection")
		return upgradeErr
	}

	// Register the new connection in the client set.
	func() {
		wss.mutex.Lock()
		defer wss.mutex.Unlock()
		wss.clients[conn] = true
	}()

	wss.logger.Debug("WebSocket client connected")

	// Initial snapshot for the newcomer, then the read/keep-alive loop.
	go wss.sendCurrentClusterState(conn)
	go wss.handleClient(conn)

	return nil
}
// handleClient services a single WebSocket client until its connection fails
// or is closed. It keeps the connection alive with periodic pings, discards
// any inbound messages, and on exit unregisters and closes the connection.
func (wss *WebSocketServer) handleClient(conn *websocket.Conn) {
	const (
		pongWait   = 60 * time.Second // read deadline, extended on every pong
		pingPeriod = 54 * time.Second // must be shorter than pongWait
		writeWait  = 10 * time.Second // time allowed to write a ping
	)

	// done tells the ping goroutine to stop as soon as the read loop exits,
	// instead of lingering until its next tick fails on the closed connection.
	done := make(chan struct{})

	defer func() {
		close(done)
		wss.mutex.Lock()
		delete(wss.clients, conn)
		wss.mutex.Unlock()
		conn.Close()
		wss.logger.Debug("WebSocket client disconnected")
	}()

	// Set read deadline and pong handler
	conn.SetReadDeadline(time.Now().Add(pongWait))
	conn.SetPongHandler(func(string) error {
		conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})

	// Start ping routine
	go func() {
		ticker := time.NewTicker(pingPeriod)
		defer ticker.Stop()
		for {
			select {
			case <-done:
				return
			case <-ticker.C:
				// WriteControl is documented as safe to call concurrently
				// with other connection methods, unlike SetWriteDeadline +
				// WriteMessage, which would race with broadcast writes.
				if err := conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(writeWait)); err != nil {
					return
				}
			}
		}
	}()

	// Read messages (we don't expect any, but reading is what processes
	// pong and close frames and detects a dead connection).
	for {
		if _, _, err := conn.ReadMessage(); err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				wss.logger.WithError(err).Error("WebSocket error")
			}
			return
		}
	}
}
// sendCurrentClusterState sends a "cluster_update" message describing the
// current cluster to a single, newly connected client. Nothing is sent when
// no nodes are known or when the cluster data cannot be obtained or encoded.
func (wss *WebSocketServer) sendCurrentClusterState(conn *websocket.Conn) {
	nodes := wss.nodeDiscovery.GetNodes()
	if len(nodes) == 0 {
		return
	}

	// Get real cluster data from SPORE nodes
	clusterData, err := wss.getCurrentClusterMembers()
	if err != nil {
		wss.logger.WithError(err).Error("Failed to get cluster data for WebSocket")
		return
	}

	message := struct {
		Type        string                 `json:"type"`
		Members     []client.ClusterMember `json:"members"`
		PrimaryNode string                 `json:"primaryNode"`
		TotalNodes  int                    `json:"totalNodes"`
		Timestamp   string                 `json:"timestamp"`
	}{
		Type:        "cluster_update",
		Members:     clusterData,
		PrimaryNode: wss.nodeDiscovery.GetPrimaryNode(),
		TotalNodes:  len(nodes),
		Timestamp:   time.Now().Format(time.RFC3339),
	}

	data, err := json.Marshal(message)
	if err != nil {
		wss.logger.WithError(err).Error("Failed to marshal cluster data")
		return
	}

	// This runs in its own goroutine while broadcasts may be writing to the
	// same connection. A WebSocket connection supports only one concurrent
	// writer, so take the same write lock the broadcast paths hold.
	wss.writeMutex.Lock()
	defer wss.writeMutex.Unlock()

	conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
	if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
		wss.logger.WithError(err).Error("Failed to send initial cluster state")
	}
}
// handleNodeUpdate is the node-discovery callback. For every reported change
// it broadcasts a full cluster update, followed by a node discovery event
// that carries the affected node IP and the action.
func (wss *WebSocketServer) handleNodeUpdate(nodeIP, action string) {
	entry := wss.logger.WithFields(log.Fields{"node_ip": nodeIP, "action": action})
	entry.Debug("Node update received, broadcasting to WebSocket clients")

	// Full cluster snapshot first, then the specific discovery event.
	wss.broadcastClusterUpdate()
	wss.broadcastNodeDiscovery(nodeIP, action)
}
// broadcastClusterUpdate sends a "cluster_update" message with the current
// cluster members to every connected client. It does nothing when no client
// is connected, and aborts if the cluster data cannot be fetched or encoded.
func (wss *WebSocketServer) broadcastClusterUpdate() {
	type clusterUpdate struct {
		Type        string                 `json:"type"`
		Members     []client.ClusterMember `json:"members"`
		PrimaryNode string                 `json:"primaryNode"`
		TotalNodes  int                    `json:"totalNodes"`
		Timestamp   string                 `json:"timestamp"`
	}

	// Snapshot the client set so the registry lock is not held during I/O.
	wss.mutex.RLock()
	conns := make([]*websocket.Conn, 0, len(wss.clients))
	for c := range wss.clients {
		conns = append(conns, c)
	}
	wss.mutex.RUnlock()

	if len(conns) == 0 {
		return
	}

	began := time.Now()

	// Fetch cluster members (synchronous call).
	members, err := wss.getCurrentClusterMembers()
	if err != nil {
		wss.logger.WithError(err).Error("Failed to get cluster data for broadcast")
		return
	}

	payload, err := json.Marshal(clusterUpdate{
		Type:        "cluster_update",
		Members:     members,
		PrimaryNode: wss.nodeDiscovery.GetPrimaryNode(),
		TotalNodes:  len(wss.nodeDiscovery.GetNodes()),
		Timestamp:   time.Now().Format(time.RFC3339),
	})
	if err != nil {
		wss.logger.WithError(err).Error("Failed to marshal cluster update")
		return
	}

	wss.logger.WithFields(log.Fields{
		"clients":   len(conns),
		"prep_time": time.Since(began),
	}).Debug("Broadcasting cluster update to WebSocket clients")

	// Writes are serialized through writeMutex, which stays held until the
	// function returns.
	wss.writeMutex.Lock()
	defer wss.writeMutex.Unlock()

	failed := 0
	for _, conn := range conns {
		conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
		if writeErr := conn.WriteMessage(websocket.TextMessage, payload); writeErr != nil {
			wss.logger.WithError(writeErr).Error("Failed to send cluster update to client")
			failed++
		}
	}

	wss.logger.WithFields(log.Fields{
		"clients":        len(conns),
		"failed_clients": failed,
		"total_time":     time.Since(began),
	}).Debug("Cluster update broadcast completed")
}
// broadcastNodeDiscovery sends a "node_discovery" message carrying the node
// IP and the action to every connected client. It does nothing when no
// client is connected.
func (wss *WebSocketServer) broadcastNodeDiscovery(nodeIP, action string) {
	type nodeDiscoveryEvent struct {
		Type      string `json:"type"`
		Action    string `json:"action"`
		NodeIP    string `json:"nodeIp"`
		Timestamp string `json:"timestamp"`
	}

	// Snapshot the client set so the registry lock is not held during I/O.
	wss.mutex.RLock()
	conns := make([]*websocket.Conn, 0, len(wss.clients))
	for c := range wss.clients {
		conns = append(conns, c)
	}
	wss.mutex.RUnlock()

	if len(conns) == 0 {
		return
	}

	payload, err := json.Marshal(nodeDiscoveryEvent{
		Type:      "node_discovery",
		Action:    action,
		NodeIP:    nodeIP,
		Timestamp: time.Now().Format(time.RFC3339),
	})
	if err != nil {
		wss.logger.WithError(err).Error("Failed to marshal node discovery event")
		return
	}

	// Writes are serialized through writeMutex.
	wss.writeMutex.Lock()
	defer wss.writeMutex.Unlock()

	for _, conn := range conns {
		conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
		if writeErr := conn.WriteMessage(websocket.TextMessage, payload); writeErr != nil {
			wss.logger.WithError(writeErr).Error("Failed to send node discovery event to client")
		}
	}
}
// BroadcastFirmwareUploadStatus sends a "firmware_upload_status" message to
// every connected client. The message carries the node IP, upload status,
// file name and file size. It does nothing when no client is connected.
func (wss *WebSocketServer) BroadcastFirmwareUploadStatus(nodeIP, status, filename string, fileSize int) {
	type firmwareUploadStatus struct {
		Type      string `json:"type"`
		NodeIP    string `json:"nodeIp"`
		Status    string `json:"status"`
		Filename  string `json:"filename"`
		FileSize  int    `json:"fileSize"`
		Timestamp string `json:"timestamp"`
	}

	// Snapshot the client set so the registry lock is not held during I/O.
	wss.mutex.RLock()
	conns := make([]*websocket.Conn, 0, len(wss.clients))
	for c := range wss.clients {
		conns = append(conns, c)
	}
	wss.mutex.RUnlock()

	if len(conns) == 0 {
		return
	}

	payload, err := json.Marshal(firmwareUploadStatus{
		Type:      "firmware_upload_status",
		NodeIP:    nodeIP,
		Status:    status,
		Filename:  filename,
		FileSize:  fileSize,
		Timestamp: time.Now().Format(time.RFC3339),
	})
	if err != nil {
		wss.logger.WithError(err).Error("Failed to marshal firmware upload status")
		return
	}

	wss.logger.WithFields(log.Fields{
		"node_ip":   nodeIP,
		"status":    status,
		"filename":  filename,
		"file_size": fileSize,
		"clients":   len(conns),
	}).Debug("Broadcasting firmware upload status to WebSocket clients")

	// Writes are serialized through writeMutex.
	wss.writeMutex.Lock()
	defer wss.writeMutex.Unlock()

	for _, conn := range conns {
		conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
		if writeErr := conn.WriteMessage(websocket.TextMessage, payload); writeErr != nil {
			wss.logger.WithError(writeErr).Error("Failed to send firmware upload status to client")
		}
	}
}
// getCurrentClusterMembers returns the current cluster membership. It asks
// the primary node's API when a primary is known; if there is no primary or
// the request fails, it falls back to locally discovered node data. With no
// known nodes it returns an empty slice. The returned error is always nil.
func (wss *WebSocketServer) getCurrentClusterMembers() ([]client.ClusterMember, error) {
	if len(wss.nodeDiscovery.GetNodes()) == 0 {
		return []client.ClusterMember{}, nil
	}

	primary := wss.nodeDiscovery.GetPrimaryNode()
	if primary == "" {
		// No primary to query: use local data.
		return wss.getFallbackClusterMembers(), nil
	}

	status, err := wss.getSporeClient(primary).GetClusterStatus()
	if err != nil {
		wss.logger.WithError(err).Error("Failed to get cluster status from primary node")
		// API failed: use local data.
		return wss.getFallbackClusterMembers(), nil
	}

	// Update local node data with API information
	wss.updateLocalNodesWithAPI(status.Members)
	return status.Members, nil
}
// updateLocalNodesWithAPI is the hook for refreshing local node data from
// API results. It currently only logs: the number of members received, and
// for each member that has labels, its IP and labels.
func (wss *WebSocketServer) updateLocalNodesWithAPI(apiMembers []client.ClusterMember) {
	wss.logger.WithField("members", len(apiMembers)).Debug("Updating local nodes with API data")

	for _, m := range apiMembers {
		if len(m.Labels) == 0 {
			continue
		}
		fields := log.Fields{
			"ip":     m.IP,
			"labels": m.Labels,
		}
		wss.logger.WithFields(fields).Debug("API member labels")
	}
}
// getFallbackClusterMembers converts the locally discovered nodes into
// cluster members, for use when API data is unavailable.
func (wss *WebSocketServer) getFallbackClusterMembers() []client.ClusterMember {
	discovered := wss.nodeDiscovery.GetNodes()
	result := make([]client.ClusterMember, 0, len(discovered))

	for _, n := range discovered {
		result = append(result, client.ClusterMember{
			IP:       n.IP,
			Hostname: n.Hostname,
			Status:   string(n.Status),
			Latency:  n.Latency,
			LastSeen: n.LastSeen.Unix(),
			Labels:   n.Labels,
		})
	}

	return result
}
// getSporeClient returns the cached SPORE client for nodeIP, creating and
// caching one (pointed at "http://<nodeIP>") on first use.
//
// The cache is reached from several goroutines at once (per-connection state
// senders and node-update broadcasts), so lookup and insertion happen under
// wss.mutex. Unsynchronized map access would be a data race. Callers must not
// already hold wss.mutex.
func (wss *WebSocketServer) getSporeClient(nodeIP string) *client.SporeClient {
	wss.mutex.Lock()
	defer wss.mutex.Unlock()

	if existing, ok := wss.sporeClients[nodeIP]; ok {
		return existing
	}

	created := client.NewSporeClient("http://" + nodeIP)
	wss.sporeClients[nodeIP] = created
	return created
}
// GetClientCount reports how many WebSocket clients are currently registered.
func (wss *WebSocketServer) GetClientCount() int {
	wss.mutex.RLock()
	count := len(wss.clients)
	wss.mutex.RUnlock()

	return count
}
// Shutdown sends a "going away" close frame to every connected client and
// closes the connections. Each close frame gets a short write deadline,
// capped by ctx's deadline if that is sooner, so a stalled peer cannot block
// shutdown indefinitely. Connections are closed even if the close frame
// cannot be delivered. It always returns nil.
func (wss *WebSocketServer) Shutdown(ctx context.Context) error {
	const closeWriteWait = time.Second

	wss.logger.Info("Shutting down WebSocket server")

	// Snapshot the client set; each connection's read loop removes its own
	// registry entry once the connection is closed.
	wss.mutex.RLock()
	conns := make([]*websocket.Conn, 0, len(wss.clients))
	for c := range wss.clients {
		conns = append(conns, c)
	}
	wss.mutex.RUnlock()

	closeFrame := websocket.FormatCloseMessage(websocket.CloseGoingAway, "Server shutting down")

	// Close all client connections
	for _, conn := range conns {
		deadline := time.Now().Add(closeWriteWait)
		if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) {
			deadline = ctxDeadline
		}

		// WriteControl is documented as safe to call concurrently with other
		// writers, unlike WriteMessage, which could collide with an in-flight
		// broadcast.
		if err := conn.WriteControl(websocket.CloseMessage, closeFrame, deadline); err != nil {
			wss.logger.WithError(err).Debug("Failed to send close frame to WebSocket client")
		}
		conn.Close()
	}

	return nil
}