package websocket import ( "encoding/json" "fmt" "net/http" "sync" "time" "api-service/pkg/logger" "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/gorilla/websocket" ) // WebSocketMessage represents a message structure for WebSocket communication type WebSocketMessage struct { Type string `json:"type"` Data interface{} `json:"data"` Timestamp time.Time `json:"timestamp"` ClientID string `json:"client_id,omitempty"` } // Client represents a WebSocket client connection type Client struct { ID string Conn *websocket.Conn Send chan WebSocketMessage Hub *Hub UserID string Room string } // Hub manages WebSocket connections and broadcasting type Hub struct { clients map[*Client]bool broadcast chan WebSocketMessage register chan *Client unregister chan *Client rooms map[string]map[*Client]bool mu sync.RWMutex } // NewHub creates a new WebSocket hub func NewHub() *Hub { return &Hub{ clients: make(map[*Client]bool), broadcast: make(chan WebSocketMessage), register: make(chan *Client), unregister: make(chan *Client), rooms: make(map[string]map[*Client]bool), } } // Run starts the hub and handles client registration/deregistration func (h *Hub) Run() { for { select { case client := <-h.register: h.clients[client] = true if client.Room != "" { h.mu.Lock() if h.rooms[client.Room] == nil { h.rooms[client.Room] = make(map[*Client]bool) } h.rooms[client.Room][client] = true h.mu.Unlock() } logger.Info(fmt.Sprintf("Client %s connected", client.ID)) case client := <-h.unregister: if _, ok := h.clients[client]; ok { delete(h.clients, client) close(client.Send) if client.Room != "" { h.mu.Lock() if room, exists := h.rooms[client.Room]; exists { delete(room, client) if len(room) == 0 { delete(h.rooms, client.Room) } } h.mu.Unlock() } } logger.Info(fmt.Sprintf("Client %s disconnected", client.ID)) case message := <-h.broadcast: h.broadcastToClients(message) } } } // broadcastToClients sends a message to appropriate clients func (h *Hub) broadcastToClients(message WebSocketMessage) { if message.ClientID != "" { // Send to specific client for client := range h.clients { if client.ID == message.ClientID { select { case client.Send <- message: default: close(client.Send) delete(h.clients, client) } break } } } else if message.Type == "room" && message.Data.(map[string]interface{})["room"] != nil { // Send to room roomName := message.Data.(map[string]interface{})["room"].(string) h.mu.RLock() if room, exists := h.rooms[roomName]; exists { for client := range room { select { case client.Send <- message: default: close(client.Send) delete(h.clients, client) } } } h.mu.RUnlock() } else { // Broadcast to all clients for client := range h.clients { select { case client.Send <- message: default: close(client.Send) delete(h.clients, client) } } } } // WebSocketHandler handles WebSocket connections type WebSocketHandler struct { hub *Hub logger *logger.Logger upgrader websocket.Upgrader broadcaster *Broadcaster } // NewWebSocketHandler creates a new WebSocket handler func NewWebSocketHandler() *WebSocketHandler { hub := NewHub() go hub.Run() return &WebSocketHandler{ hub: hub, logger: logger.Default(), upgrader: websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { // Allow connections from any origin in development // In production, you should check the origin return true }, ReadBufferSize: 1024, WriteBufferSize: 1024, }, } } // HandleWebSocket handles WebSocket upgrade and connection func (h *WebSocketHandler) HandleWebSocket(c *gin.Context) { conn, err := h.upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { h.logger.Error(fmt.Sprintf("Failed to upgrade connection: %v", err)) return } // Get user ID and room from query parameters or headers userID := c.Query("user_id") if userID == "" { userID = "anonymous" } room := c.Query("room") if room == "" { room = "default" } client := &Client{ ID: uuid.New().String(), Conn: conn, Send: make(chan WebSocketMessage, 256), Hub: h.hub, UserID: userID, Room: room, } client.Hub.register <- client // Send welcome message welcomeMsg := WebSocketMessage{ Type: "welcome", Data: map[string]string{"message": "Connected to WebSocket server"}, Timestamp: time.Now(), ClientID: client.ID, } select { case client.Send <- welcomeMsg: default: close(client.Send) return } // Start goroutines for reading and writing go client.writePump() go client.readPump() } // readPump reads messages from the WebSocket connection func (c *Client) readPump() { defer func() { c.Hub.unregister <- c c.Conn.Close() }() c.Conn.SetReadLimit(512) c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second)) c.Conn.SetPongHandler(func(string) error { c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second)) return nil }) for { _, message, err := c.Conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { logger.Error(fmt.Sprintf("WebSocket error: %v", err)) } break } // Parse incoming message var msg WebSocketMessage if err := json.Unmarshal(message, &msg); err != nil { logger.Error(fmt.Sprintf("Failed to parse message: %v", err)) continue } msg.Timestamp = time.Now() msg.ClientID = c.ID // Send response back to the client responseMsg := WebSocketMessage{ Type: "response", Data: map[string]interface{}{ "message": "Message received from server: " + fmt.Sprintf("%v", msg.Data), "original_type": msg.Type, }, Timestamp: time.Now(), ClientID: c.ID, } select { case c.Send <- responseMsg: default: close(c.Send) delete(c.Hub.clients, c) } // Broadcast the message c.Hub.broadcast <- msg } } // writePump writes messages to the WebSocket connection func (c *Client) writePump() { ticker := time.NewTicker(54 * time.Second) defer func() { ticker.Stop() c.Conn.Close() }() for { select { case message, ok := <-c.Send: c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if !ok { c.Conn.WriteMessage(websocket.CloseMessage, []byte{}) return } if err := c.Conn.WriteJSON(message); err != nil { logger.Error(fmt.Sprintf("Failed to write message: %v", err)) return } case <-ticker.C: c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil { return } } } } // BroadcastMessage broadcasts a message to all connected clients func (h *WebSocketHandler) BroadcastMessage(messageType string, data interface{}) { msg := WebSocketMessage{ Type: messageType, Data: data, Timestamp: time.Now(), } h.hub.broadcast <- msg } // NotifyDataChange sends a notification message to all clients when data changes in the database func (h *WebSocketHandler) NotifyDataChange(data interface{}) { msg := WebSocketMessage{ Type: "data_change", Data: data, Timestamp: time.Now(), } h.hub.broadcast <- msg } // BroadcastToRoom broadcasts a message to clients in a specific room func (h *WebSocketHandler) BroadcastToRoom(room string, messageType string, data interface{}) { msg := WebSocketMessage{ Type: messageType, Data: map[string]interface{}{ "room": room, "data": data, }, Timestamp: time.Now(), } h.hub.broadcast <- msg } // SendToClient sends a message to a specific client func (h *WebSocketHandler) SendToClient(clientID string, messageType string, data interface{}) { msg := WebSocketMessage{ Type: messageType, Data: data, Timestamp: time.Now(), ClientID: clientID, } h.hub.broadcast <- msg } // GetConnectedClients returns the number of connected clients func (h *WebSocketHandler) GetConnectedClients() int { return len(h.hub.clients) } // GetRoomClients returns the number of clients in a specific room func (h *WebSocketHandler) GetRoomClients(room string) int { h.hub.mu.RLock() defer h.hub.mu.RUnlock() if roomClients, exists := h.hub.rooms[room]; exists { return len(roomClients) } return 0 }