sysUserModel.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. package user
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "fmt"
  7. "strings"
  8. "time"
  9. "github.com/zeromicro/go-zero/core/stores/cache"
  10. "github.com/zeromicro/go-zero/core/stores/sqlx"
  11. )
  12. var ErrUpdateConflict = errors.New("update conflict: data has been modified by another operation")
  13. // ErrTokenVersionMismatch 表示令牌版本与数据库当前版本不一致,刷新令牌失败。
  14. // 典型场景:refreshToken rotation 并发到达 —— 只有持有当前 tokenVersion 的那一次能原子递增成功,
  15. // 其余全部返回该错误,防止两个请求都"换到"新令牌(导致会话劫持)。
  16. var ErrTokenVersionMismatch = errors.New("token version mismatch")
  17. var _ SysUserModel = (*customSysUserModel)(nil)
  18. type (
  19. SysUserModel interface {
  20. sysUserModel
  21. FindListByPage(ctx context.Context, page, pageSize int64) ([]*SysUser, int64, error)
  22. FindListByProductMembers(ctx context.Context, productCode string, page, pageSize int64) ([]*SysUser, map[int64]string, int64, error)
  23. FindByIds(ctx context.Context, ids []int64) ([]*SysUser, error)
  24. FindIdsByDeptId(ctx context.Context, deptId int64) ([]int64, error)
  25. // UpdateProfile 更新用户资料字段(昵称 / 邮箱 / 手机 / 备注 / 部门 / 状态),username 仅用于
  26. // 构造旧缓存键 `cacheSysUserUsernamePrefix` 做失效,**不会**被写入 SET 子句。若未来确实需要
  27. // 修改 username,请独立实现 `UpdateUsernameTx`:
  28. // ① 同事务内做 old/new 两份 UsernameKey 的失效;
  29. // ② 捕获 1062 (UNIQUE 冲突) → response.ErrConflict,不得混进本方法的签名(审计 L-R11-3)。
  30. UpdateProfile(ctx context.Context, id int64, username string, nickname, email, phone, remark string, deptId, newStatus int64, statusChanged bool, expectedUpdateTime int64) error
  31. // UpdateProfileWithTx 与 UpdateProfile 的 SET 子句等价,但 UPDATE 执行在调用方传入的事务里;
  32. // 用于"改 deptId → 同事务内先 FOR SHARE sys_dept 目标行"修复 DeleteDept vs UpdateUser
  33. // 的 write skew(审计 M-R11-3)。
  34. // 审计 L-R12-1:本方法**不再**自己做缓存失效——go-zero `m.ExecCtx` 的 DelCache 会在闭包
  35. // 成功返回时立即发生,这里"成功"只代表 session.ExecCtx 收到了行数,事务仍在进行中,
  36. // 并发 FindOne 会把尚未提交的旧值重新灌回 sysUser 低层缓存造成 stale。调用方**必须**在
  37. // TransactCtx 返回(事务 commit)成功之后调用 InvalidateProfileCache(id, username) 做
  38. // post-commit 失效;同模式对任何 `*WithTx` 方法都成立。
  39. UpdateProfileWithTx(ctx context.Context, session sqlx.Session, id int64, username string, nickname, email, phone, remark string, deptId, newStatus int64, statusChanged bool, expectedUpdateTime int64) error
  40. // InvalidateProfileCache 失效 sysUser 的 id / username 两把低层缓存键。仅应在事务
  41. // commit 成功后由调用方显式调用(见 UpdateProfileWithTx 说明)。best-effort:失效失败
  42. // 只留日志,最终由 TTL 兜底,不影响事务已完成的业务结果(审计 L-R12-1)。
  43. InvalidateProfileCache(ctx context.Context, id int64, username string)
  44. // UpdatePassword 审计 H-R11-1:expectedUpdateTime 必须由调用方用**外层校验旧密码时拿到的
  45. // 那一份 updateTime** 显式透传;禁止函数内部再 FindOne 自对齐乐观锁,否则内层 CAS 等于退化
  46. // 为 last-write-wins,被"旧会话 + 知道旧密码"的攻击者可以把管理员紧急改过的新密码盖回。
  47. // username 仅用于构造 `cacheSysUserUsernamePrefix` 缓存键失效(sysUser.username 唯一,本函数
  48. // 不会修改它)。
  49. UpdatePassword(ctx context.Context, id int64, username string, password string, mustChangePassword, expectedUpdateTime int64) error
  50. // UpdateStatus 审计 M-R11-2:username 由调用方透传,避免仅为构造缓存键而多打一次 FindOne。
  51. UpdateStatus(ctx context.Context, id int64, username string, status int64, expectedUpdateTime int64) error
  52. // IncrementTokenVersion 审计 M-R11-2:username 由调用方透传,避免仅为构造缓存键而多打一次 FindOne。
  53. IncrementTokenVersion(ctx context.Context, id int64, username string) (int64, error)
  54. IncrementTokenVersionIfMatch(ctx context.Context, id int64, username string, expected int64) (int64, error)
  55. }
  56. customSysUserModel struct {
  57. *defaultSysUserModel
  58. }
  59. )
  60. func NewSysUserModel(conn sqlx.SqlConn, c cache.CacheConf, cachePrefix string, opts ...cache.Option) SysUserModel {
  61. return &customSysUserModel{
  62. defaultSysUserModel: newSysUserModel(conn, c, cachePrefix, opts...),
  63. }
  64. }
  65. func (m *customSysUserModel) FindListByPage(ctx context.Context, page, pageSize int64) ([]*SysUser, int64, error) {
  66. var total int64
  67. countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s", m.table)
  68. if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery); err != nil {
  69. return nil, 0, err
  70. }
  71. var list []*SysUser
  72. query := fmt.Sprintf("SELECT %s FROM %s ORDER BY id DESC LIMIT ?,?", sysUserRows, m.table)
  73. if err := m.QueryRowsNoCacheCtx(ctx, &list, query, (page-1)*pageSize, pageSize); err != nil {
  74. return nil, 0, err
  75. }
  76. return list, total, nil
  77. }
  78. type UserWithMemberType struct {
  79. SysUser
  80. MemberType string `db:"memberType"`
  81. }
  82. func (m *customSysUserModel) FindListByProductMembers(ctx context.Context, productCode string, page, pageSize int64) ([]*SysUser, map[int64]string, int64, error) {
  83. memberTable := "`sys_product_member`"
  84. var total int64
  85. countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s u INNER JOIN %s pm ON u.`id` = pm.`userId` WHERE pm.`productCode` = ?", m.table, memberTable)
  86. if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery, productCode); err != nil {
  87. return nil, nil, 0, err
  88. }
  89. var list []*UserWithMemberType
  90. fields := strings.Join(sysUserFieldNames, ",u.")
  91. query := fmt.Sprintf("SELECT u.%s, pm.`memberType` FROM %s u INNER JOIN %s pm ON u.`id` = pm.`userId` WHERE pm.`productCode` = ? ORDER BY u.`id` DESC LIMIT ?,?", fields, m.table, memberTable)
  92. if err := m.QueryRowsNoCacheCtx(ctx, &list, query, productCode, (page-1)*pageSize, pageSize); err != nil {
  93. return nil, nil, 0, err
  94. }
  95. users := make([]*SysUser, len(list))
  96. memberMap := make(map[int64]string, len(list))
  97. for i, item := range list {
  98. users[i] = &item.SysUser
  99. memberMap[item.Id] = item.MemberType
  100. }
  101. return users, memberMap, total, nil
  102. }
  103. func (m *customSysUserModel) FindIdsByDeptId(ctx context.Context, deptId int64) ([]int64, error) {
  104. var ids []int64
  105. query := fmt.Sprintf("SELECT `id` FROM %s WHERE `deptId` = ?", m.table)
  106. if err := m.QueryRowsNoCacheCtx(ctx, &ids, query, deptId); err != nil {
  107. return nil, err
  108. }
  109. return ids, nil
  110. }
  111. func (m *customSysUserModel) UpdateProfile(ctx context.Context, id int64, username string, nickname, email, phone, remark string, deptId, newStatus int64, statusChanged bool, expectedUpdateTime int64) error {
  112. sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
  113. sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username)
  114. now := time.Now().Unix()
  115. res, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
  116. if statusChanged {
  117. query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `email`=?, `phone`=?, `remark`=?, `deptId`=?, `status`=?, `tokenVersion`=`tokenVersion`+1, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table)
  118. return conn.ExecCtx(ctx, query, nickname, email, phone, remark, deptId, newStatus, now, id, expectedUpdateTime)
  119. }
  120. query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `email`=?, `phone`=?, `remark`=?, `deptId`=?, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table)
  121. return conn.ExecCtx(ctx, query, nickname, email, phone, remark, deptId, now, id, expectedUpdateTime)
  122. }, sysUserIdKey, sysUserUsernameKey)
  123. if err != nil {
  124. return err
  125. }
  126. affected, _ := res.RowsAffected()
  127. if affected == 0 {
  128. return ErrUpdateConflict
  129. }
  130. return nil
  131. }
  132. // UpdateProfileWithTx 见接口注释(审计 M-R11-3 + L-R12-1)。
  133. // 实现上**绕过** m.ExecCtx 的 pre-commit DelCache 语义——仅调用 session.ExecCtx,缓存失效由
  134. // 调用方在事务 commit 成功后显式走 InvalidateProfileCache。
  135. // session==nil 时直接拒绝(非事务场景必须改走 UpdateProfile)。
  136. func (m *customSysUserModel) UpdateProfileWithTx(ctx context.Context, session sqlx.Session, id int64, username string, nickname, email, phone, remark string, deptId, newStatus int64, statusChanged bool, expectedUpdateTime int64) error {
  137. if session == nil {
  138. return errors.New("UpdateProfileWithTx requires a non-nil session")
  139. }
  140. _ = username // 保留形参以维持调用方契约一致,失效由 InvalidateProfileCache 另行处理
  141. now := time.Now().Unix()
  142. var (
  143. res sql.Result
  144. err error
  145. )
  146. if statusChanged {
  147. query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `email`=?, `phone`=?, `remark`=?, `deptId`=?, `status`=?, `tokenVersion`=`tokenVersion`+1, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table)
  148. res, err = session.ExecCtx(ctx, query, nickname, email, phone, remark, deptId, newStatus, now, id, expectedUpdateTime)
  149. } else {
  150. query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `email`=?, `phone`=?, `remark`=?, `deptId`=?, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table)
  151. res, err = session.ExecCtx(ctx, query, nickname, email, phone, remark, deptId, now, id, expectedUpdateTime)
  152. }
  153. if err != nil {
  154. return err
  155. }
  156. affected, _ := res.RowsAffected()
  157. if affected == 0 {
  158. return ErrUpdateConflict
  159. }
  160. return nil
  161. }
  162. // InvalidateProfileCache 见接口注释(审计 L-R12-1)。这里不暴露错误——post-commit 阶段
  163. // DB 已为权威,缓存清理失败只会让 sysUser 低层缓存继续提供旧值直至 TTL 过期,而 Logic 层
  164. // 紧接其后的 UserDetailsLoader.Clean 已经失效了上层 UserDetails 聚合缓存;两级缓存一起过期
  165. // 的最坏情形仍然等价于 loadUser 的 cache-miss 路径,不会放大故障面。
  166. func (m *customSysUserModel) InvalidateProfileCache(ctx context.Context, id int64, username string) {
  167. sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
  168. sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username)
  169. _ = m.DelCacheCtx(ctx, sysUserIdKey, sysUserUsernameKey)
  170. }
  171. func (m *customSysUserModel) UpdatePassword(ctx context.Context, id int64, username string, password string, mustChangePassword, expectedUpdateTime int64) error {
  172. sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
  173. sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username)
  174. // 审计 H-R11-1:expectedUpdateTime 必须来自**外层校验旧密码时读到的**快照。不要再内部 FindOne
  175. // 自取 data.UpdateTime —— 那会让 CAS 在本函数内自我对齐,退化为 last-write-wins,让"会话持有 +
  176. // 知道旧密码"的攻击者可以把 admin 紧急改过的新密码盖回到自己手里那一份。
  177. res, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
  178. query := fmt.Sprintf("UPDATE %s SET `password` = ?, `mustChangePassword` = ?, `tokenVersion` = `tokenVersion` + 1, `updateTime` = ? WHERE `id` = ? AND `updateTime` = ?", m.table)
  179. return conn.ExecCtx(ctx, query, password, mustChangePassword, time.Now().Unix(), id, expectedUpdateTime)
  180. }, sysUserIdKey, sysUserUsernameKey)
  181. if err != nil {
  182. return err
  183. }
  184. if affected, _ := res.RowsAffected(); affected == 0 {
  185. // 行被删除或被并发改过(任何字段,包括 status/password/profile):统一回 ErrUpdateConflict。
  186. return ErrUpdateConflict
  187. }
  188. return nil
  189. }
  190. // UpdateStatus 修改用户 status,并强制递增 tokenVersion 让已签发令牌失效。
  191. // 审计 L-N4:WHERE 必须带 `updateTime=?` 乐观锁,和 UpdateProfile / UpdatePassword 语义对齐。
  192. // 上游 UpdateUserStatusLogic 已经从 ValidateStatusChange 拿到 sysUser,调用方应把
  193. // `sysUser.UpdateTime` 当作 expectedUpdateTime 传入:
  194. // - expectedUpdateTime 不匹配 → ErrUpdateConflict;上层统一回 409 "数据已被其他操作修改"。
  195. // - 避免并发冻结/解冻请求走"last-write-wins",出现两个 admin 同时点"冻结"/"解冻"
  196. // 时后到者覆盖先到者、tokenVersion 被连续加两次把刚刚解冻的用户再次踢下线的诡异现象。
  197. func (m *customSysUserModel) UpdateStatus(ctx context.Context, id int64, username string, status int64, expectedUpdateTime int64) error {
  198. // 审计 M-R11-2:username 由调用方(ValidateStatusChange 返回的目标用户对象)显式透传,
  199. // 不再内部 FindOne。真实并发安全继续靠 `WHERE updateTime = expectedUpdateTime` 乐观锁兜底。
  200. sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
  201. sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username)
  202. res, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
  203. query := fmt.Sprintf("UPDATE %s SET `status` = ?, `tokenVersion` = `tokenVersion` + 1, `updateTime` = ? WHERE `id` = ? AND `updateTime` = ?", m.table)
  204. return conn.ExecCtx(ctx, query, status, time.Now().Unix(), id, expectedUpdateTime)
  205. }, sysUserIdKey, sysUserUsernameKey)
  206. if err != nil {
  207. return err
  208. }
  209. if affected, _ := res.RowsAffected(); affected == 0 {
  210. // 行被删除或被并发改过:对外统一回 ErrUpdateConflict,避免对已删除 / 被并发改过的行
  211. // 返回 nil 让上层误判为"冻结生效"(审计 M-2 / L-N4)。
  212. return ErrUpdateConflict
  213. }
  214. return nil
  215. }
  216. // IncrementTokenVersion 强制递增当前用户的 tokenVersion,让**所有**已签发的 access/refresh 立即失效。
  217. //
  218. // WARN: 仅限"强制全量会话失效"场景调用——主动 Logout 或封禁/重置密码。Refresh / Rotate 场景
  219. // 必须调用 IncrementTokenVersionIfMatch 走 CAS 语义,否则会回到 R5 以前的并发 rotate 窗口,
  220. // 两次并发 refresh 都能换到新令牌,等同于会话劫持(见审计 L-2)。
  221. // 调用前请先走 TokenOpLimiter 等限流,避免被反复触发把合法用户 kick 出登录。
  222. func (m *customSysUserModel) IncrementTokenVersion(ctx context.Context, id int64, username string) (int64, error) {
  223. // 审计 M-R11-2:username 由调用方显式透传,不再内部 FindOne 仅为构造缓存键。
  224. // 调用方(Logout 唯一入口)在 UserDetailsLoader.Load 之后天然已经持有 ud.Username。
  225. sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
  226. sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username)
  227. var newVersion int64
  228. err := m.TransactCtx(ctx, func(ctx context.Context, session sqlx.Session) error {
  229. query := fmt.Sprintf("UPDATE %s SET `tokenVersion` = LAST_INSERT_ID(`tokenVersion` + 1), `updateTime` = ? WHERE `id` = ?", m.table)
  230. res, err := session.ExecCtx(ctx, query, time.Now().Unix(), id)
  231. if err != nil {
  232. return err
  233. }
  234. // 审计 L-R10-3:FindOne 与本次 UPDATE 之间若有并发 DELETE,affected=0 仍会返回 nil 让
  235. // SELECT LAST_INSERT_ID() 在全新事务里读出 0,调用方无法区分"真的递增成功"和"目标已被删除"。
  236. // 对外统一回 ErrUpdateConflict,调用方据此做日志/告警或跳过后续 Clean,避免对 tokenVersion=0
  237. // 之类的伪返回值做错误假设(如"版本从 1 起递增"的默认契约踩坑)。
  238. if affected, _ := res.RowsAffected(); affected == 0 {
  239. return ErrUpdateConflict
  240. }
  241. return session.QueryRowCtx(ctx, &newVersion, "SELECT LAST_INSERT_ID()")
  242. })
  243. if err != nil {
  244. return 0, err
  245. }
  246. _ = m.DelCacheCtx(ctx, sysUserIdKey, sysUserUsernameKey)
  247. return newVersion, nil
  248. }
  249. // IncrementTokenVersionIfMatch 原子递增 tokenVersion;仅当 DB 里当前 tokenVersion == expected 时才会生效。
  250. // 这是 refreshToken rotation 的原子 CAS:两个并发的刷新请求只有一个能命中 WHERE tokenVersion=expected,
  251. // 另一个 affected=0 返回 ErrTokenVersionMismatch,从而避免"两边都换到新令牌"的会话劫持窗口。
  252. //
  253. // 由上游透传 username 以便构造 cacheSysUserUsernamePrefix 的缓存键进行失效,避免为此多查一次 FindOne
  254. // (见审计 M-8)。上游通常已经通过 UserDetailsLoader.Load 拿到 username,零额外成本。
  255. func (m *customSysUserModel) IncrementTokenVersionIfMatch(ctx context.Context, id int64, username string, expected int64) (int64, error) {
  256. sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
  257. sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username)
  258. var newVersion int64
  259. err := m.TransactCtx(ctx, func(ctx context.Context, session sqlx.Session) error {
  260. query := fmt.Sprintf("UPDATE %s SET `tokenVersion` = LAST_INSERT_ID(`tokenVersion` + 1), `updateTime` = ? WHERE `id` = ? AND `tokenVersion` = ?", m.table)
  261. res, err := session.ExecCtx(ctx, query, time.Now().Unix(), id, expected)
  262. if err != nil {
  263. return err
  264. }
  265. affected, _ := res.RowsAffected()
  266. if affected == 0 {
  267. return ErrTokenVersionMismatch
  268. }
  269. return session.QueryRowCtx(ctx, &newVersion, "SELECT LAST_INSERT_ID()")
  270. })
  271. if err != nil {
  272. return 0, err
  273. }
  274. _ = m.DelCacheCtx(ctx, sysUserIdKey, sysUserUsernameKey)
  275. return newVersion, nil
  276. }
  277. func (m *customSysUserModel) FindByIds(ctx context.Context, ids []int64) ([]*SysUser, error) {
  278. if len(ids) == 0 {
  279. return nil, nil
  280. }
  281. placeholders := make([]string, len(ids))
  282. args := make([]interface{}, len(ids))
  283. for i, id := range ids {
  284. placeholders[i] = "?"
  285. args[i] = id
  286. }
  287. var list []*SysUser
  288. query := fmt.Sprintf("SELECT %s FROM %s WHERE `id` IN (%s)", sysUserRows, m.table, strings.Join(placeholders, ","))
  289. if err := m.QueryRowsNoCacheCtx(ctx, &list, query, args...); err != nil {
  290. return nil, err
  291. }
  292. return list, nil
  293. }