| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339 |
- package middleware_test
- import (
- "context"
- "database/sql"
- "encoding/json"
- "net/http"
- "net/http/httptest"
- "testing"
- "time"
- "perms-system-server/internal/consts"
- "perms-system-server/internal/loaders"
- "perms-system-server/internal/middleware"
- "perms-system-server/internal/model"
- "perms-system-server/internal/model/user"
- "perms-system-server/internal/response"
- "perms-system-server/internal/testutil"
- "github.com/golang-jwt/jwt/v4"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- "github.com/zeromicro/go-zero/core/stores/redis"
- "github.com/zeromicro/go-zero/rest/httpx"
- )
// testAccessSecret is the signing secret shared by the middleware under test
// and by the tokens these tests mint for it.
const (
	testAccessSecret = "test-middleware-secret"
)
- func generateTestToken(secret string, expireSeconds int64, claims *middleware.Claims) string {
- now := time.Now()
- claims.RegisteredClaims = jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(expireSeconds) * time.Second)),
- IssuedAt: jwt.NewNumericDate(now),
- }
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- tokenStr, _ := token.SignedString([]byte(secret))
- return tokenStr
- }
- func newTestMiddleware() (*middleware.JwtAuthMiddleware, *loaders.UserDetailsLoader) {
- cfg := testutil.GetTestConfig()
- conn := testutil.GetTestSqlConn()
- models := model.NewModels(conn, testutil.GetTestCacheConf(), testutil.GetTestCachePrefix())
- rds := redis.MustNewRedis(cfg.CacheRedis.Nodes[0].RedisConf)
- loader := loaders.NewUserDetailsLoader(rds, testutil.GetTestCachePrefix(), models)
- m := middleware.NewJwtAuthMiddleware(testAccessSecret, loader)
- return m, loader
- }
- func createTestUser(t *testing.T, username string) (int64, func()) {
- t.Helper()
- ctx := context.Background()
- conn := testutil.GetTestSqlConn()
- models := model.NewModels(conn, testutil.GetTestCacheConf(), testutil.GetTestCachePrefix())
- now := time.Now().Unix()
- u := &user.SysUser{
- Username: username,
- Password: testutil.HashPassword("test123"),
- Nickname: "test_nick",
- Avatar: sql.NullString{Valid: false},
- Email: "[email protected]",
- Phone: "13800000000",
- Remark: "",
- DeptId: 0,
- IsSuperAdmin: consts.IsSuperAdminNo,
- MustChangePassword: consts.MustChangePasswordNo,
- Status: consts.StatusEnabled,
- CreateTime: now,
- UpdateTime: now,
- }
- result, err := models.SysUserModel.Insert(ctx, u)
- require.NoError(t, err)
- userId, err := result.LastInsertId()
- require.NoError(t, err)
- cleanup := func() {
- testutil.CleanTable(ctx, conn, "sys_user", userId)
- }
- return userId, cleanup
- }
- func init() {
- response.Setup()
- }
- // TC-0258: `Authorization: Bearer {valid}`
- func TestJwtAuthMiddleware_Handle(t *testing.T) {
- m, _ := newTestMiddleware()
- t.Run("valid token", func(t *testing.T) {
- username := "mw_valid_" + testutil.UniqueId()
- userId, cleanup := createTestUser(t, username)
- defer cleanup()
- tokenStr := generateTestToken(testAccessSecret, 3600, &middleware.Claims{
- TokenType: consts.TokenTypeAccess,
- UserId: userId,
- Username: username,
- ProductCode: "",
- })
- var capturedCtx context.Context
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- capturedCtx = r.Context()
- w.WriteHeader(http.StatusOK)
- })
- req := httptest.NewRequest(http.MethodPost, "/test", nil)
- req.Header.Set("Authorization", "Bearer "+tokenStr)
- rr := httptest.NewRecorder()
- handler.ServeHTTP(rr, req)
- assert.Equal(t, http.StatusOK, rr.Code)
- assert.Equal(t, userId, middleware.GetUserId(capturedCtx))
- assert.Equal(t, "", middleware.GetProductCode(capturedCtx))
- details := middleware.GetUserDetails(capturedCtx)
- require.NotNil(t, details)
- assert.Equal(t, username, details.Username)
- assert.Equal(t, "", details.MemberType)
- assert.False(t, details.IsSuperAdmin)
- })
- t.Run("no authorization header", func(t *testing.T) {
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- t.Fatal("should not reach handler")
- })
- req := httptest.NewRequest(http.MethodPost, "/test", nil)
- rr := httptest.NewRecorder()
- handler.ServeHTTP(rr, req)
- var body response.Body
- err := json.Unmarshal(rr.Body.Bytes(), &body)
- require.NoError(t, err)
- assert.Equal(t, 401, body.Code)
- assert.Equal(t, "未登录", body.Msg)
- })
- t.Run("no Bearer prefix", func(t *testing.T) {
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- t.Fatal("should not reach handler")
- })
- req := httptest.NewRequest(http.MethodPost, "/test", nil)
- req.Header.Set("Authorization", "Basic some-token")
- rr := httptest.NewRecorder()
- handler.ServeHTTP(rr, req)
- var body response.Body
- err := json.Unmarshal(rr.Body.Bytes(), &body)
- require.NoError(t, err)
- assert.Equal(t, 401, body.Code)
- assert.Equal(t, "token格式错误", body.Msg)
- })
- t.Run("invalid token", func(t *testing.T) {
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- t.Fatal("should not reach handler")
- })
- req := httptest.NewRequest(http.MethodPost, "/test", nil)
- req.Header.Set("Authorization", "Bearer invalid-token-string")
- rr := httptest.NewRecorder()
- handler.ServeHTTP(rr, req)
- var body response.Body
- err := json.Unmarshal(rr.Body.Bytes(), &body)
- require.NoError(t, err)
- assert.Equal(t, 401, body.Code)
- assert.Equal(t, "token无效或已过期", body.Msg)
- })
- t.Run("wrong secret", func(t *testing.T) {
- tokenStr := generateTestToken("wrong-secret", 3600, &middleware.Claims{
- TokenType: consts.TokenTypeAccess,
- UserId: 1,
- })
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- t.Fatal("should not reach handler")
- })
- req := httptest.NewRequest(http.MethodPost, "/test", nil)
- req.Header.Set("Authorization", "Bearer "+tokenStr)
- rr := httptest.NewRecorder()
- handler.ServeHTTP(rr, req)
- var body response.Body
- err := json.Unmarshal(rr.Body.Bytes(), &body)
- require.NoError(t, err)
- assert.Equal(t, 401, body.Code)
- assert.Equal(t, "token无效或已过期", body.Msg)
- })
- t.Run("expired token", func(t *testing.T) {
- tokenStr := generateTestToken(testAccessSecret, -10, &middleware.Claims{
- TokenType: consts.TokenTypeAccess,
- UserId: 1,
- })
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- t.Fatal("should not reach handler")
- })
- req := httptest.NewRequest(http.MethodPost, "/test", nil)
- req.Header.Set("Authorization", "Bearer "+tokenStr)
- rr := httptest.NewRecorder()
- handler.ServeHTTP(rr, req)
- var body response.Body
- err := json.Unmarshal(rr.Body.Bytes(), &body)
- require.NoError(t, err)
- assert.Equal(t, 401, body.Code)
- assert.Equal(t, "token无效或已过期", body.Msg)
- })
- // TC-0264: refresh token 不应被中间件接受
- t.Run("refresh token rejected", func(t *testing.T) {
- tokenStr := generateTestToken(testAccessSecret, 3600, &middleware.Claims{
- TokenType: consts.TokenTypeRefresh,
- UserId: 100,
- })
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- t.Fatal("should not reach handler")
- })
- req := httptest.NewRequest(http.MethodPost, "/test", nil)
- req.Header.Set("Authorization", "Bearer "+tokenStr)
- rr := httptest.NewRecorder()
- handler.ServeHTTP(rr, req)
- var body response.Body
- err := json.Unmarshal(rr.Body.Bytes(), &body)
- require.NoError(t, err)
- assert.Equal(t, 401, body.Code)
- assert.Equal(t, "token无效或类型错误", body.Msg)
- })
- t.Run("frozen user rejected", func(t *testing.T) {
- username := "mw_frozen_" + testutil.UniqueId()
- userId, cleanup := createTestUser(t, username)
- defer cleanup()
- ctx := context.Background()
- conn := testutil.GetTestSqlConn()
- now := time.Now().Unix()
- models := model.NewModels(conn, testutil.GetTestCacheConf(), testutil.GetTestCachePrefix())
- err := models.SysUserModel.Update(ctx, &user.SysUser{
- Id: userId,
- Username: username,
- Password: testutil.HashPassword("test123"),
- Nickname: "test_nick",
- Avatar: sql.NullString{Valid: false},
- Email: "[email protected]",
- Phone: "13800000000",
- DeptId: 0,
- IsSuperAdmin: consts.IsSuperAdminNo,
- MustChangePassword: consts.MustChangePasswordNo,
- Status: consts.StatusDisabled,
- CreateTime: now,
- UpdateTime: now,
- })
- require.NoError(t, err)
- tokenStr := generateTestToken(testAccessSecret, 3600, &middleware.Claims{
- TokenType: consts.TokenTypeAccess,
- UserId: userId,
- Username: username,
- ProductCode: "",
- })
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- t.Fatal("should not reach handler")
- })
- req := httptest.NewRequest(http.MethodPost, "/test", nil)
- req.Header.Set("Authorization", "Bearer "+tokenStr)
- rr := httptest.NewRecorder()
- handler.ServeHTTP(rr, req)
- var body response.Body
- err = json.Unmarshal(rr.Body.Bytes(), &body)
- require.NoError(t, err)
- assert.Equal(t, 403, body.Code)
- assert.Equal(t, "账号已被冻结", body.Msg)
- })
- }
- // TC-0306: ctx含userId=100
- func TestGetUserId(t *testing.T) {
- ctx := context.Background()
- assert.Equal(t, int64(0), middleware.GetUserId(ctx))
- ctx = middleware.WithUserDetails(ctx, &loaders.UserDetails{UserId: 42})
- assert.Equal(t, int64(42), middleware.GetUserId(ctx))
- ctx2 := context.Background()
- assert.Equal(t, int64(0), middleware.GetUserId(ctx2))
- }
- // TC-0275: 空ctx
- func TestGetProductCode(t *testing.T) {
- ctx := context.Background()
- assert.Equal(t, "", middleware.GetProductCode(ctx))
- ctx = middleware.WithUserDetails(ctx, &loaders.UserDetails{ProductCode: "p1"})
- assert.Equal(t, "p1", middleware.GetProductCode(ctx))
- }
- // TC-0309: GetUserDetails 返回完整用户信息
- func TestGetUserDetails(t *testing.T) {
- ctx := context.Background()
- assert.Nil(t, middleware.GetUserDetails(ctx))
- expected := &loaders.UserDetails{
- UserId: 42,
- Username: "admin",
- ProductCode: "p1",
- MemberType: "ADMIN",
- IsSuperAdmin: true,
- }
- ctx = middleware.WithUserDetails(ctx, expected)
- got := middleware.GetUserDetails(ctx)
- require.NotNil(t, got)
- assert.Equal(t, expected.UserId, got.UserId)
- assert.Equal(t, expected.Username, got.Username)
- assert.Equal(t, expected.ProductCode, got.ProductCode)
- assert.Equal(t, expected.MemberType, got.MemberType)
- assert.Equal(t, expected.IsSuperAdmin, got.IsSuperAdmin)
- }
// TC-0263: claims type assertion failure (defensive branch).
//
// jwt.ParseWithClaims(tokenStr, &Claims{}, keyFunc) always sets token.Claims
// to *Claims, and when parsing fails Handle has already returned from its
// err != nil branch, so the !ok branch cannot be reached. The test exists only
// to record that case and is always skipped.
func TestJwtAuthMiddleware_Handle_ClaimsTypeAssertionUnreachable(t *testing.T) {
	const reason = "defensive branch: unreachable via jwt.ParseWithClaims — claims is always *Claims"
	t.Skip(reason)
}
- // suppress unused import warning for httpx
- var _ = httpx.Error
|