feat 网关鉴权
This commit is contained in:
@@ -1,10 +1,15 @@
|
||||
package http_gateway
|
||||
|
||||
import (
|
||||
"common/net/http/http_resp"
|
||||
"common/utils"
|
||||
"fmt"
|
||||
"gateway/config"
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -35,3 +40,28 @@ func ginLogger(logger *zap.SugaredLogger) gin.HandlerFunc {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func authJwt() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
http_resp.AbortUnauthorized(c)
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" {
|
||||
http_resp.AbortUnauthorized(c)
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := utils.ParseToken(parts[1], config.Get().Auth.Secret)
|
||||
if err != nil {
|
||||
http_resp.AbortUnauthorized(c)
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Header.Set("X-Usn", strconv.Itoa(int(claims.USN)))
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,13 +2,16 @@ package http_gateway
|
||||
|
||||
import (
|
||||
"common/log"
|
||||
"common/net/grpc/service"
|
||||
"common/net/http/http_resp"
|
||||
wrapper2 "gateway/internal/net/http_gateway/wrapper"
|
||||
"common/proto/ss/grpc_pb"
|
||||
"context"
|
||||
"gateway/internal/handler/http_handler"
|
||||
"gateway/internal/net/http_gateway/wrapper"
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func InitServeMux() *runtime.ServeMux {
|
||||
@@ -21,16 +24,22 @@ func InitServeMux() *runtime.ServeMux {
|
||||
DiscardUnknown: true,
|
||||
},
|
||||
}
|
||||
unifiedMarshaler := wrapper2.NewWrappedMarshaler(baseMarshaler)
|
||||
unifiedMarshaler := wrapper.NewWrappedMarshaler(baseMarshaler)
|
||||
|
||||
mux := runtime.NewServeMux(
|
||||
runtime.WithMarshalerOption(runtime.MIMEWildcard, unifiedMarshaler),
|
||||
runtime.WithErrorHandler(wrapper2.ErrorHandler),
|
||||
runtime.WithErrorHandler(wrapper.ErrorHandler),
|
||||
runtime.WithIncomingHeaderMatcher(func(header string) (string, bool) {
|
||||
if header == "X-Usn" {
|
||||
return "X-Usn", true
|
||||
}
|
||||
return runtime.DefaultHeaderMatcher(header)
|
||||
}),
|
||||
)
|
||||
return mux
|
||||
}
|
||||
|
||||
func InitRouter(mux *runtime.ServeMux) *gin.Engine {
|
||||
func InitRouter() *gin.Engine {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
|
||||
r := gin.New()
|
||||
@@ -42,13 +51,42 @@ func InitRouter(mux *runtime.ServeMux) *gin.Engine {
|
||||
|
||||
r.HandleMethodNotAllowed = true
|
||||
r.NoMethod(func(c *gin.Context) {
|
||||
c.JSON(http.StatusMethodNotAllowed, http_resp.Error(http_resp.Failed.Code(), "Method Not Allowed"))
|
||||
http_resp.JsonMethodNotAllowed(c)
|
||||
})
|
||||
r.NoRoute(func(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, http_resp.Error(http_resp.Failed.Code(), "Endpoint Not Found"))
|
||||
http_resp.JsonNotFound(c)
|
||||
})
|
||||
|
||||
r.Any("/*any", gin.WrapH(mux))
|
||||
initBaseRoute(r.Group("/"))
|
||||
|
||||
auth := r.Group("/")
|
||||
auth.Use(authJwt())
|
||||
|
||||
// 用户中心
|
||||
initUserPath(auth)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func initBaseRoute(r *gin.RouterGroup) {
|
||||
g := r.Group("/gw")
|
||||
g.POST("/login", http_handler.Login)
|
||||
g.POST("/refresh_token", http_handler.RefreshToken)
|
||||
}
|
||||
|
||||
func initUserPath(r *gin.RouterGroup) {
|
||||
g := r.Group("/user")
|
||||
client, err := service.UserNewClientLB()
|
||||
if err != nil {
|
||||
log.Errorf("get user conn failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
gwMux := InitServeMux()
|
||||
if err = grpc_pb.RegisterUserHandlerClient(context.Background(), gwMux, client); err != nil {
|
||||
log.Errorf("RegisterUserHandlerClient err: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
g.Any("/*path", gin.WrapH(gwMux))
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ func ErrorHandler(_ context.Context, _ *runtime.ServeMux, _ runtime.Marshaler, w
|
||||
if !ok {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_ = json.NewEncoder(w).Encode(http_resp.Error(http_resp.Failed.Code(), http_resp.Failed.Error()))
|
||||
_ = json.NewEncoder(w).Encode(http_resp.Error(http_resp.Failed))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -33,13 +33,15 @@ func ErrorHandler(_ context.Context, _ *runtime.ServeMux, _ runtime.Marshaler, w
|
||||
code = http_resp.Failed.Code()
|
||||
msg = http_resp.Failed.Error()
|
||||
}
|
||||
if st.Code() == codes.Unknown || st.Code() == codes.Unimplemented {
|
||||
if st.Code() == codes.Unknown ||
|
||||
st.Code() == codes.Unimplemented ||
|
||||
st.Code() == codes.NotFound {
|
||||
msg = st.Message()
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(grpcCodeToHTTPCode(st.Code()))
|
||||
_ = json.NewEncoder(w).Encode(http_resp.Error(code, msg))
|
||||
_ = json.NewEncoder(w).Encode(http_resp.Error(http_resp.NewCode(code, msg)))
|
||||
}
|
||||
|
||||
// 这里定义 Internal 属于业务错误,其他的属于 500 报错
|
||||
@@ -47,7 +49,7 @@ func grpcCodeToHTTPCode(c codes.Code) int {
|
||||
switch c {
|
||||
case codes.OK, codes.Unknown:
|
||||
return http.StatusOK
|
||||
case codes.Unimplemented:
|
||||
case codes.Unimplemented, codes.NotFound:
|
||||
return http.StatusNotFound
|
||||
default:
|
||||
return http.StatusInternalServerError
|
||||
|
||||
@@ -20,7 +20,7 @@ func NewWrappedMarshaler(inner runtime.Marshaler) *WrappedMarshaler {
|
||||
func (w *WrappedMarshaler) Marshal(v interface{}) ([]byte, error) {
|
||||
dataBytes, err := w.inner.Marshal(v)
|
||||
if err != nil {
|
||||
return json.Marshal(http_resp.Error(http_resp.Failed.Code(), http_resp.Failed.Error()))
|
||||
return json.Marshal(http_resp.Error(http_resp.Failed))
|
||||
}
|
||||
return json.Marshal(http_resp.Success(json.RawMessage(dataBytes)))
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ func (g *GatewayWsServer) OnOpen(conn socket.ISocketConn) ([]byte, socket.Action
|
||||
|
||||
func (g *GatewayWsServer) OnHandShake(conn socket.ISocketConn) {
|
||||
token, ok := conn.GetParam("token").(string)
|
||||
if !ok || len(token) == 0 {
|
||||
if !ok || token == "" {
|
||||
g.logger.Warnf("token is not string")
|
||||
_ = conn.Close()
|
||||
return
|
||||
@@ -30,17 +30,17 @@ func (g *GatewayWsServer) OnHandShake(conn socket.ISocketConn) {
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
if oldClient := ws_handler2.UserMgr.GetByUID(int32(t)); oldClient != nil {
|
||||
if oldClient := ws_handler2.UserMgr.GetByUSN(int64(t)); oldClient != nil {
|
||||
oldClient.CloseClient()
|
||||
}
|
||||
client := ws_handler2.NewClient(int32(t), conn)
|
||||
ws_handler2.UserMgr.Add(int32(t), client)
|
||||
client := ws_handler2.NewClient(int64(t), conn)
|
||||
ws_handler2.UserMgr.Add(int64(t), client)
|
||||
conn.SetParam("client", client)
|
||||
}
|
||||
|
||||
func (g *GatewayWsServer) OnMessage(conn socket.ISocketConn, bytes []byte) socket.Action {
|
||||
client, ok := conn.GetParam("client").(*ws_handler2.Client)
|
||||
if !ok || client.UID == 0 {
|
||||
if !ok || client.USN == 0 {
|
||||
return socket.Close
|
||||
}
|
||||
client.OnEvent(&ws_handler2.ClientEvent{Msg: bytes})
|
||||
@@ -49,7 +49,7 @@ func (g *GatewayWsServer) OnMessage(conn socket.ISocketConn, bytes []byte) socke
|
||||
|
||||
func (g *GatewayWsServer) OnPong(conn socket.ISocketConn) {
|
||||
client, ok := conn.GetParam("client").(*ws_handler2.Client)
|
||||
if !ok || client.UID == 0 {
|
||||
if !ok || client.USN == 0 {
|
||||
return
|
||||
}
|
||||
client.OnEvent(&ws_handler2.PongEvent{})
|
||||
|
||||
Reference in New Issue
Block a user