package middleware_test import ( "context" "database/sql" "encoding/json" "net/http" "net/http/httptest" "testing" "time" "perms-system-server/internal/consts" "perms-system-server/internal/loaders" "perms-system-server/internal/middleware" "perms-system-server/internal/model" "perms-system-server/internal/model/user" "perms-system-server/internal/response" "perms-system-server/internal/testutil" "github.com/golang-jwt/jwt/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/zeromicro/go-zero/core/stores/redis" "github.com/zeromicro/go-zero/rest/httpx" ) const testAccessSecret = "test-middleware-secret" func generateTestToken(secret string, expireSeconds int64, claims *middleware.Claims) string { now := time.Now() claims.RegisteredClaims = jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(expireSeconds) * time.Second)), IssuedAt: jwt.NewNumericDate(now), } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenStr, _ := token.SignedString([]byte(secret)) return tokenStr } func newTestMiddleware() (*middleware.JwtAuthMiddleware, *loaders.UserDetailsLoader) { cfg := testutil.GetTestConfig() conn := testutil.GetTestSqlConn() models := model.NewModels(conn, testutil.GetTestCacheConf(), testutil.GetTestCachePrefix()) rds := redis.MustNewRedis(cfg.CacheRedis.Nodes[0].RedisConf) loader := loaders.NewUserDetailsLoader(rds, testutil.GetTestCachePrefix(), models) m := middleware.NewJwtAuthMiddleware(testAccessSecret, loader) return m, loader } func createTestUser(t *testing.T, username string) (int64, func()) { t.Helper() ctx := context.Background() conn := testutil.GetTestSqlConn() models := model.NewModels(conn, testutil.GetTestCacheConf(), testutil.GetTestCachePrefix()) now := time.Now().Unix() u := &user.SysUser{ Username: username, Password: testutil.HashPassword("test123"), Nickname: "test_nick", Avatar: sql.NullString{Valid: false}, Email: "test@example.com", Phone: "13800000000", Remark: "", DeptId: 0, IsSuperAdmin: consts.IsSuperAdminNo, MustChangePassword: consts.MustChangePasswordNo, Status: consts.StatusEnabled, CreateTime: now, UpdateTime: now, } result, err := models.SysUserModel.Insert(ctx, u) require.NoError(t, err) userId, err := result.LastInsertId() require.NoError(t, err) cleanup := func() { testutil.CleanTable(ctx, conn, "sys_user", userId) } return userId, cleanup } func init() { response.Setup() } // TC-0258: `Authorization: Bearer {valid}` func TestJwtAuthMiddleware_Handle(t *testing.T) { m, _ := newTestMiddleware() t.Run("valid token", func(t *testing.T) { username := "mw_valid_" + testutil.UniqueId() userId, cleanup := createTestUser(t, username) defer cleanup() tokenStr := generateTestToken(testAccessSecret, 3600, &middleware.Claims{ TokenType: consts.TokenTypeAccess, UserId: userId, Username: username, ProductCode: "", }) var capturedCtx context.Context handler := m.Handle(func(w http.ResponseWriter, r *http.Request) { capturedCtx = r.Context() w.WriteHeader(http.StatusOK) }) req := httptest.NewRequest(http.MethodPost, "/test", nil) req.Header.Set("Authorization", "Bearer "+tokenStr) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code) assert.Equal(t, userId, middleware.GetUserId(capturedCtx)) assert.Equal(t, "", middleware.GetProductCode(capturedCtx)) details := middleware.GetUserDetails(capturedCtx) require.NotNil(t, details) assert.Equal(t, username, details.Username) assert.Equal(t, "", details.MemberType) assert.False(t, details.IsSuperAdmin) }) t.Run("no authorization header", func(t *testing.T) { handler := m.Handle(func(w http.ResponseWriter, r *http.Request) { t.Fatal("should not reach handler") }) req := httptest.NewRequest(http.MethodPost, "/test", nil) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) var body response.Body err := json.Unmarshal(rr.Body.Bytes(), &body) require.NoError(t, err) assert.Equal(t, 401, body.Code) assert.Equal(t, "未登录", body.Msg) }) t.Run("no Bearer prefix", func(t *testing.T) { handler := m.Handle(func(w http.ResponseWriter, r *http.Request) { t.Fatal("should not reach handler") }) req := httptest.NewRequest(http.MethodPost, "/test", nil) req.Header.Set("Authorization", "Basic some-token") rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) var body response.Body err := json.Unmarshal(rr.Body.Bytes(), &body) require.NoError(t, err) assert.Equal(t, 401, body.Code) assert.Equal(t, "token格式错误", body.Msg) }) t.Run("invalid token", func(t *testing.T) { handler := m.Handle(func(w http.ResponseWriter, r *http.Request) { t.Fatal("should not reach handler") }) req := httptest.NewRequest(http.MethodPost, "/test", nil) req.Header.Set("Authorization", "Bearer invalid-token-string") rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) var body response.Body err := json.Unmarshal(rr.Body.Bytes(), &body) require.NoError(t, err) assert.Equal(t, 401, body.Code) assert.Equal(t, "token无效或已过期", body.Msg) }) t.Run("wrong secret", func(t *testing.T) { tokenStr := generateTestToken("wrong-secret", 3600, &middleware.Claims{ TokenType: consts.TokenTypeAccess, UserId: 1, }) handler := m.Handle(func(w http.ResponseWriter, r *http.Request) { t.Fatal("should not reach handler") }) req := httptest.NewRequest(http.MethodPost, "/test", nil) req.Header.Set("Authorization", "Bearer "+tokenStr) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) var body response.Body err := json.Unmarshal(rr.Body.Bytes(), &body) require.NoError(t, err) assert.Equal(t, 401, body.Code) assert.Equal(t, "token无效或已过期", body.Msg) }) t.Run("expired token", func(t *testing.T) { tokenStr := generateTestToken(testAccessSecret, -10, &middleware.Claims{ TokenType: consts.TokenTypeAccess, UserId: 1, }) handler := m.Handle(func(w http.ResponseWriter, r *http.Request) { t.Fatal("should not reach handler") }) req := httptest.NewRequest(http.MethodPost, "/test", nil) req.Header.Set("Authorization", "Bearer "+tokenStr) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) var body response.Body err := json.Unmarshal(rr.Body.Bytes(), &body) require.NoError(t, err) assert.Equal(t, 401, body.Code) assert.Equal(t, "token无效或已过期", body.Msg) }) // TC-0264: refresh token 不应被中间件接受 t.Run("refresh token rejected", func(t *testing.T) { tokenStr := generateTestToken(testAccessSecret, 3600, &middleware.Claims{ TokenType: consts.TokenTypeRefresh, UserId: 100, }) handler := m.Handle(func(w http.ResponseWriter, r *http.Request) { t.Fatal("should not reach handler") }) req := httptest.NewRequest(http.MethodPost, "/test", nil) req.Header.Set("Authorization", "Bearer "+tokenStr) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) var body response.Body err := json.Unmarshal(rr.Body.Bytes(), &body) require.NoError(t, err) assert.Equal(t, 401, body.Code) assert.Equal(t, "token无效或类型错误", body.Msg) }) t.Run("frozen user rejected", func(t *testing.T) { username := "mw_frozen_" + testutil.UniqueId() userId, cleanup := createTestUser(t, username) defer cleanup() ctx := context.Background() conn := testutil.GetTestSqlConn() now := time.Now().Unix() models := model.NewModels(conn, testutil.GetTestCacheConf(), testutil.GetTestCachePrefix()) err := models.SysUserModel.Update(ctx, &user.SysUser{ Id: userId, Username: username, Password: testutil.HashPassword("test123"), Nickname: "test_nick", Avatar: sql.NullString{Valid: false}, Email: "test@example.com", Phone: "13800000000", DeptId: 0, IsSuperAdmin: consts.IsSuperAdminNo, MustChangePassword: consts.MustChangePasswordNo, Status: consts.StatusDisabled, CreateTime: now, UpdateTime: now, }) require.NoError(t, err) tokenStr := generateTestToken(testAccessSecret, 3600, &middleware.Claims{ TokenType: consts.TokenTypeAccess, UserId: userId, Username: username, ProductCode: "", }) handler := m.Handle(func(w http.ResponseWriter, r *http.Request) { t.Fatal("should not reach handler") }) req := httptest.NewRequest(http.MethodPost, "/test", nil) req.Header.Set("Authorization", "Bearer "+tokenStr) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) var body response.Body err = json.Unmarshal(rr.Body.Bytes(), &body) require.NoError(t, err) assert.Equal(t, 403, body.Code) assert.Equal(t, "账号已被冻结", body.Msg) }) } // TC-0306: ctx含userId=100 func TestGetUserId(t *testing.T) { ctx := context.Background() assert.Equal(t, int64(0), middleware.GetUserId(ctx)) ctx = middleware.WithUserDetails(ctx, &loaders.UserDetails{UserId: 42}) assert.Equal(t, int64(42), middleware.GetUserId(ctx)) ctx2 := context.Background() assert.Equal(t, int64(0), middleware.GetUserId(ctx2)) } // TC-0275: 空ctx func TestGetProductCode(t *testing.T) { ctx := context.Background() assert.Equal(t, "", middleware.GetProductCode(ctx)) ctx = middleware.WithUserDetails(ctx, &loaders.UserDetails{ProductCode: "p1"}) assert.Equal(t, "p1", middleware.GetProductCode(ctx)) } // TC-0309: GetUserDetails 返回完整用户信息 func TestGetUserDetails(t *testing.T) { ctx := context.Background() assert.Nil(t, middleware.GetUserDetails(ctx)) expected := &loaders.UserDetails{ UserId: 42, Username: "admin", ProductCode: "p1", MemberType: "ADMIN", IsSuperAdmin: true, } ctx = middleware.WithUserDetails(ctx, expected) got := middleware.GetUserDetails(ctx) require.NotNil(t, got) assert.Equal(t, expected.UserId, got.UserId) assert.Equal(t, expected.Username, got.Username) assert.Equal(t, expected.ProductCode, got.ProductCode) assert.Equal(t, expected.MemberType, got.MemberType) assert.Equal(t, expected.IsSuperAdmin, got.IsSuperAdmin) } // TC-0263: claims类型断言失败(防御性分支) // jwt.ParseWithClaims(tokenStr, &Claims{}, keyFunc) 始终将 token.Claims 设为 *Claims, // 且解析失败时 Handle 已在 err!=nil 分支提前返回,因此 !ok 分支不可达。 func TestJwtAuthMiddleware_Handle_ClaimsTypeAssertionUnreachable(t *testing.T) { t.Skip("defensive branch: unreachable via jwt.ParseWithClaims — claims is always *Claims") } // suppress unused import warning for httpx var _ = httpx.Error