package middleware import ( "context" "fmt" "net/http" "strings" "perms-system-server/internal/consts" "perms-system-server/internal/loaders" "perms-system-server/internal/response" "github.com/golang-jwt/jwt/v4" "github.com/zeromicro/go-zero/rest/httpx" ) type contextKey string const ( ctxKeyUserDetails contextKey = "userDetails" ) // Claims JWT access token 的 Claims 结构。 type Claims struct { TokenType string `json:"tokenType"` UserId int64 `json:"userId"` Username string `json:"username"` ProductCode string `json:"productCode"` MemberType string `json:"memberType"` TokenVersion int64 `json:"tokenVersion"` jwt.RegisteredClaims } // ParseWithHMAC 所有 JWT 解析点(HTTP 中间件 / gRPC VerifyToken / RefreshToken 等) // 的统一入口。必须显式断言 token.Method 为 *jwt.SigningMethodHMAC,避免未来迁移到 RSA/ECDSA // 非对称密钥时把公钥当成 HMAC 共享密钥伪造 token(jwt-go 历史 CVE-2016-10555 同类问题, // OWASP JWT Cheat Sheet / RFC 8725 强制要求 alg 白名单,见审计 H-4 / L-N1)。 // // 函数放在 middleware 包是为了避免 auth → middleware 的循环依赖(auth 包已经引用 // middleware.Claims)。所有历史 inline keyfunc 调用点都应统一替换为本 helper, // 把"算法混淆防御"的审计覆盖矩阵收敛到一个函数。 func ParseWithHMAC(tokenStr, secret string, claims jwt.Claims) (*jwt.Token, error) { return jwt.ParseWithClaims(tokenStr, claims, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return []byte(secret), nil }) } type JwtAuthMiddleware struct { accessSecret string loader *loaders.UserDetailsLoader } func NewJwtAuthMiddleware(accessSecret string, loader *loaders.UserDetailsLoader) *JwtAuthMiddleware { return &JwtAuthMiddleware{ accessSecret: accessSecret, loader: loader, } } func (m *JwtAuthMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { authHeader := r.Header.Get("Authorization") if authHeader == "" { httpx.ErrorCtx(r.Context(), w, response.NewCodeError(401, "未登录")) return } tokenStr := strings.TrimPrefix(authHeader, "Bearer ") if tokenStr == authHeader { httpx.ErrorCtx(r.Context(), w, response.NewCodeError(401, "token格式错误")) return } token, err := ParseWithHMAC(tokenStr, m.accessSecret, &Claims{}) if err != nil || !token.Valid { httpx.ErrorCtx(r.Context(), w, response.NewCodeError(401, "token无效或已过期")) return } claims, ok := token.Claims.(*Claims) if !ok || claims.TokenType != consts.TokenTypeAccess { httpx.ErrorCtx(r.Context(), w, response.NewCodeError(401, "token无效或类型错误")) return } ud, err := m.loader.Load(r.Context(), claims.UserId, claims.ProductCode) if err != nil { // DB / Redis 短时不可用;与"用户不存在(Username=="")"严格区分,避免把一次 DB 抖动同化 // 成"全站用户被删除"把客户端集体 kick 掉形成雪崩(见审计 M-1)。返回 503 让客户端按 // 临时故障重试策略处理,token 不作废。 httpx.ErrorCtx(r.Context(), w, response.NewCodeError(503, "服务暂时不可用,请稍后重试")) return } if ud.Username == "" { httpx.ErrorCtx(r.Context(), w, response.NewCodeError(401, "用户不存在或已被删除")) return } if ud.Status != consts.StatusEnabled { httpx.ErrorCtx(r.Context(), w, response.NewCodeError(403, "账号已被冻结")) return } // 审计 H-R18-2:所在部门已被冻结(DeptStatus=Disabled)时硬拦截所有非超管请求, // 与 UpdateDeptLogic 的 normalDeptFrozen / devFullAccessRevoked 语义闭环—— // 冻结部门 = 冻结部门所有成员所有活动,而不仅是"吊销一次 session"。 // DeptId==0(超管或无部门的历史数据)不命中此分支,避免误伤。 if !ud.IsSuperAdmin && ud.DeptId > 0 && ud.DeptStatus != consts.StatusEnabled { httpx.ErrorCtx(r.Context(), w, response.NewCodeError(403, "所在部门已被冻结")) return } if claims.TokenVersion != ud.TokenVersion { httpx.ErrorCtx(r.Context(), w, response.NewCodeError(401, "登录状态已失效,请重新登录")) return } if claims.ProductCode != "" && ud.ProductStatus != consts.StatusEnabled { httpx.ErrorCtx(r.Context(), w, response.NewCodeError(403, "该产品已被禁用")) return } if claims.ProductCode != "" && !ud.IsSuperAdmin && ud.MemberType == "" { httpx.ErrorCtx(r.Context(), w, response.NewCodeError(403, "您已不是该产品的有效成员")) return } ctx := context.WithValue(r.Context(), ctxKeyUserDetails, ud) next(w, r.WithContext(ctx)) } } // -------- context helpers -------- func WithUserDetails(ctx context.Context, ud *loaders.UserDetails) context.Context { return context.WithValue(ctx, ctxKeyUserDetails, ud) } func GetUserDetails(ctx context.Context) *loaders.UserDetails { v, _ := ctx.Value(ctxKeyUserDetails).(*loaders.UserDetails) return v } func GetUserId(ctx context.Context) int64 { if ud := GetUserDetails(ctx); ud != nil { return ud.UserId } return 0 } func GetProductCode(ctx context.Context) string { if ud := GetUserDetails(ctx); ud != nil { return ud.ProductCode } return "" }