| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203 |
- package user_test
- import (
- "context"
- "database/sql"
- "errors"
- "fmt"
- "sync"
- "sync/atomic"
- "testing"
- "time"
- "perms-system-server/internal/model/user"
- "perms-system-server/internal/testutil"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- )
- // ---------------------------------------------------------------------------
- // 覆盖目标:审计 H-1 修复 —— IncrementTokenVersionIfMatch 必须是"当 DB.tokenVersion == expected
- // 时才原子递增"的 CAS,否则 refreshToken rotation 在并发刷新时会放出"两枚都合法的新令牌"导致会话劫持。
- // 以下用例把这一契约显式钉死:
- // - 匹配 → 成功并返回递增后的新版本
- // - 不匹配 → ErrTokenVersionMismatch(不能返回成功,不能悄悄递增)
- // - 并发竞态 → N 个 goroutine 用同一个 expected 打入,必须只有 1 个成功
- // - 成功后必须清掉 id-key / username-key 双路缓存
- // ---------------------------------------------------------------------------
- // TC-0802: H-1 —— expected 与 DB 当前 tokenVersion 一致时返回递增后的新版本。
- func TestSysUserModel_IncrementTokenVersionIfMatch_Match(t *testing.T) {
- m, conn := newModel(t)
- ctx := context.Background()
- now := time.Now().Unix()
- username := "cas_match_" + testutil.UniqueId()
- res, err := m.Insert(ctx, &user.SysUser{
- Username: username, Password: "x", Nickname: "n",
- Avatar: sql.NullString{}, IsSuperAdmin: 2, MustChangePassword: 2,
- Status: 1, TokenVersion: 5, CreateTime: now, UpdateTime: now,
- })
- require.NoError(t, err)
- id, _ := res.LastInsertId()
- t.Cleanup(func() { testutil.CleanTable(ctx, conn, "`sys_user`", id) })
- got, err := m.IncrementTokenVersionIfMatch(ctx, id, 5)
- require.NoError(t, err)
- assert.Equal(t, int64(6), got, "expected 命中时返回 DB 真实递增后的新版本")
- fresh, err := m.FindOne(ctx, id)
- require.NoError(t, err)
- assert.Equal(t, int64(6), fresh.TokenVersion, "DB 落盘值必须也是 6")
- }
- // TC-0803: H-1 —— expected 与 DB 不一致时返回 ErrTokenVersionMismatch 且 DB 不得发生任何变更。
- // 这是会话劫持窗口的关键拦截:攻击者的 token 里 TokenVersion = V,但合法用户已刷新到 V+1,
- // 攻击者再来刷新时 expected=V 打不中 WHERE 子句 → 必须失败。
- func TestSysUserModel_IncrementTokenVersionIfMatch_Mismatch_NoSideEffect(t *testing.T) {
- m, conn := newModel(t)
- ctx := context.Background()
- now := time.Now().Unix()
- username := "cas_mismatch_" + testutil.UniqueId()
- res, err := m.Insert(ctx, &user.SysUser{
- Username: username, Password: "x", Nickname: "n",
- Avatar: sql.NullString{}, IsSuperAdmin: 2, MustChangePassword: 2,
- Status: 1, TokenVersion: 10, CreateTime: now, UpdateTime: now,
- })
- require.NoError(t, err)
- id, _ := res.LastInsertId()
- t.Cleanup(func() { testutil.CleanTable(ctx, conn, "`sys_user`", id) })
- got, err := m.IncrementTokenVersionIfMatch(ctx, id, 9)
- require.Error(t, err, "expected 未命中时必须返回错误")
- assert.True(t, errors.Is(err, user.ErrTokenVersionMismatch), "错误必须是 ErrTokenVersionMismatch 以供 logic 层分辨")
- assert.Equal(t, int64(0), got)
- fresh, err := m.FindOne(ctx, id)
- require.NoError(t, err)
- assert.Equal(t, int64(10), fresh.TokenVersion, "CAS 失败必须对 DB 零副作用")
- }
- // TC-0804: H-1 —— user 不存在时必须返回原生 NotFound 错误(不得被 ErrTokenVersionMismatch 掩盖)。
- // 这个边界保证 logic 层能区分"用户被删"(应走 UserDetailsLoader 的 status 分支)和"令牌被接管"。
- func TestSysUserModel_IncrementTokenVersionIfMatch_UserNotFound(t *testing.T) {
- m, _ := newModel(t)
- ctx := context.Background()
- got, err := m.IncrementTokenVersionIfMatch(ctx, 999999998, 0)
- require.Error(t, err)
- assert.False(t, errors.Is(err, user.ErrTokenVersionMismatch),
- "用户不存在的错误不得伪装成 TokenVersionMismatch,避免混淆 logic 层的分支")
- assert.Equal(t, int64(0), got)
- }
- // TC-0805: H-1 并发回归 —— N 个 goroutine 用同一个 expected 去 CAS,
- // 必须恰好只有 1 个返回 success,其余全部 ErrTokenVersionMismatch;
- // 最终 DB 的 tokenVersion 必须只递增 1(攻击者无法劫持第二枚令牌)。
- func TestSysUserModel_IncrementTokenVersionIfMatch_ConcurrentSingleWinner(t *testing.T) {
- m, conn := newModel(t)
- ctx := context.Background()
- now := time.Now().Unix()
- username := "cas_race_" + testutil.UniqueId()
- res, err := m.Insert(ctx, &user.SysUser{
- Username: username, Password: "x", Nickname: "n",
- Avatar: sql.NullString{}, IsSuperAdmin: 2, MustChangePassword: 2,
- Status: 1, TokenVersion: 20, CreateTime: now, UpdateTime: now,
- })
- require.NoError(t, err)
- id, _ := res.LastInsertId()
- t.Cleanup(func() { testutil.CleanTable(ctx, conn, "`sys_user`", id) })
- // 限制在 8 并发以避免触发 go-zero sqlx breaker(单机 MySQL + breaker 对同批次突发
- // 的并发 UPDATE 容易误伤;CAS 契约在 N=8 时已足以验证"唯一胜出")。
- const N = 8
- var (
- wg sync.WaitGroup
- successCnt int32
- mismatchCnt int32
- otherErr atomic.Value
- winners sync.Map
- )
- start := make(chan struct{})
- for i := 0; i < N; i++ {
- wg.Add(1)
- go func(idx int) {
- defer wg.Done()
- <-start // 最大程度对齐并发起跑线
- v, e := m.IncrementTokenVersionIfMatch(ctx, id, 20)
- switch {
- case e == nil:
- atomic.AddInt32(&successCnt, 1)
- winners.Store(idx, v)
- case errors.Is(e, user.ErrTokenVersionMismatch):
- atomic.AddInt32(&mismatchCnt, 1)
- default:
- otherErr.Store(e)
- }
- }(i)
- }
- close(start)
- wg.Wait()
- if v := otherErr.Load(); v != nil {
- t.Fatalf("并发 CAS 出现非预期错误:%v", v)
- }
- assert.Equal(t, int32(1), atomic.LoadInt32(&successCnt),
- "会话劫持防线:N=16 的竞态中必须有且仅有 1 个 CAS 胜出")
- assert.Equal(t, int32(N-1), atomic.LoadInt32(&mismatchCnt),
- "其他并发者必须全部返回 ErrTokenVersionMismatch,即攻击者会被 401 下线")
- // 唯一胜出者的返回值必须等于 21(起点 20 → +1)
- winners.Range(func(_, v any) bool {
- assert.Equal(t, int64(21), v.(int64), "唯一胜出的 CAS 应返回 expected+1")
- return true
- })
- fresh, err := m.FindOne(ctx, id)
- require.NoError(t, err)
- assert.Equal(t, int64(21), fresh.TokenVersion, "DB 最终只能递增 1(CAS 原子性的外部可观察证据)")
- }
- // TC-0806: H-1 —— 成功后必须使 id-key / username-key 双路缓存失效,
- // 否则 middleware 读缓存拿到的 tokenVersion 与 DB 不一致,依然存在"旧令牌合法误放"的旁路。
- func TestSysUserModel_IncrementTokenVersionIfMatch_InvalidatesCaches(t *testing.T) {
- m, conn := newModel(t)
- ctx := context.Background()
- now := time.Now().Unix()
- username := "cas_cache_" + testutil.UniqueId()
- res, err := m.Insert(ctx, &user.SysUser{
- Username: username, Password: "x", Nickname: "n",
- Avatar: sql.NullString{}, IsSuperAdmin: 2, MustChangePassword: 2,
- Status: 1, TokenVersion: 0, CreateTime: now, UpdateTime: now,
- })
- require.NoError(t, err)
- id, _ := res.LastInsertId()
- t.Cleanup(func() { testutil.CleanTable(ctx, conn, "`sys_user`", id) })
- u0a, err := m.FindOne(ctx, id)
- require.NoError(t, err)
- require.Equal(t, int64(0), u0a.TokenVersion)
- u0b, err := m.FindOneByUsername(ctx, username)
- require.NoError(t, err)
- require.Equal(t, int64(0), u0b.TokenVersion)
- got, err := m.IncrementTokenVersionIfMatch(ctx, id, 0)
- require.NoError(t, err)
- require.Equal(t, int64(1), got)
- // 再次读两路缓存,必须看到递增后的 1(而非 stale 0)
- u1a, err := m.FindOne(ctx, id)
- require.NoError(t, err)
- assert.Equal(t, int64(1), u1a.TokenVersion, fmt.Sprintf(
- "id-key 缓存未被清理,stale tokenVersion=%d(审计 H-1 的缓存一致性防线)", u1a.TokenVersion))
- u1b, err := m.FindOneByUsername(ctx, username)
- require.NoError(t, err)
- assert.Equal(t, int64(1), u1b.TokenVersion, fmt.Sprintf(
- "username-key 缓存未被清理,stale tokenVersion=%d", u1b.TokenVersion))
- }
|