package websocket import ( "bufio" "bytes" "errors" "github.com/gobwas/ws" "github.com/gobwas/ws/wsutil" "github.com/panjf2000/gnet/v2" "github.com/panjf2000/gnet/v2/pkg/logging" "io" "net/url" ) // WSConn 实现ISocketConn接口 type WSConn struct { gnet.Conn buf bytes.Buffer logger logging.Logger isUpgrade bool isClose bool param map[string]interface{} openTime int64 remoteAddr string wsMessageBuf } type wsMessageBuf struct { curHeader *ws.Header cachedBuf bytes.Buffer } type readWrite struct { io.Reader io.Writer } func (w *WSConn) readBytesBuf(c gnet.Conn) gnet.Action { size := c.InboundBuffered() if size <= 0 { return gnet.None } buf := make([]byte, size) read, err := c.Read(buf) if err != nil { if w.logger != nil { w.logger.Errorf("ws read bytes buf error", err) } return gnet.Close } if read < size { if w.logger != nil { w.logger.Errorf("read bytes len err! size: %d read: %d", size, read) } return gnet.Close } w.buf.Write(buf) return gnet.None } func (w *WSConn) upgrade() (data []byte, ok bool, action gnet.Action) { if w.isUpgrade { ok = true return } buf := &w.buf tmpReader := bytes.NewReader(buf.Bytes()) oldLen := tmpReader.Len() result := &bytes.Buffer{} tempWriter := bufio.NewWriter(result) var err error = nil up := ws.Upgrader{ OnRequest: w.OnRequest, } _, err = up.Upgrade(readWrite{tmpReader, tempWriter}) skipN := oldLen - tmpReader.Len() if err != nil { if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { //数据不完整 return } buf.Next(skipN) if w.logger != nil { w.logger.Errorf("ws upgrade error", err.Error()) } action = gnet.Close return } buf.Next(skipN) if w.logger != nil { w.logger.Infof("ws upgrade success conn upgrade websocket protocol!") } _ = tempWriter.Flush() data = result.Bytes() ok = true w.isUpgrade = true return } func (w *WSConn) readWsMessages() (messages []wsutil.Message, err error) { in := &w.buf //messages, err = wsutil.ReadClientMessage(in, messages) //return for { if w.curHeader == nil { if in.Len() < ws.MinHeaderSize { //头长度至少是2 return } var head ws.Header //有可能不完整,构建新的 reader 读取 head 读取成功才实际对 in 进行读操作 tmpReader := bytes.NewReader(in.Bytes()) oldLen := tmpReader.Len() head, err = ws.ReadHeader(tmpReader) skipN := oldLen - tmpReader.Len() if err != nil { if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { //数据不完整 return messages, nil } in.Next(skipN) return nil, err } _, err = io.CopyN(&w.cachedBuf, in, int64(skipN)) if err != nil { return } //in.Next(skipN) w.curHeader = &head //err = ws.WriteHeader(&msgBuf.cachedBuf, head) //if err != nil { // return nil, err //} } dataLen := (int)(w.curHeader.Length) if dataLen > 0 { if in.Len() >= dataLen { _, err = io.CopyN(&w.cachedBuf, in, int64(dataLen)) if err != nil { return } } else { //数据不完整 if w.logger != nil { w.logger.Debugf("ws read ws message incomplete data", in.Len(), dataLen) } return } } if w.curHeader.Fin { //当前 header 已经是一个完整消息 messages, err = wsutil.ReadClientMessage(&w.cachedBuf, messages) if err != nil { return nil, err } w.cachedBuf.Reset() } w.curHeader = nil } } func (w *WSConn) OnRequest(u []byte) error { parsedURL, err := url.Parse(string(u)) if err != nil { return err } for key, value := range parsedURL.Query() { w.SetParam(key, value[0]) } return nil } func (w *WSConn) GetParam(key string) interface{} { return w.param[key] } func (w *WSConn) SetParam(key string, values interface{}) { w.param[key] = values } func (w *WSConn) RemoteAddr() string { return w.remoteAddr } func (w *WSConn) Write(data []byte) error { return w.write(data, ws.OpBinary) } func (w *WSConn) Ping() (err error) { return w.write(make([]byte, 0), ws.OpPing) } func (w *WSConn) Close() (err error) { defer func(Conn gnet.Conn) { err = Conn.Close() }(w.Conn) return w.write(make([]byte, 0), ws.OpClose) } func (w *WSConn) IsClose() bool { return w.isClose } func (w *WSConn) write(data []byte, opCode ws.OpCode) error { if w.isClose { return errors.New("connection has close") } buf := bytes.Buffer{} if err := wsutil.WriteServerMessage(&buf, opCode, data); err != nil { return err } return w.Conn.AsyncWrite(buf.Bytes(), nil) }