jwt_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. package auth
  2. import (
  3. "crypto/hmac"
  4. "crypto/sha256"
  5. "encoding/base64"
  6. "encoding/json"
  7. "github.com/golang-jwt/jwt/v4"
  8. "github.com/stretchr/testify/assert"
  9. "github.com/stretchr/testify/require"
  10. "perms-system-server/internal/consts"
  11. "perms-system-server/internal/middleware"
  12. "strings"
  13. "testing"
  14. "time"
  15. )
  16. const testSecret = "test-jwt-secret-key"
  17. // TC-0292: secret="s", expire=3600, userId=1, username="u", productCode="p", memberType=""
  18. func TestGenerateAccessToken(t *testing.T) {
  19. tests := []struct {
  20. name string
  21. secret string
  22. expire int64
  23. userId int64
  24. username string
  25. productCode string
  26. memberType string
  27. tokenVersion int64
  28. }{
  29. {
  30. name: "normal generation",
  31. secret: testSecret,
  32. expire: 3600,
  33. userId: 1,
  34. username: "admin",
  35. productCode: "p1",
  36. memberType: "ADMIN",
  37. },
  38. {
  39. name: "empty productCode",
  40. secret: testSecret,
  41. expire: 3600,
  42. userId: 3,
  43. username: "user2",
  44. productCode: "",
  45. memberType: "",
  46. },
  47. {
  48. name: "super admin with tokenVersion",
  49. secret: testSecret,
  50. expire: 7200,
  51. userId: 100,
  52. username: "super",
  53. productCode: "p1",
  54. memberType: "SUPER_ADMIN",
  55. tokenVersion: 5,
  56. },
  57. }
  58. for _, tt := range tests {
  59. t.Run(tt.name, func(t *testing.T) {
  60. tokenStr, err := GenerateAccessToken(tt.secret, tt.expire, tt.userId, tt.username, tt.productCode, tt.memberType, tt.tokenVersion)
  61. require.NoError(t, err)
  62. assert.NotEmpty(t, tokenStr)
  63. token, err := jwt.ParseWithClaims(tokenStr, &middleware.Claims{}, func(token *jwt.Token) (interface{}, error) {
  64. return []byte(tt.secret), nil
  65. })
  66. require.NoError(t, err)
  67. assert.True(t, token.Valid)
  68. claims, ok := token.Claims.(*middleware.Claims)
  69. require.True(t, ok)
  70. assert.Equal(t, tt.userId, claims.UserId)
  71. assert.Equal(t, tt.username, claims.Username)
  72. assert.Equal(t, tt.productCode, claims.ProductCode)
  73. assert.Equal(t, tt.memberType, claims.MemberType)
  74. assert.Equal(t, tt.tokenVersion, claims.TokenVersion)
  75. // 项 :`perms` 字段已从 Claims 结构体中移除。
  76. // 解析原始 JWT payload,确保 token JSON 中不存在 "perms" key。
  77. segments := strings.Split(tokenStr, ".")
  78. require.Len(t, segments, 3, "jwt must have 3 segments")
  79. payloadBytes, err := base64.RawURLEncoding.DecodeString(segments[1])
  80. require.NoError(t, err)
  81. var raw map[string]interface{}
  82. require.NoError(t, json.Unmarshal(payloadBytes, &raw))
  83. _, hasPerms := raw["perms"]
  84. assert.False(t, hasPerms, "access token payload must NOT contain perms field")
  85. })
  86. }
  87. }
  88. // TC-0296: expireSeconds=1, sleep 2s
  89. func TestGenerateAccessToken_Expiry(t *testing.T) {
  90. tokenStr, err := GenerateAccessToken(testSecret, 1, 1, "u", "", "", 0)
  91. require.NoError(t, err)
  92. time.Sleep(2 * time.Second)
  93. _, err = jwt.ParseWithClaims(tokenStr, &middleware.Claims{}, func(token *jwt.Token) (interface{}, error) {
  94. return []byte(testSecret), nil
  95. })
  96. assert.Error(t, err)
  97. assert.Contains(t, err.Error(), "token is expired")
  98. }
  99. // TC-0297: secret="s", expire=86400, userId=1, productCode="p"
  100. func TestGenerateRefreshToken(t *testing.T) {
  101. tests := []struct {
  102. name string
  103. secret string
  104. expire int64
  105. userId int64
  106. productCode string
  107. }{
  108. {"normal", testSecret, 86400, 1, "p1"},
  109. {"empty productCode", testSecret, 86400, 2, ""},
  110. }
  111. for _, tt := range tests {
  112. t.Run(tt.name, func(t *testing.T) {
  113. tokenStr, err := GenerateRefreshToken(tt.secret, tt.expire, tt.userId, tt.productCode, 0)
  114. require.NoError(t, err)
  115. assert.NotEmpty(t, tokenStr)
  116. claims, err := ParseRefreshToken(tokenStr, tt.secret)
  117. require.NoError(t, err)
  118. assert.Equal(t, tt.userId, claims.UserId)
  119. assert.Equal(t, tt.productCode, claims.ProductCode)
  120. })
  121. }
  122. }
  123. // TC-0300: 有效token+正确secret
  124. func TestParseRefreshToken(t *testing.T) {
  125. validToken, err := GenerateRefreshToken(testSecret, 3600, 42, "prod", 0)
  126. require.NoError(t, err)
  127. t.Run("valid token", func(t *testing.T) {
  128. claims, err := ParseRefreshToken(validToken, testSecret)
  129. require.NoError(t, err)
  130. assert.Equal(t, int64(42), claims.UserId)
  131. assert.Equal(t, "prod", claims.ProductCode)
  132. })
  133. t.Run("wrong secret", func(t *testing.T) {
  134. _, err := ParseRefreshToken(validToken, "wrong-secret")
  135. assert.Error(t, err)
  136. })
  137. t.Run("invalid token string", func(t *testing.T) {
  138. _, err := ParseRefreshToken("not-a-valid-token", testSecret)
  139. assert.Error(t, err)
  140. })
  141. t.Run("empty token", func(t *testing.T) {
  142. _, err := ParseRefreshToken("", testSecret)
  143. assert.Error(t, err)
  144. })
  145. t.Run("expired token", func(t *testing.T) {
  146. expiredToken, err := GenerateRefreshToken(testSecret, 1, 1, "p", 0)
  147. require.NoError(t, err)
  148. time.Sleep(2 * time.Second)
  149. _, err = ParseRefreshToken(expiredToken, testSecret)
  150. assert.Error(t, err)
  151. })
  152. // TC-0305: AccessToken误用 — TokenType校验拒绝
  153. t.Run("access token used as refresh - should be rejected", func(t *testing.T) {
  154. accessToken, err := GenerateAccessToken(testSecret, 3600, 1, "u", "p", "M", 0)
  155. require.NoError(t, err)
  156. _, err = ParseRefreshToken(accessToken, testSecret)
  157. assert.Error(t, err, "BUG-002: access token 不应被 ParseRefreshToken 接受,应通过 TokenType 字段区分")
  158. })
  159. }
  160. // TC-0294: secret=""
  161. func TestGenerateAccessToken_EmptySecret(t *testing.T) {
  162. tokenStr, err := GenerateAccessToken("", 3600, 1, "u", "p", "M", 0)
  163. require.NoError(t, err)
  164. assert.NotEmpty(t, tokenStr)
  165. token, err := jwt.ParseWithClaims(tokenStr, &middleware.Claims{}, func(token *jwt.Token) (interface{}, error) {
  166. return []byte(""), nil
  167. })
  168. require.NoError(t, err)
  169. assert.True(t, token.Valid)
  170. claims, ok := token.Claims.(*middleware.Claims)
  171. require.True(t, ok)
  172. assert.Equal(t, int64(1), claims.UserId)
  173. }
  174. // ---------------------------------------------------------------------------
  175. // 覆盖目标:ParseWithHMAC 必须显式断言 token.Method 为
  176. // *jwt.SigningMethodHMAC,拒绝任何非 HMAC 的 alg 头,包括 "none" / "RS256" 等。
  177. // 这里不等同于 jwt-go v4 对 "alg=none" 的默认拒绝,而是深度防御的显式白名单校验,
  178. // 杜绝未来迁移到 RSA/ECDSA 时攻击者把公钥当共享密钥伪造 HS256 token
  179. // (CVE-2016-10555 同类问题、OWASP JWT / RFC 8725 要求)。
  180. // ---------------------------------------------------------------------------
  181. const h4Secret = "h4-audit-secret-key"
  182. // b64url returns the jwt-style base64url (no padding) encoding.
  183. func b64url(b []byte) string { return base64.RawURLEncoding.EncodeToString(b) }
  184. // forgeToken 手动拼接一个 JWT:自定义 header.alg + payload,再用任意密钥做 HMAC 签名。
  185. // 这用于模拟"攻击者伪造头部 alg 但签名仍走 HS256"的场景。
  186. func forgeToken(t *testing.T, alg string, claims any, signingKey string) string {
  187. t.Helper()
  188. header := map[string]string{"alg": alg, "typ": "JWT"}
  189. hBytes, err := json.Marshal(header)
  190. require.NoError(t, err)
  191. pBytes, err := json.Marshal(claims)
  192. require.NoError(t, err)
  193. signingInput := b64url(hBytes) + "." + b64url(pBytes)
  194. mac := hmac.New(sha256.New, []byte(signingKey))
  195. mac.Write([]byte(signingInput))
  196. sig := mac.Sum(nil)
  197. return signingInput + "." + b64url(sig)
  198. }
  199. // forgeTokenNoSig 拼接一个没有签名的 token(alg=none 典型攻击,第三段签名留空)。
  200. func forgeTokenNoSig(t *testing.T, alg string, claims any) string {
  201. t.Helper()
  202. header := map[string]string{"alg": alg, "typ": "JWT"}
  203. hBytes, err := json.Marshal(header)
  204. require.NoError(t, err)
  205. pBytes, err := json.Marshal(claims)
  206. require.NoError(t, err)
  207. return b64url(hBytes) + "." + b64url(pBytes) + "."
  208. }
  209. // validRefreshClaims 返回一组完整、未过期的 refresh claims,用于伪造攻击 token。
  210. func validRefreshClaims() RefreshClaims {
  211. now := time.Now()
  212. return RefreshClaims{
  213. TokenType: consts.TokenTypeRefresh,
  214. UserId: 7,
  215. ProductCode: "h4_pc",
  216. TokenVersion: 0,
  217. RegisteredClaims: jwt.RegisteredClaims{
  218. ExpiresAt: jwt.NewNumericDate(now.Add(1 * time.Hour)),
  219. IssuedAt: jwt.NewNumericDate(now),
  220. },
  221. }
  222. }
  223. // TC-0951: 正常 HS256 token 必须被 ParseWithHMAC 正确接受。
  224. func TestParseWithHMAC_HS256_Valid(t *testing.T) {
  225. tok, err := GenerateRefreshToken(h4Secret, 3600, 7, "h4_pc", 0)
  226. require.NoError(t, err)
  227. token, err := ParseWithHMAC(tok, h4Secret, &RefreshClaims{})
  228. require.NoError(t, err)
  229. assert.True(t, token.Valid)
  230. claims, ok := token.Claims.(*RefreshClaims)
  231. require.True(t, ok)
  232. assert.Equal(t, int64(7), claims.UserId)
  233. assert.Equal(t, consts.TokenTypeRefresh, claims.TokenType)
  234. }
  235. // TC-0952: alg=none 的伪造 token 必须被拒绝。
  236. // jwt-go v4 默认就会拦住 "none",但显式 HMAC 断言保证即使 lib 行为变化我们仍 fail-close。
  237. func TestParseWithHMAC_AlgNone_Rejected(t *testing.T) {
  238. forged := forgeTokenNoSig(t, "none", validRefreshClaims())
  239. _, err := ParseWithHMAC(forged, h4Secret, &RefreshClaims{})
  240. require.Error(t, err, "alg=none 必须被 ParseWithHMAC 拒绝")
  241. }
  242. // TC-0953: 攻击者把 header alg 改成 RS256 但仍用 secret 作 HS256 签名
  243. // (RSA 公钥 → HMAC secret 混淆攻击)。必须被 ParseWithHMAC 显式拒绝:
  244. // 命中 keyfunc 的 `token.Method.(*SigningMethodHMAC)` 断言失败分支。
  245. func TestParseWithHMAC_RS256HeaderButHMACSigned_Rejected(t *testing.T) {
  246. forged := forgeToken(t, "RS256", validRefreshClaims(), h4Secret)
  247. _, err := ParseWithHMAC(forged, h4Secret, &RefreshClaims{})
  248. require.Error(t, err, "alg=RS256 必须被 ParseWithHMAC 拒绝")
  249. assert.Contains(t, err.Error(), "unexpected signing method",
  250. "错误信息必须明确指出 alg 与预期不符(便于运维快速定位攻击尝试)")
  251. }
  252. // TC-0954: alg=ES256 同样应被拒绝(非 HMAC 算法一律拒绝)。
  253. func TestParseWithHMAC_ES256HeaderButHMACSigned_Rejected(t *testing.T) {
  254. forged := forgeToken(t, "ES256", validRefreshClaims(), h4Secret)
  255. _, err := ParseWithHMAC(forged, h4Secret, &RefreshClaims{})
  256. require.Error(t, err)
  257. assert.Contains(t, err.Error(), "unexpected signing method")
  258. }
  259. // TC-0955: alg=HS256 但用错误的 secret 签名应被拒绝(签名校验失败路径)。
  260. func TestParseWithHMAC_HS256WrongSecret_Rejected(t *testing.T) {
  261. tok, err := GenerateRefreshToken("attacker-guessed-secret", 3600, 7, "h4_pc", 0)
  262. require.NoError(t, err)
  263. _, err = ParseWithHMAC(tok, h4Secret, &RefreshClaims{})
  264. require.Error(t, err, "签名校验失败必须回错,不得放行")
  265. }
  266. // TC-0956: ParseRefreshToken(对外真实入口)也走 HMAC 断言,alg=RS256 必须被拒。
  267. // 保证 ParseWithHMAC 不是孤立函数,而是已被真实调用链使用。
  268. func TestParseRefreshToken_RS256Header_Rejected(t *testing.T) {
  269. forged := forgeToken(t, "RS256", validRefreshClaims(), h4Secret)
  270. _, err := ParseRefreshToken(forged, h4Secret)
  271. require.Error(t, err, "ParseRefreshToken 必须转交 ParseWithHMAC 拒绝 RS256 伪造 token")
  272. }
  273. // TC-0957: ParseRefreshToken 对 alg=none 的 token 也必须拒绝。
  274. func TestParseRefreshToken_AlgNone_Rejected(t *testing.T) {
  275. forged := forgeTokenNoSig(t, "none", validRefreshClaims())
  276. _, err := ParseRefreshToken(forged, h4Secret)
  277. require.Error(t, err)
  278. }
  279. // TC-0958: 回归 —— 格式错误的 token(非三段式)必须 error 而不是 panic。
  280. func TestParseWithHMAC_Malformed_Rejected(t *testing.T) {
  281. cases := []string{
  282. "",
  283. "not-a-token",
  284. "only.two",
  285. "a.b.c.d", // 四段
  286. }
  287. for _, s := range cases {
  288. t.Run("malformed:"+s, func(t *testing.T) {
  289. _, err := ParseWithHMAC(s, h4Secret, &RefreshClaims{})
  290. require.Error(t, err)
  291. })
  292. }
  293. }
  294. // TC-0959: payload 中 TokenType 非 refresh 的 HS256 token 应被 ParseRefreshToken
  295. // 以 ErrTokenTypeMismatch 拒绝。确认 修复不会误吞该业务校验。
  296. func TestParseRefreshToken_AccessTokenRejectedWithTypeMismatch(t *testing.T) {
  297. accessTok, err := GenerateAccessToken(h4Secret, 3600, 7, "u", "p", "M", 0)
  298. require.NoError(t, err)
  299. _, err = ParseRefreshToken(accessTok, h4Secret)
  300. require.Error(t, err)
  301. assert.Equal(t, ErrTokenTypeMismatch, err,
  302. "的 ParseWithHMAC 不能吞掉业务层 TokenType 校验错误")
  303. }
  304. // TC-0960: 伪造 alg=HS256 但 header.typ 异常(如 "JWT"→"xxx")也不能绕过
  305. // HMAC 校验。此用例用来证明只要底层签名正确,header 其余字段不影响放行/拒绝的核心语义。
  306. // 反之,任何 alg 头不是 HS* 的一律拒,和 typ 无关。
  307. func TestParseWithHMAC_HS256UnusualTyp_Accepted(t *testing.T) {
  308. // header.alg = HS256, header.typ = "JWT+weird",签名正确 → 应放行(typ 不参与断言)
  309. header := map[string]string{"alg": "HS256", "typ": "JWT+weird"}
  310. hBytes, _ := json.Marshal(header)
  311. claims := validRefreshClaims()
  312. pBytes, _ := json.Marshal(claims)
  313. signingInput := b64url(hBytes) + "." + b64url(pBytes)
  314. mac := hmac.New(sha256.New, []byte(h4Secret))
  315. mac.Write([]byte(signingInput))
  316. tok := signingInput + "." + b64url(mac.Sum(nil))
  317. _, err := ParseWithHMAC(tok, h4Secret, &RefreshClaims{})
  318. require.NoError(t, err,
  319. "HMAC 断言只看 alg,typ 不属于签名算法白名单范畴,正常 HS256 应放行")
  320. }
  321. // 辅助:保持 strings 导入被使用,避免 go vet 警告。
  322. var _ = strings.Split
  323. // 确保 middleware.Claims 在包内可被用于 TypeRefresh / TypeAccess 等正反测试(未来扩展)。
  324. var _ = middleware.Claims{}