jwtauthMiddleware_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. package middleware_test
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "net/http"
  7. "net/http/httptest"
  8. "testing"
  9. "time"
  10. "perms-system-server/internal/consts"
  11. "perms-system-server/internal/loaders"
  12. "perms-system-server/internal/middleware"
  13. "perms-system-server/internal/model"
  14. "perms-system-server/internal/model/user"
  15. "perms-system-server/internal/response"
  16. "perms-system-server/internal/testutil"
  17. "github.com/golang-jwt/jwt/v4"
  18. "github.com/stretchr/testify/assert"
  19. "github.com/stretchr/testify/require"
  20. "github.com/zeromicro/go-zero/core/stores/redis"
  21. "github.com/zeromicro/go-zero/rest/httpx"
  22. )
  23. const testAccessSecret = "test-middleware-secret"
  24. func generateTestToken(secret string, expireSeconds int64, claims *middleware.Claims) string {
  25. now := time.Now()
  26. claims.RegisteredClaims = jwt.RegisteredClaims{
  27. ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(expireSeconds) * time.Second)),
  28. IssuedAt: jwt.NewNumericDate(now),
  29. }
  30. token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
  31. tokenStr, _ := token.SignedString([]byte(secret))
  32. return tokenStr
  33. }
  34. func newTestMiddleware() (*middleware.JwtAuthMiddleware, *loaders.UserDetailsLoader) {
  35. cfg := testutil.GetTestConfig()
  36. conn := testutil.GetTestSqlConn()
  37. models := model.NewModels(conn, testutil.GetTestCacheConf(), testutil.GetTestCachePrefix())
  38. rds := redis.MustNewRedis(cfg.CacheRedis.Nodes[0].RedisConf)
  39. loader := loaders.NewUserDetailsLoader(rds, testutil.GetTestCachePrefix(), models)
  40. m := middleware.NewJwtAuthMiddleware(testAccessSecret, loader)
  41. return m, loader
  42. }
  43. func createTestUser(t *testing.T, username string) (int64, func()) {
  44. t.Helper()
  45. ctx := context.Background()
  46. conn := testutil.GetTestSqlConn()
  47. models := model.NewModels(conn, testutil.GetTestCacheConf(), testutil.GetTestCachePrefix())
  48. now := time.Now().Unix()
  49. u := &user.SysUser{
  50. Username: username,
  51. Password: testutil.HashPassword("test123"),
  52. Nickname: "test_nick",
  53. Avatar: sql.NullString{Valid: false},
  54. Email: "[email protected]",
  55. Phone: "13800000000",
  56. Remark: "",
  57. DeptId: 0,
  58. IsSuperAdmin: consts.IsSuperAdminNo,
  59. MustChangePassword: consts.MustChangePasswordNo,
  60. Status: consts.StatusEnabled,
  61. CreateTime: now,
  62. UpdateTime: now,
  63. }
  64. result, err := models.SysUserModel.Insert(ctx, u)
  65. require.NoError(t, err)
  66. userId, err := result.LastInsertId()
  67. require.NoError(t, err)
  68. cleanup := func() {
  69. testutil.CleanTable(ctx, conn, "sys_user", userId)
  70. }
  71. return userId, cleanup
  72. }
  73. func init() {
  74. response.Setup()
  75. }
  76. // TC-0184: `Authorization: Bearer {valid}`
  77. func TestJwtAuthMiddleware_Handle(t *testing.T) {
  78. m, _ := newTestMiddleware()
  79. t.Run("valid token", func(t *testing.T) {
  80. username := "mw_valid_" + testutil.UniqueId()
  81. userId, cleanup := createTestUser(t, username)
  82. defer cleanup()
  83. tokenStr := generateTestToken(testAccessSecret, 3600, &middleware.Claims{
  84. TokenType: consts.TokenTypeAccess,
  85. UserId: userId,
  86. Username: username,
  87. ProductCode: "",
  88. })
  89. var capturedCtx context.Context
  90. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  91. capturedCtx = r.Context()
  92. w.WriteHeader(http.StatusOK)
  93. })
  94. req := httptest.NewRequest(http.MethodPost, "/test", nil)
  95. req.Header.Set("Authorization", "Bearer "+tokenStr)
  96. rr := httptest.NewRecorder()
  97. handler.ServeHTTP(rr, req)
  98. assert.Equal(t, http.StatusOK, rr.Code)
  99. assert.Equal(t, userId, middleware.GetUserId(capturedCtx))
  100. assert.Equal(t, username, middleware.GetUsername(capturedCtx))
  101. assert.Equal(t, "", middleware.GetProductCode(capturedCtx))
  102. assert.Equal(t, "", middleware.GetMemberType(capturedCtx))
  103. assert.False(t, middleware.IsSuperAdmin(capturedCtx))
  104. })
  105. t.Run("no authorization header", func(t *testing.T) {
  106. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  107. t.Fatal("should not reach handler")
  108. })
  109. req := httptest.NewRequest(http.MethodPost, "/test", nil)
  110. rr := httptest.NewRecorder()
  111. handler.ServeHTTP(rr, req)
  112. var body response.Body
  113. err := json.Unmarshal(rr.Body.Bytes(), &body)
  114. require.NoError(t, err)
  115. assert.Equal(t, 401, body.Code)
  116. assert.Equal(t, "未登录", body.Msg)
  117. })
  118. t.Run("no Bearer prefix", func(t *testing.T) {
  119. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  120. t.Fatal("should not reach handler")
  121. })
  122. req := httptest.NewRequest(http.MethodPost, "/test", nil)
  123. req.Header.Set("Authorization", "Basic some-token")
  124. rr := httptest.NewRecorder()
  125. handler.ServeHTTP(rr, req)
  126. var body response.Body
  127. err := json.Unmarshal(rr.Body.Bytes(), &body)
  128. require.NoError(t, err)
  129. assert.Equal(t, 401, body.Code)
  130. assert.Equal(t, "token格式错误", body.Msg)
  131. })
  132. t.Run("invalid token", func(t *testing.T) {
  133. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  134. t.Fatal("should not reach handler")
  135. })
  136. req := httptest.NewRequest(http.MethodPost, "/test", nil)
  137. req.Header.Set("Authorization", "Bearer invalid-token-string")
  138. rr := httptest.NewRecorder()
  139. handler.ServeHTTP(rr, req)
  140. var body response.Body
  141. err := json.Unmarshal(rr.Body.Bytes(), &body)
  142. require.NoError(t, err)
  143. assert.Equal(t, 401, body.Code)
  144. assert.Equal(t, "token无效或已过期", body.Msg)
  145. })
  146. t.Run("wrong secret", func(t *testing.T) {
  147. tokenStr := generateTestToken("wrong-secret", 3600, &middleware.Claims{
  148. TokenType: consts.TokenTypeAccess,
  149. UserId: 1,
  150. })
  151. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  152. t.Fatal("should not reach handler")
  153. })
  154. req := httptest.NewRequest(http.MethodPost, "/test", nil)
  155. req.Header.Set("Authorization", "Bearer "+tokenStr)
  156. rr := httptest.NewRecorder()
  157. handler.ServeHTTP(rr, req)
  158. var body response.Body
  159. err := json.Unmarshal(rr.Body.Bytes(), &body)
  160. require.NoError(t, err)
  161. assert.Equal(t, 401, body.Code)
  162. assert.Equal(t, "token无效或已过期", body.Msg)
  163. })
  164. t.Run("expired token", func(t *testing.T) {
  165. tokenStr := generateTestToken(testAccessSecret, -10, &middleware.Claims{
  166. TokenType: consts.TokenTypeAccess,
  167. UserId: 1,
  168. })
  169. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  170. t.Fatal("should not reach handler")
  171. })
  172. req := httptest.NewRequest(http.MethodPost, "/test", nil)
  173. req.Header.Set("Authorization", "Bearer "+tokenStr)
  174. rr := httptest.NewRecorder()
  175. handler.ServeHTTP(rr, req)
  176. var body response.Body
  177. err := json.Unmarshal(rr.Body.Bytes(), &body)
  178. require.NoError(t, err)
  179. assert.Equal(t, 401, body.Code)
  180. assert.Equal(t, "token无效或已过期", body.Msg)
  181. })
  182. // TC-0434: refresh token 不应被中间件接受
  183. t.Run("refresh token rejected", func(t *testing.T) {
  184. tokenStr := generateTestToken(testAccessSecret, 3600, &middleware.Claims{
  185. TokenType: consts.TokenTypeRefresh,
  186. UserId: 100,
  187. })
  188. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  189. t.Fatal("should not reach handler")
  190. })
  191. req := httptest.NewRequest(http.MethodPost, "/test", nil)
  192. req.Header.Set("Authorization", "Bearer "+tokenStr)
  193. rr := httptest.NewRecorder()
  194. handler.ServeHTTP(rr, req)
  195. var body response.Body
  196. err := json.Unmarshal(rr.Body.Bytes(), &body)
  197. require.NoError(t, err)
  198. assert.Equal(t, 401, body.Code)
  199. assert.Equal(t, "token无效或类型错误", body.Msg)
  200. })
  201. t.Run("frozen user rejected", func(t *testing.T) {
  202. username := "mw_frozen_" + testutil.UniqueId()
  203. userId, cleanup := createTestUser(t, username)
  204. defer cleanup()
  205. ctx := context.Background()
  206. conn := testutil.GetTestSqlConn()
  207. now := time.Now().Unix()
  208. models := model.NewModels(conn, testutil.GetTestCacheConf(), testutil.GetTestCachePrefix())
  209. err := models.SysUserModel.Update(ctx, &user.SysUser{
  210. Id: userId,
  211. Username: username,
  212. Password: testutil.HashPassword("test123"),
  213. Nickname: "test_nick",
  214. Avatar: sql.NullString{Valid: false},
  215. Email: "[email protected]",
  216. Phone: "13800000000",
  217. DeptId: 0,
  218. IsSuperAdmin: consts.IsSuperAdminNo,
  219. MustChangePassword: consts.MustChangePasswordNo,
  220. Status: consts.StatusDisabled,
  221. CreateTime: now,
  222. UpdateTime: now,
  223. })
  224. require.NoError(t, err)
  225. tokenStr := generateTestToken(testAccessSecret, 3600, &middleware.Claims{
  226. TokenType: consts.TokenTypeAccess,
  227. UserId: userId,
  228. Username: username,
  229. ProductCode: "",
  230. })
  231. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  232. t.Fatal("should not reach handler")
  233. })
  234. req := httptest.NewRequest(http.MethodPost, "/test", nil)
  235. req.Header.Set("Authorization", "Bearer "+tokenStr)
  236. rr := httptest.NewRecorder()
  237. handler.ServeHTTP(rr, req)
  238. var body response.Body
  239. err = json.Unmarshal(rr.Body.Bytes(), &body)
  240. require.NoError(t, err)
  241. assert.Equal(t, 403, body.Code)
  242. assert.Equal(t, "账号已被冻结", body.Msg)
  243. })
  244. }
  245. // TC-0258: ctx含userId=100
  246. func TestGetUserId(t *testing.T) {
  247. ctx := context.Background()
  248. assert.Equal(t, int64(0), middleware.GetUserId(ctx))
  249. ctx = middleware.WithUserDetails(ctx, &loaders.UserDetails{UserId: 42})
  250. assert.Equal(t, int64(42), middleware.GetUserId(ctx))
  251. ctx2 := context.Background()
  252. assert.Equal(t, int64(0), middleware.GetUserId(ctx2))
  253. }
  254. // TC-0260: ctx含username="admin"
  255. func TestGetUsername(t *testing.T) {
  256. ctx := context.Background()
  257. assert.Equal(t, "", middleware.GetUsername(ctx))
  258. ctx = middleware.WithUserDetails(ctx, &loaders.UserDetails{Username: "admin"})
  259. assert.Equal(t, "admin", middleware.GetUsername(ctx))
  260. }
  261. // TC-0261: 空ctx
  262. func TestGetProductCode(t *testing.T) {
  263. ctx := context.Background()
  264. assert.Equal(t, "", middleware.GetProductCode(ctx))
  265. ctx = middleware.WithUserDetails(ctx, &loaders.UserDetails{ProductCode: "p1"})
  266. assert.Equal(t, "p1", middleware.GetProductCode(ctx))
  267. }
  268. // TC-0262: ctx含productCode="p1"
  269. func TestGetMemberType(t *testing.T) {
  270. ctx := context.Background()
  271. assert.Equal(t, "", middleware.GetMemberType(ctx))
  272. ctx = middleware.WithUserDetails(ctx, &loaders.UserDetails{MemberType: "ADMIN"})
  273. assert.Equal(t, "ADMIN", middleware.GetMemberType(ctx))
  274. }
  275. // TC-0263: ctx含memberType="ADMIN"
  276. func TestIsSuperAdmin(t *testing.T) {
  277. tests := []struct {
  278. name string
  279. isSuperAdmin bool
  280. want bool
  281. }{
  282. {"is super admin", true, true},
  283. {"is not super admin", false, false},
  284. }
  285. for _, tt := range tests {
  286. t.Run(tt.name, func(t *testing.T) {
  287. ctx := middleware.WithUserDetails(context.Background(), &loaders.UserDetails{IsSuperAdmin: tt.isSuperAdmin})
  288. assert.Equal(t, tt.want, middleware.IsSuperAdmin(ctx))
  289. })
  290. }
  291. t.Run("empty context", func(t *testing.T) {
  292. assert.False(t, middleware.IsSuperAdmin(context.Background()))
  293. })
  294. }
  // TC-0189: failed claims type assertion (defensive branch).
// jwt.ParseWithClaims(tokenStr, &Claims{}, keyFunc) always leaves token.Claims
// as *Claims, and Handle already returns early on the err != nil path when
// parsing fails. The `!ok` branch therefore cannot be driven from a test;
// this skipped placeholder records that the case was considered.
func TestJwtAuthMiddleware_Handle_ClaimsTypeAssertionUnreachable(t *testing.T) {
	const skipReason = "defensive branch: unreachable via jwt.ParseWithClaims — claims is always *Claims"
	t.Skip(skipReason)
}
  301. // suppress unused import warning for httpx
  302. var _ = httpx.Error