package middleware_test

import (
	"context"
	"database/sql"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"perms-system-server/internal/consts"
	"perms-system-server/internal/loaders"
	"perms-system-server/internal/middleware"
	"perms-system-server/internal/model"
	"perms-system-server/internal/model/user"
	"perms-system-server/internal/response"
	"perms-system-server/internal/testutil"

	"github.com/golang-jwt/jwt/v4"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/zeromicro/go-zero/core/stores/redis"
	"github.com/zeromicro/go-zero/rest/httpx"
)

const testAccessSecret = "test-middleware-secret"

// generateTestToken signs claims with HS256 using secret and returns the
// compact token string.
//
// It overwrites claims.RegisteredClaims in place: IssuedAt is set to now and
// ExpiresAt to now+expireSeconds, so a negative expireSeconds produces a token
// that is already expired. It panics if signing fails, so that a test never
// proceeds with an empty token and passes or fails for the wrong reason.
func generateTestToken(secret string, expireSeconds int64, claims *middleware.Claims) string {
	now := time.Now()
	claims.RegisteredClaims = jwt.RegisteredClaims{
		ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(expireSeconds) * time.Second)),
		IssuedAt:  jwt.NewNumericDate(now),
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	tokenStr, err := token.SignedString([]byte(secret))
	if err != nil {
		panic("generateTestToken: signing token: " + err.Error())
	}
	return tokenStr
}

// newTestMiddleware builds a JwtAuthMiddleware that verifies tokens with
// testAccessSecret. It is backed by a UserDetailsLoader wired to the test SQL
// connection (through model.NewModels) and to the first configured cache
// Redis node. redis.MustNewRedis panics on failure, per the Must- convention.
func newTestMiddleware() (*middleware.JwtAuthMiddleware, *loaders.UserDetailsLoader) {
	cfg := testutil.GetTestConfig()
	conn := testutil.GetTestSqlConn()
	models := model.NewModels(conn, testutil.GetTestCacheConf(), testutil.GetTestCachePrefix())
	rds := redis.MustNewRedis(cfg.CacheRedis.Nodes[0].RedisConf)
	loader := loaders.NewUserDetailsLoader(rds, testutil.GetTestCachePrefix(), models)
	m := middleware.NewJwtAuthMiddleware(testAccessSecret, loader)
	return m, loader
}

// createTestUser inserts an enabled, non-super-admin sys_user row named
// username and returns its id. It also returns a cleanup func that deletes
// that row via testutil.CleanTable. The test fails immediately if the insert
// or LastInsertId returns an error.
func createTestUser(t *testing.T, username string) (int64, func()) {
	t.Helper()
	ctx := context.Background()
	conn := testutil.GetTestSqlConn()
	models := model.NewModels(conn, testutil.GetTestCacheConf(), testutil.GetTestCachePrefix())

	now := time.Now().Unix()
	u := &user.SysUser{
		Username:           username,
		Password:           testutil.HashPassword("test123"),
		Nickname:           "test_nick",
		Avatar:             sql.NullString{Valid: false},
		Email:              "test@example.com",
		Phone:              "13800000000",
		Remark:             "",
		DeptId:             0,
		IsSuperAdmin:       consts.IsSuperAdminNo,
		MustChangePassword: consts.MustChangePasswordNo,
		Status:             consts.StatusEnabled,
		CreateTime:         now,
		UpdateTime:         now,
	}
	result, err := models.SysUserModel.Insert(ctx, u)
	require.NoError(t, err)
	userId, err := result.LastInsertId()
	require.NoError(t, err)

	cleanup := func() {
		testutil.CleanTable(ctx, conn, "sys_user", userId)
	}
	return userId, cleanup
}

// init calls response.Setup once before any test runs. The rejection cases
// below decode response.Body from the middleware's output, which presumably
// depends on this setup.
// NOTE(review): confirm what Setup installs.
func init() {
	response.Setup()
}

