feat 网关鉴权

This commit is contained in:
2025-12-22 18:04:36 +08:00
parent 69cc960fe5
commit 670140e7d3
68 changed files with 1424 additions and 492 deletions

View File

@@ -23,7 +23,7 @@ func (s *Server) ToClient(server grpc_pb.Gateway_ToClientServer) error {
if args, err := server.Recv(); err != nil {
return
} else {
if args.UID == -1 {
if args.USN == -1 {
//utils.WorkerPool(ws_handler.UserMgr.GetAllInterface(), func(task interface{}) {
// client := task.(*ws_handler.Client)
@@ -46,7 +46,7 @@ func (s *Server) ToClient(server grpc_pb.Gateway_ToClientServer) error {
// client.WriteBytes(sc_pb.MessageID(args.MessageID), args.Payload)
//}
} else {
if client := ws_handler.UserMgr.GetByUID(args.UID); client != nil {
if client := ws_handler.UserMgr.GetByUSN(args.USN); client != nil {
client.WriteBytes(sc_pb.MessageID(args.MessageID), args.Payload)
}
}

View File

@@ -0,0 +1,118 @@
package http_handler
import (
"common/db/redis"
"common/log"
"common/net/grpc/service"
"common/net/http/http_resp"
"common/proto/ss/grpc_pb"
"common/utils"
"context"
"fmt"
"gateway/config"
"github.com/gin-gonic/gin"
"time"
)
// 这个模块处理用户登录
type LoginReq struct {
Phone string `json:"phone" binding:"required,min=1"`
Code string `json:"code" binding:"required,min=1"`
}
type LoginResp struct {
USN int64 `json:"usn"`
Name string `json:"name"`
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
}
func Login(c *gin.Context) {
req := &LoginReq{}
if err := c.ShouldBindJSON(req); err != nil {
http_resp.JsonBadRequest(c)
return
}
client, err := service.UserNewClientLB()
if err != nil {
log.Errorf("Login UserNewClientLB error: %v", err)
http_resp.JsonOK(c, http_resp.Error(http_resp.Failed))
return
}
login, err := client.Login(c, &grpc_pb.LoginReq{
Phone: req.Phone,
Code: req.Code,
})
if err != nil {
log.Errorf("Login Login error: %v", err)
http_resp.JsonOK(c, http_resp.Error(http_resp.Failed))
return
}
at, rt, err := genToken(c, login.USN)
http_resp.JsonOK(c, http_resp.Success(&LoginResp{
USN: login.USN,
Name: login.Name,
AccessToken: at,
RefreshToken: rt,
}))
}
type RefreshTokenReq struct {
RefreshToken string `json:"refreshToken" binding:"required,min=1"`
}
type RefreshTokenResp struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
}
func RefreshToken(c *gin.Context) {
req := &RefreshTokenReq{}
if err := c.ShouldBindJSON(req); err != nil {
http_resp.JsonBadRequest(c)
return
}
claims, err := utils.ParseToken(req.RefreshToken, config.Get().Auth.Secret)
if err != nil {
http_resp.JsonOK(c, http_resp.Error(http_resp.TokenInvalid))
return
}
if redis.GetClient().Get(c, fmt.Sprintf(config.KeyUserRefreshToken, claims.USN)).String() != req.RefreshToken {
http_resp.JsonOK(c, http_resp.Error(http_resp.TokenInvalid))
return
}
at, rt, err := genToken(c, claims.USN)
if err != nil {
log.Errorf("RefreshToken genToken error: %v, usn: %v", err, claims.USN)
http_resp.JsonOK(c, http_resp.Error(http_resp.Failed))
return
}
http_resp.JsonOK(c, http_resp.Success(&RefreshTokenResp{
AccessToken: at,
RefreshToken: rt,
}))
}
func genToken(ctx context.Context, usn int64) (string, string, error) {
at, err := genTokenOne(ctx, config.KeyUserAccessToken, usn, 2*time.Hour)
if err != nil {
return "", "", err
}
rt, err := genTokenOne(ctx, config.KeyUserRefreshToken, usn, 3*24*time.Hour)
if err != nil {
return "", "", err
}
return at, rt, nil
}
func genTokenOne(ctx context.Context, key string, usn int64, ttl time.Duration) (string, error) {
token, err := utils.GenToken(usn, config.Get().Auth.Secret, time.Duration(config.Get().Auth.Expire)*time.Second)
if err != nil {
return "", err
}
redis.GetClient().Set(ctx, fmt.Sprintf(key, usn), token, ttl)
return token, err
}

View File

@@ -22,17 +22,17 @@ type Client struct {
cancel context.CancelFunc // 取消上下文
heartBeat time.Time // 最后一次心跳
UID int32 // 用户ID
USN int64 // 用户ID
SceneSID int64 // 场景服ID
InstanceID int32 // 副本ID副本类型
UniqueNo int64 // 副本唯一编号
}
func NewClient(uid int32, conn socket.ISocketConn) *Client {
func NewClient(usn int64, conn socket.ISocketConn) *Client {
client := &Client{
UID: uid,
USN: usn,
conn: conn,
logger: log.GetLogger().Named(fmt.Sprintf("uid:%v", uid)),
logger: log.GetLogger().Named(fmt.Sprintf("usn:%v", usn)),
heartBeat: time.Now(),
mailChan: make(chan Event, 1024),
}
@@ -99,7 +99,7 @@ func (c *Client) onClose() {
close(c.mailChan)
c.mailChan = nil
}
UserMgr.Delete(c.UID)
UserMgr.Delete(c.USN)
c.onLeave()
c.Done()
}

View File

@@ -48,7 +48,7 @@ func (c *Client) onEnter(msg *sc_pb.C2S_EnterInstance) {
return
}
resp, err := client.Enter(c.ctx, &grpc_pb.EnterReq{
UID: c.UID,
USN: c.USN,
GatewaySID: GatewaySID,
InstanceID: msg.InstanceID,
})
@@ -69,7 +69,7 @@ func (c *Client) onLeave() {
return
}
_, err = client.Leave(c.ctx, &grpc_pb.LeaveReq{
UID: c.UID,
USN: c.USN,
GatewaySID: GatewaySID,
InstanceID: c.InstanceID,
UniqueNo: c.UniqueNo,
@@ -86,7 +86,7 @@ func (c *Client) onAction(msg *sc_pb.C2S_Action) {
}
if err := stream_client.SendMessageToScene(c.SceneSID, stream_client.FunAction, &grpc_pb.ActionReq{
UniqueNo: c.UniqueNo,
UID: c.UID,
USN: c.USN,
Action: int32(msg.Action),
DirX: msg.DirX,
DirY: msg.DirY,

View File

@@ -7,33 +7,33 @@ import (
var UserMgr *userManager
type userManager struct {
userMap map[int32]*Client
userMap map[int64]*Client
sync.RWMutex
}
func init() {
UserMgr = &userManager{
userMap: make(map[int32]*Client),
userMap: make(map[int64]*Client),
}
}
func (m *userManager) Add(uid int32, client *Client) {
func (m *userManager) Add(usn int64, client *Client) {
m.Lock()
defer m.Unlock()
m.userMap[uid] = client
m.userMap[usn] = client
}
func (m *userManager) Delete(uid int32) {
func (m *userManager) Delete(usn int64) {
m.Lock()
defer m.Unlock()
delete(m.userMap, uid)
delete(m.userMap, usn)
}
func (m *userManager) GetAll() map[int32]*Client {
func (m *userManager) GetAll() map[int64]*Client {
m.RLock()
defer m.RUnlock()
copyMap := make(map[int32]*Client, len(m.userMap))
copyMap := make(map[int64]*Client, len(m.userMap))
for k, v := range m.userMap {
copyMap[k] = v
}
@@ -51,8 +51,8 @@ func (m *userManager) GetAllInterface() []interface{} {
return r
}
func (m *userManager) GetByUID(uid int32) *Client {
func (m *userManager) GetByUSN(usn int64) *Client {
m.RLock()
defer m.RUnlock()
return m.userMap[uid]
return m.userMap[usn]
}

View File

@@ -1,10 +1,15 @@
package http_gateway
import (
"common/net/http/http_resp"
"common/utils"
"fmt"
"gateway/config"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"strconv"
"strings"
"time"
)
@@ -35,3 +40,28 @@ func ginLogger(logger *zap.SugaredLogger) gin.HandlerFunc {
)
}
}
func authJwt() gin.HandlerFunc {
return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
http_resp.AbortUnauthorized(c)
return
}
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
http_resp.AbortUnauthorized(c)
return
}
claims, err := utils.ParseToken(parts[1], config.Get().Auth.Secret)
if err != nil {
http_resp.AbortUnauthorized(c)
return
}
c.Request.Header.Set("X-Usn", strconv.Itoa(int(claims.USN)))
c.Next()
}
}

View File

@@ -2,13 +2,16 @@ package http_gateway
import (
"common/log"
"common/net/grpc/service"
"common/net/http/http_resp"
wrapper2 "gateway/internal/net/http_gateway/wrapper"
"common/proto/ss/grpc_pb"
"context"
"gateway/internal/handler/http_handler"
"gateway/internal/net/http_gateway/wrapper"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"google.golang.org/protobuf/encoding/protojson"
"net/http"
)
func InitServeMux() *runtime.ServeMux {
@@ -21,16 +24,22 @@ func InitServeMux() *runtime.ServeMux {
DiscardUnknown: true,
},
}
unifiedMarshaler := wrapper2.NewWrappedMarshaler(baseMarshaler)
unifiedMarshaler := wrapper.NewWrappedMarshaler(baseMarshaler)
mux := runtime.NewServeMux(
runtime.WithMarshalerOption(runtime.MIMEWildcard, unifiedMarshaler),
runtime.WithErrorHandler(wrapper2.ErrorHandler),
runtime.WithErrorHandler(wrapper.ErrorHandler),
runtime.WithIncomingHeaderMatcher(func(header string) (string, bool) {
if header == "X-Usn" {
return "X-Usn", true
}
return runtime.DefaultHeaderMatcher(header)
}),
)
return mux
}
func InitRouter(mux *runtime.ServeMux) *gin.Engine {
func InitRouter() *gin.Engine {
gin.SetMode(gin.ReleaseMode)
r := gin.New()
@@ -42,13 +51,42 @@ func InitRouter(mux *runtime.ServeMux) *gin.Engine {
r.HandleMethodNotAllowed = true
r.NoMethod(func(c *gin.Context) {
c.JSON(http.StatusMethodNotAllowed, http_resp.Error(http_resp.Failed.Code(), "Method Not Allowed"))
http_resp.JsonMethodNotAllowed(c)
})
r.NoRoute(func(c *gin.Context) {
c.JSON(http.StatusNotFound, http_resp.Error(http_resp.Failed.Code(), "Endpoint Not Found"))
http_resp.JsonNotFound(c)
})
r.Any("/*any", gin.WrapH(mux))
initBaseRoute(r.Group("/"))
auth := r.Group("/")
auth.Use(authJwt())
// 用户中心
initUserPath(auth)
return r
}
func initBaseRoute(r *gin.RouterGroup) {
g := r.Group("/gw")
g.POST("/login", http_handler.Login)
g.POST("/refresh_token", http_handler.RefreshToken)
}
func initUserPath(r *gin.RouterGroup) {
g := r.Group("/user")
client, err := service.UserNewClientLB()
if err != nil {
log.Errorf("get user conn failed: %v", err)
return
}
gwMux := InitServeMux()
if err = grpc_pb.RegisterUserHandlerClient(context.Background(), gwMux, client); err != nil {
log.Errorf("RegisterUserHandlerClient err: %v", err)
return
}
g.Any("/*path", gin.WrapH(gwMux))
}

View File

@@ -17,7 +17,7 @@ func ErrorHandler(_ context.Context, _ *runtime.ServeMux, _ runtime.Marshaler, w
if !ok {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
_ = json.NewEncoder(w).Encode(http_resp.Error(http_resp.Failed.Code(), http_resp.Failed.Error()))
_ = json.NewEncoder(w).Encode(http_resp.Error(http_resp.Failed))
return
}
@@ -33,13 +33,15 @@ func ErrorHandler(_ context.Context, _ *runtime.ServeMux, _ runtime.Marshaler, w
code = http_resp.Failed.Code()
msg = http_resp.Failed.Error()
}
if st.Code() == codes.Unknown || st.Code() == codes.Unimplemented {
if st.Code() == codes.Unknown ||
st.Code() == codes.Unimplemented ||
st.Code() == codes.NotFound {
msg = st.Message()
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(grpcCodeToHTTPCode(st.Code()))
_ = json.NewEncoder(w).Encode(http_resp.Error(code, msg))
_ = json.NewEncoder(w).Encode(http_resp.Error(http_resp.NewCode(code, msg)))
}
// 这里定义 Internal 属于业务错误,其他的属于 500 报错
@@ -47,7 +49,7 @@ func grpcCodeToHTTPCode(c codes.Code) int {
switch c {
case codes.OK, codes.Unknown:
return http.StatusOK
case codes.Unimplemented:
case codes.Unimplemented, codes.NotFound:
return http.StatusNotFound
default:
return http.StatusInternalServerError

View File

@@ -20,7 +20,7 @@ func NewWrappedMarshaler(inner runtime.Marshaler) *WrappedMarshaler {
func (w *WrappedMarshaler) Marshal(v interface{}) ([]byte, error) {
dataBytes, err := w.inner.Marshal(v)
if err != nil {
return json.Marshal(http_resp.Error(http_resp.Failed.Code(), http_resp.Failed.Error()))
return json.Marshal(http_resp.Error(http_resp.Failed))
}
return json.Marshal(http_resp.Success(json.RawMessage(dataBytes)))
}

View File

@@ -21,7 +21,7 @@ func (g *GatewayWsServer) OnOpen(conn socket.ISocketConn) ([]byte, socket.Action
func (g *GatewayWsServer) OnHandShake(conn socket.ISocketConn) {
token, ok := conn.GetParam("token").(string)
if !ok || len(token) == 0 {
if !ok || token == "" {
g.logger.Warnf("token is not string")
_ = conn.Close()
return
@@ -30,17 +30,17 @@ func (g *GatewayWsServer) OnHandShake(conn socket.ISocketConn) {
if err != nil {
_ = conn.Close()
}
if oldClient := ws_handler2.UserMgr.GetByUID(int32(t)); oldClient != nil {
if oldClient := ws_handler2.UserMgr.GetByUSN(int64(t)); oldClient != nil {
oldClient.CloseClient()
}
client := ws_handler2.NewClient(int32(t), conn)
ws_handler2.UserMgr.Add(int32(t), client)
client := ws_handler2.NewClient(int64(t), conn)
ws_handler2.UserMgr.Add(int64(t), client)
conn.SetParam("client", client)
}
func (g *GatewayWsServer) OnMessage(conn socket.ISocketConn, bytes []byte) socket.Action {
client, ok := conn.GetParam("client").(*ws_handler2.Client)
if !ok || client.UID == 0 {
if !ok || client.USN == 0 {
return socket.Close
}
client.OnEvent(&ws_handler2.ClientEvent{Msg: bytes})
@@ -49,7 +49,7 @@ func (g *GatewayWsServer) OnMessage(conn socket.ISocketConn, bytes []byte) socke
func (g *GatewayWsServer) OnPong(conn socket.ISocketConn) {
client, ok := conn.GetParam("client").(*ws_handler2.Client)
if !ok || client.UID == 0 {
if !ok || client.USN == 0 {
return
}
client.OnEvent(&ws_handler2.PongEvent{})