package websocket import ( "bytes" "common/net/socket" "context" "github.com/panjf2000/gnet/v2" "github.com/panjf2000/gnet/v2/pkg/logging" "sync" "time" ) // WSServer 实现GNet库接口 type WSServer struct { gnet.BuiltinEventEngine eng gnet.Engine i socket.ISocketServer logger logging.Logger // 日志 upgradeTimeout time.Duration // 升级超时时间 unUpgradeConn sync.Map } func NewWSServer(i socket.ISocketServer, logger logging.Logger, timeout time.Duration) *WSServer { if i == nil { return nil } return &WSServer{ i: i, logger: logger, upgradeTimeout: timeout, unUpgradeConn: sync.Map{}, } } func (s *WSServer) Run(logger logging.Logger, addr string, multiCore, reusePort, tick, lockOSThread, reuseAddr bool, processNum int) error { return gnet.Run( s, addr, gnet.WithMulticore(multiCore), gnet.WithNumEventLoop(processNum), gnet.WithReusePort(reusePort), gnet.WithTicker(tick), gnet.WithLogger(logger), gnet.WithLockOSThread(lockOSThread), gnet.WithReuseAddr(reuseAddr), ) } func (s *WSServer) Stop() error { return s.eng.Stop(context.Background()) } func (s *WSServer) OnBoot(eng gnet.Engine) gnet.Action { s.eng = eng return gnet.None } func (s *WSServer) OnOpen(c gnet.Conn) ([]byte, gnet.Action) { ws := &WSConn{ Conn: c, isUpgrade: false, openTime: time.Now().Unix(), buf: bytes.Buffer{}, logger: s.logger, wsMessageBuf: wsMessageBuf{ curHeader: nil, cachedBuf: bytes.Buffer{}, }, param: make(map[string]interface{}), remoteAddr: c.RemoteAddr().String(), } c.SetContext(ws) s.unUpgradeConn.Store(c.RemoteAddr().String(), ws) d, a := s.i.OnOpen(ws) return d, (gnet.Action)(a) } // OnClose fires when a connection has been closed. // The parameter err is the last known connection error. func (s *WSServer) OnClose(c gnet.Conn, err error) (action gnet.Action) { s.unUpgradeConn.Delete(c.RemoteAddr().String()) ws, ok := c.Context().(*WSConn) if ok { ws.isClose = true ws.logger.Warnf("connection close, err: %v", err) return gnet.Action(s.i.OnClose(ws, err)) } return } // OnTraffic fires when a local socket receives data from the peer. func (s *WSServer) OnTraffic(c gnet.Conn) (action gnet.Action) { tmp := c.Context() if tmp == nil { s.logger.Errorf("OnTraffic context nil: %v", c) action = gnet.Close return } ws, ok := tmp.(*WSConn) if !ok { ws.logger.Errorf("OnTraffic convert ws error: %v", tmp) action = gnet.Close return } action = ws.readBytesBuf(c) if action != gnet.None { return } if !ws.isUpgrade { var data []byte data, ok, action = ws.upgrade() if ok { s.unUpgradeConn.Delete(c.RemoteAddr().String()) s.i.OnHandShake(ws) if data != nil { err := ws.Conn.AsyncWrite(data, nil) if err != nil { ws.logger.Errorf("update ws write upgrade protocol error", err) action = gnet.Close } } } } else { msg, err := ws.readWsMessages() if err != nil { ws.logger.Errorf("read ws messages errors", err) return gnet.Close } if msg != nil { for _, m := range msg { if socket.OpCode(m.OpCode) == socket.OpPong { s.i.OnPong(ws) continue } if socket.OpCode(m.OpCode) == socket.OpClose { return gnet.Close } if socket.OpCode(m.OpCode) == socket.OpPing { continue } a := s.i.OnMessage(ws, m.Payload) if gnet.Action(a) != gnet.None { action = gnet.Action(a) } } } } return } // OnTick fires immediately after the engine starts and will fire again // following the duration specified by the delay return value. func (s *WSServer) OnTick() (delay time.Duration, action gnet.Action) { now := time.Now().Unix() delConn := make([]string, 0) s.unUpgradeConn.Range(func(key, value interface{}) bool { k, ok := key.(string) if !ok { return true } v, ok := value.(*WSConn) if !ok { return true } if now-v.openTime >= int64(s.upgradeTimeout.Seconds()) { delConn = append(delConn, k) } return true }) for _, k := range delConn { wsConn, _ := s.unUpgradeConn.LoadAndDelete(k) if wsConn == nil { continue } v, ok := wsConn.(*WSConn) if !ok { continue } if err := v.Close(); err != nil { v.logger.Errorf("upgrade ws time out close socket error: %v", err) } } d, a := s.i.OnTick() delay = d action = gnet.Action(a) return }