// TestJwtAuthMiddleware_Handle drives JwtAuthMiddleware.Handle through a
// POST /test request for each scenario. A valid access token must reach the
// wrapped handler with the user's details in the request context. A missing
// header, a non-Bearer header, garbage, a wrong signature, an expired token, a
// refresh token, or a frozen account must all be answered by the middleware
// with the expected code/message, without reaching the handler.
//
// TC-0223: `Authorization: Bearer {valid}`.
func TestJwtAuthMiddleware_Handle(t *testing.T) {
	m, _ := newTestMiddleware()

	// expectRejected sends POST /test through the middleware with the given
	// Authorization header, or with no header at all when authHeader is empty.
	// It asserts that the JSON response body carries wantCode and wantMsg. The
	// wrapped handler fails the test if it is ever reached.
	expectRejected := func(t *testing.T, authHeader string, wantCode int, wantMsg string) {
		t.Helper()
		handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
			t.Fatal("should not reach handler")
		})
		req := httptest.NewRequest(http.MethodPost, "/test", nil)
		if authHeader != "" {
			req.Header.Set("Authorization", authHeader)
		}
		rr := httptest.NewRecorder()
		handler.ServeHTTP(rr, req)

		var body response.Body
		require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &body))
		assert.Equal(t, wantCode, body.Code)
		assert.Equal(t, wantMsg, body.Msg)
	}

	t.Run("valid token", func(t *testing.T) {
		username := "mw_valid_" + testutil.UniqueId()
		userId, cleanup := createTestUser(t, username)
		defer cleanup()

		tokenStr := generateTestToken(testAccessSecret, 3600, &middleware.Claims{
			TokenType:   consts.TokenTypeAccess,
			UserId:      userId,
			Username:    username,
			ProductCode: "",
		})

		var capturedCtx context.Context
		handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
			capturedCtx = r.Context()
			w.WriteHeader(http.StatusOK)
		})
		req := httptest.NewRequest(http.MethodPost, "/test", nil)
		req.Header.Set("Authorization", "Bearer "+tokenStr)
		rr := httptest.NewRecorder()
		handler.ServeHTTP(rr, req)

		// If the middleware rejected the request, capturedCtx is still nil.
		// Stop with a clear failure instead of passing a nil context to the
		// getters below.
		require.Equal(t, http.StatusOK, rr.Code)
		require.NotNil(t, capturedCtx, "handler was not reached")

		assert.Equal(t, userId, middleware.GetUserId(capturedCtx))
		assert.Equal(t, username, middleware.GetUsername(capturedCtx))
		assert.Equal(t, "", middleware.GetProductCode(capturedCtx))
		assert.Equal(t, "", middleware.GetMemberType(capturedCtx))
		assert.False(t, middleware.IsSuperAdmin(capturedCtx))
	})

	t.Run("no authorization header", func(t *testing.T) {
		expectRejected(t, "", 401, "未登录")
	})

	t.Run("no Bearer prefix", func(t *testing.T) {
		expectRejected(t, "Basic some-token", 401, "token格式错误")
	})

	t.Run("invalid token", func(t *testing.T) {
		expectRejected(t, "Bearer invalid-token-string", 401, "token无效或已过期")
	})

	t.Run("wrong secret", func(t *testing.T) {
		tokenStr := generateTestToken("wrong-secret", 3600, &middleware.Claims{
			TokenType: consts.TokenTypeAccess,
			UserId:    1,
		})
		expectRejected(t, "Bearer "+tokenStr, 401, "token无效或已过期")
	})

	t.Run("expired token", func(t *testing.T) {
		tokenStr := generateTestToken(testAccessSecret, -10, &middleware.Claims{
			TokenType: consts.TokenTypeAccess,
			UserId:    1,
		})
		expectRejected(t, "Bearer "+tokenStr, 401, "token无效或已过期")
	})

	// TC-0229: the middleware must not accept a refresh token.
	t.Run("refresh token rejected", func(t *testing.T) {
		tokenStr := generateTestToken(testAccessSecret, 3600, &middleware.Claims{
			TokenType: consts.TokenTypeRefresh,
			UserId:    100,
		})
		expectRejected(t, "Bearer "+tokenStr, 401, "token无效或类型错误")
	})

	t.Run("frozen user rejected", func(t *testing.T) {
		username := "mw_frozen_" + testutil.UniqueId()
		userId, cleanup := createTestUser(t, username)
		defer cleanup()

		// Flip the freshly created user to disabled before any token for it
		// is presented to the middleware.
		ctx := context.Background()
		conn := testutil.GetTestSqlConn()
		now := time.Now().Unix()
		models := model.NewModels(conn, testutil.GetTestCacheConf(), testutil.GetTestCachePrefix())
		err := models.SysUserModel.Update(ctx, &user.SysUser{
			Id:                 userId,
			Username:           username,
			Password:           testutil.HashPassword("test123"),
			Nickname:           "test_nick",
			Avatar:             sql.NullString{Valid: false},
			Email:              "test@example.com",
			Phone:              "13800000000",
			DeptId:             0,
			IsSuperAdmin:       consts.IsSuperAdminNo,
			MustChangePassword: consts.MustChangePasswordNo,
			Status:             consts.StatusDisabled,
			CreateTime:         now,
			UpdateTime:         now,
		})
		require.NoError(t, err)

		tokenStr := generateTestToken(testAccessSecret, 3600, &middleware.Claims{
			TokenType:   consts.TokenTypeAccess,
			UserId:      userId,
			Username:    username,
			ProductCode: "",
		})
		expectRejected(t, "Bearer "+tokenStr, 403, "账号已被冻结")
	})
}

