| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132 |
- package middleware
- import (
- "context"
- "net/http"
- "strings"
- "perms-system-server/internal/consts"
- "perms-system-server/internal/loaders"
- "perms-system-server/internal/response"
- "github.com/golang-jwt/jwt/v4"
- "github.com/zeromicro/go-zero/rest/httpx"
- )
// contextKey is the package-private type used for keys this package stores
// in a context.Context. Because the type is unexported, its keys cannot
// collide with context keys defined by any other package.
type contextKey string

// ctxKeyUserDetails is the context key under which JwtAuthMiddleware stores
// the loaded user details for the current request.
const ctxKeyUserDetails = contextKey("userDetails")
// Claims is the claim set carried by a JWT access token.
//
// On top of the standard registered claims embedded from
// jwt.RegisteredClaims, it identifies the user (UserId, Username), the
// product the token is scoped to (ProductCode, which is passed to the user
// details loader), and the session generation (TokenVersion).
// JwtAuthMiddleware rejects a token when TokenType is not
// consts.TokenTypeAccess. It also rejects a token when TokenVersion differs
// from the version currently loaded for the user.
//
// NOTE(review): MemberType and Perms are not read by JwtAuthMiddleware,
// which authorizes from the freshly loaded user details instead. Confirm
// whether any consumer relies on the copies embedded in the token.
type Claims struct {
	TokenType    string   `json:"tokenType"`
	UserId       int64    `json:"userId"`
	Username     string   `json:"username"`
	ProductCode  string   `json:"productCode"`
	MemberType   string   `json:"memberType"`
	TokenVersion int64    `json:"tokenVersion"`
	Perms        []string `json:"perms"`
	jwt.RegisteredClaims
}
// JwtAuthMiddleware authenticates requests by verifying a Bearer access
// token. It then reloads the caller's user details for every request.
type JwtAuthMiddleware struct {
	// accessSecret is the shared secret used as the token verification key.
	accessSecret string
	// loader fetches the user's current details (status, token version, ...)
	// by user id and product code.
	loader *loaders.UserDetailsLoader
}
- func NewJwtAuthMiddleware(accessSecret string, loader *loaders.UserDetailsLoader) *JwtAuthMiddleware {
- return &JwtAuthMiddleware{
- accessSecret: accessSecret,
- loader: loader,
- }
- }
- func (m *JwtAuthMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- authHeader := r.Header.Get("Authorization")
- if authHeader == "" {
- httpx.ErrorCtx(r.Context(), w, response.NewCodeError(401, "未登录"))
- return
- }
- tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
- if tokenStr == authHeader {
- httpx.ErrorCtx(r.Context(), w, response.NewCodeError(401, "token格式错误"))
- return
- }
- token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(token *jwt.Token) (interface{}, error) {
- return []byte(m.accessSecret), nil
- })
- if err != nil || !token.Valid {
- httpx.ErrorCtx(r.Context(), w, response.NewCodeError(401, "token无效或已过期"))
- return
- }
- claims, ok := token.Claims.(*Claims)
- if !ok || claims.TokenType != consts.TokenTypeAccess {
- httpx.ErrorCtx(r.Context(), w, response.NewCodeError(401, "token无效或类型错误"))
- return
- }
- ud := m.loader.Load(r.Context(), claims.UserId, claims.ProductCode)
- if ud.Status != consts.StatusEnabled {
- httpx.ErrorCtx(r.Context(), w, response.NewCodeError(403, "账号已被冻结"))
- return
- }
- if claims.TokenVersion != ud.TokenVersion {
- httpx.ErrorCtx(r.Context(), w, response.NewCodeError(401, "登录状态已失效,请重新登录"))
- return
- }
- ctx := context.WithValue(r.Context(), ctxKeyUserDetails, ud)
- next(w, r.WithContext(ctx))
- }
- }
- // -------- context helpers --------
- func WithUserDetails(ctx context.Context, ud *loaders.UserDetails) context.Context {
- return context.WithValue(ctx, ctxKeyUserDetails, ud)
- }
- func GetUserDetails(ctx context.Context) *loaders.UserDetails {
- v, _ := ctx.Value(ctxKeyUserDetails).(*loaders.UserDetails)
- return v
- }
- func GetUserId(ctx context.Context) int64 {
- if ud := GetUserDetails(ctx); ud != nil {
- return ud.UserId
- }
- return 0
- }
- func GetUsername(ctx context.Context) string {
- if ud := GetUserDetails(ctx); ud != nil {
- return ud.Username
- }
- return ""
- }
- func GetProductCode(ctx context.Context) string {
- if ud := GetUserDetails(ctx); ud != nil {
- return ud.ProductCode
- }
- return ""
- }
- func GetMemberType(ctx context.Context) string {
- if ud := GetUserDetails(ctx); ud != nil {
- return ud.MemberType
- }
- return ""
- }
- func IsSuperAdmin(ctx context.Context) bool {
- if ud := GetUserDetails(ctx); ud != nil {
- return ud.IsSuperAdmin
- }
- return false
- }
|