sysUserModel.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. package user
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "fmt"
  7. "strings"
  8. "time"
  9. "github.com/zeromicro/go-zero/core/stores/cache"
  10. "github.com/zeromicro/go-zero/core/stores/sqlx"
  11. )
  12. var ErrUpdateConflict = errors.New("update conflict: data has been modified by another operation")
  13. // ErrTokenVersionMismatch 表示令牌版本与数据库当前版本不一致,刷新令牌失败。
  14. // 典型场景:refreshToken rotation 并发到达 —— 只有持有当前 tokenVersion 的那一次能原子递增成功,
  15. // 其余全部返回该错误,防止两个请求都"换到"新令牌(导致会话劫持)。
  16. var ErrTokenVersionMismatch = errors.New("token version mismatch")
  17. var _ SysUserModel = (*customSysUserModel)(nil)
  18. type (
  19. SysUserModel interface {
  20. sysUserModel
  21. FindListByPage(ctx context.Context, page, pageSize int64) ([]*SysUser, int64, error)
  22. FindListByProductMembers(ctx context.Context, productCode string, page, pageSize int64) ([]*SysUser, map[int64]string, int64, error)
  23. FindByIds(ctx context.Context, ids []int64) ([]*SysUser, error)
  24. FindIdsByDeptId(ctx context.Context, deptId int64) ([]int64, error)
  25. UpdateProfile(ctx context.Context, id int64, username string, nickname, email, phone, remark string, deptId, newStatus int64, statusChanged bool, expectedUpdateTime int64) error
  26. UpdatePassword(ctx context.Context, id int64, password string, mustChangePassword int64) error
  27. UpdateStatus(ctx context.Context, id int64, status int64, expectedUpdateTime int64) error
  28. IncrementTokenVersion(ctx context.Context, id int64) (int64, error)
  29. IncrementTokenVersionIfMatch(ctx context.Context, id int64, username string, expected int64) (int64, error)
  30. }
  31. customSysUserModel struct {
  32. *defaultSysUserModel
  33. }
  34. )
  35. func NewSysUserModel(conn sqlx.SqlConn, c cache.CacheConf, cachePrefix string, opts ...cache.Option) SysUserModel {
  36. return &customSysUserModel{
  37. defaultSysUserModel: newSysUserModel(conn, c, cachePrefix, opts...),
  38. }
  39. }
  40. func (m *customSysUserModel) FindListByPage(ctx context.Context, page, pageSize int64) ([]*SysUser, int64, error) {
  41. var total int64
  42. countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s", m.table)
  43. if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery); err != nil {
  44. return nil, 0, err
  45. }
  46. var list []*SysUser
  47. query := fmt.Sprintf("SELECT %s FROM %s ORDER BY id DESC LIMIT ?,?", sysUserRows, m.table)
  48. if err := m.QueryRowsNoCacheCtx(ctx, &list, query, (page-1)*pageSize, pageSize); err != nil {
  49. return nil, 0, err
  50. }
  51. return list, total, nil
  52. }