// TestGetUserId checks that GetUserId returns 0 for a bare context and the
// stored UserId once UserDetails are attached with WithUserDetails.
//
// TC-0272 (catalogue text: "ctx contains userId=100").
// NOTE(review): this test uses 42, not 100 — confirm against the catalogue.
func TestGetUserId(t *testing.T) {
	ctx := context.Background()
	assert.Equal(t, int64(0), middleware.GetUserId(ctx))

	ctx = middleware.WithUserDetails(ctx, &loaders.UserDetails{UserId: 42})
	assert.Equal(t, int64(42), middleware.GetUserId(ctx))
}

// TestGetUsername checks that GetUsername returns "" for a bare context and
// "admin" once UserDetails{Username: "admin"} are attached.
//
// TC-0274: ctx contains username="admin".
func TestGetUsername(t *testing.T) {
	ctx := context.Background()
	assert.Equal(t, "", middleware.GetUsername(ctx))

	ctx = middleware.WithUserDetails(ctx, &loaders.UserDetails{Username: "admin"})
	assert.Equal(t, "admin", middleware.GetUsername(ctx))
}

// TestGetProductCode checks that GetProductCode returns "" for a bare context
// and "p1" once UserDetails{ProductCode: "p1"} are attached.
//
// TC-0275 (catalogue text: "empty ctx").
// NOTE(review): the TC-0275..TC-0277 labels look shifted by one test. The text
// originally attached to TC-0276 (productCode="p1") describes this test —
// confirm the mapping.
func TestGetProductCode(t *testing.T) {
	ctx := context.Background()
	assert.Equal(t, "", middleware.GetProductCode(ctx))

	ctx = middleware.WithUserDetails(ctx, &loaders.UserDetails{ProductCode: "p1"})
	assert.Equal(t, "p1", middleware.GetProductCode(ctx))
}

// TestGetMemberType checks that GetMemberType returns "" for a bare context
// and "ADMIN" once UserDetails{MemberType: "ADMIN"} are attached.
//
// TC-0276 (catalogue text: "ctx contains productCode=p1").
// NOTE(review): the text originally attached to TC-0277 (memberType="ADMIN")
// describes this test — confirm the mapping.
func TestGetMemberType(t *testing.T) {
	ctx := context.Background()
	assert.Equal(t, "", middleware.GetMemberType(ctx))

	ctx = middleware.WithUserDetails(ctx, &loaders.UserDetails{MemberType: "ADMIN"})
	assert.Equal(t, "ADMIN", middleware.GetMemberType(ctx))
}

// TestIsSuperAdmin checks that IsSuperAdmin reflects UserDetails.IsSuperAdmin
// and reports false for a context with no UserDetails.
//
// TC-0277 (catalogue text: "ctx contains memberType=ADMIN").
// NOTE(review): the label does not match this test — confirm the mapping.
func TestIsSuperAdmin(t *testing.T) {
	tests := []struct {
		name         string
		isSuperAdmin bool
		want         bool
	}{
		{"is super admin", true, true},
		{"is not super admin", false, false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctx := middleware.WithUserDetails(context.Background(), &loaders.UserDetails{IsSuperAdmin: tt.isSuperAdmin})
			assert.Equal(t, tt.want, middleware.IsSuperAdmin(ctx))
		})
	}

	t.Run("empty context", func(t *testing.T) {
		assert.False(t, middleware.IsSuperAdmin(context.Background()))
	})
}

// TC-0228: claims type-assertion failure (defensive branch).
jwt.ParseWithClaims(tokenStr, &Claims{}, keyFunc) 始终将 token.Claims 设为 *Claims, // 且解析失败时 Handle 已在 err!=nil 分支提前返回,因此 !ok 分支不可达。 func TestJwtAuthMiddleware_Handle_ClaimsTypeAssertionUnreachable(t *testing.T) { t.Skip("defensive branch: unreachable via jwt.ParseWithClaims — claims is always *Claims") } // suppress unused import warning for httpx var _ = httpx.Error