sysUserModel.go 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519
  1. package user
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "fmt"
  7. "strings"
  8. "time"
  9. "github.com/zeromicro/go-zero/core/logx"
  10. "github.com/zeromicro/go-zero/core/stores/cache"
  11. "github.com/zeromicro/go-zero/core/stores/sqlx"
  12. )
  13. var ErrUpdateConflict = errors.New("update conflict: data has been modified by another operation")
  14. // ErrTokenVersionMismatch 表示令牌版本与数据库当前版本不一致,刷新令牌失败。
  15. // 典型场景:refreshToken rotation 并发到达 —— 只有持有当前 tokenVersion 的那一次能原子递增成功,
  16. // 其余全部返回该错误,防止两个请求都"换到"新令牌(导致会话劫持)。
  17. var ErrTokenVersionMismatch = errors.New("token version mismatch")
  18. var _ SysUserModel = (*customSysUserModel)(nil)
  19. type (
  20. UserListFilter struct {
  21. Username string
  22. Nickname string
  23. Status int64
  24. DeptIds []int64
  25. }
  26. SysUserModel interface {
  27. sysUserModel
  28. FindListByPage(ctx context.Context, page, pageSize int64) ([]*SysUser, int64, error)
  29. FindListByFilter(ctx context.Context, filter UserListFilter, page, pageSize int64) ([]*SysUser, int64, error)
  30. FindListByProductMembers(ctx context.Context, productCode string, page, pageSize int64) ([]*SysUser, map[int64]string, int64, error)
  31. FindByIds(ctx context.Context, ids []int64) ([]*SysUser, error)
  32. FindIdsByDeptId(ctx context.Context, deptId int64) ([]int64, error)
  33. // FindIdsByDeptIdForShareTx 在调用方事务里取目标 deptId 下所有用户 id,并对命中行加 S 锁
  34. // (LOCK IN SHARE MODE)。用于 UpdateDept 收窄(DEV→NORMAL / DEV 部门 Enabled→Disabled /
  35. // NORMAL 部门 Enabled→Disabled)后需要批量递增 tokenVersion 时,阻塞并发
  36. // UpdateProfileWithTx 对同一批 sys_user 行的 X 锁,避免"枚举 userIds 的瞬间有行刚被挪出
  37. // 本部门 / 新行刚被挪入"导致的漏吊销 / 多吊销(审计 L-R16-2)。
  38. // session==nil 返回错误。
  39. FindIdsByDeptIdForShareTx(ctx context.Context, session sqlx.Session, deptId int64) ([]int64, error)
  40. // UpdateProfile 更新用户资料字段(昵称 / 邮箱 / 手机 / 备注 / 部门 / 状态),username 仅用于
  41. // 构造旧缓存键 `cacheSysUserUsernamePrefix` 做失效,**不会**被写入 SET 子句。若未来确实需要
  42. // 修改 username,请独立实现 `UpdateUsernameTx`:
  43. // ① 同事务内做 old/new 两份 UsernameKey 的失效;
  44. // ② 捕获 1062 (UNIQUE 冲突) → response.ErrConflict,不得混进本方法的签名(审计 L-R11-3)。
  45. UpdateProfile(ctx context.Context, id int64, username string, nickname, email, phone, remark string, deptId, newStatus int64, statusChanged bool, expectedUpdateTime int64) error
  46. // UpdateProfileWithTx 与 UpdateProfile 的 SET 子句等价,但 UPDATE 执行在调用方传入的事务里;
  47. // 用于"改 deptId → 同事务内先 FOR SHARE sys_dept 目标行"修复 DeleteDept vs UpdateUser
  48. // 的 write skew(审计 M-R11-3)。
  49. // 审计 L-R12-1:本方法**不再**自己做缓存失效——go-zero `m.ExecCtx` 的 DelCache 会在闭包
  50. // 成功返回时立即发生,这里"成功"只代表 session.ExecCtx 收到了行数,事务仍在进行中,
  51. // 并发 FindOne 会把尚未提交的旧值重新灌回 sysUser 低层缓存造成 stale。调用方**必须**在
  52. // TransactCtx 返回(事务 commit)成功之后调用 InvalidateProfileCache(id, username) 做
  53. // post-commit 失效;同模式对任何 `*WithTx` 方法都成立。
  54. UpdateProfileWithTx(ctx context.Context, session sqlx.Session, id int64, username string, nickname, email, phone, remark string, deptId, newStatus int64, statusChanged bool, expectedUpdateTime int64) error
  55. // InvalidateProfileCache 失效 sysUser 的 id / username 两把低层缓存键。仅应在事务
  56. // commit 成功后由调用方显式调用(见 UpdateProfileWithTx 说明)。best-effort:失效失败
  57. // 只留日志,最终由 TTL 兜底,不影响事务已完成的业务结果(审计 L-R12-1)。
  58. InvalidateProfileCache(ctx context.Context, id int64, username string)
  59. // UpdatePassword 审计 H-R11-1:expectedUpdateTime 必须由调用方用**外层校验旧密码时拿到的
  60. // 那一份 updateTime** 显式透传;禁止函数内部再 FindOne 自对齐乐观锁,否则内层 CAS 等于退化
  61. // 为 last-write-wins,被"旧会话 + 知道旧密码"的攻击者可以把管理员紧急改过的新密码盖回。
  62. // username 仅用于构造 `cacheSysUserUsernamePrefix` 缓存键失效(sysUser.username 唯一,本函数
  63. // 不会修改它)。
  64. UpdatePassword(ctx context.Context, id int64, username string, password string, mustChangePassword, expectedUpdateTime int64) error
  65. // UpdateStatus 审计 M-R11-2:username 由调用方透传,避免仅为构造缓存键而多打一次 FindOne。
  66. UpdateStatus(ctx context.Context, id int64, username string, status int64, expectedUpdateTime int64) error
  67. // IncrementTokenVersion 审计 M-R11-2:username 由调用方透传,避免仅为构造缓存键而多打一次 FindOne。
  68. IncrementTokenVersion(ctx context.Context, id int64, username string) (int64, error)
  69. IncrementTokenVersionIfMatch(ctx context.Context, id int64, username string, expected int64) (int64, error)
  70. // IncrementTokenVersionWithTx 在调用方提供的事务内递增 sys_user.tokenVersion,供 UpdateMember
  71. // 降级 / 禁用等业务把旧 access/refresh token 立刻打到 middleware 的 `claims.TokenVersion !=
  72. // ud.TokenVersion` / refresh CAS 的拒绝分支(审计 M-R15-1 方案 A)。
  73. //
  74. // 与非事务版本 IncrementTokenVersion 的差异:
  75. // - 不在方法内做任何缓存失效;tx 成功提交后由调用方走 InvalidateProfileCache(id, username)
  76. // + UserDetailsLoader.Del/CleanByProduct 的 post-commit 链路(对齐 UpdateProfileWithTx
  77. // 的 L-R12-1 契约;若事务回滚,调用方**不要**触发失效,避免"操作失败但用户被踢下线");
  78. // - session==nil 直接报错——非事务场景请改走 IncrementTokenVersion。
  79. // 返回新的 tokenVersion(SELECT LAST_INSERT_ID())供调用方 log / forensic 比对。
  80. IncrementTokenVersionWithTx(ctx context.Context, session sqlx.Session, id int64) (int64, error)
  81. // BatchIncrementTokenVersionWithTx 批量递增一组 userIds 的 tokenVersion,供 UpdateProduct
  82. // 禁用时吊销该产品全体成员已签发的 access/refresh(审计 L-R15-3)。
  83. //
  84. // 契约:
  85. // - 不做缓存失效;tx 提交后调用方按 (id, username) 对逐个 InvalidateProfileCache,UD 聚合
  86. // 缓存通过 UserDetailsLoader.CleanByProduct 失效;
  87. // - 调用方负责对 ids 去重并控制长度,空切片直接返回 nil;
  88. // - session==nil 返回错误。
  89. BatchIncrementTokenVersionWithTx(ctx context.Context, session sqlx.Session, ids []int64) error
  90. // UpdateSelfInfo 更新用户自身安全字段(昵称/头像/邮箱/手机),不涉及 deptId/status/tokenVersion。
  91. // username 仅用于构造缓存键失效。avatar 使用 sql.NullString 语义写入(空串置 NULL)。
  92. UpdateSelfInfo(ctx context.Context, id int64, username string, nickname, avatar, email, phone string, expectedUpdateTime int64) error
  93. }
  94. customSysUserModel struct {
  95. *defaultSysUserModel
  96. }
  97. )
  98. func NewSysUserModel(conn sqlx.SqlConn, c cache.CacheConf, cachePrefix string, opts ...cache.Option) SysUserModel {
  99. return &customSysUserModel{
  100. defaultSysUserModel: newSysUserModel(conn, c, cachePrefix, opts...),
  101. }
  102. }
  103. func (m *customSysUserModel) FindListByPage(ctx context.Context, page, pageSize int64) ([]*SysUser, int64, error) {
  104. var total int64
  105. countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s", m.table)
  106. if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery); err != nil {
  107. return nil, 0, err
  108. }
  109. var list []*SysUser
  110. query := fmt.Sprintf("SELECT %s FROM %s ORDER BY id DESC LIMIT ?,?", sysUserRows, m.table)
  111. if err := m.QueryRowsNoCacheCtx(ctx, &list, query, (page-1)*pageSize, pageSize); err != nil {
  112. return nil, 0, err
  113. }
  114. return list, total, nil
  115. }
  116. func (m *customSysUserModel) FindListByFilter(ctx context.Context, filter UserListFilter, page, pageSize int64) ([]*SysUser, int64, error) {
  117. var conditions []string
  118. var args []interface{}
  119. if filter.Username != "" {
  120. conditions = append(conditions, "`username` LIKE ?")
  121. args = append(args, "%"+filter.Username+"%")
  122. }
  123. if filter.Nickname != "" {
  124. conditions = append(conditions, "`nickname` LIKE ?")
  125. args = append(args, "%"+filter.Nickname+"%")
  126. }
  127. if filter.Status > 0 {
  128. conditions = append(conditions, "`status` = ?")
  129. args = append(args, filter.Status)
  130. }
  131. if len(filter.DeptIds) > 0 {
  132. placeholders := strings.Repeat("?,", len(filter.DeptIds))
  133. placeholders = placeholders[:len(placeholders)-1]
  134. conditions = append(conditions, "`deptId` IN ("+placeholders+")")
  135. for _, id := range filter.DeptIds {
  136. args = append(args, id)
  137. }
  138. }
  139. where := ""
  140. if len(conditions) > 0 {
  141. where = " WHERE " + strings.Join(conditions, " AND ")
  142. }
  143. var total int64
  144. countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s%s", m.table, where)
  145. if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery, args...); err != nil {
  146. return nil, 0, err
  147. }
  148. var list []*SysUser
  149. queryArgs := append(args, (page-1)*pageSize, pageSize)
  150. query := fmt.Sprintf("SELECT %s FROM %s%s ORDER BY id DESC LIMIT ?,?", sysUserRows, m.table, where)
  151. if err := m.QueryRowsNoCacheCtx(ctx, &list, query, queryArgs...); err != nil {
  152. return nil, 0, err
  153. }
  154. return list, total, nil
  155. }
  156. type UserWithMemberType struct {
  157. SysUser
  158. MemberType string `db:"memberType"`
  159. }
  160. func (m *customSysUserModel) FindListByProductMembers(ctx context.Context, productCode string, page, pageSize int64) ([]*SysUser, map[int64]string, int64, error) {
  161. memberTable := "`sys_product_member`"
  162. var total int64
  163. countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s u INNER JOIN %s pm ON u.`id` = pm.`userId` WHERE pm.`productCode` = ?", m.table, memberTable)
  164. if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery, productCode); err != nil {
  165. return nil, nil, 0, err
  166. }
  167. var list []*UserWithMemberType
  168. fields := strings.Join(sysUserFieldNames, ",u.")
  169. query := fmt.Sprintf("SELECT u.%s, pm.`memberType` FROM %s u INNER JOIN %s pm ON u.`id` = pm.`userId` WHERE pm.`productCode` = ? ORDER BY u.`id` DESC LIMIT ?,?", fields, m.table, memberTable)
  170. if err := m.QueryRowsNoCacheCtx(ctx, &list, query, productCode, (page-1)*pageSize, pageSize); err != nil {
  171. return nil, nil, 0, err
  172. }
  173. users := make([]*SysUser, len(list))
  174. memberMap := make(map[int64]string, len(list))
  175. for i, item := range list {
  176. users[i] = &item.SysUser
  177. memberMap[item.Id] = item.MemberType
  178. }
  179. return users, memberMap, total, nil
  180. }
  181. func (m *customSysUserModel) FindIdsByDeptId(ctx context.Context, deptId int64) ([]int64, error) {
  182. var ids []int64
  183. query := fmt.Sprintf("SELECT `id` FROM %s WHERE `deptId` = ?", m.table)
  184. if err := m.QueryRowsNoCacheCtx(ctx, &ids, query, deptId); err != nil {
  185. return nil, err
  186. }
  187. return ids, nil
  188. }
  189. // FindIdsByDeptIdForShareTx 见接口注释(审计 L-R16-2)。
  190. func (m *customSysUserModel) FindIdsByDeptIdForShareTx(ctx context.Context, session sqlx.Session, deptId int64) ([]int64, error) {
  191. if session == nil {
  192. return nil, errors.New("FindIdsByDeptIdForShareTx requires a non-nil session")
  193. }
  194. var ids []int64
  195. // ORDER BY id 让 S 锁的申请顺序确定,与并发的 UpdateProfileWithTx(按 id 触发 X 锁)交叉时
  196. // 形成固定加锁偏序,减小死锁概率;同时方便上层日志对照。
  197. query := fmt.Sprintf("SELECT `id` FROM %s WHERE `deptId` = ? ORDER BY `id` LOCK IN SHARE MODE", m.table)
  198. if err := session.QueryRowsCtx(ctx, &ids, query, deptId); err != nil {
  199. return nil, err
  200. }
  201. return ids, nil
  202. }
  203. func (m *customSysUserModel) UpdateProfile(ctx context.Context, id int64, username string, nickname, email, phone, remark string, deptId, newStatus int64, statusChanged bool, expectedUpdateTime int64) error {
  204. sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
  205. sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username)
  206. now := time.Now().Unix()
  207. res, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
  208. if statusChanged {
  209. query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `email`=?, `phone`=?, `remark`=?, `deptId`=?, `status`=?, `tokenVersion`=`tokenVersion`+1, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table)
  210. return conn.ExecCtx(ctx, query, nickname, email, phone, remark, deptId, newStatus, now, id, expectedUpdateTime)
  211. }
  212. query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `email`=?, `phone`=?, `remark`=?, `deptId`=?, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table)
  213. return conn.ExecCtx(ctx, query, nickname, email, phone, remark, deptId, now, id, expectedUpdateTime)
  214. }, sysUserIdKey, sysUserUsernameKey)
  215. if err != nil {
  216. return err
  217. }
  218. affected, _ := res.RowsAffected()
  219. if affected == 0 {
  220. return ErrUpdateConflict
  221. }
  222. return nil
  223. }
  224. func (m *customSysUserModel) UpdateSelfInfo(ctx context.Context, id int64, username string, nickname, avatar, email, phone string, expectedUpdateTime int64) error {
  225. sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
  226. sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username)
  227. now := time.Now().Unix()
  228. var avatarVal sql.NullString
  229. if avatar != "" {
  230. avatarVal = sql.NullString{String: avatar, Valid: true}
  231. }
  232. res, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
  233. query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `avatar`=?, `email`=?, `phone`=?, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table)
  234. return conn.ExecCtx(ctx, query, nickname, avatarVal, email, phone, now, id, expectedUpdateTime)
  235. }, sysUserIdKey, sysUserUsernameKey)
  236. if err != nil {
  237. return err
  238. }
  239. affected, _ := res.RowsAffected()
  240. if affected == 0 {
  241. return ErrUpdateConflict
  242. }
  243. return nil
  244. }
  245. // UpdateProfileWithTx 见接口注释(审计 M-R11-3 + L-R12-1)。
  246. // 实现上**绕过** m.ExecCtx 的 pre-commit DelCache 语义——仅调用 session.ExecCtx,缓存失效由
  247. // 调用方在事务 commit 成功后显式走 InvalidateProfileCache。
  248. // session==nil 时直接拒绝(非事务场景必须改走 UpdateProfile)。
  249. func (m *customSysUserModel) UpdateProfileWithTx(ctx context.Context, session sqlx.Session, id int64, username string, nickname, email, phone, remark string, deptId, newStatus int64, statusChanged bool, expectedUpdateTime int64) error {
  250. if session == nil {
  251. return errors.New("UpdateProfileWithTx requires a non-nil session")
  252. }
  253. _ = username // 保留形参以维持调用方契约一致,失效由 InvalidateProfileCache 另行处理
  254. now := time.Now().Unix()
  255. var (
  256. res sql.Result
  257. err error
  258. )
  259. if statusChanged {
  260. query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `email`=?, `phone`=?, `remark`=?, `deptId`=?, `status`=?, `tokenVersion`=`tokenVersion`+1, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table)
  261. res, err = session.ExecCtx(ctx, query, nickname, email, phone, remark, deptId, newStatus, now, id, expectedUpdateTime)
  262. } else {
  263. query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `email`=?, `phone`=?, `remark`=?, `deptId`=?, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table)
  264. res, err = session.ExecCtx(ctx, query, nickname, email, phone, remark, deptId, now, id, expectedUpdateTime)
  265. }
  266. if err != nil {
  267. return err
  268. }
  269. affected, _ := res.RowsAffected()
  270. if affected == 0 {
  271. return ErrUpdateConflict
  272. }
  273. return nil
  274. }
  275. // InvalidateProfileCache 见接口注释(审计 L-R12-1)。这里不暴露错误——post-commit 阶段
  276. // DB 已为权威,缓存清理失败只会让 sysUser 低层缓存继续提供旧值直至 TTL 过期,而 Logic 层
  277. // 紧接其后的 UserDetailsLoader.Clean 已经失效了上层 UserDetails 聚合缓存;两级缓存一起过期
  278. // 的最坏情形仍然等价于 loadUser 的 cache-miss 路径,不会放大故障面。
  279. func (m *customSysUserModel) InvalidateProfileCache(ctx context.Context, id int64, username string) {
  280. sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
  281. sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username)
  282. if err := m.DelCacheCtx(ctx, sysUserIdKey, sysUserUsernameKey); err != nil {
  283. // 审计 L-R13-5 方案 B:失败原因拆成 ctx 取消 vs 其它两档——前者打独立 audit tag 方便
  284. // 运维按 `cache_invalidation_skipped_due_to_ctx_cancel` 建看板,避免与真正的 Redis 故障
  285. // 混在一起报警;其它错误仍然 Errorf,保持与 sqlc 原生失效路径(Insert/Update 触发的
  286. // DelCache 失败)一致的可观测性口径。
  287. if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
  288. logx.WithContext(ctx).Errorw("cache invalidation skipped: ctx canceled",
  289. logx.Field("audit", "cache_invalidation_skipped_due_to_ctx_cancel"),
  290. logx.Field("scope", "sysUserModel.InvalidateProfileCache"),
  291. logx.Field("id", id),
  292. logx.Field("err", err.Error()),
  293. )
  294. } else {
  295. logx.WithContext(ctx).Errorf("sysUserModel.InvalidateProfileCache failed: id=%d err=%v", id, err)
  296. }
  297. }
  298. }
  299. func (m *customSysUserModel) UpdatePassword(ctx context.Context, id int64, username string, password string, mustChangePassword, expectedUpdateTime int64) error {
  300. sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
  301. sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username)
  302. // 审计 H-R11-1:expectedUpdateTime 必须来自**外层校验旧密码时读到的**快照。不要再内部 FindOne
  303. // 自取 data.UpdateTime —— 那会让 CAS 在本函数内自我对齐,退化为 last-write-wins,让"会话持有 +
  304. // 知道旧密码"的攻击者可以把 admin 紧急改过的新密码盖回到自己手里那一份。
  305. res, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
  306. query := fmt.Sprintf("UPDATE %s SET `password` = ?, `mustChangePassword` = ?, `tokenVersion` = `tokenVersion` + 1, `updateTime` = ? WHERE `id` = ? AND `updateTime` = ?", m.table)
  307. return conn.ExecCtx(ctx, query, password, mustChangePassword, time.Now().Unix(), id, expectedUpdateTime)
  308. }, sysUserIdKey, sysUserUsernameKey)
  309. if err != nil {
  310. return err
  311. }
  312. if affected, _ := res.RowsAffected(); affected == 0 {
  313. // 行被删除或被并发改过(任何字段,包括 status/password/profile):统一回 ErrUpdateConflict。
  314. return ErrUpdateConflict
  315. }
  316. return nil
  317. }
  318. // UpdateStatus 修改用户 status,并强制递增 tokenVersion 让已签发令牌失效。
  319. // 审计 L-N4:WHERE 必须带 `updateTime=?` 乐观锁,和 UpdateProfile / UpdatePassword 语义对齐。
  320. // 上游 UpdateUserStatusLogic 已经从 ValidateStatusChange 拿到 sysUser,调用方应把
  321. // `sysUser.UpdateTime` 当作 expectedUpdateTime 传入:
  322. // - expectedUpdateTime 不匹配 → ErrUpdateConflict;上层统一回 409 "数据已被其他操作修改"。
  323. // - 避免并发冻结/解冻请求走"last-write-wins",出现两个 admin 同时点"冻结"/"解冻"
  324. // 时后到者覆盖先到者、tokenVersion 被连续加两次把刚刚解冻的用户再次踢下线的诡异现象。
  325. //
  326. // 审计 L-R17-6 · 无条件递增 tokenVersion 是**刻意**设计(不要改成"仅冻结时 +1"的条件递增):
  327. // - Enabled→Disabled(冻结):签发层吊销旧 access/refresh token,JWT middleware 凭
  328. // `claims.tokenVersion < DB.tokenVersion` 一票否决。
  329. // - Disabled→Enabled(解冻):用户已因 Status!=Enabled 无法登录/刷新,+1 在业务上是空动作;
  330. // 但保留 +1 覆盖"冻结 → Redis Clean 失败 → 解冻"这条极窄路径里"旧 access token 凭残留
  331. // UD 复活"的可能,与 UpdateUserStatusLogic 的 L-R17-6 注释同口径。
  332. func (m *customSysUserModel) UpdateStatus(ctx context.Context, id int64, username string, status int64, expectedUpdateTime int64) error {
  333. // 审计 M-R11-2:username 由调用方(ValidateStatusChange 返回的目标用户对象)显式透传,
  334. // 不再内部 FindOne。真实并发安全继续靠 `WHERE updateTime = expectedUpdateTime` 乐观锁兜底。
  335. sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
  336. sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username)
  337. res, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
  338. query := fmt.Sprintf("UPDATE %s SET `status` = ?, `tokenVersion` = `tokenVersion` + 1, `updateTime` = ? WHERE `id` = ? AND `updateTime` = ?", m.table)
  339. return conn.ExecCtx(ctx, query, status, time.Now().Unix(), id, expectedUpdateTime)
  340. }, sysUserIdKey, sysUserUsernameKey)
  341. if err != nil {
  342. return err
  343. }
  344. if affected, _ := res.RowsAffected(); affected == 0 {
  345. // 行被删除或被并发改过:对外统一回 ErrUpdateConflict,避免对已删除 / 被并发改过的行
  346. // 返回 nil 让上层误判为"冻结生效"(审计 M-2 / L-N4)。
  347. return ErrUpdateConflict
  348. }
  349. return nil
  350. }
  351. // IncrementTokenVersion 强制递增当前用户的 tokenVersion,让**所有**已签发的 access/refresh 立即失效。
  352. //
  353. // WARN: 仅限"强制全量会话失效"场景调用——主动 Logout 或封禁/重置密码。Refresh / Rotate 场景
  354. // 必须调用 IncrementTokenVersionIfMatch 走 CAS 语义,否则会回到 R5 以前的并发 rotate 窗口,
  355. // 两次并发 refresh 都能换到新令牌,等同于会话劫持(见审计 L-2)。
  356. // 调用前请先走 TokenOpLimiter 等限流,避免被反复触发把合法用户 kick 出登录。
  357. func (m *customSysUserModel) IncrementTokenVersion(ctx context.Context, id int64, username string) (int64, error) {
  358. // 审计 M-R11-2:username 由调用方显式透传,不再内部 FindOne 仅为构造缓存键。
  359. // 调用方(Logout 唯一入口)在 UserDetailsLoader.Load 之后天然已经持有 ud.Username。
  360. sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
  361. sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username)
  362. var newVersion int64
  363. err := m.TransactCtx(ctx, func(ctx context.Context, session sqlx.Session) error {
  364. query := fmt.Sprintf("UPDATE %s SET `tokenVersion` = LAST_INSERT_ID(`tokenVersion` + 1), `updateTime` = ? WHERE `id` = ?", m.table)
  365. res, err := session.ExecCtx(ctx, query, time.Now().Unix(), id)
  366. if err != nil {
  367. return err
  368. }
  369. // 审计 L-R10-3:FindOne 与本次 UPDATE 之间若有并发 DELETE,affected=0 仍会返回 nil 让
  370. // SELECT LAST_INSERT_ID() 在全新事务里读出 0,调用方无法区分"真的递增成功"和"目标已被删除"。
  371. // 对外统一回 ErrUpdateConflict,调用方据此做日志/告警或跳过后续 Clean,避免对 tokenVersion=0
  372. // 之类的伪返回值做错误假设(如"版本从 1 起递增"的默认契约踩坑)。
  373. if affected, _ := res.RowsAffected(); affected == 0 {
  374. return ErrUpdateConflict
  375. }
  376. return session.QueryRowCtx(ctx, &newVersion, "SELECT LAST_INSERT_ID()")
  377. })
  378. if err != nil {
  379. return 0, err
  380. }
  381. _ = m.DelCacheCtx(ctx, sysUserIdKey, sysUserUsernameKey)
  382. return newVersion, nil
  383. }
  384. // IncrementTokenVersionIfMatch 原子递增 tokenVersion;仅当 DB 里当前 tokenVersion == expected 时才会生效。
  385. // 这是 refreshToken rotation 的原子 CAS:两个并发的刷新请求只有一个能命中 WHERE tokenVersion=expected,
  386. // 另一个 affected=0 返回 ErrTokenVersionMismatch,从而避免"两边都换到新令牌"的会话劫持窗口。
  387. //
  388. // 由上游透传 username 以便构造 cacheSysUserUsernamePrefix 的缓存键进行失效,避免为此多查一次 FindOne
  389. // (见审计 M-8)。上游通常已经通过 UserDetailsLoader.Load 拿到 username,零额外成本。
  390. func (m *customSysUserModel) IncrementTokenVersionIfMatch(ctx context.Context, id int64, username string, expected int64) (int64, error) {
  391. sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
  392. sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username)
  393. var newVersion int64
  394. err := m.TransactCtx(ctx, func(ctx context.Context, session sqlx.Session) error {
  395. query := fmt.Sprintf("UPDATE %s SET `tokenVersion` = LAST_INSERT_ID(`tokenVersion` + 1), `updateTime` = ? WHERE `id` = ? AND `tokenVersion` = ?", m.table)
  396. res, err := session.ExecCtx(ctx, query, time.Now().Unix(), id, expected)
  397. if err != nil {
  398. return err
  399. }
  400. affected, _ := res.RowsAffected()
  401. if affected == 0 {
  402. return ErrTokenVersionMismatch
  403. }
  404. return session.QueryRowCtx(ctx, &newVersion, "SELECT LAST_INSERT_ID()")
  405. })
  406. if err != nil {
  407. return 0, err
  408. }
  409. _ = m.DelCacheCtx(ctx, sysUserIdKey, sysUserUsernameKey)
  410. return newVersion, nil
  411. }
  412. // IncrementTokenVersionWithTx 见接口注释(审计 M-R15-1 方案 A)。
  413. func (m *customSysUserModel) IncrementTokenVersionWithTx(ctx context.Context, session sqlx.Session, id int64) (int64, error) {
  414. if session == nil {
  415. return 0, errors.New("IncrementTokenVersionWithTx requires a non-nil session")
  416. }
  417. query := fmt.Sprintf("UPDATE %s SET `tokenVersion` = LAST_INSERT_ID(`tokenVersion` + 1), `updateTime` = ? WHERE `id` = ?", m.table)
  418. res, err := session.ExecCtx(ctx, query, time.Now().Unix(), id)
  419. if err != nil {
  420. return 0, err
  421. }
  422. // 与 IncrementTokenVersion 的 L-R10-3 契约一致:affected=0 表示目标已被并发删除,回
  423. // ErrUpdateConflict 让上层业务事务自行 rollback——不递增 tokenVersion 比"哑失败后继续走
  424. // post-commit 失效链"安全(后者会把已不存在的 userId 从缓存里删除,没有副作用但污染日志)。
  425. if affected, _ := res.RowsAffected(); affected == 0 {
  426. return 0, ErrUpdateConflict
  427. }
  428. var newVersion int64
  429. if err := session.QueryRowCtx(ctx, &newVersion, "SELECT LAST_INSERT_ID()"); err != nil {
  430. return 0, err
  431. }
  432. return newVersion, nil
  433. }
  434. // BatchIncrementTokenVersionWithTx 见接口注释(审计 L-R15-3)。
  435. func (m *customSysUserModel) BatchIncrementTokenVersionWithTx(ctx context.Context, session sqlx.Session, ids []int64) error {
  436. if session == nil {
  437. return errors.New("BatchIncrementTokenVersionWithTx requires a non-nil session")
  438. }
  439. if len(ids) == 0 {
  440. return nil
  441. }
  442. placeholders := make([]string, len(ids))
  443. args := make([]interface{}, 0, len(ids)+1)
  444. args = append(args, time.Now().Unix())
  445. for i, id := range ids {
  446. placeholders[i] = "?"
  447. args = append(args, id)
  448. }
  449. // 批量 UPDATE 不再回读每行新 tokenVersion——调用方(UpdateProduct 禁用)只关心"集体递增
  450. // 发生了",不做 forensic 比对;若未来需要逐行返回,请改走 IN(..) + SELECT 的两步模式,
  451. // 不要试图让 LAST_INSERT_ID() 承担 N 行的反向通道(它只会保留最后一次赋值)。
  452. query := fmt.Sprintf("UPDATE %s SET `tokenVersion` = `tokenVersion` + 1, `updateTime` = ? WHERE `id` IN (%s)", m.table, strings.Join(placeholders, ","))
  453. if _, err := session.ExecCtx(ctx, query, args...); err != nil {
  454. return err
  455. }
  456. return nil
  457. }
  458. func (m *customSysUserModel) FindByIds(ctx context.Context, ids []int64) ([]*SysUser, error) {
  459. if len(ids) == 0 {
  460. return nil, nil
  461. }
  462. placeholders := make([]string, len(ids))
  463. args := make([]interface{}, len(ids))
  464. for i, id := range ids {
  465. placeholders[i] = "?"
  466. args[i] = id
  467. }
  468. var list []*SysUser
  469. query := fmt.Sprintf("SELECT %s FROM %s WHERE `id` IN (%s)", sysUserRows, m.table, strings.Join(placeholders, ","))
  470. if err := m.QueryRowsNoCacheCtx(ctx, &list, query, args...); err != nil {
  471. return nil, err
  472. }
  473. return list, nil
  474. }