// UserWithMemberType is the row shape scanned by FindListByProductMembers: all
// columns of the embedded SysUser plus the member type taken from the joined
// product-member table.
type UserWithMemberType struct {
	SysUser
	// MemberType is populated from the `memberType` result column.
	MemberType string `db:"memberType"`
}
  57. func (m *customSysUserModel) FindListByProductMembers(ctx context.Context, productCode string, page, pageSize int64) ([]*SysUser, map[int64]string, int64, error) {
  58. memberTable := "`sys_product_member`"
  59. var total int64
  60. countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s u INNER JOIN %s pm ON u.`id` = pm.`userId` WHERE pm.`productCode` = ?", m.table, memberTable)
  61. if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery, productCode); err != nil {
  62. return nil, nil, 0, err
  63. }
  64. var list []*UserWithMemberType
  65. fields := strings.Join(sysUserFieldNames, ",u.")
  66. query := fmt.Sprintf("SELECT u.%s, pm.`memberType` FROM %s u INNER JOIN %s pm ON u.`id` = pm.`userId` WHERE pm.`productCode` = ? ORDER BY u.`id` DESC LIMIT ?,?", fields, m.table, memberTable)
  67. if err := m.QueryRowsNoCacheCtx(ctx, &list, query, productCode, (page-1)*pageSize, pageSize); err != nil {
  68. return nil, nil, 0, err
  69. }
  70. users := make([]*SysUser, len(list))
  71. memberMap := make(map[int64]string, len(list))
  72. for i, item := range list {
  73. users[i] = &item.SysUser
  74. memberMap[item.Id] = item.MemberType
  75. }
  76. return users, memberMap, total, nil
  77. }
  78. func (m *customSysUserModel) FindIdsByDeptId(ctx context.Context, deptId int64) ([]int64, error) {
  79. var ids []int64
  80. query := fmt.Sprintf("SELECT `id` FROM %s WHERE `deptId` = ?", m.table)
  81. if err := m.QueryRowsNoCacheCtx(ctx, &ids, query, deptId); err != nil {
  82. return nil, err
  83. }
  84. return ids, nil
  85. }
  86. func (m *customSysUserModel) UpdateProfile(ctx context.Context, id int64, username string, nickname, email, phone, remark string, deptId, newStatus int64, statusChanged bool, expectedUpdateTime int64) error {
  87. sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
  88. sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username)
  89. now := time.Now().Unix()
  90. res, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
  91. if statusChanged {
  92. query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `email`=?, `phone`=?, `remark`=?, `deptId`=?, `status`=?, `tokenVersion`=`tokenVersion`+1, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table)
  93. return conn.ExecCtx(ctx, query, nickname, email, phone, remark, deptId, newStatus, now, id, expectedUpdateTime)
  94. }
  95. query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `email`=?, `phone`=?, `remark`=?, `deptId`=?, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table)
  96. return conn.ExecCtx(ctx, query, nickname, email, phone, remark, deptId, now, id, expectedUpdateTime)
  97. }, sysUserIdKey, sysUserUsernameKey)
  98. if err != nil {
  99. return err
  100. }
  101. affected, _ := res.RowsAffected()
  102. if affected == 0 {
  103. return ErrUpdateConflict
  104. }
  105. return nil
  106. }
  107. func (m *customSysUserModel) UpdatePassword(ctx context.Context, id int64, password string, mustChangePassword int64) error {
  108. data, err := m.FindOne(ctx, id)
  109. if err != nil {
  110. return err
  111. }
  112. sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
  113. sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, data.Username)
  114. // 乐观锁:WHERE 叠加 updateTime 与 FindOne 拿到的一致。避免 FindOne → Exec 之间并发改密把
  115. // 本次写盖成"最后一写赢"、或目标行被删除后仍返回成功造成语义欺骗(见审计 M-2)。
  116. expectedUpdateTime := data.UpdateTime
  117. res, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
  118. query := fmt.Sprintf("UPDATE %s SET `password` = ?, `mustChangePassword` = ?, `tokenVersion` = `tokenVersion` + 1, `updateTime` = ? WHERE `id` = ? AND `updateTime` = ?", m.table)
  119. return conn.ExecCtx(ctx, query, password, mustChangePassword, time.Now().Unix(), id, expectedUpdateTime)
  120. }, sysUserIdKey, sysUserUsernameKey)
  121. if err != nil {
  122. return err
  123. }
  124. if affected, _ := res.RowsAffected(); affected == 0 {
  125. // 行被删除或被并发改过:对外统一回 ErrUpdateConflict,避免对已删除用户返回 nil 让上层
  126. // 误判为"改密成功"(审计 M-2)。
  127. return ErrUpdateConflict
  128. }
  129. return nil
  130. }
  131. // UpdateStatus 修改用户 status,并强制递增 tokenVersion 让已签发令牌失效。
  132. // 审计 L-N4:WHERE 必须带 `updateTime=?` 乐观锁,和 UpdateProfile / UpdatePassword 语义对齐。
  133. // 上游 UpdateUserStatusLogic 已经从 ValidateStatusChange 拿到 sysUser,调用方应把
  134. // `sysUser.UpdateTime` 当作 expectedUpdateTime 传入:
  135. // - expectedUpdateTime 不匹配 → ErrUpdateConflict;上层统一回 409 "数据已被其他操作修改"。
  136. // - 避免并发冻结/解冻请求走"last-write-wins",出现两个 admin 同时点"冻结"/"解冻"
  137. // 时后到者覆盖先到者、tokenVersion 被连续加两次把刚刚解冻的用户再次踢下线的诡异现象。
  138. func (m *customSysUserModel) UpdateStatus(ctx context.Context, id int64, status int64, expectedUpdateTime int64) error {
  139. data, err := m.FindOne(ctx, id)
  140. if err != nil {
  141. return err
  142. }
  143. sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
  144. sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, data.Username)
  145. res, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
  146. query := fmt.Sprintf("UPDATE %s SET `status` = ?, `tokenVersion` = `tokenVersion` + 1, `updateTime` = ? WHERE `id` = ? AND `updateTime` = ?", m.table)
  147. return conn.ExecCtx(ctx, query, status, time.Now().Unix(), id, expectedUpdateTime)
  148. }, sysUserIdKey, sysUserUsernameKey)
  149. if err != nil {
  150. return err
  151. }
  152. if affected, _ := res.RowsAffected(); affected == 0 {
  153. // 行被删除或被并发改过:对外统一回 ErrUpdateConflict,避免对已删除 / 被并发改过的行
  154. // 返回 nil 让上层误判为"冻结生效"(审计 M-2 / L-N4)。
  155. return ErrUpdateConflict
  156. }
  157. return nil
  158. }
  159. // IncrementTokenVersion 强制递增当前用户的 tokenVersion,让**所有**已签发的 access/refresh 立即失效。
  160. //
  161. // WARN: 仅限"强制全量会话失效"场景调用——主动 Logout 或封禁/重置密码。Refresh / Rotate 场景
  162. // 必须调用 IncrementTokenVersionIfMatch 走 CAS 语义,否则会回到 R5 以前的并发 rotate 窗口,
  163. // 两次并发 refresh 都能换到新令牌,等同于会话劫持(见审计 L-2)。
  164. // 调用前请先走 TokenOpLimiter 等限流,避免被反复触发把合法用户 kick 出登录。
  165. func (m *customSysUserModel) IncrementTokenVersion(ctx context.Context, id int64) (int64, error) {
  166. data, err := m.FindOne(ctx, id)
  167. if err != nil {
  168. return 0, err
  169. }
  170. sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
  171. sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, data.Username)
  172. var newVersion int64
  173. err = m.TransactCtx(ctx, func(ctx context.Context, session sqlx.Session) error {
  174. query := fmt.Sprintf("UPDATE %s SET `tokenVersion` = LAST_INSERT_ID(`tokenVersion` + 1), `updateTime` = ? WHERE `id` = ?", m.table)
  175. if _, err := session.ExecCtx(ctx, query, time.Now().Unix(), id); err != nil {
  176. return err
  177. }
  178. return session.QueryRowCtx(ctx, &newVersion, "SELECT LAST_INSERT_ID()")
  179. })
  180. if err != nil {
  181. return 0, err
  182. }
  183. _ = m.DelCacheCtx(ctx, sysUserIdKey, sysUserUsernameKey)
  184. return newVersion, nil
  185. }
  186. // IncrementTokenVersionIfMatch 原子递增 tokenVersion;仅当 DB 里当前 tokenVersion == expected 时才会生效。
  187. // 这是 refreshToken rotation 的原子 CAS:两个并发的刷新请求只有一个能命中 WHERE tokenVersion=expected,
  188. // 另一个 affected=0 返回 ErrTokenVersionMismatch,从而避免"两边都换到新令牌"的会话劫持窗口。
  189. //
  190. // 由上游透传 username 以便构造 cacheSysUserUsernamePrefix 的缓存键进行失效,避免为此多查一次 FindOne
  191. // (见审计 M-8)。上游通常已经通过 UserDetailsLoader.Load 拿到 username,零额外成本。
  192. func (m *customSysUserModel) IncrementTokenVersionIfMatch(ctx context.Context, id int64, username string, expected int64) (int64, error) {
  193. sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
  194. sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username)
  195. var newVersion int64
  196. err := m.TransactCtx(ctx, func(ctx context.Context, session sqlx.Session) error {
  197. query := fmt.Sprintf("UPDATE %s SET `tokenVersion` = LAST_INSERT_ID(`tokenVersion` + 1), `updateTime` = ? WHERE `id` = ? AND `tokenVersion` = ?", m.table)
  198. res, err := session.ExecCtx(ctx, query, time.Now().Unix(), id, expected)
  199. if err != nil {
  200. return err
  201. }
  202. affected, _ := res.RowsAffected()
  203. if affected == 0 {
  204. return ErrTokenVersionMismatch
  205. }
  206. return session.QueryRowCtx(ctx, &newVersion, "SELECT LAST_INSERT_ID()")
  207. })
  208. if err != nil {
  209. return 0, err
  210. }
  211. _ = m.DelCacheCtx(ctx, sysUserIdKey, sysUserUsernameKey)
  212. return newVersion, nil
  213. }
  214. func (m *customSysUserModel) FindByIds(ctx context.Context, ids []int64) ([]*SysUser, error) {
  215. if len(ids) == 0 {
  216. return nil, nil
  217. }
  218. placeholders := make([]string, len(ids))
  219. args := make([]interface{}, len(ids))
  220. for i, id := range ids {
  221. placeholders[i] = "?"
  222. args[i] = id
  223. }
  224. var list []*SysUser
  225. query := fmt.Sprintf("SELECT %s FROM %s WHERE `id` IN (%s)", sysUserRows, m.table, strings.Join(placeholders, ","))
  226. if err := m.QueryRowsNoCacheCtx(ctx, &list, query, args...); err != nil {
  227. return nil, err
  228. }
  229. return list, nil
  230. }