package user_test import ( "context" "database/sql" "errors" "fmt" "sync" "sync/atomic" "testing" "time" "perms-system-server/internal/model/user" "perms-system-server/internal/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // --------------------------------------------------------------------------- // 覆盖目标:审计 H-1 修复 —— IncrementTokenVersionIfMatch 必须是"当 DB.tokenVersion == expected // 时才原子递增"的 CAS,否则 refreshToken rotation 在并发刷新时会放出"两枚都合法的新令牌"导致会话劫持。 // 以下用例把这一契约显式钉死: // - 匹配 → 成功并返回递增后的新版本 // - 不匹配 → ErrTokenVersionMismatch(不能返回成功,不能悄悄递增) // - 并发竞态 → N 个 goroutine 用同一个 expected 打入,必须只有 1 个成功 // - 成功后必须清掉 id-key / username-key 双路缓存 // --------------------------------------------------------------------------- // TC-0802: H-1 —— expected 与 DB 当前 tokenVersion 一致时返回递增后的新版本。 func TestSysUserModel_IncrementTokenVersionIfMatch_Match(t *testing.T) { m, conn := newModel(t) ctx := context.Background() now := time.Now().Unix() username := "cas_match_" + testutil.UniqueId() res, err := m.Insert(ctx, &user.SysUser{ Username: username, Password: "x", Nickname: "n", Avatar: sql.NullString{}, IsSuperAdmin: 2, MustChangePassword: 2, Status: 1, TokenVersion: 5, CreateTime: now, UpdateTime: now, }) require.NoError(t, err) id, _ := res.LastInsertId() t.Cleanup(func() { testutil.CleanTable(ctx, conn, "`sys_user`", id) }) got, err := m.IncrementTokenVersionIfMatch(ctx, id, 5) require.NoError(t, err) assert.Equal(t, int64(6), got, "expected 命中时返回 DB 真实递增后的新版本") fresh, err := m.FindOne(ctx, id) require.NoError(t, err) assert.Equal(t, int64(6), fresh.TokenVersion, "DB 落盘值必须也是 6") } // TC-0803: H-1 —— expected 与 DB 不一致时返回 ErrTokenVersionMismatch 且 DB 不得发生任何变更。 // 这是会话劫持窗口的关键拦截:攻击者的 token 里 TokenVersion = V,但合法用户已刷新到 V+1, // 攻击者再来刷新时 expected=V 打不中 WHERE 子句 → 必须失败。 func TestSysUserModel_IncrementTokenVersionIfMatch_Mismatch_NoSideEffect(t *testing.T) { m, conn := newModel(t) ctx := context.Background() now := time.Now().Unix() username := "cas_mismatch_" + testutil.UniqueId() res, err := m.Insert(ctx, &user.SysUser{ Username: username, Password: "x", Nickname: "n", Avatar: sql.NullString{}, IsSuperAdmin: 2, MustChangePassword: 2, Status: 1, TokenVersion: 10, CreateTime: now, UpdateTime: now, }) require.NoError(t, err) id, _ := res.LastInsertId() t.Cleanup(func() { testutil.CleanTable(ctx, conn, "`sys_user`", id) }) got, err := m.IncrementTokenVersionIfMatch(ctx, id, 9) require.Error(t, err, "expected 未命中时必须返回错误") assert.True(t, errors.Is(err, user.ErrTokenVersionMismatch), "错误必须是 ErrTokenVersionMismatch 以供 logic 层分辨") assert.Equal(t, int64(0), got) fresh, err := m.FindOne(ctx, id) require.NoError(t, err) assert.Equal(t, int64(10), fresh.TokenVersion, "CAS 失败必须对 DB 零副作用") } // TC-0804: H-1 —— user 不存在时必须返回原生 NotFound 错误(不得被 ErrTokenVersionMismatch 掩盖)。 // 这个边界保证 logic 层能区分"用户被删"(应走 UserDetailsLoader 的 status 分支)和"令牌被接管"。 func TestSysUserModel_IncrementTokenVersionIfMatch_UserNotFound(t *testing.T) { m, _ := newModel(t) ctx := context.Background() got, err := m.IncrementTokenVersionIfMatch(ctx, 999999998, 0) require.Error(t, err) assert.False(t, errors.Is(err, user.ErrTokenVersionMismatch), "用户不存在的错误不得伪装成 TokenVersionMismatch,避免混淆 logic 层的分支") assert.Equal(t, int64(0), got) } // TC-0805: H-1 并发回归 —— N 个 goroutine 用同一个 expected 去 CAS, // 必须恰好只有 1 个返回 success,其余全部 ErrTokenVersionMismatch; // 最终 DB 的 tokenVersion 必须只递增 1(攻击者无法劫持第二枚令牌)。 func TestSysUserModel_IncrementTokenVersionIfMatch_ConcurrentSingleWinner(t *testing.T) { m, conn := newModel(t) ctx := context.Background() now := time.Now().Unix() username := "cas_race_" + testutil.UniqueId() res, err := m.Insert(ctx, &user.SysUser{ Username: username, Password: "x", Nickname: "n", Avatar: sql.NullString{}, IsSuperAdmin: 2, MustChangePassword: 2, Status: 1, TokenVersion: 20, CreateTime: now, UpdateTime: now, }) require.NoError(t, err) id, _ := res.LastInsertId() t.Cleanup(func() { testutil.CleanTable(ctx, conn, "`sys_user`", id) }) // 限制在 8 并发以避免触发 go-zero sqlx breaker(单机 MySQL + breaker 对同批次突发 // 的并发 UPDATE 容易误伤;CAS 契约在 N=8 时已足以验证"唯一胜出")。 const N = 8 var ( wg sync.WaitGroup successCnt int32 mismatchCnt int32 otherErr atomic.Value winners sync.Map ) start := make(chan struct{}) for i := 0; i < N; i++ { wg.Add(1) go func(idx int) { defer wg.Done() <-start // 最大程度对齐并发起跑线 v, e := m.IncrementTokenVersionIfMatch(ctx, id, 20) switch { case e == nil: atomic.AddInt32(&successCnt, 1) winners.Store(idx, v) case errors.Is(e, user.ErrTokenVersionMismatch): atomic.AddInt32(&mismatchCnt, 1) default: otherErr.Store(e) } }(i) } close(start) wg.Wait() if v := otherErr.Load(); v != nil { t.Fatalf("并发 CAS 出现非预期错误:%v", v) } assert.Equal(t, int32(1), atomic.LoadInt32(&successCnt), "会话劫持防线:N=16 的竞态中必须有且仅有 1 个 CAS 胜出") assert.Equal(t, int32(N-1), atomic.LoadInt32(&mismatchCnt), "其他并发者必须全部返回 ErrTokenVersionMismatch,即攻击者会被 401 下线") // 唯一胜出者的返回值必须等于 21(起点 20 → +1) winners.Range(func(_, v any) bool { assert.Equal(t, int64(21), v.(int64), "唯一胜出的 CAS 应返回 expected+1") return true }) fresh, err := m.FindOne(ctx, id) require.NoError(t, err) assert.Equal(t, int64(21), fresh.TokenVersion, "DB 最终只能递增 1(CAS 原子性的外部可观察证据)") } // TC-0806: H-1 —— 成功后必须使 id-key / username-key 双路缓存失效, // 否则 middleware 读缓存拿到的 tokenVersion 与 DB 不一致,依然存在"旧令牌合法误放"的旁路。 func TestSysUserModel_IncrementTokenVersionIfMatch_InvalidatesCaches(t *testing.T) { m, conn := newModel(t) ctx := context.Background() now := time.Now().Unix() username := "cas_cache_" + testutil.UniqueId() res, err := m.Insert(ctx, &user.SysUser{ Username: username, Password: "x", Nickname: "n", Avatar: sql.NullString{}, IsSuperAdmin: 2, MustChangePassword: 2, Status: 1, TokenVersion: 0, CreateTime: now, UpdateTime: now, }) require.NoError(t, err) id, _ := res.LastInsertId() t.Cleanup(func() { testutil.CleanTable(ctx, conn, "`sys_user`", id) }) u0a, err := m.FindOne(ctx, id) require.NoError(t, err) require.Equal(t, int64(0), u0a.TokenVersion) u0b, err := m.FindOneByUsername(ctx, username) require.NoError(t, err) require.Equal(t, int64(0), u0b.TokenVersion) got, err := m.IncrementTokenVersionIfMatch(ctx, id, 0) require.NoError(t, err) require.Equal(t, int64(1), got) // 再次读两路缓存,必须看到递增后的 1(而非 stale 0) u1a, err := m.FindOne(ctx, id) require.NoError(t, err) assert.Equal(t, int64(1), u1a.TokenVersion, fmt.Sprintf( "id-key 缓存未被清理,stale tokenVersion=%d(审计 H-1 的缓存一致性防线)", u1a.TokenVersion)) u1b, err := m.FindOneByUsername(ctx, username) require.NoError(t, err) assert.Equal(t, int64(1), u1b.TokenVersion, fmt.Sprintf( "username-key 缓存未被清理,stale tokenVersion=%d", u1b.TokenVersion)) }