jwtauthMiddleware.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. package middleware
  2. import (
  3. "context"
  4. "fmt"
  5. "net/http"
  6. "strings"
  7. "perms-system-server/internal/consts"
  8. "perms-system-server/internal/loaders"
  9. "perms-system-server/internal/response"
  10. "github.com/golang-jwt/jwt/v4"
  11. "github.com/zeromicro/go-zero/rest/httpx"
  12. )
  13. type contextKey string
  14. const (
  15. ctxKeyUserDetails contextKey = "userDetails"
  16. )
  17. // Claims JWT access token 的 Claims 结构。
  18. type Claims struct {
  19. TokenType string `json:"tokenType"`
  20. UserId int64 `json:"userId"`
  21. Username string `json:"username"`
  22. ProductCode string `json:"productCode"`
  23. MemberType string `json:"memberType"`
  24. TokenVersion int64 `json:"tokenVersion"`
  25. jwt.RegisteredClaims
  26. }
  27. type JwtAuthMiddleware struct {
  28. accessSecret string
  29. loader *loaders.UserDetailsLoader
  30. }
  31. func NewJwtAuthMiddleware(accessSecret string, loader *loaders.UserDetailsLoader) *JwtAuthMiddleware {
  32. return &JwtAuthMiddleware{
  33. accessSecret: accessSecret,
  34. loader: loader,
  35. }
  36. }
  37. func (m *JwtAuthMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
  38. return func(w http.ResponseWriter, r *http.Request) {
  39. authHeader := r.Header.Get("Authorization")
  40. if authHeader == "" {
  41. httpx.ErrorCtx(r.Context(), w, response.NewCodeError(401, "未登录"))
  42. return
  43. }
  44. tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
  45. if tokenStr == authHeader {
  46. httpx.ErrorCtx(r.Context(), w, response.NewCodeError(401, "token格式错误"))
  47. return
  48. }
  49. // 显式断言 HMAC 签名算法,避免 RSA/ECDSA 公钥被当 HMAC 共享密钥伪造 token
  50. // (审计 H-4 / RFC 8725)。
  51. token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(token *jwt.Token) (interface{}, error) {
  52. if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
  53. return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
  54. }
  55. return []byte(m.accessSecret), nil
  56. })
  57. if err != nil || !token.Valid {
  58. httpx.ErrorCtx(r.Context(), w, response.NewCodeError(401, "token无效或已过期"))
  59. return
  60. }
  61. claims, ok := token.Claims.(*Claims)
  62. if !ok || claims.TokenType != consts.TokenTypeAccess {
  63. httpx.ErrorCtx(r.Context(), w, response.NewCodeError(401, "token无效或类型错误"))
  64. return
  65. }
  66. ud, err := m.loader.Load(r.Context(), claims.UserId, claims.ProductCode)
  67. if err != nil {
  68. // DB / Redis 短时不可用;与"用户不存在(Username=="")"严格区分,避免把一次 DB 抖动同化
  69. // 成"全站用户被删除"把客户端集体 kick 掉形成雪崩(见审计 M-1)。返回 503 让客户端按
  70. // 临时故障重试策略处理,token 不作废。
  71. httpx.ErrorCtx(r.Context(), w, response.NewCodeError(503, "服务暂时不可用,请稍后重试"))
  72. return
  73. }
  74. if ud.Username == "" {
  75. httpx.ErrorCtx(r.Context(), w, response.NewCodeError(401, "用户不存在或已被删除"))
  76. return
  77. }
  78. if ud.Status != consts.StatusEnabled {
  79. httpx.ErrorCtx(r.Context(), w, response.NewCodeError(403, "账号已被冻结"))
  80. return
  81. }
  82. if claims.TokenVersion != ud.TokenVersion {
  83. httpx.ErrorCtx(r.Context(), w, response.NewCodeError(401, "登录状态已失效,请重新登录"))
  84. return
  85. }
  86. if claims.ProductCode != "" && ud.ProductStatus != consts.StatusEnabled {
  87. httpx.ErrorCtx(r.Context(), w, response.NewCodeError(403, "该产品已被禁用"))
  88. return
  89. }
  90. if claims.ProductCode != "" && !ud.IsSuperAdmin && ud.MemberType == "" {
  91. httpx.ErrorCtx(r.Context(), w, response.NewCodeError(403, "您已不是该产品的有效成员"))
  92. return
  93. }
  94. ctx := context.WithValue(r.Context(), ctxKeyUserDetails, ud)
  95. next(w, r.WithContext(ctx))
  96. }
  97. }
  98. // -------- context helpers --------
  99. func WithUserDetails(ctx context.Context, ud *loaders.UserDetails) context.Context {
  100. return context.WithValue(ctx, ctxKeyUserDetails, ud)
  101. }
  102. func GetUserDetails(ctx context.Context) *loaders.UserDetails {
  103. v, _ := ctx.Value(ctxKeyUserDetails).(*loaders.UserDetails)
  104. return v
  105. }
  106. func GetUserId(ctx context.Context) int64 {
  107. if ud := GetUserDetails(ctx); ud != nil {
  108. return ud.UserId
  109. }
  110. return 0
  111. }
  112. func GetProductCode(ctx context.Context) string {
  113. if ud := GetUserDetails(ctx); ud != nil {
  114. return ud.ProductCode
  115. }
  116. return ""
  117. }