incrementTokenVersionIfMatch_audit_test.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. package user_test
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "fmt"
  7. "sync"
  8. "sync/atomic"
  9. "testing"
  10. "time"
  11. "perms-system-server/internal/model/user"
  12. "perms-system-server/internal/testutil"
  13. "github.com/stretchr/testify/assert"
  14. "github.com/stretchr/testify/require"
  15. )
  16. // ---------------------------------------------------------------------------
  17. // 覆盖目标:审计 H-1 修复 —— IncrementTokenVersionIfMatch 必须是"当 DB.tokenVersion == expected
  18. // 时才原子递增"的 CAS,否则 refreshToken rotation 在并发刷新时会放出"两枚都合法的新令牌"导致会话劫持。
  19. // 以下用例把这一契约显式钉死:
  20. // - 匹配 → 成功并返回递增后的新版本
  21. // - 不匹配 → ErrTokenVersionMismatch(不能返回成功,不能悄悄递增)
  22. // - 并发竞态 → N 个 goroutine 用同一个 expected 打入,必须只有 1 个成功
  23. // - 成功后必须清掉 id-key / username-key 双路缓存
  24. // ---------------------------------------------------------------------------
  25. // TC-0802: H-1 —— expected 与 DB 当前 tokenVersion 一致时返回递增后的新版本。
  26. func TestSysUserModel_IncrementTokenVersionIfMatch_Match(t *testing.T) {
  27. m, conn := newModel(t)
  28. ctx := context.Background()
  29. now := time.Now().Unix()
  30. username := "cas_match_" + testutil.UniqueId()
  31. res, err := m.Insert(ctx, &user.SysUser{
  32. Username: username, Password: "x", Nickname: "n",
  33. Avatar: sql.NullString{}, IsSuperAdmin: 2, MustChangePassword: 2,
  34. Status: 1, TokenVersion: 5, CreateTime: now, UpdateTime: now,
  35. })
  36. require.NoError(t, err)
  37. id, _ := res.LastInsertId()
  38. t.Cleanup(func() { testutil.CleanTable(ctx, conn, "`sys_user`", id) })
  39. got, err := m.IncrementTokenVersionIfMatch(ctx, id, username, 5)
  40. require.NoError(t, err)
  41. assert.Equal(t, int64(6), got, "expected 命中时返回 DB 真实递增后的新版本")
  42. fresh, err := m.FindOne(ctx, id)
  43. require.NoError(t, err)
  44. assert.Equal(t, int64(6), fresh.TokenVersion, "DB 落盘值必须也是 6")
  45. }
  46. // TC-0803: H-1 —— expected 与 DB 不一致时返回 ErrTokenVersionMismatch 且 DB 不得发生任何变更。
  47. // 这是会话劫持窗口的关键拦截:攻击者的 token 里 TokenVersion = V,但合法用户已刷新到 V+1,
  48. // 攻击者再来刷新时 expected=V 打不中 WHERE 子句 → 必须失败。
  49. func TestSysUserModel_IncrementTokenVersionIfMatch_Mismatch_NoSideEffect(t *testing.T) {
  50. m, conn := newModel(t)
  51. ctx := context.Background()
  52. now := time.Now().Unix()
  53. username := "cas_mismatch_" + testutil.UniqueId()
  54. res, err := m.Insert(ctx, &user.SysUser{
  55. Username: username, Password: "x", Nickname: "n",
  56. Avatar: sql.NullString{}, IsSuperAdmin: 2, MustChangePassword: 2,
  57. Status: 1, TokenVersion: 10, CreateTime: now, UpdateTime: now,
  58. })
  59. require.NoError(t, err)
  60. id, _ := res.LastInsertId()
  61. t.Cleanup(func() { testutil.CleanTable(ctx, conn, "`sys_user`", id) })
  62. got, err := m.IncrementTokenVersionIfMatch(ctx, id, username, 9)
  63. require.Error(t, err, "expected 未命中时必须返回错误")
  64. assert.True(t, errors.Is(err, user.ErrTokenVersionMismatch), "错误必须是 ErrTokenVersionMismatch 以供 logic 层分辨")
  65. assert.Equal(t, int64(0), got)
  66. fresh, err := m.FindOne(ctx, id)
  67. require.NoError(t, err)
  68. assert.Equal(t, int64(10), fresh.TokenVersion, "CAS 失败必须对 DB 零副作用")
  69. }
  70. // 原 TC-0804 "用户不存在必须返回原生 NotFound 而非 ErrTokenVersionMismatch" 已按 M-8
  71. // 新契约废止:M-8 取消了模型内 FindOne 预检,所有 CAS 未命中(无论是版本不匹配还是
  72. // 行根本不存在)都统一返回 ErrTokenVersionMismatch。logic 层 RefreshToken 改由
  73. // 上游 UserDetailsLoader.Load 的 status 分支分辨"离职/冻结"。
  74. // TC-0805: H-1 并发回归 —— N 个 goroutine 用同一个 expected 去 CAS,
  75. // 必须恰好只有 1 个返回 success,其余全部 ErrTokenVersionMismatch;
  76. // 最终 DB 的 tokenVersion 必须只递增 1(攻击者无法劫持第二枚令牌)。
  77. func TestSysUserModel_IncrementTokenVersionIfMatch_ConcurrentSingleWinner(t *testing.T) {
  78. m, conn := newModel(t)
  79. ctx := context.Background()
  80. now := time.Now().Unix()
  81. username := "cas_race_" + testutil.UniqueId()
  82. res, err := m.Insert(ctx, &user.SysUser{
  83. Username: username, Password: "x", Nickname: "n",
  84. Avatar: sql.NullString{}, IsSuperAdmin: 2, MustChangePassword: 2,
  85. Status: 1, TokenVersion: 20, CreateTime: now, UpdateTime: now,
  86. })
  87. require.NoError(t, err)
  88. id, _ := res.LastInsertId()
  89. t.Cleanup(func() { testutil.CleanTable(ctx, conn, "`sys_user`", id) })
  90. // 限制在 8 并发以避免触发 go-zero sqlx breaker(单机 MySQL + breaker 对同批次突发
  91. // 的并发 UPDATE 容易误伤;CAS 契约在 N=8 时已足以验证"唯一胜出")。
  92. const N = 8
  93. var (
  94. wg sync.WaitGroup
  95. successCnt int32
  96. mismatchCnt int32
  97. otherErr atomic.Value
  98. winners sync.Map
  99. )
  100. start := make(chan struct{})
  101. for i := 0; i < N; i++ {
  102. wg.Add(1)
  103. go func(idx int) {
  104. defer wg.Done()
  105. <-start // 最大程度对齐并发起跑线
  106. v, e := m.IncrementTokenVersionIfMatch(ctx, id, username, 20)
  107. switch {
  108. case e == nil:
  109. atomic.AddInt32(&successCnt, 1)
  110. winners.Store(idx, v)
  111. case errors.Is(e, user.ErrTokenVersionMismatch):
  112. atomic.AddInt32(&mismatchCnt, 1)
  113. default:
  114. otherErr.Store(e)
  115. }
  116. }(i)
  117. }
  118. close(start)
  119. wg.Wait()
  120. if v := otherErr.Load(); v != nil {
  121. t.Fatalf("并发 CAS 出现非预期错误:%v", v)
  122. }
  123. assert.Equal(t, int32(1), atomic.LoadInt32(&successCnt),
  124. "会话劫持防线:N=16 的竞态中必须有且仅有 1 个 CAS 胜出")
  125. assert.Equal(t, int32(N-1), atomic.LoadInt32(&mismatchCnt),
  126. "其他并发者必须全部返回 ErrTokenVersionMismatch,即攻击者会被 401 下线")
  127. // 唯一胜出者的返回值必须等于 21(起点 20 → +1)
  128. winners.Range(func(_, v any) bool {
  129. assert.Equal(t, int64(21), v.(int64), "唯一胜出的 CAS 应返回 expected+1")
  130. return true
  131. })
  132. fresh, err := m.FindOne(ctx, id)
  133. require.NoError(t, err)
  134. assert.Equal(t, int64(21), fresh.TokenVersion, "DB 最终只能递增 1(CAS 原子性的外部可观察证据)")
  135. }
  136. // TC-0806: H-1 —— 成功后必须使 id-key / username-key 双路缓存失效,
  137. // 否则 middleware 读缓存拿到的 tokenVersion 与 DB 不一致,依然存在"旧令牌合法误放"的旁路。
  138. func TestSysUserModel_IncrementTokenVersionIfMatch_InvalidatesCaches(t *testing.T) {
  139. m, conn := newModel(t)
  140. ctx := context.Background()
  141. now := time.Now().Unix()
  142. username := "cas_cache_" + testutil.UniqueId()
  143. res, err := m.Insert(ctx, &user.SysUser{
  144. Username: username, Password: "x", Nickname: "n",
  145. Avatar: sql.NullString{}, IsSuperAdmin: 2, MustChangePassword: 2,
  146. Status: 1, TokenVersion: 0, CreateTime: now, UpdateTime: now,
  147. })
  148. require.NoError(t, err)
  149. id, _ := res.LastInsertId()
  150. t.Cleanup(func() { testutil.CleanTable(ctx, conn, "`sys_user`", id) })
  151. u0a, err := m.FindOne(ctx, id)
  152. require.NoError(t, err)
  153. require.Equal(t, int64(0), u0a.TokenVersion)
  154. u0b, err := m.FindOneByUsername(ctx, username)
  155. require.NoError(t, err)
  156. require.Equal(t, int64(0), u0b.TokenVersion)
  157. got, err := m.IncrementTokenVersionIfMatch(ctx, id, username, 0)
  158. require.NoError(t, err)
  159. require.Equal(t, int64(1), got)
  160. // 再次读两路缓存,必须看到递增后的 1(而非 stale 0)
  161. u1a, err := m.FindOne(ctx, id)
  162. require.NoError(t, err)
  163. assert.Equal(t, int64(1), u1a.TokenVersion, fmt.Sprintf(
  164. "id-key 缓存未被清理,stale tokenVersion=%d(审计 H-1 的缓存一致性防线)", u1a.TokenVersion))
  165. u1b, err := m.FindOneByUsername(ctx, username)
  166. require.NoError(t, err)
  167. assert.Equal(t, int64(1), u1b.TokenVersion, fmt.Sprintf(
  168. "username-key 缓存未被清理,stale tokenVersion=%d", u1b.TokenVersion))
  169. }