175 lines
4.0 KiB
Go
175 lines
4.0 KiB
Go
package network
|
|
|
|
import (
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
"a301_game_server/pkg/logger"
|
|
)
|
|
|
|
// ConnState is the lifecycle state of a Connection. Connection keeps it in an
// atomic.Int32, so the underlying type is int32.
type ConnState int32

// Lifecycle states. NewConnection stores ConnStateActive; Close stores
// ConnStateClosed.
const (
	ConnStateActive ConnState = 0 // connection is usable; sends are queued
	ConnStateClosed ConnState = 1 // Close has run; sends are dropped
)
|
|
|
|
// Connection wraps a WebSocket connection with send buffering and lifecycle management.
|
|
type Connection struct {
|
|
id uint64
|
|
ws *websocket.Conn
|
|
sendCh chan []byte
|
|
handler PacketHandler
|
|
state atomic.Int32
|
|
closeOnce sync.Once
|
|
|
|
maxMessageSize int64
|
|
heartbeatInterval time.Duration
|
|
heartbeatTimeout time.Duration
|
|
}
|
|
|
|
// PacketHandler processes incoming packets from a connection.
|
|
type PacketHandler interface {
|
|
OnPacket(conn *Connection, pkt *Packet)
|
|
OnDisconnect(conn *Connection)
|
|
}
|
|
|
|
// NewConnection creates a new Connection wrapping the given WebSocket.
|
|
func NewConnection(id uint64, ws *websocket.Conn, handler PacketHandler, sendChSize int, maxMsgSize int64, hbInterval, hbTimeout time.Duration) *Connection {
|
|
c := &Connection{
|
|
id: id,
|
|
ws: ws,
|
|
sendCh: make(chan []byte, sendChSize),
|
|
handler: handler,
|
|
maxMessageSize: maxMsgSize,
|
|
heartbeatInterval: hbInterval,
|
|
heartbeatTimeout: hbTimeout,
|
|
}
|
|
c.state.Store(int32(ConnStateActive))
|
|
return c
|
|
}
|
|
|
|
// ID returns the connection's unique identifier.
|
|
func (c *Connection) ID() uint64 { return c.id }
|
|
|
|
// Start launches the read and write goroutines.
|
|
func (c *Connection) Start() {
|
|
go c.readLoop()
|
|
go c.writeLoop()
|
|
}
|
|
|
|
// Send encodes and queues a message for sending. Non-blocking: drops if buffer is full.
|
|
func (c *Connection) Send(msgType uint16, msg proto.Message) {
|
|
if c.IsClosed() {
|
|
return
|
|
}
|
|
|
|
data, err := Encode(msgType, msg)
|
|
if err != nil {
|
|
logger.Error("encode failed", "connID", c.id, "msgType", msgType, "error", err)
|
|
return
|
|
}
|
|
|
|
select {
|
|
case c.sendCh <- data:
|
|
default:
|
|
logger.Warn("send buffer full, dropping message", "connID", c.id, "msgType", msgType)
|
|
}
|
|
}
|
|
|
|
// SendRaw queues pre-encoded data for sending. Non-blocking.
|
|
func (c *Connection) SendRaw(data []byte) {
|
|
if c.IsClosed() {
|
|
return
|
|
}
|
|
select {
|
|
case c.sendCh <- data:
|
|
default:
|
|
logger.Warn("send buffer full, dropping raw message", "connID", c.id)
|
|
}
|
|
}
|
|
|
|
// Close terminates the connection.
|
|
func (c *Connection) Close() {
|
|
c.closeOnce.Do(func() {
|
|
c.state.Store(int32(ConnStateClosed))
|
|
close(c.sendCh)
|
|
_ = c.ws.Close()
|
|
})
|
|
}
|
|
|
|
// IsClosed returns true if the connection has been closed.
|
|
func (c *Connection) IsClosed() bool {
|
|
return ConnState(c.state.Load()) == ConnStateClosed
|
|
}
|
|
|
|
func (c *Connection) readLoop() {
|
|
defer func() {
|
|
c.handler.OnDisconnect(c)
|
|
c.Close()
|
|
}()
|
|
|
|
c.ws.SetReadLimit(c.maxMessageSize)
|
|
_ = c.ws.SetReadDeadline(time.Now().Add(c.heartbeatTimeout))
|
|
|
|
c.ws.SetPongHandler(func(string) error {
|
|
_ = c.ws.SetReadDeadline(time.Now().Add(c.heartbeatTimeout))
|
|
return nil
|
|
})
|
|
|
|
for {
|
|
msgType, data, err := c.ws.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) {
|
|
logger.Debug("read error", "connID", c.id, "error", err)
|
|
}
|
|
return
|
|
}
|
|
|
|
if msgType != websocket.BinaryMessage {
|
|
continue
|
|
}
|
|
|
|
pkt, err := Decode(data)
|
|
if err != nil {
|
|
logger.Warn("decode error", "connID", c.id, "error", err)
|
|
continue
|
|
}
|
|
|
|
c.handler.OnPacket(c, pkt)
|
|
}
|
|
}
|
|
|
|
func (c *Connection) writeLoop() {
|
|
ticker := time.NewTicker(c.heartbeatInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case data, ok := <-c.sendCh:
|
|
if !ok {
|
|
_ = c.ws.WriteMessage(websocket.CloseMessage,
|
|
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
|
return
|
|
}
|
|
|
|
_ = c.ws.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
|
if err := c.ws.WriteMessage(websocket.BinaryMessage, data); err != nil {
|
|
logger.Debug("write error", "connID", c.id, "error", err)
|
|
return
|
|
}
|
|
|
|
case <-ticker.C:
|
|
_ = c.ws.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
|
if err := c.ws.WriteMessage(websocket.PingMessage, nil); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|