367 lines
8.3 KiB
Go
367 lines
8.3 KiB
Go
package websocket
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"api-service/pkg/logger"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/google/uuid"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// WebSocketMessage represents a message structure for WebSocket communication
|
|
type WebSocketMessage struct {
|
|
Type string `json:"type"`
|
|
Data interface{} `json:"data"`
|
|
Timestamp time.Time `json:"timestamp"`
|
|
ClientID string `json:"client_id,omitempty"`
|
|
}
|
|
|
|
// Client represents a WebSocket client connection
|
|
type Client struct {
|
|
ID string
|
|
Conn *websocket.Conn
|
|
Send chan WebSocketMessage
|
|
Hub *Hub
|
|
UserID string
|
|
Room string
|
|
}
|
|
|
|
// Hub manages WebSocket connections and broadcasting
|
|
type Hub struct {
|
|
clients map[*Client]bool
|
|
broadcast chan WebSocketMessage
|
|
register chan *Client
|
|
unregister chan *Client
|
|
rooms map[string]map[*Client]bool
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// NewHub creates a new WebSocket hub
|
|
func NewHub() *Hub {
|
|
return &Hub{
|
|
clients: make(map[*Client]bool),
|
|
broadcast: make(chan WebSocketMessage),
|
|
register: make(chan *Client),
|
|
unregister: make(chan *Client),
|
|
rooms: make(map[string]map[*Client]bool),
|
|
}
|
|
}
|
|
|
|
// Run is the hub's event loop and never returns. NewWebSocketHandler starts it
// with `go hub.Run()`.
//
// For each event it:
//   - register:   adds the client to the client set and to its room;
//   - unregister: removes the client from the client set (closing its Send
//     channel if it was still registered) and from its room;
//   - broadcast:  hands the message to broadcastToClients for routing.
//
// Map mutations happen under h.mu. This lets readers on other goroutines,
// such as GetConnectedClients and GetRoomClients, synchronize with them.
func (h *Hub) Run() {
	for {
		select {
		case client := <-h.register:
			h.mu.Lock()
			h.clients[client] = true
			if client.Room != "" {
				if h.rooms[client.Room] == nil {
					h.rooms[client.Room] = make(map[*Client]bool)
				}
				h.rooms[client.Room][client] = true
			}
			h.mu.Unlock()
			logger.Info(fmt.Sprintf("Client %s connected", client.ID))

		case client := <-h.unregister:
			h.mu.Lock()
			// Close Send only if the client is still registered. A client
			// dropped earlier has already had its channel closed, and closing
			// it twice would panic.
			if _, ok := h.clients[client]; ok {
				delete(h.clients, client)
				close(client.Send)
			}
			// Purge the room entry unconditionally. A client dropped elsewhere
			// (broadcastToClients removes slow clients from h.clients) may still
			// be listed in its room. Skipping this would leak that entry and
			// leave a closed channel reachable by room broadcasts.
			if client.Room != "" {
				if room, exists := h.rooms[client.Room]; exists {
					delete(room, client)
					if len(room) == 0 {
						delete(h.rooms, client.Room)
					}
				}
			}
			h.mu.Unlock()
			logger.Info(fmt.Sprintf("Client %s disconnected", client.ID))

		case message := <-h.broadcast:
			h.broadcastToClients(message)
		}
	}
}
|
|
|
|
// broadcastToClients sends a message to appropriate clients
|
|
func (h *Hub) broadcastToClients(message WebSocketMessage) {
|
|
if message.ClientID != "" {
|
|
// Send to specific client
|
|
for client := range h.clients {
|
|
if client.ID == message.ClientID {
|
|
select {
|
|
case client.Send <- message:
|
|
default:
|
|
close(client.Send)
|
|
delete(h.clients, client)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
} else if message.Type == "room" && message.Data.(map[string]interface{})["room"] != nil {
|
|
// Send to room
|
|
roomName := message.Data.(map[string]interface{})["room"].(string)
|
|
h.mu.RLock()
|
|
if room, exists := h.rooms[roomName]; exists {
|
|
for client := range room {
|
|
select {
|
|
case client.Send <- message:
|
|
default:
|
|
close(client.Send)
|
|
delete(h.clients, client)
|
|
}
|
|
}
|
|
}
|
|
h.mu.RUnlock()
|
|
} else {
|
|
// Broadcast to all clients
|
|
for client := range h.clients {
|
|
select {
|
|
case client.Send <- message:
|
|
default:
|
|
close(client.Send)
|
|
delete(h.clients, client)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// WebSocketHandler handles WebSocket connections
|
|
type WebSocketHandler struct {
|
|
hub *Hub
|
|
logger *logger.Logger
|
|
upgrader websocket.Upgrader
|
|
broadcaster *Broadcaster
|
|
}
|
|
|
|
// NewWebSocketHandler builds a handler backed by a fresh Hub and starts the
// hub's Run loop in a background goroutine. The broadcaster field is left nil.
//
// SECURITY NOTE(review): CheckOrigin accepts every Origin. This lets any web
// page open a WebSocket to this server using a visitor's browser (cross-site
// WebSocket hijacking). Restrict the allowed origins before production use.
func NewWebSocketHandler() *WebSocketHandler {
	hub := NewHub()
	go hub.Run()

	upgrader := websocket.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
		// Allow connections from any origin in development.
		// In production, you should check the origin.
		CheckOrigin: func(*http.Request) bool { return true },
	}

	return &WebSocketHandler{
		hub:      hub,
		logger:   logger.Default(),
		upgrader: upgrader,
	}
}
|
|
|
|
// HandleWebSocket upgrades the HTTP request to a WebSocket connection. It
// registers the resulting client with the hub, queues a "welcome" message for
// it, and starts its read and write pumps.
//
// Query parameters:
//   - user_id: stored on the client; defaults to "anonymous".
//     NOTE(review): taken verbatim from the request and not verified here.
//   - room: the room the client joins; defaults to "default".
func (h *WebSocketHandler) HandleWebSocket(c *gin.Context) {
	conn, err := h.upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		// Upgrade has already written an HTTP error response to the client.
		h.logger.Error(fmt.Sprintf("Failed to upgrade connection: %v", err))
		return
	}

	userID := c.Query("user_id")
	if userID == "" {
		userID = "anonymous"
	}

	room := c.Query("room")
	if room == "" {
		room = "default"
	}

	client := &Client{
		ID:     uuid.New().String(),
		Conn:   conn,
		Send:   make(chan WebSocketMessage, 256),
		Hub:    h.hub,
		UserID: userID,
		Room:   room,
	}

	// Queue the welcome message before the client is registered. The channel
	// is new and buffered, so this send cannot block or fail. Because the hub
	// does not know about the client yet, nothing can be queued ahead of the
	// welcome. This avoids any failure path that would have to unwind a
	// registration and close the connection.
	client.Send <- WebSocketMessage{
		Type:      "welcome",
		Data:      map[string]string{"message": "Connected to WebSocket server"},
		Timestamp: time.Now(),
		ClientID:  client.ID,
	}

	client.Hub.register <- client

	go client.writePump()
	go client.readPump()
}
|
|
|
|
// readPump reads frames from the client's connection until a read fails. It
// then unregisters the client from the hub and closes the connection.
//
// Each frame is handled as follows:
//   - It is decoded as a WebSocketMessage; malformed frames are logged and
//     skipped.
//   - It is stamped with the server time and this client's ID.
//   - A "response" acknowledgement is queued for this client.
//   - The message is forwarded to the hub.
//
// readPump never writes to c.Send directly. The hub closes that channel when
// it drops or unregisters a client, and a send racing with that close would
// panic. Instead, both outgoing messages go through the hub's broadcast
// channel. The hub serializes them with its own closes and delivers only to
// clients that are still registered.
func (c *Client) readPump() {
	defer func() {
		c.Hub.unregister <- c
		c.Conn.Close()
	}()

	c.Conn.SetReadLimit(512)
	c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
	c.Conn.SetPongHandler(func(string) error {
		// Each pong (the reply to writePump's pings) extends the read deadline.
		c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
		return nil
	})

	for {
		_, message, err := c.Conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				logger.Error(fmt.Sprintf("WebSocket error: %v", err))
			}
			return
		}

		var msg WebSocketMessage
		if err := json.Unmarshal(message, &msg); err != nil {
			logger.Error(fmt.Sprintf("Failed to parse message: %v", err))
			continue
		}

		// Overwrite any client-supplied timestamp and sender ID.
		msg.Timestamp = time.Now()
		msg.ClientID = c.ID

		// Acknowledge receipt. The non-empty ClientID makes the hub deliver this
		// to c only. If c's buffer is full, the hub drops c, which closes Send
		// and makes writePump shut the connection.
		c.Hub.broadcast <- WebSocketMessage{
			Type: "response",
			Data: map[string]interface{}{
				"message":       "Message received from server: " + fmt.Sprintf("%v", msg.Data),
				"original_type": msg.Type,
			},
			Timestamp: time.Now(),
			ClientID:  c.ID,
		}

		// NOTE(review): msg.ClientID was set to the sender's ID above. The hub
		// delivers any message with a non-empty ClientID to that single client,
		// so this reaches only the sender, not the other clients. Confirm
		// whether an echo or a real broadcast is intended.
		c.Hub.broadcast <- msg
	}
}
|
|
|
|
// writePump writes messages to the WebSocket connection
|
|
func (c *Client) writePump() {
|
|
ticker := time.NewTicker(54 * time.Second)
|
|
defer func() {
|
|
ticker.Stop()
|
|
c.Conn.Close()
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case message, ok := <-c.Send:
|
|
c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
|
if !ok {
|
|
c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
|
|
return
|
|
}
|
|
|
|
if err := c.Conn.WriteJSON(message); err != nil {
|
|
logger.Error(fmt.Sprintf("Failed to write message: %v", err))
|
|
return
|
|
}
|
|
|
|
case <-ticker.C:
|
|
c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
|
if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// BroadcastMessage queues a message of the given type with no specific target.
// It blocks until the hub's Run loop accepts it. Normally the hub delivers it
// to every connected client. The exception is a messageType of "room": the hub
// then routes by data["room"] instead (see broadcastToClients).
func (h *WebSocketHandler) BroadcastMessage(messageType string, data interface{}) {
	h.hub.broadcast <- WebSocketMessage{Type: messageType, Data: data, Timestamp: time.Now()}
}
|
|
|
|
// NotifyDataChange pushes a "data_change" notification carrying data to every
// connected client. It blocks until the hub's Run loop accepts it.
func (h *WebSocketHandler) NotifyDataChange(data interface{}) {
	notification := WebSocketMessage{Timestamp: time.Now(), Type: "data_change", Data: data}
	h.hub.broadcast <- notification
}
|
|
|
|
// BroadcastToRoom queues a message intended for the clients in room. The
// payload is wrapped as {"room": room, "data": data}. It blocks until the
// hub's Run loop accepts it.
//
// NOTE(review): the hub only restricts delivery to a room when the message's
// Type is "room" (see broadcastToClients). For any other messageType this
// message is delivered to every connected client, not just the room. Confirm
// what callers pass, or change the hub's routing.
func (h *WebSocketHandler) BroadcastToRoom(room string, messageType string, data interface{}) {
	payload := map[string]interface{}{"room": room, "data": data}
	h.hub.broadcast <- WebSocketMessage{
		Type:      messageType,
		Data:      payload,
		Timestamp: time.Now(),
	}
}
|
|
|
|
// SendToClient queues a message for the single client whose ID is clientID.
// It blocks until the hub's Run loop accepts the message. The hub silently
// drops the message if no registered client has that ID.
//
// An empty clientID is rejected and logged. The hub treats an empty ClientID
// as "no specific target" and would fan the message out like a broadcast
// (to every client, or to a room when messageType is "room").
func (h *WebSocketHandler) SendToClient(clientID string, messageType string, data interface{}) {
	if clientID == "" {
		h.logger.Error("SendToClient called with empty clientID; message dropped")
		return
	}

	msg := WebSocketMessage{
		Type:      messageType,
		Data:      data,
		Timestamp: time.Now(),
		ClientID:  clientID,
	}

	h.hub.broadcast <- msg
}
|
|
|
|
// GetConnectedClients returns the number of connected clients
|
|
func (h *WebSocketHandler) GetConnectedClients() int {
|
|
return len(h.hub.clients)
|
|
}
|
|
|
|
// GetRoomClients returns the number of clients in a specific room
|
|
func (h *WebSocketHandler) GetRoomClients(room string) int {
|
|
h.hub.mu.RLock()
|
|
defer h.hub.mu.RUnlock()
|
|
|
|
if roomClients, exists := h.hub.rooms[room]; exists {
|
|
return len(roomClients)
|
|
}
|
|
return 0
|
|
}
|