// jwtauthMiddleware_test.go: tests for the JWT authentication middleware.

  1. package middleware_test
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "net/http"
  7. "net/http/httptest"
  8. "testing"
  9. "time"
  10. "perms-system-server/internal/consts"
  11. "perms-system-server/internal/loaders"
  12. "perms-system-server/internal/middleware"
  13. "perms-system-server/internal/model"
  14. "perms-system-server/internal/model/user"
  15. "perms-system-server/internal/response"
  16. "perms-system-server/internal/testutil"
  17. "github.com/golang-jwt/jwt/v4"
  18. "github.com/stretchr/testify/assert"
  19. "github.com/stretchr/testify/require"
  20. "github.com/zeromicro/go-zero/core/stores/redis"
  21. "github.com/zeromicro/go-zero/rest/httpx"
  22. )
  // testAccessSecret is the signing secret handed to the middleware under test
  // and used by the tests to mint tokens the middleware should accept.
  const (
  	testAccessSecret = "test-middleware-secret"
  )
  24. func generateTestToken(secret string, expireSeconds int64, claims *middleware.Claims) string {
  25. now := time.Now()
  26. claims.RegisteredClaims = jwt.RegisteredClaims{
  27. ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(expireSeconds) * time.Second)),
  28. IssuedAt: jwt.NewNumericDate(now),
  29. }
  30. token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
  31. tokenStr, _ := token.SignedString([]byte(secret))
  32. return tokenStr
  33. }
  34. func newTestMiddleware() (*middleware.JwtAuthMiddleware, *loaders.UserDetailsLoader) {
  35. cfg := testutil.GetTestConfig()
  36. conn := testutil.GetTestSqlConn()
  37. models := model.NewModels(conn, testutil.GetTestCacheConf(), testutil.GetTestCachePrefix())
  38. rds := redis.MustNewRedis(cfg.CacheRedis.Nodes[0].RedisConf)
  39. loader := loaders.NewUserDetailsLoader(rds, testutil.GetTestCachePrefix(), models)
  40. m := middleware.NewJwtAuthMiddleware(testAccessSecret, loader)
  41. return m, loader
  42. }
  43. func createTestUser(t *testing.T, username string) (int64, func()) {
  44. t.Helper()
  45. ctx := context.Background()
  46. conn := testutil.GetTestSqlConn()
  47. models := model.NewModels(conn, testutil.GetTestCacheConf(), testutil.GetTestCachePrefix())
  48. now := time.Now().Unix()
  49. u := &user.SysUser{
  50. Username: username,
  51. Password: testutil.HashPassword("test123"),
  52. Nickname: "test_nick",
  53. Avatar: sql.NullString{Valid: false},
  54. Email: "[email protected]",
  55. Phone: "13800000000",
  56. Remark: "",
  57. DeptId: 0,
  58. IsSuperAdmin: consts.IsSuperAdminNo,
  59. MustChangePassword: consts.MustChangePasswordNo,
  60. Status: consts.StatusEnabled,
  61. CreateTime: now,
  62. UpdateTime: now,
  63. }
  64. result, err := models.SysUserModel.Insert(ctx, u)
  65. require.NoError(t, err)
  66. userId, err := result.LastInsertId()
  67. require.NoError(t, err)
  68. cleanup := func() {
  69. testutil.CleanTable(ctx, conn, "sys_user", userId)
  70. }
  71. return userId, cleanup
  72. }
  73. func init() {
  74. response.Setup()
  75. }
  76. // TC-0258: `Authorization: Bearer {valid}`
  77. func TestJwtAuthMiddleware_Handle(t *testing.T) {
  78. m, _ := newTestMiddleware()
  79. t.Run("valid token", func(t *testing.T) {
  80. username := "mw_valid_" + testutil.UniqueId()
  81. userId, cleanup := createTestUser(t, username)
  82. defer cleanup()
  83. tokenStr := generateTestToken(testAccessSecret, 3600, &middleware.Claims{
  84. TokenType: consts.TokenTypeAccess,
  85. UserId: userId,
  86. Username: username,
  87. ProductCode: "",
  88. })
  89. var capturedCtx context.Context
  90. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  91. capturedCtx = r.Context()
  92. w.WriteHeader(http.StatusOK)
  93. })
  94. req := httptest.NewRequest(http.MethodPost, "/test", nil)
  95. req.Header.Set("Authorization", "Bearer "+tokenStr)
  96. rr := httptest.NewRecorder()
  97. handler.ServeHTTP(rr, req)
  98. assert.Equal(t, http.StatusOK, rr.Code)
  99. assert.Equal(t, userId, middleware.GetUserId(capturedCtx))
  100. assert.Equal(t, "", middleware.GetProductCode(capturedCtx))
  101. details := middleware.GetUserDetails(capturedCtx)
  102. require.NotNil(t, details)
  103. assert.Equal(t, username, details.Username)
  104. assert.Equal(t, "", details.MemberType)
  105. assert.False(t, details.IsSuperAdmin)
  106. })
  107. t.Run("no authorization header", func(t *testing.T) {
  108. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  109. t.Fatal("should not reach handler")
  110. })
  111. req := httptest.NewRequest(http.MethodPost, "/test", nil)
  112. rr := httptest.NewRecorder()
  113. handler.ServeHTTP(rr, req)
  114. var body response.Body
  115. err := json.Unmarshal(rr.Body.Bytes(), &body)
  116. require.NoError(t, err)
  117. assert.Equal(t, 401, body.Code)
  118. assert.Equal(t, "未登录", body.Msg)
  119. })
  120. t.Run("no Bearer prefix", func(t *testing.T) {
  121. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  122. t.Fatal("should not reach handler")
  123. })
  124. req := httptest.NewRequest(http.MethodPost, "/test", nil)
  125. req.Header.Set("Authorization", "Basic some-token")
  126. rr := httptest.NewRecorder()
  127. handler.ServeHTTP(rr, req)
  128. var body response.Body
  129. err := json.Unmarshal(rr.Body.Bytes(), &body)
  130. require.NoError(t, err)
  131. assert.Equal(t, 401, body.Code)
  132. assert.Equal(t, "token格式错误", body.Msg)
  133. })
  134. t.Run("invalid token", func(t *testing.T) {
  135. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  136. t.Fatal("should not reach handler")
  137. })
  138. req := httptest.NewRequest(http.MethodPost, "/test", nil)
  139. req.Header.Set("Authorization", "Bearer invalid-token-string")
  140. rr := httptest.NewRecorder()
  141. handler.ServeHTTP(rr, req)
  142. var body response.Body
  143. err := json.Unmarshal(rr.Body.Bytes(), &body)
  144. require.NoError(t, err)
  145. assert.Equal(t, 401, body.Code)
  146. assert.Equal(t, "token无效或已过期", body.Msg)
  147. })
  148. t.Run("wrong secret", func(t *testing.T) {
  149. tokenStr := generateTestToken("wrong-secret", 3600, &middleware.Claims{
  150. TokenType: consts.TokenTypeAccess,
  151. UserId: 1,
  152. })
  153. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  154. t.Fatal("should not reach handler")
  155. })
  156. req := httptest.NewRequest(http.MethodPost, "/test", nil)
  157. req.Header.Set("Authorization", "Bearer "+tokenStr)
  158. rr := httptest.NewRecorder()
  159. handler.ServeHTTP(rr, req)
  160. var body response.Body
  161. err := json.Unmarshal(rr.Body.Bytes(), &body)
  162. require.NoError(t, err)
  163. assert.Equal(t, 401, body.Code)
  164. assert.Equal(t, "token无效或已过期", body.Msg)
  165. })
  166. t.Run("expired token", func(t *testing.T) {
  167. tokenStr := generateTestToken(testAccessSecret, -10, &middleware.Claims{
  168. TokenType: consts.TokenTypeAccess,
  169. UserId: 1,
  170. })
  171. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  172. t.Fatal("should not reach handler")
  173. })
  174. req := httptest.NewRequest(http.MethodPost, "/test", nil)
  175. req.Header.Set("Authorization", "Bearer "+tokenStr)
  176. rr := httptest.NewRecorder()
  177. handler.ServeHTTP(rr, req)
  178. var body response.Body
  179. err := json.Unmarshal(rr.Body.Bytes(), &body)
  180. require.NoError(t, err)
  181. assert.Equal(t, 401, body.Code)
  182. assert.Equal(t, "token无效或已过期", body.Msg)
  183. })
  184. // TC-0264: refresh token 不应被中间件接受
  185. t.Run("refresh token rejected", func(t *testing.T) {
  186. tokenStr := generateTestToken(testAccessSecret, 3600, &middleware.Claims{
  187. TokenType: consts.TokenTypeRefresh,
  188. UserId: 100,
  189. })
  190. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  191. t.Fatal("should not reach handler")
  192. })
  193. req := httptest.NewRequest(http.MethodPost, "/test", nil)
  194. req.Header.Set("Authorization", "Bearer "+tokenStr)
  195. rr := httptest.NewRecorder()
  196. handler.ServeHTTP(rr, req)
  197. var body response.Body
  198. err := json.Unmarshal(rr.Body.Bytes(), &body)
  199. require.NoError(t, err)
  200. assert.Equal(t, 401, body.Code)
  201. assert.Equal(t, "token无效或类型错误", body.Msg)
  202. })
  203. t.Run("frozen user rejected", func(t *testing.T) {
  204. username := "mw_frozen_" + testutil.UniqueId()
  205. userId, cleanup := createTestUser(t, username)
  206. defer cleanup()
  207. ctx := context.Background()
  208. conn := testutil.GetTestSqlConn()
  209. now := time.Now().Unix()
  210. models := model.NewModels(conn, testutil.GetTestCacheConf(), testutil.GetTestCachePrefix())
  211. err := models.SysUserModel.Update(ctx, &user.SysUser{
  212. Id: userId,
  213. Username: username,
  214. Password: testutil.HashPassword("test123"),
  215. Nickname: "test_nick",
  216. Avatar: sql.NullString{Valid: false},
  217. Email: "[email protected]",
  218. Phone: "13800000000",
  219. DeptId: 0,
  220. IsSuperAdmin: consts.IsSuperAdminNo,
  221. MustChangePassword: consts.MustChangePasswordNo,
  222. Status: consts.StatusDisabled,
  223. CreateTime: now,
  224. UpdateTime: now,
  225. })
  226. require.NoError(t, err)
  227. tokenStr := generateTestToken(testAccessSecret, 3600, &middleware.Claims{
  228. TokenType: consts.TokenTypeAccess,
  229. UserId: userId,
  230. Username: username,
  231. ProductCode: "",
  232. })
  233. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  234. t.Fatal("should not reach handler")
  235. })
  236. req := httptest.NewRequest(http.MethodPost, "/test", nil)
  237. req.Header.Set("Authorization", "Bearer "+tokenStr)
  238. rr := httptest.NewRecorder()
  239. handler.ServeHTTP(rr, req)
  240. var body response.Body
  241. err = json.Unmarshal(rr.Body.Bytes(), &body)
  242. require.NoError(t, err)
  243. assert.Equal(t, 403, body.Code)
  244. assert.Equal(t, "账号已被冻结", body.Msg)
  245. })
  246. }
  247. // TC-0306: ctx含userId=100
  248. func TestGetUserId(t *testing.T) {
  249. ctx := context.Background()
  250. assert.Equal(t, int64(0), middleware.GetUserId(ctx))
  251. ctx = middleware.WithUserDetails(ctx, &loaders.UserDetails{UserId: 42})
  252. assert.Equal(t, int64(42), middleware.GetUserId(ctx))
  253. ctx2 := context.Background()
  254. assert.Equal(t, int64(0), middleware.GetUserId(ctx2))
  255. }
  256. // TC-0275: 空ctx
  257. func TestGetProductCode(t *testing.T) {
  258. ctx := context.Background()
  259. assert.Equal(t, "", middleware.GetProductCode(ctx))
  260. ctx = middleware.WithUserDetails(ctx, &loaders.UserDetails{ProductCode: "p1"})
  261. assert.Equal(t, "p1", middleware.GetProductCode(ctx))
  262. }
  263. // TC-0309: GetUserDetails 返回完整用户信息
  264. func TestGetUserDetails(t *testing.T) {
  265. ctx := context.Background()
  266. assert.Nil(t, middleware.GetUserDetails(ctx))
  267. expected := &loaders.UserDetails{
  268. UserId: 42,
  269. Username: "admin",
  270. ProductCode: "p1",
  271. MemberType: "ADMIN",
  272. IsSuperAdmin: true,
  273. }
  274. ctx = middleware.WithUserDetails(ctx, expected)
  275. got := middleware.GetUserDetails(ctx)
  276. require.NotNil(t, got)
  277. assert.Equal(t, expected.UserId, got.UserId)
  278. assert.Equal(t, expected.Username, got.Username)
  279. assert.Equal(t, expected.ProductCode, got.ProductCode)
  280. assert.Equal(t, expected.MemberType, got.MemberType)
  281. assert.Equal(t, expected.IsSuperAdmin, got.IsSuperAdmin)
  282. }
  // TC-0263: claims type assertion failure (defensive branch).
  //
  // jwt.ParseWithClaims(tokenStr, &Claims{}, keyFunc) always sets token.Claims
  // to *Claims, and when parsing fails Handle has already returned in its
  // err != nil branch, so the !ok branch cannot be reached. The test is kept as
  // a skipped placeholder that records this reasoning.
  func TestJwtAuthMiddleware_Handle_ClaimsTypeAssertionUnreachable(t *testing.T) {
  	t.Skip("defensive branch: unreachable via jwt.ParseWithClaims — claims is always *Claims")
  }
  289. // suppress unused import warning for httpx
  290. var _ = httpx.Error