incrementTokenVersionIfMatch_audit_test.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. package user_test
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "fmt"
  7. "sync"
  8. "sync/atomic"
  9. "testing"
  10. "time"
  11. "perms-system-server/internal/model/user"
  12. "perms-system-server/internal/testutil"
  13. "github.com/stretchr/testify/assert"
  14. "github.com/stretchr/testify/require"
  15. )
  16. // ---------------------------------------------------------------------------
  17. // 覆盖目标:审计 H-1 修复 —— IncrementTokenVersionIfMatch 必须是"当 DB.tokenVersion == expected
  18. // 时才原子递增"的 CAS,否则 refreshToken rotation 在并发刷新时会放出"两枚都合法的新令牌"导致会话劫持。
  19. // 以下用例把这一契约显式钉死:
  20. // - 匹配 → 成功并返回递增后的新版本
  21. // - 不匹配 → ErrTokenVersionMismatch(不能返回成功,不能悄悄递增)
  22. // - 并发竞态 → N 个 goroutine 用同一个 expected 打入,必须只有 1 个成功
  23. // - 成功后必须清掉 id-key / username-key 双路缓存
  24. // ---------------------------------------------------------------------------
  25. // TC-0802: H-1 —— expected 与 DB 当前 tokenVersion 一致时返回递增后的新版本。
  26. func TestSysUserModel_IncrementTokenVersionIfMatch_Match(t *testing.T) {
  27. m, conn := newModel(t)
  28. ctx := context.Background()
  29. now := time.Now().Unix()
  30. username := "cas_match_" + testutil.UniqueId()
  31. res, err := m.Insert(ctx, &user.SysUser{
  32. Username: username, Password: "x", Nickname: "n",
  33. Avatar: sql.NullString{}, IsSuperAdmin: 2, MustChangePassword: 2,
  34. Status: 1, TokenVersion: 5, CreateTime: now, UpdateTime: now,
  35. })
  36. require.NoError(t, err)
  37. id, _ := res.LastInsertId()
  38. t.Cleanup(func() { testutil.CleanTable(ctx, conn, "`sys_user`", id) })
  39. got, err := m.IncrementTokenVersionIfMatch(ctx, id, 5)
  40. require.NoError(t, err)
  41. assert.Equal(t, int64(6), got, "expected 命中时返回 DB 真实递增后的新版本")
  42. fresh, err := m.FindOne(ctx, id)
  43. require.NoError(t, err)
  44. assert.Equal(t, int64(6), fresh.TokenVersion, "DB 落盘值必须也是 6")
  45. }
  46. // TC-0803: H-1 —— expected 与 DB 不一致时返回 ErrTokenVersionMismatch 且 DB 不得发生任何变更。
  47. // 这是会话劫持窗口的关键拦截:攻击者的 token 里 TokenVersion = V,但合法用户已刷新到 V+1,
  48. // 攻击者再来刷新时 expected=V 打不中 WHERE 子句 → 必须失败。
  49. func TestSysUserModel_IncrementTokenVersionIfMatch_Mismatch_NoSideEffect(t *testing.T) {
  50. m, conn := newModel(t)
  51. ctx := context.Background()
  52. now := time.Now().Unix()
  53. username := "cas_mismatch_" + testutil.UniqueId()
  54. res, err := m.Insert(ctx, &user.SysUser{
  55. Username: username, Password: "x", Nickname: "n",
  56. Avatar: sql.NullString{}, IsSuperAdmin: 2, MustChangePassword: 2,
  57. Status: 1, TokenVersion: 10, CreateTime: now, UpdateTime: now,
  58. })
  59. require.NoError(t, err)
  60. id, _ := res.LastInsertId()
  61. t.Cleanup(func() { testutil.CleanTable(ctx, conn, "`sys_user`", id) })
  62. got, err := m.IncrementTokenVersionIfMatch(ctx, id, 9)
  63. require.Error(t, err, "expected 未命中时必须返回错误")
  64. assert.True(t, errors.Is(err, user.ErrTokenVersionMismatch), "错误必须是 ErrTokenVersionMismatch 以供 logic 层分辨")
  65. assert.Equal(t, int64(0), got)
  66. fresh, err := m.FindOne(ctx, id)
  67. require.NoError(t, err)
  68. assert.Equal(t, int64(10), fresh.TokenVersion, "CAS 失败必须对 DB 零副作用")
  69. }
  70. // TC-0804: H-1 —— user 不存在时必须返回原生 NotFound 错误(不得被 ErrTokenVersionMismatch 掩盖)。
  71. // 这个边界保证 logic 层能区分"用户被删"(应走 UserDetailsLoader 的 status 分支)和"令牌被接管"。
  72. func TestSysUserModel_IncrementTokenVersionIfMatch_UserNotFound(t *testing.T) {
  73. m, _ := newModel(t)
  74. ctx := context.Background()
  75. got, err := m.IncrementTokenVersionIfMatch(ctx, 999999998, 0)
  76. require.Error(t, err)
  77. assert.False(t, errors.Is(err, user.ErrTokenVersionMismatch),
  78. "用户不存在的错误不得伪装成 TokenVersionMismatch,避免混淆 logic 层的分支")
  79. assert.Equal(t, int64(0), got)
  80. }
  81. // TC-0805: H-1 并发回归 —— N 个 goroutine 用同一个 expected 去 CAS,
  82. // 必须恰好只有 1 个返回 success,其余全部 ErrTokenVersionMismatch;
  83. // 最终 DB 的 tokenVersion 必须只递增 1(攻击者无法劫持第二枚令牌)。
  84. func TestSysUserModel_IncrementTokenVersionIfMatch_ConcurrentSingleWinner(t *testing.T) {
  85. m, conn := newModel(t)
  86. ctx := context.Background()
  87. now := time.Now().Unix()
  88. username := "cas_race_" + testutil.UniqueId()
  89. res, err := m.Insert(ctx, &user.SysUser{
  90. Username: username, Password: "x", Nickname: "n",
  91. Avatar: sql.NullString{}, IsSuperAdmin: 2, MustChangePassword: 2,
  92. Status: 1, TokenVersion: 20, CreateTime: now, UpdateTime: now,
  93. })
  94. require.NoError(t, err)
  95. id, _ := res.LastInsertId()
  96. t.Cleanup(func() { testutil.CleanTable(ctx, conn, "`sys_user`", id) })
  97. // 限制在 8 并发以避免触发 go-zero sqlx breaker(单机 MySQL + breaker 对同批次突发
  98. // 的并发 UPDATE 容易误伤;CAS 契约在 N=8 时已足以验证"唯一胜出")。
  99. const N = 8
  100. var (
  101. wg sync.WaitGroup
  102. successCnt int32
  103. mismatchCnt int32
  104. otherErr atomic.Value
  105. winners sync.Map
  106. )
  107. start := make(chan struct{})
  108. for i := 0; i < N; i++ {
  109. wg.Add(1)
  110. go func(idx int) {
  111. defer wg.Done()
  112. <-start // 最大程度对齐并发起跑线
  113. v, e := m.IncrementTokenVersionIfMatch(ctx, id, 20)
  114. switch {
  115. case e == nil:
  116. atomic.AddInt32(&successCnt, 1)
  117. winners.Store(idx, v)
  118. case errors.Is(e, user.ErrTokenVersionMismatch):
  119. atomic.AddInt32(&mismatchCnt, 1)
  120. default:
  121. otherErr.Store(e)
  122. }
  123. }(i)
  124. }
  125. close(start)
  126. wg.Wait()
  127. if v := otherErr.Load(); v != nil {
  128. t.Fatalf("并发 CAS 出现非预期错误:%v", v)
  129. }
  130. assert.Equal(t, int32(1), atomic.LoadInt32(&successCnt),
  131. "会话劫持防线:N=16 的竞态中必须有且仅有 1 个 CAS 胜出")
  132. assert.Equal(t, int32(N-1), atomic.LoadInt32(&mismatchCnt),
  133. "其他并发者必须全部返回 ErrTokenVersionMismatch,即攻击者会被 401 下线")
  134. // 唯一胜出者的返回值必须等于 21(起点 20 → +1)
  135. winners.Range(func(_, v any) bool {
  136. assert.Equal(t, int64(21), v.(int64), "唯一胜出的 CAS 应返回 expected+1")
  137. return true
  138. })
  139. fresh, err := m.FindOne(ctx, id)
  140. require.NoError(t, err)
  141. assert.Equal(t, int64(21), fresh.TokenVersion, "DB 最终只能递增 1(CAS 原子性的外部可观察证据)")
  142. }
  143. // TC-0806: H-1 —— 成功后必须使 id-key / username-key 双路缓存失效,
  144. // 否则 middleware 读缓存拿到的 tokenVersion 与 DB 不一致,依然存在"旧令牌合法误放"的旁路。
  145. func TestSysUserModel_IncrementTokenVersionIfMatch_InvalidatesCaches(t *testing.T) {
  146. m, conn := newModel(t)
  147. ctx := context.Background()
  148. now := time.Now().Unix()
  149. username := "cas_cache_" + testutil.UniqueId()
  150. res, err := m.Insert(ctx, &user.SysUser{
  151. Username: username, Password: "x", Nickname: "n",
  152. Avatar: sql.NullString{}, IsSuperAdmin: 2, MustChangePassword: 2,
  153. Status: 1, TokenVersion: 0, CreateTime: now, UpdateTime: now,
  154. })
  155. require.NoError(t, err)
  156. id, _ := res.LastInsertId()
  157. t.Cleanup(func() { testutil.CleanTable(ctx, conn, "`sys_user`", id) })
  158. u0a, err := m.FindOne(ctx, id)
  159. require.NoError(t, err)
  160. require.Equal(t, int64(0), u0a.TokenVersion)
  161. u0b, err := m.FindOneByUsername(ctx, username)
  162. require.NoError(t, err)
  163. require.Equal(t, int64(0), u0b.TokenVersion)
  164. got, err := m.IncrementTokenVersionIfMatch(ctx, id, 0)
  165. require.NoError(t, err)
  166. require.Equal(t, int64(1), got)
  167. // 再次读两路缓存,必须看到递增后的 1(而非 stale 0)
  168. u1a, err := m.FindOne(ctx, id)
  169. require.NoError(t, err)
  170. assert.Equal(t, int64(1), u1a.TokenVersion, fmt.Sprintf(
  171. "id-key 缓存未被清理,stale tokenVersion=%d(审计 H-1 的缓存一致性防线)", u1a.TokenVersion))
  172. u1b, err := m.FindOneByUsername(ctx, username)
  173. require.NoError(t, err)
  174. assert.Equal(t, int64(1), u1b.TokenVersion, fmt.Sprintf(
  175. "username-key 缓存未被清理,stale tokenVersion=%d", u1b.TokenVersion))
  176. }