refreshTokenCas_audit_test.go 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. package pub
  2. import (
  3. "context"
  4. "errors"
  5. "sync"
  6. "sync/atomic"
  7. "testing"
  8. authHelper "perms-system-server/internal/logic/auth"
  9. "perms-system-server/internal/response"
  10. "perms-system-server/internal/testutil"
  11. "perms-system-server/internal/types"
  12. "github.com/stretchr/testify/assert"
  13. "github.com/stretchr/testify/require"
  14. )
  15. // ---------------------------------------------------------------------------
  16. // 覆盖目标:审计 H-1 修复的 logic 层回归 —— 在 logic 里用 CAS 递增 tokenVersion。
  17. // 该文件聚焦"并发 refresh 同一旧令牌时的行为":
  18. // 1) N 个并发 RefreshToken 共用同一把 claims.TokenVersion=0 的 refreshToken,
  19. // 必须恰好 1 个返回成功;其余 N-1 个被 401 拒绝(字样必为"登录状态已失效")。
  20. // 2) DB 的 tokenVersion 最终只能递增 1;
  21. // 3) 明确 CAS 失败时返回的 401 错误是通过 ErrTokenVersionMismatch 路径产出,
  22. // 与"账号冻结"等 403 分支互不混用。
  23. // ---------------------------------------------------------------------------
  24. // TC-0812: H-1 logic 并发回归 —— 并发重放同一个旧 refreshToken,只允许一位胜出。
  25. func TestRefreshToken_ConcurrentSameToken_SingleWinner(t *testing.T) {
  26. ctx := context.Background()
  27. svcCtx := newTestSvcCtx()
  28. username := "rt_cas_" + testutil.UniqueId()
  29. userId, cleanUser := insertRefreshTestUser(t, ctx, username, "TestPass123", 1, 2)
  30. t.Cleanup(cleanUser)
  31. // 禁用 TokenOpLimiter,以让本测试的变量只剩"并发 CAS 胜负"。
  32. svcCtx.TokenOpLimiter = nil
  33. rt, err := authHelper.GenerateRefreshToken(
  34. svcCtx.Config.Auth.RefreshSecret, svcCtx.Config.Auth.RefreshExpire,
  35. userId, "", 0,
  36. )
  37. require.NoError(t, err)
  38. // 限制在 6 并发以避免触发 go-zero sqlx breaker(单机 MySQL + breaker 对同批次突发
  39. // 的并发 UPDATE 容易误伤,生产里 refreshToken 也是 per-user 限频 + CAS 双层保护,
  40. // 没机会打成这么高的并发)。CAS "唯一胜出" 的契约在 N=6 时已足以钉死。
  41. const N = 6
  42. var (
  43. wg sync.WaitGroup
  44. okCount int32
  45. authFailCnt int32
  46. otherErr atomic.Value
  47. )
  48. start := make(chan struct{})
  49. for i := 0; i < N; i++ {
  50. wg.Add(1)
  51. go func() {
  52. defer wg.Done()
  53. <-start
  54. resp, e := NewRefreshTokenLogic(ctx, svcCtx).RefreshToken(&types.RefreshTokenReq{
  55. Authorization: "Bearer " + rt,
  56. })
  57. switch {
  58. case e == nil && resp != nil:
  59. atomic.AddInt32(&okCount, 1)
  60. case e != nil:
  61. var ce *response.CodeError
  62. if errors.As(e, &ce) && ce.Code() == 401 &&
  63. ce.Error() == "登录状态已失效,请重新登录" {
  64. atomic.AddInt32(&authFailCnt, 1)
  65. } else {
  66. otherErr.Store(e)
  67. }
  68. }
  69. }()
  70. }
  71. close(start)
  72. wg.Wait()
  73. if v := otherErr.Load(); v != nil {
  74. t.Fatalf("并发 RefreshToken 出现非预期错误:%v", v)
  75. }
  76. assert.Equal(t, int32(1), atomic.LoadInt32(&okCount),
  77. "H-1 会话劫持防线:重放同一旧 refreshToken 的 N 个并发请求必须只有 1 个成功")
  78. assert.Equal(t, int32(N-1), atomic.LoadInt32(&authFailCnt),
  79. "其他并发者必须返回 401 '登录状态已失效'")
  80. // DB 必然只递增 1。
  81. u, err := svcCtx.SysUserModel.FindOne(ctx, userId)
  82. require.NoError(t, err)
  83. assert.Equal(t, int64(1), u.TokenVersion,
  84. "DB tokenVersion 递增幅度就是 CAS 成功次数 → 只能是 1")
  85. }