package network import ( "sync" "sync/atomic" "time" "github.com/gorilla/websocket" "google.golang.org/protobuf/proto" "a301_game_server/pkg/logger" ) // ConnState represents the lifecycle state of a connection. type ConnState int32 const ( ConnStateActive ConnState = iota ConnStateClosed ) // Connection wraps a WebSocket connection with send buffering and lifecycle management. type Connection struct { id uint64 ws *websocket.Conn sendCh chan []byte handler PacketHandler state atomic.Int32 closeOnce sync.Once maxMessageSize int64 heartbeatInterval time.Duration heartbeatTimeout time.Duration } // PacketHandler processes incoming packets from a connection. type PacketHandler interface { OnPacket(conn *Connection, pkt *Packet) OnDisconnect(conn *Connection) } // NewConnection creates a new Connection wrapping the given WebSocket. func NewConnection(id uint64, ws *websocket.Conn, handler PacketHandler, sendChSize int, maxMsgSize int64, hbInterval, hbTimeout time.Duration) *Connection { c := &Connection{ id: id, ws: ws, sendCh: make(chan []byte, sendChSize), handler: handler, maxMessageSize: maxMsgSize, heartbeatInterval: hbInterval, heartbeatTimeout: hbTimeout, } c.state.Store(int32(ConnStateActive)) return c } // ID returns the connection's unique identifier. func (c *Connection) ID() uint64 { return c.id } // Start launches the read and write goroutines. func (c *Connection) Start() { go c.readLoop() go c.writeLoop() } // Send encodes and queues a message for sending. Non-blocking: drops if buffer is full. func (c *Connection) Send(msgType uint16, msg proto.Message) { if c.IsClosed() { return } data, err := Encode(msgType, msg) if err != nil { logger.Error("encode failed", "connID", c.id, "msgType", msgType, "error", err) return } select { case c.sendCh <- data: default: logger.Warn("send buffer full, dropping message", "connID", c.id, "msgType", msgType) } } // SendRaw queues pre-encoded data for sending. Non-blocking. func (c *Connection) SendRaw(data []byte) { if c.IsClosed() { return } select { case c.sendCh <- data: default: logger.Warn("send buffer full, dropping raw message", "connID", c.id) } } // Close terminates the connection. func (c *Connection) Close() { c.closeOnce.Do(func() { c.state.Store(int32(ConnStateClosed)) close(c.sendCh) _ = c.ws.Close() }) } // IsClosed returns true if the connection has been closed. func (c *Connection) IsClosed() bool { return ConnState(c.state.Load()) == ConnStateClosed } func (c *Connection) readLoop() { defer func() { c.handler.OnDisconnect(c) c.Close() }() c.ws.SetReadLimit(c.maxMessageSize) _ = c.ws.SetReadDeadline(time.Now().Add(c.heartbeatTimeout)) c.ws.SetPongHandler(func(string) error { _ = c.ws.SetReadDeadline(time.Now().Add(c.heartbeatTimeout)) return nil }) for { msgType, data, err := c.ws.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { logger.Debug("read error", "connID", c.id, "error", err) } return } if msgType != websocket.BinaryMessage { continue } pkt, err := Decode(data) if err != nil { logger.Warn("decode error", "connID", c.id, "error", err) continue } c.handler.OnPacket(c, pkt) } } func (c *Connection) writeLoop() { ticker := time.NewTicker(c.heartbeatInterval) defer ticker.Stop() for { select { case data, ok := <-c.sendCh: if !ok { _ = c.ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) return } _ = c.ws.SetWriteDeadline(time.Now().Add(10 * time.Second)) if err := c.ws.WriteMessage(websocket.BinaryMessage, data); err != nil { logger.Debug("write error", "connID", c.id, "error", err) return } case <-ticker.C: _ = c.ws.SetWriteDeadline(time.Now().Add(10 * time.Second)) if err := c.ws.WriteMessage(websocket.PingMessage, nil); err != nil { return } } } }