feat: mock gateway
This commit is contained in:
475
internal/mock/websocket.go
Normal file
475
internal/mock/websocket.go
Normal file
@@ -0,0 +1,475 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // Allow connections from any origin
|
||||
},
|
||||
}
|
||||
|
||||
// MockWebSocketServer manages WebSocket connections and mock broadcasts
|
||||
type MockWebSocketServer struct {
|
||||
discovery *MockNodeDiscovery
|
||||
clients map[*websocket.Conn]bool
|
||||
mutex sync.RWMutex
|
||||
writeMutex sync.Mutex
|
||||
shutdownChan chan struct{}
|
||||
shutdownOnce sync.Once
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// NewMockWebSocketServer creates a new mock WebSocket server
|
||||
func NewMockWebSocketServer(discovery *MockNodeDiscovery) *MockWebSocketServer {
|
||||
mws := &MockWebSocketServer{
|
||||
discovery: discovery,
|
||||
clients: make(map[*websocket.Conn]bool),
|
||||
shutdownChan: make(chan struct{}),
|
||||
logger: log.New(),
|
||||
}
|
||||
|
||||
// Register callback for node updates
|
||||
discovery.AddCallback(mws.handleNodeUpdate)
|
||||
|
||||
// Start periodic broadcasts
|
||||
go mws.startPeriodicBroadcasts()
|
||||
|
||||
return mws
|
||||
}
|
||||
|
||||
// HandleWebSocket handles WebSocket upgrade and connection
|
||||
func (mws *MockWebSocketServer) HandleWebSocket(w http.ResponseWriter, r *http.Request) error {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
mws.logger.WithError(err).Error("Failed to upgrade WebSocket connection")
|
||||
return err
|
||||
}
|
||||
|
||||
mws.mutex.Lock()
|
||||
mws.clients[conn] = true
|
||||
mws.mutex.Unlock()
|
||||
|
||||
mws.logger.Debug("Mock WebSocket client connected")
|
||||
|
||||
// Send current cluster state to newly connected client
|
||||
go mws.sendCurrentClusterState(conn)
|
||||
|
||||
// Handle client messages and disconnection
|
||||
go mws.handleClient(conn)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleClient handles messages from a WebSocket client
|
||||
func (mws *MockWebSocketServer) handleClient(conn *websocket.Conn) {
|
||||
defer func() {
|
||||
mws.mutex.Lock()
|
||||
delete(mws.clients, conn)
|
||||
mws.mutex.Unlock()
|
||||
conn.Close()
|
||||
mws.logger.Debug("Mock WebSocket client disconnected")
|
||||
}()
|
||||
|
||||
// Set read deadline and pong handler
|
||||
conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
conn.SetPongHandler(func(string) error {
|
||||
conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
return nil
|
||||
})
|
||||
|
||||
// Start ping routine
|
||||
go func() {
|
||||
ticker := time.NewTicker(54 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-mws.shutdownChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Read messages
|
||||
for {
|
||||
_, _, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
mws.logger.WithError(err).Error("WebSocket error")
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendCurrentClusterState sends the current cluster state to a newly connected client
|
||||
func (mws *MockWebSocketServer) sendCurrentClusterState(conn *websocket.Conn) {
|
||||
nodes := mws.discovery.GetNodes()
|
||||
if len(nodes) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Convert to mock format
|
||||
mockNodes := make(map[string]*NodeInfo)
|
||||
for ip, node := range nodes {
|
||||
mockNodes[ip] = &NodeInfo{
|
||||
IP: node.IP,
|
||||
Hostname: node.Hostname,
|
||||
Status: string(node.Status),
|
||||
Latency: node.Latency,
|
||||
LastSeen: node.LastSeen,
|
||||
Labels: node.Labels,
|
||||
}
|
||||
}
|
||||
|
||||
members := GenerateMockClusterMembers(mockNodes)
|
||||
|
||||
message := struct {
|
||||
Type string `json:"type"`
|
||||
Members interface{} `json:"members"`
|
||||
PrimaryNode string `json:"primaryNode"`
|
||||
TotalNodes int `json:"totalNodes"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
}{
|
||||
Type: "cluster_update",
|
||||
Members: members,
|
||||
PrimaryNode: mws.discovery.GetPrimaryNode(),
|
||||
TotalNodes: len(nodes),
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
mws.logger.WithError(err).Error("Failed to marshal cluster data")
|
||||
return
|
||||
}
|
||||
|
||||
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
if err := conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
mws.logger.WithError(err).Error("Failed to send initial cluster state")
|
||||
}
|
||||
}
|
||||
|
||||
// handleNodeUpdate is called when node information changes
|
||||
func (mws *MockWebSocketServer) handleNodeUpdate(nodeIP, action string) {
|
||||
mws.logger.WithFields(log.Fields{
|
||||
"node_ip": nodeIP,
|
||||
"action": action,
|
||||
}).Debug("Mock node update received, broadcasting to WebSocket clients")
|
||||
|
||||
// Broadcast cluster update
|
||||
mws.BroadcastClusterUpdate()
|
||||
|
||||
// Also broadcast node discovery event
|
||||
mws.broadcastNodeDiscovery(nodeIP, action)
|
||||
}
|
||||
|
||||
// startPeriodicBroadcasts sends periodic updates to keep clients informed
|
||||
func (mws *MockWebSocketServer) startPeriodicBroadcasts() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-mws.shutdownChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
mws.BroadcastClusterUpdate()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastClusterUpdate sends cluster updates to all connected clients
|
||||
func (mws *MockWebSocketServer) BroadcastClusterUpdate() {
|
||||
mws.mutex.RLock()
|
||||
clients := make([]*websocket.Conn, 0, len(mws.clients))
|
||||
for client := range mws.clients {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
mws.mutex.RUnlock()
|
||||
|
||||
if len(clients) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
nodes := mws.discovery.GetNodes()
|
||||
|
||||
// Convert to mock format
|
||||
mockNodes := make(map[string]*NodeInfo)
|
||||
for ip, node := range nodes {
|
||||
mockNodes[ip] = &NodeInfo{
|
||||
IP: node.IP,
|
||||
Hostname: node.Hostname,
|
||||
Status: string(node.Status),
|
||||
Latency: node.Latency,
|
||||
LastSeen: node.LastSeen,
|
||||
Labels: node.Labels,
|
||||
}
|
||||
}
|
||||
|
||||
members := GenerateMockClusterMembers(mockNodes)
|
||||
|
||||
message := struct {
|
||||
Type string `json:"type"`
|
||||
Members interface{} `json:"members"`
|
||||
PrimaryNode string `json:"primaryNode"`
|
||||
TotalNodes int `json:"totalNodes"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
}{
|
||||
Type: "cluster_update",
|
||||
Members: members,
|
||||
PrimaryNode: mws.discovery.GetPrimaryNode(),
|
||||
TotalNodes: len(nodes),
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
mws.logger.WithError(err).Error("Failed to marshal cluster update")
|
||||
return
|
||||
}
|
||||
|
||||
mws.logger.WithField("clients", len(clients)).Debug("Broadcasting mock cluster update")
|
||||
|
||||
// Send to all clients with write synchronization
|
||||
mws.writeMutex.Lock()
|
||||
defer mws.writeMutex.Unlock()
|
||||
|
||||
for _, client := range clients {
|
||||
client.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
||||
if err := client.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
mws.logger.WithError(err).Error("Failed to send cluster update to client")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// broadcastNodeDiscovery sends node discovery events to all clients
|
||||
func (mws *MockWebSocketServer) broadcastNodeDiscovery(nodeIP, action string) {
|
||||
mws.mutex.RLock()
|
||||
clients := make([]*websocket.Conn, 0, len(mws.clients))
|
||||
for client := range mws.clients {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
mws.mutex.RUnlock()
|
||||
|
||||
if len(clients) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
message := struct {
|
||||
Type string `json:"type"`
|
||||
Action string `json:"action"`
|
||||
NodeIP string `json:"nodeIp"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
}{
|
||||
Type: "node_discovery",
|
||||
Action: action,
|
||||
NodeIP: nodeIP,
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
mws.logger.WithError(err).Error("Failed to marshal node discovery event")
|
||||
return
|
||||
}
|
||||
|
||||
// Send to all clients with write synchronization
|
||||
mws.writeMutex.Lock()
|
||||
defer mws.writeMutex.Unlock()
|
||||
|
||||
for _, client := range clients {
|
||||
client.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
||||
if err := client.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
mws.logger.WithError(err).Error("Failed to send node discovery event to client")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastFirmwareUploadStatus sends firmware upload status updates to all clients
|
||||
func (mws *MockWebSocketServer) BroadcastFirmwareUploadStatus(nodeIP, status, filename string, fileSize int) {
|
||||
mws.mutex.RLock()
|
||||
clients := make([]*websocket.Conn, 0, len(mws.clients))
|
||||
for client := range mws.clients {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
mws.mutex.RUnlock()
|
||||
|
||||
if len(clients) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
message := struct {
|
||||
Type string `json:"type"`
|
||||
NodeIP string `json:"nodeIp"`
|
||||
Status string `json:"status"`
|
||||
Filename string `json:"filename"`
|
||||
FileSize int `json:"fileSize"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
}{
|
||||
Type: "firmware_upload_status",
|
||||
NodeIP: nodeIP,
|
||||
Status: status,
|
||||
Filename: filename,
|
||||
FileSize: fileSize,
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
mws.logger.WithError(err).Error("Failed to marshal firmware upload status")
|
||||
return
|
||||
}
|
||||
|
||||
mws.logger.WithFields(log.Fields{
|
||||
"node_ip": nodeIP,
|
||||
"status": status,
|
||||
"clients": len(clients),
|
||||
}).Debug("Broadcasting mock firmware upload status")
|
||||
|
||||
// Send to all clients with write synchronization
|
||||
mws.writeMutex.Lock()
|
||||
defer mws.writeMutex.Unlock()
|
||||
|
||||
for _, client := range clients {
|
||||
client.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
||||
if err := client.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
mws.logger.WithError(err).Error("Failed to send firmware upload status to client")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastRolloutProgress sends rollout progress updates to all clients
|
||||
func (mws *MockWebSocketServer) BroadcastRolloutProgress(rolloutID, nodeIP, status string, current, total int) {
|
||||
mws.mutex.RLock()
|
||||
clients := make([]*websocket.Conn, 0, len(mws.clients))
|
||||
for client := range mws.clients {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
mws.mutex.RUnlock()
|
||||
|
||||
if len(clients) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
message := struct {
|
||||
Type string `json:"type"`
|
||||
RolloutID string `json:"rolloutId"`
|
||||
NodeIP string `json:"nodeIp"`
|
||||
Status string `json:"status"`
|
||||
Current int `json:"current"`
|
||||
Total int `json:"total"`
|
||||
Progress int `json:"progress"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
}{
|
||||
Type: "rollout_progress",
|
||||
RolloutID: rolloutID,
|
||||
NodeIP: nodeIP,
|
||||
Status: status,
|
||||
Current: current,
|
||||
Total: total,
|
||||
Progress: calculateProgress(current, total, status),
|
||||
Timestamp: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
mws.logger.WithError(err).Error("Failed to marshal rollout progress")
|
||||
return
|
||||
}
|
||||
|
||||
mws.logger.WithFields(log.Fields{
|
||||
"rollout_id": rolloutID,
|
||||
"node_ip": nodeIP,
|
||||
"status": status,
|
||||
"progress": fmt.Sprintf("%d/%d", current, total),
|
||||
}).Debug("Broadcasting mock rollout progress")
|
||||
|
||||
// Send to all clients with write synchronization
|
||||
mws.writeMutex.Lock()
|
||||
defer mws.writeMutex.Unlock()
|
||||
|
||||
for _, client := range clients {
|
||||
client.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
||||
if err := client.WriteMessage(websocket.TextMessage, data); err != nil {
|
||||
mws.logger.WithError(err).Error("Failed to send rollout progress to client")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// calculateProgress calculates the correct progress percentage based on current status
|
||||
func calculateProgress(current, total int, status string) int {
|
||||
if total == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Base progress is based on completed nodes
|
||||
completedNodes := current - 1
|
||||
if status == "completed" {
|
||||
completedNodes = current
|
||||
}
|
||||
|
||||
// Calculate base progress (completed nodes / total nodes)
|
||||
baseProgress := float64(completedNodes) / float64(total) * 100
|
||||
|
||||
// If currently updating labels or uploading, add partial progress for the current node
|
||||
if status == "updating_labels" {
|
||||
nodeProgress := 100.0 / float64(total) * 0.25
|
||||
baseProgress += nodeProgress
|
||||
} else if status == "uploading" {
|
||||
nodeProgress := 100.0 / float64(total) * 0.5
|
||||
baseProgress += nodeProgress
|
||||
}
|
||||
|
||||
// Ensure we don't exceed 100%
|
||||
if baseProgress > 100 {
|
||||
baseProgress = 100
|
||||
}
|
||||
|
||||
return int(baseProgress)
|
||||
}
|
||||
|
||||
// GetClientCount returns the number of connected WebSocket clients
|
||||
func (mws *MockWebSocketServer) GetClientCount() int {
|
||||
mws.mutex.RLock()
|
||||
defer mws.mutex.RUnlock()
|
||||
return len(mws.clients)
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the WebSocket server
|
||||
func (mws *MockWebSocketServer) Shutdown(ctx context.Context) error {
|
||||
mws.shutdownOnce.Do(func() {
|
||||
mws.logger.Info("Shutting down mock WebSocket server")
|
||||
close(mws.shutdownChan)
|
||||
|
||||
mws.mutex.Lock()
|
||||
clients := make([]*websocket.Conn, 0, len(mws.clients))
|
||||
for client := range mws.clients {
|
||||
clients = append(clients, client)
|
||||
}
|
||||
mws.mutex.Unlock()
|
||||
|
||||
// Close all client connections
|
||||
for _, client := range clients {
|
||||
client.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, "Server shutting down"))
|
||||
client.Close()
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user