diff --git a/app/app.go b/app/app.go index 5b04f71..7d953c2 100644 --- a/app/app.go +++ b/app/app.go @@ -25,6 +25,7 @@ func (p *Program) Init(_ svc.Environment) error { p.moduleList = append(p.moduleList, &ModuleWebServer{}) p.moduleList = append(p.moduleList, &ModuleWebsocketServer{}) p.moduleList = append(p.moduleList, &ModuleGrpcServer{}) + p.moduleList = append(p.moduleList, &ModuleLoginQueue{}) for _, module := range p.moduleList { if err := module.init(); err != nil { diff --git a/app/login_queue.go b/app/login_queue.go new file mode 100644 index 0000000..22559df --- /dev/null +++ b/app/login_queue.go @@ -0,0 +1,29 @@ +package app + +import ( + "gateway/internal/global" + "gateway/internal/handler/ws_handler/login" + "runtime" +) + +// ModuleLoginQueue 登录队列模块 +type ModuleLoginQueue struct { + login *login.Login + queueUp *login.QueueUp +} + +func (m *ModuleLoginQueue) init() error { + m.login = login.NewLoginQueue(global.MaxQueueUpSize) + m.queueUp = login.NewQueueUp(global.MaxQueueUpSize) + return nil +} + +func (m *ModuleLoginQueue) start() error { + m.login.Start(runtime.NumCPU()) + return nil +} + +func (m *ModuleLoginQueue) stop() error { + m.login.Stop() + return nil +} diff --git a/config/config.go b/config/config.go index 2e4b389..e5167f2 100644 --- a/config/config.go +++ b/config/config.go @@ -2,16 +2,7 @@ package config import "common/config" -const ( - path = "./config" - KeyUserAccessToken = "user:access:%v" - KeyUserRefreshToken = "user:refresh:%v" -) - -// PublicPaths 不需要鉴权的接口,硬编码注册 -var PublicPaths = []string{ - "/user/info", -} +const path = "./config" type Config struct { App *config.AppConfig `yaml:"app"` diff --git a/go.mod b/go.mod index c15e3d3..7b9486d 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/judwhite/go-svc v1.2.1 github.com/panjf2000/gnet/v2 v2.9.7 github.com/prometheus/client_golang v1.20.5 + github.com/redis/go-redis/v9 v9.10.0 github.com/stretchr/testify v1.11.1 go.uber.org/zap v1.27.0 google.golang.org/grpc v1.71.1 @@ -66,7 +67,6 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/quic-go/qpack v0.5.1 // indirect github.com/quic-go/quic-go v0.54.0 // indirect - github.com/redis/go-redis/v9 v9.10.0 // indirect github.com/sagikazarmark/locafero v0.11.0 // indirect github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect github.com/spf13/afero v1.15.0 // indirect diff --git a/internal/global/global.go b/internal/global/global.go new file mode 100644 index 0000000..54e86ff --- /dev/null +++ b/internal/global/global.go @@ -0,0 +1,19 @@ +package global + +const ( + KeyGatewayAccessToken = "gateway:token:access:%v" + KeyGatewayRefreshToken = "gateway:token:refresh:%v" + + KeyGatewayInfo = "gateway:info:%v" + HFieldInfoGatewaySID = "gateway_sid" +) + +const ( + MaxOnlineSize = 100 // 最大在线人数 + MaxQueueUpSize = 100 // 最大排队人数 +) + +// PublicPaths 不需要鉴权的接口,硬编码注册 +var PublicPaths = []string{ + "/user/info", +} diff --git a/internal/grpc_server/server/server.go b/internal/grpc_server/server/server.go index 89dc7a2..f908800 100644 --- a/internal/grpc_server/server/server.go +++ b/internal/grpc_server/server/server.go @@ -4,7 +4,8 @@ import ( "common/log" "common/proto/sc/sc_pb" "common/proto/ss/grpc_pb" - "gateway/internal/handler/ws_handler" + "context" + "gateway/internal/handler/ws_handler/client" "google.golang.org/protobuf/proto" "sync" ) @@ -38,16 +39,16 @@ func (s *Server) ToClient(server grpc_pb.Gateway_ToClientServer) error { log.Errorf("ToClient proto.Marshal error: %v", err) continue } - for _, client := range ws_handler.UserMgr.GetAll() { - client.WriteBytesPreMarshal(data) + for _, cli := range client.UserMgr.GetAll() { + cli.WriteBytesPreMarshal(data) } //for _, client := range ws_handler.UserMgr.GetAll() { // client.WriteBytes(sc_pb.MessageID(args.MessageID), args.Payload) //} } else { - if client := ws_handler.UserMgr.GetByUSN(args.USN); client != nil { - client.WriteBytes(sc_pb.MessageID(args.MessageID), args.Payload) + if cli := client.UserMgr.GetByUSN(args.USN); cli != nil { + cli.WriteBytes(sc_pb.MessageID(args.MessageID), args.Payload) } } } @@ -56,3 +57,7 @@ func (s *Server) ToClient(server grpc_pb.Gateway_ToClientServer) error { wg.Wait() return server.SendAndClose(&grpc_pb.ToClientResp{}) } + +func (s *Server) KickUser(ctx context.Context, req *grpc_pb.KickUserReq) (*grpc_pb.KickUserResp, error) { + return &grpc_pb.KickUserResp{}, nil +} diff --git a/internal/grpc_server/server/server_init.go b/internal/grpc_server/server/server_init.go index c677208..5f51097 100644 --- a/internal/grpc_server/server/server_init.go +++ b/internal/grpc_server/server/server_init.go @@ -4,7 +4,7 @@ import ( "common/discover/common" "common/net/grpc/service" "common/proto/ss/grpc_pb" - "gateway/internal/handler/ws_handler" + "gateway/internal/handler/ws_handler/client" "google.golang.org/grpc" ) @@ -26,7 +26,7 @@ func NewServer(ttl int64) *Server { } func (s *Server) OnInit(serve *grpc.Server) { - ws_handler.GatewaySID = s.SID + client.GatewaySID = s.SID grpc_pb.RegisterGatewayServer(serve, s) } diff --git a/internal/grpc_server/stream_client/scene.go b/internal/grpc_server/stream_client/scene.go index 90e6ac2..3db503a 100644 --- a/internal/grpc_server/stream_client/scene.go +++ b/internal/grpc_server/stream_client/scene.go @@ -6,60 +6,79 @@ import ( "context" "google.golang.org/grpc" "google.golang.org/protobuf/proto" + "strconv" + "sync" ) -var sceneServerM map[int64]map[SceneFun]grpc.ClientStream // map[sid]map[方法名]流连接 - type SceneFun int const ( FunAction SceneFun = iota ) -func init() { - sceneServerM = make(map[int64]map[SceneFun]grpc.ClientStream) +var sceneServer sync.Map // map[string]*sceneStream + +type sceneStream struct { + mu sync.Mutex + stream grpc.ClientStream } -func findSceneBySID(sid int64, fun SceneFun) (grpc.ClientStream, error) { - g := sceneServerM[sid] - if g == nil { - g = make(map[SceneFun]grpc.ClientStream) - sceneServerM[sid] = g +func findSceneBySID(sid int64, fun SceneFun) (*sceneStream, error) { + key := sceneKey(sid, fun) + + if v, ok := sceneServer.Load(key); ok { + return v.(*sceneStream), nil } - sceneLink := g[fun] - if sceneLink == nil { - sceneClient, err := service.SceneNewClient(sid) - if err != nil { - log.Errorf("cannot find sceneClient: %v", err) - return nil, err - } - var link grpc.ClientStream - switch fun { - case FunAction: - link, err = sceneClient.Action(context.Background()) - } - if err != nil { - log.Errorf("findSceneBySID %v err: %v, sid: %v", fun, err, sid) - return nil, err - } - g[fun] = link - sceneLink = link + + client, err := service.SceneNewClient(sid) + if err != nil { + log.Errorf("findSceneBySID cannot find client: %v", err) + return nil, err } - return sceneLink, nil + var stream grpc.ClientStream + switch fun { + case FunAction: + stream, err = client.Action(context.Background()) + } + if err != nil { + log.Errorf("findSceneBySID %v err: %v, sid: %v", fun, err, sid) + return nil, err + } + + ss := &sceneStream{stream: stream} + if actual, loaded := sceneServer.LoadOrStore(key, ss); loaded { + go func() { _ = stream.CloseSend() }() + return actual.(*sceneStream), nil + } + + return ss, nil } func SendMessageToScene(sid int64, fun SceneFun, msg proto.Message, re ...bool) error { - stream, err := findSceneBySID(sid, fun) + ss, err := findSceneBySID(sid, fun) if err != nil { return err } - if err = stream.SendMsg(msg); err != nil { + + ss.mu.Lock() + err = ss.stream.SendMsg(msg) + ss.mu.Unlock() + + if err != nil { + key := sceneKey(sid, fun) + if v, ok := sceneServer.Load(key); ok && v == ss { + sceneServer.Delete(key) + _ = ss.stream.CloseSend() + } + // 如果没有标识本次是重试的,就重试一次(默认重试) if re == nil || !re[0] { - _ = stream.CloseSend() - delete(sceneServerM[sid], fun) return SendMessageToScene(sid, fun, msg, true) } return err } return nil } + +func sceneKey(sid int64, fun SceneFun) string { + return strconv.FormatInt(sid, 10) + "-" + strconv.Itoa(int(fun)) +} diff --git a/internal/grpc_server/stream_client/scene_test.go b/internal/grpc_server/stream_client/scene_test.go new file mode 100644 index 0000000..136cf41 --- /dev/null +++ b/internal/grpc_server/stream_client/scene_test.go @@ -0,0 +1,20 @@ +package stream_client + +import ( + "gateway/internal/testutil" + "github.com/stretchr/testify/suite" + "testing" +) + +type SceneTestSuite struct { + testutil.TestSuite +} + +func (ts *SceneTestSuite) TestSceneKey() { + r := sceneKey(1122, FunAction) + ts.Assert().Equal("1122-0", r) +} + +func TestLoginTestSuite(t *testing.T) { + suite.Run(t, &SceneTestSuite{}) +} diff --git a/internal/handler/http_handler/login.go b/internal/handler/http_handler/login.go index 28f1f4d..cd94807 100644 --- a/internal/handler/http_handler/login.go +++ b/internal/handler/http_handler/login.go @@ -10,6 +10,7 @@ import ( "context" "fmt" "gateway/config" + "gateway/internal/global" "github.com/gin-gonic/gin" "time" ) @@ -79,7 +80,7 @@ func RefreshToken(c *gin.Context) { http_resp.JsonOK(c, http_resp.Error(http_resp.TokenInvalid)) return } - if redis.GetClient().Get(c, fmt.Sprintf(config.KeyUserRefreshToken, claims.USN)).Val() != req.RefreshToken { + if redis.GetClient().Get(c, fmt.Sprintf(global.KeyGatewayRefreshToken, claims.USN)).Val() != req.RefreshToken { http_resp.JsonOK(c, http_resp.Error(http_resp.TokenInvalid)) return } @@ -97,11 +98,11 @@ func RefreshToken(c *gin.Context) { } func genToken(ctx context.Context, usn int64) (string, string, error) { - at, err := genTokenOne(ctx, config.KeyUserAccessToken, usn, 2*time.Hour) + at, err := genTokenOne(ctx, global.KeyGatewayAccessToken, usn, 2*time.Hour) if err != nil { return "", "", err } - rt, err := genTokenOne(ctx, config.KeyUserRefreshToken, usn, 3*24*time.Hour) + rt, err := genTokenOne(ctx, global.KeyGatewayRefreshToken, usn, 3*24*time.Hour) if err != nil { return "", "", err } diff --git a/internal/handler/http_handler/login_test.go b/internal/handler/http_handler/login_test.go index 20aff21..b300764 100644 --- a/internal/handler/http_handler/login_test.go +++ b/internal/handler/http_handler/login_test.go @@ -10,7 +10,7 @@ import ( "common/utils" "context" "fmt" - "gateway/config" + "gateway/internal/global" "gateway/internal/testutil" "github.com/gin-gonic/gin" "github.com/golang/mock/gomock" @@ -30,9 +30,9 @@ func (ts *LoginTestSuite) TestGenToken() { at, rt, err := genToken(context.Background(), int64(usn)) ts.Assert().NoError(err) - redisAt := redis.GetClient().Get(context.Background(), fmt.Sprintf(config.KeyUserAccessToken, usn)).Val() + redisAt := redis.GetClient().Get(context.Background(), fmt.Sprintf(global.KeyGatewayAccessToken, usn)).Val() ts.Assert().Equal(at, redisAt) - redisRt := redis.GetClient().Get(context.Background(), fmt.Sprintf(config.KeyUserRefreshToken, usn)).Val() + redisRt := redis.GetClient().Get(context.Background(), fmt.Sprintf(global.KeyGatewayRefreshToken, usn)).Val() ts.Assert().Equal(rt, redisRt) } @@ -138,7 +138,7 @@ func (ts *LoginTestSuite) TestRefreshToken() { }) defer monkey.Unpatch(utils.ParseToken) - redis.GetClient().Set(context.Background(), fmt.Sprintf(config.KeyUserRefreshToken, 1), "ab", redis2.KeepTTL) + redis.GetClient().Set(context.Background(), fmt.Sprintf(global.KeyGatewayRefreshToken, 1), "ab", redis2.KeepTTL) w, c := utils.CreateTestContext("POST", "/", &RefreshTokenReq{ RefreshToken: "abc", @@ -156,7 +156,7 @@ func (ts *LoginTestSuite) TestRefreshToken() { }) defer monkey.Unpatch(utils.ParseToken) - redis.GetClient().Set(context.Background(), fmt.Sprintf(config.KeyUserRefreshToken, 1), "abc", redis2.KeepTTL) + redis.GetClient().Set(context.Background(), fmt.Sprintf(global.KeyGatewayRefreshToken, 1), "abc", redis2.KeepTTL) w, c := utils.CreateTestContext("POST", "/", &RefreshTokenReq{ RefreshToken: "abc", diff --git a/internal/handler/ws_handler/client.go b/internal/handler/ws_handler/client/client.go similarity index 95% rename from internal/handler/ws_handler/client.go rename to internal/handler/ws_handler/client/client.go index a5da24b..736fcee 100644 --- a/internal/handler/ws_handler/client.go +++ b/internal/handler/ws_handler/client/client.go @@ -1,4 +1,4 @@ -package ws_handler +package client import ( "common/log" @@ -22,6 +22,7 @@ type Client struct { cancel context.CancelFunc // 取消上下文 heartBeat time.Time // 最后一次心跳 + Status int32 // 状态:0 登陆中 1 正常 2 离线 USN int64 // 用户ID SceneSID int64 // 场景服ID InstanceID int32 // 副本ID,副本类型 @@ -100,6 +101,7 @@ func (c *Client) onClose() { close(c.mailChan) c.mailChan = nil } + c.Status = 2 UserMgr.Delete(c.USN) c.onLeave() c.Done() diff --git a/internal/handler/ws_handler/event.go b/internal/handler/ws_handler/client/client_event.go similarity index 63% rename from internal/handler/ws_handler/event.go rename to internal/handler/ws_handler/client/client_event.go index 7a1ef51..5bc6f2a 100644 --- a/internal/handler/ws_handler/event.go +++ b/internal/handler/ws_handler/client/client_event.go @@ -1,4 +1,4 @@ -package ws_handler +package client type Event interface { } @@ -11,3 +11,7 @@ type ClientEvent struct { type PongEvent struct { Event } + +type SystemLoginSuccessEvent struct { + Event +} diff --git a/internal/handler/ws_handler/handler.go b/internal/handler/ws_handler/client/client_handler.go similarity index 88% rename from internal/handler/ws_handler/handler.go rename to internal/handler/ws_handler/client/client_handler.go index 9577e19..5e8c339 100644 --- a/internal/handler/ws_handler/handler.go +++ b/internal/handler/ws_handler/client/client_handler.go @@ -1,4 +1,4 @@ -package ws_handler +package client import ( "common/net/grpc/service" @@ -38,6 +38,14 @@ func (c *Client) handle(event Event) { } case *PongEvent: c.heartBeat = time.Now() + case *SystemLoginSuccessEvent: + if c.Status == 0 { + c.Status = 1 + UserMgr.Add(c.USN, c) + c.WriteMessage(sc_pb.MessageID_MESSAGE_ID_LOGIN_SUCCESS, &sc_pb.S2C_LoginSuccess{ + InstanceID: 1, + }) + } } } @@ -63,16 +71,17 @@ func (c *Client) onEnter(msg *sc_pb.C2S_EnterInstance) { } func (c *Client) onLeave() { + if c.SceneSID == 0 { + return + } client, err := service.SceneNewClient(c.SceneSID) if err != nil { c.logger.Errorf("SceneNewClient err: %v", err) return } _, err = client.Leave(c.ctx, &grpc_pb.LeaveReq{ - USN: c.USN, - GatewaySID: GatewaySID, - InstanceID: c.InstanceID, - UniqueNo: c.UniqueNo, + USN: c.USN, + UniqueNo: c.UniqueNo, }) if err != nil { c.logger.Errorf("leave err: %v", err) diff --git a/internal/handler/ws_handler/client_write.go b/internal/handler/ws_handler/client/client_write.go similarity index 98% rename from internal/handler/ws_handler/client_write.go rename to internal/handler/ws_handler/client/client_write.go index e76c046..ccd7f80 100644 --- a/internal/handler/ws_handler/client_write.go +++ b/internal/handler/ws_handler/client/client_write.go @@ -1,4 +1,4 @@ -package ws_handler +package client import ( "common/proto/sc/sc_pb" diff --git a/internal/handler/ws_handler/manager.go b/internal/handler/ws_handler/client/manager.go similarity index 88% rename from internal/handler/ws_handler/manager.go rename to internal/handler/ws_handler/client/manager.go index 8cbd1a4..769c15a 100644 --- a/internal/handler/ws_handler/manager.go +++ b/internal/handler/ws_handler/client/manager.go @@ -1,4 +1,4 @@ -package ws_handler +package client import ( "sync" @@ -56,3 +56,9 @@ func (m *userManager) GetByUSN(usn int64) *Client { defer m.RUnlock() return m.userMap[usn] } + +func (m *userManager) GetSize() int32 { + m.RLock() + defer m.RUnlock() + return int32(len(m.userMap)) +} diff --git a/internal/handler/ws_handler/login/login.go b/internal/handler/ws_handler/login/login.go new file mode 100644 index 0000000..3010ba0 --- /dev/null +++ b/internal/handler/ws_handler/login/login.go @@ -0,0 +1,181 @@ +package login + +import ( + "common/db/redis" + "common/log" + "common/net/grpc/service" + "common/proto/sc/sc_pb" + "common/proto/ss/grpc_pb" + "context" + "fmt" + "gateway/internal/global" + "gateway/internal/handler/ws_handler/client" + "sync" + "time" +) + +var loginQueue *Login + +// Login 登录队列结构 +type Login struct { + queue chan *User // 用户队列 + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup +} + +type User struct { + Cli *client.Client + Token string +} + +func NewLoginQueue(maxSize int) *Login { + ctx, cancel := context.WithCancel(context.Background()) + loginQueue = &Login{ + queue: make(chan *User, maxSize), + ctx: ctx, + cancel: cancel, + } + return loginQueue +} + +func GetLoginQueue() *Login { + return loginQueue +} + +// AddToLoginQueue 添加到登录队列 +func (l *Login) AddToLoginQueue(user *User) bool { + select { + case l.queue <- user: + return true + default: + } + return false +} + +func (l *Login) Start(num int) { + for i := 0; i < num; i++ { + l.wg.Add(1) + go func() { + defer l.wg.Done() + for { + select { + case <-l.ctx.Done(): + return + case user, ok := <-l.queue: + if ok { + l.StartLogin(user) + } + } + } + }() + } + // 定时从排队的队列中取用户 + l.wg.Add(1) + go func() { + defer l.wg.Done() + tick := time.NewTicker(1 * time.Second) + for { + select { + case <-tick.C: + for { + if client.UserMgr.GetSize() < global.MaxOnlineSize { + cli, err := GetQueueUp().Dequeue() + if err != nil { + log.Errorf("dequeue err: %v", err) + return + } + if cli == nil { + break + } + l.LoginSuccess(cli) + } else { + break + } + } + case <-l.ctx.Done(): + return + } + } + }() +} + +func (l *Login) Stop() { + l.cancel() + l.wg.Wait() +} + +// StartLogin 开始登录流程 +func (l *Login) StartLogin(user *User) { + if !l.CheckToken(user) { + user.Cli.WriteMessage(sc_pb.MessageID_MESSAGE_ID_KICK_OUT, &sc_pb.S2C_KickOut{ + ID: sc_pb.KickOutID_KICK_OUT_ID_TOKEN_INVALID, + }) + user.Cli.CloseClient() + return + } + if gatewaySID := l.CheckOnline(user); len(gatewaySID) > 0 { + // 如果在线就要踢,如果踢失败了就返回服务器繁忙,一般不应该走到这里 + if !l.KickUser(user.Cli.SceneSID, user.Cli.USN) { + user.Cli.WriteMessage(sc_pb.MessageID_MESSAGE_ID_KICK_OUT, &sc_pb.S2C_KickOut{ + ID: sc_pb.KickOutID_KICK_OUT_ID_SERVER_BUSY, + }) + user.Cli.CloseClient() + return + } + } + if client.UserMgr.GetSize() >= global.MaxOnlineSize { + // 如果人数满了就排队 + if err := GetQueueUp().Enqueue(user.Cli); err != nil { + user.Cli.WriteMessage(sc_pb.MessageID_MESSAGE_ID_KICK_OUT, &sc_pb.S2C_KickOut{ + ID: sc_pb.KickOutID_KICK_OUT_ID_QUEUE_UP_FULL, + }) + user.Cli.CloseClient() + return + } + // 告诉客户端正在排队 + position, ok := GetQueueUp().GetPosition(user.Cli.USN) + if !ok { + user.Cli.WriteMessage(sc_pb.MessageID_MESSAGE_ID_KICK_OUT, &sc_pb.S2C_KickOut{ + ID: sc_pb.KickOutID_KICK_OUT_ID_SERVER_BUSY, + }) + user.Cli.CloseClient() + return + } + user.Cli.WriteMessage(sc_pb.MessageID_MESSAGE_ID_QUEUE_UP, &sc_pb.S2C_QueueUp{ + QueueUpCount: int32(position), + }) + } else { + l.LoginSuccess(user.Cli) + } +} + +// CheckToken 校验Token是否有效 +func (l *Login) CheckToken(user *User) bool { + return redis.GetClient().Get(l.ctx, fmt.Sprintf(global.KeyGatewayAccessToken, user.Cli.USN)).Val() == user.Token +} + +// CheckOnline 校验是否在线 +func (l *Login) CheckOnline(user *User) string { + return redis.GetClient().HGet(l.ctx, fmt.Sprintf(global.KeyGatewayInfo, user.Cli.USN), global.HFieldInfoGatewaySID).Val() +} + +// KickUser 把玩家踢下线 +func (l *Login) KickUser(gatewaySID int64, usn int64) bool { + gc, err := service.GatewayNewClient(gatewaySID) + if err != nil { + log.Errorf("KickUser cannot find gateway client: %v, sid: %v", err, gatewaySID) + return false + } + _, err = gc.KickUser(l.ctx, &grpc_pb.KickUserReq{USN: usn}) + if err != nil { + log.Errorf("KickUser err: %v, sid: %v, usn: %v", err, gatewaySID, usn) + return false + } + return true +} + +// LoginSuccess 登录成功 +func (l *Login) LoginSuccess(user *client.Client) { + user.OnEvent(&client.SystemLoginSuccessEvent{}) +} diff --git a/internal/handler/ws_handler/login/queue_up.go b/internal/handler/ws_handler/login/queue_up.go new file mode 100644 index 0000000..e482f9d --- /dev/null +++ b/internal/handler/ws_handler/login/queue_up.go @@ -0,0 +1,104 @@ +package login + +import ( + "context" + "errors" + "gateway/internal/handler/ws_handler/client" + "sync" + "sync/atomic" +) + +var queueUp *QueueUp + +// QueueUp 排队队列结构 +type QueueUp struct { + queue chan *QueueUser // 用户队列 + waiting sync.Map // map[usn]*QueueUser + minTicket atomic.Int64 // 最小的Ticket + maxTicket atomic.Int64 // 最大的Ticket + ctx context.Context + cancel context.CancelFunc +} + +type QueueUser struct { + Cli *client.Client + Ticket int64 +} + +func NewQueueUp(maxSize int) *QueueUp { + ctx, cancel := context.WithCancel(context.Background()) + queueUp = &QueueUp{ + queue: make(chan *QueueUser, maxSize), + ctx: ctx, + cancel: cancel, + } + return queueUp +} + +func GetQueueUp() *QueueUp { + return queueUp +} + +// Enqueue 将用户加入排队队列 +func (q *QueueUp) Enqueue(cli *client.Client) error { + select { + case <-q.ctx.Done(): + return errors.New("queue stopped") + default: + } + + ticket := q.maxTicket.Add(1) + item := &QueueUser{Cli: cli, Ticket: ticket} + + select { + case q.queue <- item: + q.waiting.Store(cli.USN, item) + return nil + default: + return errors.New("queue is full") + } +} + +// Dequeue 从排队队列中取出下一个有效用户 +func (q *QueueUp) Dequeue() (*client.Client, error) { + select { + case item, ok := <-q.queue: + if ok { + q.minTicket.Store(item.Ticket) + if _, loaded := q.waiting.LoadAndDelete(item.Cli.USN); loaded { + return item.Cli, nil + } + return q.Dequeue() + } + case <-q.ctx.Done(): + return nil, q.ctx.Err() + default: + } + return nil, nil +} + +// GetPosition 返回用户前面还有多少人在排队 +func (q *QueueUp) GetPosition(usn int64) (int64, bool) { + val, ok := q.waiting.Load(usn) + if !ok { + return 0, false + } + user := val.(*QueueUser) + return user.Ticket - q.minTicket.Load() - 1, true +} + +// RemoveUser 安全移除用户(标记为取消) +func (q *QueueUp) RemoveUser(usn int64) bool { + _, loaded := q.waiting.LoadAndDelete(usn) + return loaded +} + +// GetQueueSize 获取当前排队人数 +func (q *QueueUp) GetQueueSize() int { + return len(q.queue) +} + +// Stop 停止整个队列服务 +func (q *QueueUp) Stop() { + q.cancel() +} diff --git a/internal/net/http_gateway/middleward.go b/internal/net/http_gateway/middleward.go index 0af71d4..095d233 100644 --- a/internal/net/http_gateway/middleward.go +++ b/internal/net/http_gateway/middleward.go @@ -5,6 +5,7 @@ import ( "common/utils" "fmt" "gateway/config" + "gateway/internal/global" "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" "go.uber.org/zap" @@ -45,7 +46,7 @@ func authJwt() gin.HandlerFunc { return func(c *gin.Context) { // 如果是Public接口,有Token就读,没有就算了 public := false - for _, path := range config.PublicPaths { + for _, path := range global.PublicPaths { if strings.HasPrefix(c.Request.URL.Path, path) { public = true break diff --git a/internal/net/ws_gateway/server.go b/internal/net/ws_gateway/server.go index 5506a43..0d8ddde 100644 --- a/internal/net/ws_gateway/server.go +++ b/internal/net/ws_gateway/server.go @@ -5,9 +5,10 @@ import ( "common/net/socket" "common/utils" "fmt" - "gateway/internal/handler/ws_handler" + "gateway/config" + "gateway/internal/handler/ws_handler/client" + "gateway/internal/handler/ws_handler/login" "go.uber.org/zap" - "strconv" "time" ) @@ -20,49 +21,42 @@ func (g *GatewayWsServer) OnOpen(conn socket.ISocketConn) ([]byte, socket.Action return nil, socket.None } -func (g *GatewayWsServer) OnHandShake(conn socket.ISocketConn, bytes []byte, callback func(conn socket.ISocketConn, bytes []byte)) socket.Action { +func (g *GatewayWsServer) OnHandShake(conn socket.ISocketConn) socket.Action { token, ok := conn.GetParam("token").(string) if !ok { - g.logger.Warnf("token is not string") + g.logger.Warnf("token is invalid") return socket.Close } - //claims, err := utils.ParseToken(token, config.Get().Auth.Secret) - //if err != nil { - // g.logger.Warnf("token is invalid") - // return socket.Close - //} - - t, _ := strconv.Atoi(token) - claims := utils.Claims{ - USN: int64(t), + claims, err := utils.ParseToken(token, config.Get().Auth.Secret) + if err != nil { + g.logger.Warnf("token is invalid") + return socket.Close + } + + cli := client.NewClient(claims.USN, conn) + conn.SetParam("client", cli) + if !login.GetLoginQueue().AddToLoginQueue(&login.User{Cli: cli, Token: token}) { + g.logger.Warnf("AddToLoginQueue err, login queue full, usn: %v", claims.USN) + return socket.Close } - go func(shResp []byte) { - if oldClient := ws_handler.UserMgr.GetByUSN(claims.USN); oldClient != nil { - oldClient.CloseClient() - } - client := ws_handler.NewClient(claims.USN, conn) - ws_handler.UserMgr.Add(claims.USN, client) - conn.SetParam("client", client) - callback(conn, shResp) - }(bytes) return socket.None } func (g *GatewayWsServer) OnMessage(conn socket.ISocketConn, bytes []byte) socket.Action { - client, ok := conn.GetParam("client").(*ws_handler.Client) - if !ok || client.USN == 0 { + cli, ok := conn.GetParam("client").(*client.Client) + if !ok || cli.USN == 0 || cli.Status != 1 { return socket.Close } - client.OnEvent(&ws_handler.ClientEvent{Msg: bytes}) + cli.OnEvent(&client.ClientEvent{Msg: bytes}) return socket.None } func (g *GatewayWsServer) OnPong(conn socket.ISocketConn) { - client, ok := conn.GetParam("client").(*ws_handler.Client) - if !ok || client.USN == 0 { + cli, ok := conn.GetParam("client").(*client.Client) + if !ok || cli.USN == 0 { return } - client.OnEvent(&ws_handler.PongEvent{}) + cli.OnEvent(&client.PongEvent{}) } func (g *GatewayWsServer) OnClose(_ socket.ISocketConn, _ error) socket.Action {