feat 排队
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gateway/config"
|
||||
"gateway/internal/global"
|
||||
"github.com/gin-gonic/gin"
|
||||
"time"
|
||||
)
|
||||
@@ -79,7 +80,7 @@ func RefreshToken(c *gin.Context) {
|
||||
http_resp.JsonOK(c, http_resp.Error(http_resp.TokenInvalid))
|
||||
return
|
||||
}
|
||||
if redis.GetClient().Get(c, fmt.Sprintf(config.KeyUserRefreshToken, claims.USN)).Val() != req.RefreshToken {
|
||||
if redis.GetClient().Get(c, fmt.Sprintf(global.KeyGatewayRefreshToken, claims.USN)).Val() != req.RefreshToken {
|
||||
http_resp.JsonOK(c, http_resp.Error(http_resp.TokenInvalid))
|
||||
return
|
||||
}
|
||||
@@ -97,11 +98,11 @@ func RefreshToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
func genToken(ctx context.Context, usn int64) (string, string, error) {
|
||||
at, err := genTokenOne(ctx, config.KeyUserAccessToken, usn, 2*time.Hour)
|
||||
at, err := genTokenOne(ctx, global.KeyGatewayAccessToken, usn, 2*time.Hour)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
rt, err := genTokenOne(ctx, config.KeyUserRefreshToken, usn, 3*24*time.Hour)
|
||||
rt, err := genTokenOne(ctx, global.KeyGatewayRefreshToken, usn, 3*24*time.Hour)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"common/utils"
|
||||
"context"
|
||||
"fmt"
|
||||
"gateway/config"
|
||||
"gateway/internal/global"
|
||||
"gateway/internal/testutil"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang/mock/gomock"
|
||||
@@ -30,9 +30,9 @@ func (ts *LoginTestSuite) TestGenToken() {
|
||||
at, rt, err := genToken(context.Background(), int64(usn))
|
||||
ts.Assert().NoError(err)
|
||||
|
||||
redisAt := redis.GetClient().Get(context.Background(), fmt.Sprintf(config.KeyUserAccessToken, usn)).Val()
|
||||
redisAt := redis.GetClient().Get(context.Background(), fmt.Sprintf(global.KeyGatewayAccessToken, usn)).Val()
|
||||
ts.Assert().Equal(at, redisAt)
|
||||
redisRt := redis.GetClient().Get(context.Background(), fmt.Sprintf(config.KeyUserRefreshToken, usn)).Val()
|
||||
redisRt := redis.GetClient().Get(context.Background(), fmt.Sprintf(global.KeyGatewayRefreshToken, usn)).Val()
|
||||
ts.Assert().Equal(rt, redisRt)
|
||||
}
|
||||
|
||||
@@ -138,7 +138,7 @@ func (ts *LoginTestSuite) TestRefreshToken() {
|
||||
})
|
||||
defer monkey.Unpatch(utils.ParseToken)
|
||||
|
||||
redis.GetClient().Set(context.Background(), fmt.Sprintf(config.KeyUserRefreshToken, 1), "ab", redis2.KeepTTL)
|
||||
redis.GetClient().Set(context.Background(), fmt.Sprintf(global.KeyGatewayRefreshToken, 1), "ab", redis2.KeepTTL)
|
||||
|
||||
w, c := utils.CreateTestContext("POST", "/", &RefreshTokenReq{
|
||||
RefreshToken: "abc",
|
||||
@@ -156,7 +156,7 @@ func (ts *LoginTestSuite) TestRefreshToken() {
|
||||
})
|
||||
defer monkey.Unpatch(utils.ParseToken)
|
||||
|
||||
redis.GetClient().Set(context.Background(), fmt.Sprintf(config.KeyUserRefreshToken, 1), "abc", redis2.KeepTTL)
|
||||
redis.GetClient().Set(context.Background(), fmt.Sprintf(global.KeyGatewayRefreshToken, 1), "abc", redis2.KeepTTL)
|
||||
|
||||
w, c := utils.CreateTestContext("POST", "/", &RefreshTokenReq{
|
||||
RefreshToken: "abc",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package ws_handler
|
||||
package client
|
||||
|
||||
import (
|
||||
"common/log"
|
||||
@@ -22,6 +22,7 @@ type Client struct {
|
||||
cancel context.CancelFunc // 取消上下文
|
||||
heartBeat time.Time // 最后一次心跳
|
||||
|
||||
Status int32 // 状态:0 登陆中 1 正常 2 离线
|
||||
USN int64 // 用户ID
|
||||
SceneSID int64 // 场景服ID
|
||||
InstanceID int32 // 副本ID,副本类型
|
||||
@@ -100,6 +101,7 @@ func (c *Client) onClose() {
|
||||
close(c.mailChan)
|
||||
c.mailChan = nil
|
||||
}
|
||||
c.Status = 2
|
||||
UserMgr.Delete(c.USN)
|
||||
c.onLeave()
|
||||
c.Done()
|
||||
@@ -1,4 +1,4 @@
|
||||
package ws_handler
|
||||
package client
|
||||
|
||||
type Event interface {
|
||||
}
|
||||
@@ -11,3 +11,7 @@ type ClientEvent struct {
|
||||
type PongEvent struct {
|
||||
Event
|
||||
}
|
||||
|
||||
type SystemLoginSuccessEvent struct {
|
||||
Event
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package ws_handler
|
||||
package client
|
||||
|
||||
import (
|
||||
"common/net/grpc/service"
|
||||
@@ -38,6 +38,14 @@ func (c *Client) handle(event Event) {
|
||||
}
|
||||
case *PongEvent:
|
||||
c.heartBeat = time.Now()
|
||||
case *SystemLoginSuccessEvent:
|
||||
if c.Status == 0 {
|
||||
c.Status = 1
|
||||
UserMgr.Add(c.USN, c)
|
||||
c.WriteMessage(sc_pb.MessageID_MESSAGE_ID_LOGIN_SUCCESS, &sc_pb.S2C_LoginSuccess{
|
||||
InstanceID: 1,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,16 +71,17 @@ func (c *Client) onEnter(msg *sc_pb.C2S_EnterInstance) {
|
||||
}
|
||||
|
||||
func (c *Client) onLeave() {
|
||||
if c.SceneSID == 0 {
|
||||
return
|
||||
}
|
||||
client, err := service.SceneNewClient(c.SceneSID)
|
||||
if err != nil {
|
||||
c.logger.Errorf("SceneNewClient err: %v", err)
|
||||
return
|
||||
}
|
||||
_, err = client.Leave(c.ctx, &grpc_pb.LeaveReq{
|
||||
USN: c.USN,
|
||||
GatewaySID: GatewaySID,
|
||||
InstanceID: c.InstanceID,
|
||||
UniqueNo: c.UniqueNo,
|
||||
USN: c.USN,
|
||||
UniqueNo: c.UniqueNo,
|
||||
})
|
||||
if err != nil {
|
||||
c.logger.Errorf("leave err: %v", err)
|
||||
@@ -1,4 +1,4 @@
|
||||
package ws_handler
|
||||
package client
|
||||
|
||||
import (
|
||||
"common/proto/sc/sc_pb"
|
||||
@@ -1,4 +1,4 @@
|
||||
package ws_handler
|
||||
package client
|
||||
|
||||
import (
|
||||
"sync"
|
||||
@@ -56,3 +56,9 @@ func (m *userManager) GetByUSN(usn int64) *Client {
|
||||
defer m.RUnlock()
|
||||
return m.userMap[usn]
|
||||
}
|
||||
|
||||
func (m *userManager) GetSize() int32 {
|
||||
m.RLock()
|
||||
defer m.RUnlock()
|
||||
return int32(len(m.userMap))
|
||||
}
|
||||
181
internal/handler/ws_handler/login/login.go
Normal file
181
internal/handler/ws_handler/login/login.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package login
|
||||
|
||||
import (
|
||||
"common/db/redis"
|
||||
"common/log"
|
||||
"common/net/grpc/service"
|
||||
"common/proto/sc/sc_pb"
|
||||
"common/proto/ss/grpc_pb"
|
||||
"context"
|
||||
"fmt"
|
||||
"gateway/internal/global"
|
||||
"gateway/internal/handler/ws_handler/client"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var loginQueue *Login
|
||||
|
||||
// Login 登录队列结构
|
||||
type Login struct {
|
||||
queue chan *User // 用户队列
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Cli *client.Client
|
||||
Token string
|
||||
}
|
||||
|
||||
func NewLoginQueue(maxSize int) *Login {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
loginQueue = &Login{
|
||||
queue: make(chan *User, maxSize),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
return loginQueue
|
||||
}
|
||||
|
||||
func GetLoginQueue() *Login {
|
||||
return loginQueue
|
||||
}
|
||||
|
||||
// AddToLoginQueue 添加到登录队列
|
||||
func (l *Login) AddToLoginQueue(user *User) bool {
|
||||
select {
|
||||
case l.queue <- user:
|
||||
return true
|
||||
default:
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (l *Login) Start(num int) {
|
||||
for i := 0; i < num; i++ {
|
||||
l.wg.Add(1)
|
||||
go func() {
|
||||
defer l.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-l.ctx.Done():
|
||||
return
|
||||
case user, ok := <-l.queue:
|
||||
if ok {
|
||||
l.StartLogin(user)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
// 定时从排队的队列中取用户
|
||||
l.wg.Add(1)
|
||||
go func() {
|
||||
defer l.wg.Done()
|
||||
tick := time.NewTicker(1 * time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-tick.C:
|
||||
for {
|
||||
if client.UserMgr.GetSize() < global.MaxOnlineSize {
|
||||
cli, err := GetQueueUp().Dequeue()
|
||||
if err != nil {
|
||||
log.Errorf("dequeue err: %v", err)
|
||||
return
|
||||
}
|
||||
if cli == nil {
|
||||
break
|
||||
}
|
||||
l.LoginSuccess(cli)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
case <-l.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (l *Login) Stop() {
|
||||
l.cancel()
|
||||
l.wg.Wait()
|
||||
}
|
||||
|
||||
// StartLogin 开始登录流程
|
||||
func (l *Login) StartLogin(user *User) {
|
||||
if !l.CheckToken(user) {
|
||||
user.Cli.WriteMessage(sc_pb.MessageID_MESSAGE_ID_KICK_OUT, &sc_pb.S2C_KickOut{
|
||||
ID: sc_pb.KickOutID_KICK_OUT_ID_TOKEN_INVALID,
|
||||
})
|
||||
user.Cli.CloseClient()
|
||||
return
|
||||
}
|
||||
if gatewaySID := l.CheckOnline(user); len(gatewaySID) > 0 {
|
||||
// 如果在线就要踢,如果踢失败了就返回服务器繁忙,一般不应该走到这里
|
||||
if !l.KickUser(user.Cli.SceneSID, user.Cli.USN) {
|
||||
user.Cli.WriteMessage(sc_pb.MessageID_MESSAGE_ID_KICK_OUT, &sc_pb.S2C_KickOut{
|
||||
ID: sc_pb.KickOutID_KICK_OUT_ID_SERVER_BUSY,
|
||||
})
|
||||
user.Cli.CloseClient()
|
||||
return
|
||||
}
|
||||
}
|
||||
if client.UserMgr.GetSize() >= global.MaxOnlineSize {
|
||||
// 如果人数满了就排队
|
||||
if err := GetQueueUp().Enqueue(user.Cli); err != nil {
|
||||
user.Cli.WriteMessage(sc_pb.MessageID_MESSAGE_ID_KICK_OUT, &sc_pb.S2C_KickOut{
|
||||
ID: sc_pb.KickOutID_KICK_OUT_ID_QUEUE_UP_FULL,
|
||||
})
|
||||
user.Cli.CloseClient()
|
||||
return
|
||||
}
|
||||
// 告诉客户端正在排队
|
||||
position, ok := GetQueueUp().GetPosition(user.Cli.USN)
|
||||
if !ok {
|
||||
user.Cli.WriteMessage(sc_pb.MessageID_MESSAGE_ID_KICK_OUT, &sc_pb.S2C_KickOut{
|
||||
ID: sc_pb.KickOutID_KICK_OUT_ID_SERVER_BUSY,
|
||||
})
|
||||
user.Cli.CloseClient()
|
||||
return
|
||||
}
|
||||
user.Cli.WriteMessage(sc_pb.MessageID_MESSAGE_ID_QUEUE_UP, &sc_pb.S2C_QueueUp{
|
||||
QueueUpCount: int32(position),
|
||||
})
|
||||
} else {
|
||||
l.LoginSuccess(user.Cli)
|
||||
}
|
||||
}
|
||||
|
||||
// CheckToken 校验Token是否有效
|
||||
func (l *Login) CheckToken(user *User) bool {
|
||||
return redis.GetClient().Get(l.ctx, fmt.Sprintf(global.KeyGatewayAccessToken, user.Cli.USN)).Val() == user.Token
|
||||
}
|
||||
|
||||
// CheckOnline 校验是否在线
|
||||
func (l *Login) CheckOnline(user *User) string {
|
||||
return redis.GetClient().HGet(l.ctx, fmt.Sprintf(global.KeyGatewayInfo, user.Cli.USN), global.HFieldInfoGatewaySID).Val()
|
||||
}
|
||||
|
||||
// KickUser 把玩家踢下线
|
||||
func (l *Login) KickUser(gatewaySID int64, usn int64) bool {
|
||||
gc, err := service.GatewayNewClient(gatewaySID)
|
||||
if err != nil {
|
||||
log.Errorf("KickUser cannot find gateway client: %v, sid: %v", err, gatewaySID)
|
||||
return false
|
||||
}
|
||||
_, err = gc.KickUser(l.ctx, &grpc_pb.KickUserReq{USN: usn})
|
||||
if err != nil {
|
||||
log.Errorf("KickUser err: %v, sid: %v, usn: %v", err, gatewaySID, usn)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// LoginSuccess 登录成功
|
||||
func (l *Login) LoginSuccess(user *client.Client) {
|
||||
user.OnEvent(&client.SystemLoginSuccessEvent{})
|
||||
}
|
||||
104
internal/handler/ws_handler/login/queue_up.go
Normal file
104
internal/handler/ws_handler/login/queue_up.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package login
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"gateway/internal/handler/ws_handler/client"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var queueUp *QueueUp
|
||||
|
||||
// QueueUp 排队队列结构
|
||||
type QueueUp struct {
|
||||
queue chan *QueueUser // 用户队列
|
||||
waiting sync.Map // map[usn]*QueueUser
|
||||
minTicket atomic.Int64 // 最小的Ticket
|
||||
maxTicket atomic.Int64 // 最大的Ticket
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
type QueueUser struct {
|
||||
Cli *client.Client
|
||||
Ticket int64
|
||||
}
|
||||
|
||||
func NewQueueUp(maxSize int) *QueueUp {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
queueUp = &QueueUp{
|
||||
queue: make(chan *QueueUser, maxSize),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
return queueUp
|
||||
}
|
||||
|
||||
func GetQueueUp() *QueueUp {
|
||||
return queueUp
|
||||
}
|
||||
|
||||
// Enqueue 将用户加入排队队列
|
||||
func (q *QueueUp) Enqueue(cli *client.Client) error {
|
||||
select {
|
||||
case <-q.ctx.Done():
|
||||
return errors.New("queue stopped")
|
||||
default:
|
||||
}
|
||||
|
||||
ticket := q.maxTicket.Add(1)
|
||||
item := &QueueUser{Cli: cli, Ticket: ticket}
|
||||
|
||||
select {
|
||||
case q.queue <- item:
|
||||
q.waiting.Store(cli.USN, item)
|
||||
return nil
|
||||
default:
|
||||
return errors.New("queue is full")
|
||||
}
|
||||
}
|
||||
|
||||
// Dequeue 从排队队列中取出下一个有效用户
|
||||
func (q *QueueUp) Dequeue() (*client.Client, error) {
|
||||
select {
|
||||
case item, ok := <-q.queue:
|
||||
if ok {
|
||||
q.minTicket.Store(item.Ticket)
|
||||
if _, loaded := q.waiting.LoadAndDelete(item.Cli.USN); loaded {
|
||||
return item.Cli, nil
|
||||
}
|
||||
return q.Dequeue()
|
||||
}
|
||||
case <-q.ctx.Done():
|
||||
return nil, q.ctx.Err()
|
||||
default:
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetPosition 返回用户前面还有多少人在排队
|
||||
func (q *QueueUp) GetPosition(usn int64) (int64, bool) {
|
||||
val, ok := q.waiting.Load(usn)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
user := val.(*QueueUser)
|
||||
return user.Ticket - q.minTicket.Load() - 1, true
|
||||
}
|
||||
|
||||
// RemoveUser 安全移除用户(标记为取消)
|
||||
func (q *QueueUp) RemoveUser(usn int64) bool {
|
||||
_, loaded := q.waiting.LoadAndDelete(usn)
|
||||
return loaded
|
||||
}
|
||||
|
||||
// GetQueueSize 获取当前排队人数
|
||||
func (q *QueueUp) GetQueueSize() int {
|
||||
return len(q.queue)
|
||||
}
|
||||
|
||||
// Stop 停止整个队列服务
|
||||
func (q *QueueUp) Stop() {
|
||||
q.cancel()
|
||||
}
|
||||
Reference in New Issue
Block a user