package user import ( "context" "database/sql" "errors" "fmt" "strings" "time" "github.com/zeromicro/go-zero/core/stores/cache" "github.com/zeromicro/go-zero/core/stores/sqlx" ) var ErrUpdateConflict = errors.New("update conflict: data has been modified by another operation") // ErrTokenVersionMismatch 表示令牌版本与数据库当前版本不一致,刷新令牌失败。 // 典型场景:refreshToken rotation 并发到达 —— 只有持有当前 tokenVersion 的那一次能原子递增成功, // 其余全部返回该错误,防止两个请求都"换到"新令牌(导致会话劫持)。 var ErrTokenVersionMismatch = errors.New("token version mismatch") var _ SysUserModel = (*customSysUserModel)(nil) type ( SysUserModel interface { sysUserModel FindListByPage(ctx context.Context, page, pageSize int64) ([]*SysUser, int64, error) FindListByProductMembers(ctx context.Context, productCode string, page, pageSize int64) ([]*SysUser, map[int64]string, int64, error) FindByIds(ctx context.Context, ids []int64) ([]*SysUser, error) FindIdsByDeptId(ctx context.Context, deptId int64) ([]int64, error) // UpdateProfile 更新用户资料字段(昵称 / 邮箱 / 手机 / 备注 / 部门 / 状态),username 仅用于 // 构造旧缓存键 `cacheSysUserUsernamePrefix` 做失效,**不会**被写入 SET 子句。若未来确实需要 // 修改 username,请独立实现 `UpdateUsernameTx`: // ① 同事务内做 old/new 两份 UsernameKey 的失效; // ② 捕获 1062 (UNIQUE 冲突) → response.ErrConflict,不得混进本方法的签名(审计 L-R11-3)。 UpdateProfile(ctx context.Context, id int64, username string, nickname, email, phone, remark string, deptId, newStatus int64, statusChanged bool, expectedUpdateTime int64) error // UpdateProfileWithTx 与 UpdateProfile 行为等价,但 UPDATE 执行在调用方传入的事务里; // 用于"改 deptId → 同事务内先 FOR SHARE sys_dept 目标行"修复 DeleteDept vs UpdateUser // 的 write skew(审计 M-R11-3)。缓存失效仍由 m.ExecCtx 按 (idKey, usernameKey) 兜底。 UpdateProfileWithTx(ctx context.Context, session sqlx.Session, id int64, username string, nickname, email, phone, remark string, deptId, newStatus int64, statusChanged bool, expectedUpdateTime int64) error // UpdatePassword 审计 H-R11-1:expectedUpdateTime 必须由调用方用**外层校验旧密码时拿到的 // 那一份 updateTime** 显式透传;禁止函数内部再 FindOne 自对齐乐观锁,否则内层 CAS 等于退化 // 为 last-write-wins,被"旧会话 + 知道旧密码"的攻击者可以把管理员紧急改过的新密码盖回。 // username 仅用于构造 `cacheSysUserUsernamePrefix` 缓存键失效(sysUser.username 唯一,本函数 // 不会修改它)。 UpdatePassword(ctx context.Context, id int64, username string, password string, mustChangePassword, expectedUpdateTime int64) error // UpdateStatus 审计 M-R11-2:username 由调用方透传,避免仅为构造缓存键而多打一次 FindOne。 UpdateStatus(ctx context.Context, id int64, username string, status int64, expectedUpdateTime int64) error // IncrementTokenVersion 审计 M-R11-2:username 由调用方透传,避免仅为构造缓存键而多打一次 FindOne。 IncrementTokenVersion(ctx context.Context, id int64, username string) (int64, error) IncrementTokenVersionIfMatch(ctx context.Context, id int64, username string, expected int64) (int64, error) } customSysUserModel struct { *defaultSysUserModel } ) func NewSysUserModel(conn sqlx.SqlConn, c cache.CacheConf, cachePrefix string, opts ...cache.Option) SysUserModel { return &customSysUserModel{ defaultSysUserModel: newSysUserModel(conn, c, cachePrefix, opts...), } } func (m *customSysUserModel) FindListByPage(ctx context.Context, page, pageSize int64) ([]*SysUser, int64, error) { var total int64 countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s", m.table) if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery); err != nil { return nil, 0, err } var list []*SysUser query := fmt.Sprintf("SELECT %s FROM %s ORDER BY id DESC LIMIT ?,?", sysUserRows, m.table) if err := m.QueryRowsNoCacheCtx(ctx, &list, query, (page-1)*pageSize, pageSize); err != nil { return nil, 0, err } return list, total, nil } type UserWithMemberType struct { SysUser MemberType string `db:"memberType"` } func (m *customSysUserModel) FindListByProductMembers(ctx context.Context, productCode string, page, pageSize int64) ([]*SysUser, map[int64]string, int64, error) { memberTable := "`sys_product_member`" var total int64 countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s u INNER JOIN %s pm ON u.`id` = pm.`userId` WHERE pm.`productCode` = ?", m.table, memberTable) if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery, productCode); err != nil { return nil, nil, 0, err } var list []*UserWithMemberType fields := strings.Join(sysUserFieldNames, ",u.") query := fmt.Sprintf("SELECT u.%s, pm.`memberType` FROM %s u INNER JOIN %s pm ON u.`id` = pm.`userId` WHERE pm.`productCode` = ? ORDER BY u.`id` DESC LIMIT ?,?", fields, m.table, memberTable) if err := m.QueryRowsNoCacheCtx(ctx, &list, query, productCode, (page-1)*pageSize, pageSize); err != nil { return nil, nil, 0, err } users := make([]*SysUser, len(list)) memberMap := make(map[int64]string, len(list)) for i, item := range list { users[i] = &item.SysUser memberMap[item.Id] = item.MemberType } return users, memberMap, total, nil } func (m *customSysUserModel) FindIdsByDeptId(ctx context.Context, deptId int64) ([]int64, error) { var ids []int64 query := fmt.Sprintf("SELECT `id` FROM %s WHERE `deptId` = ?", m.table) if err := m.QueryRowsNoCacheCtx(ctx, &ids, query, deptId); err != nil { return nil, err } return ids, nil } func (m *customSysUserModel) UpdateProfile(ctx context.Context, id int64, username string, nickname, email, phone, remark string, deptId, newStatus int64, statusChanged bool, expectedUpdateTime int64) error { sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id) sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username) now := time.Now().Unix() res, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) { if statusChanged { query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `email`=?, `phone`=?, `remark`=?, `deptId`=?, `status`=?, `tokenVersion`=`tokenVersion`+1, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table) return conn.ExecCtx(ctx, query, nickname, email, phone, remark, deptId, newStatus, now, id, expectedUpdateTime) } query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `email`=?, `phone`=?, `remark`=?, `deptId`=?, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table) return conn.ExecCtx(ctx, query, nickname, email, phone, remark, deptId, now, id, expectedUpdateTime) }, sysUserIdKey, sysUserUsernameKey) if err != nil { return err } affected, _ := res.RowsAffected() if affected == 0 { return ErrUpdateConflict } return nil } // UpdateProfileWithTx 见接口注释(审计 M-R11-3)。实现上复用 m.ExecCtx 负责的 cache 失效语义, // 但 UPDATE 语句在调用方传入的 session 事务里执行。session==nil 时 panic,阻止误用—— // 非事务场景必须走 UpdateProfile 而不是本方法。 func (m *customSysUserModel) UpdateProfileWithTx(ctx context.Context, session sqlx.Session, id int64, username string, nickname, email, phone, remark string, deptId, newStatus int64, statusChanged bool, expectedUpdateTime int64) error { if session == nil { return errors.New("UpdateProfileWithTx requires a non-nil session") } sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id) sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username) now := time.Now().Unix() res, err := m.ExecCtx(ctx, func(ctx context.Context, _ sqlx.SqlConn) (sql.Result, error) { if statusChanged { query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `email`=?, `phone`=?, `remark`=?, `deptId`=?, `status`=?, `tokenVersion`=`tokenVersion`+1, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table) return session.ExecCtx(ctx, query, nickname, email, phone, remark, deptId, newStatus, now, id, expectedUpdateTime) } query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `email`=?, `phone`=?, `remark`=?, `deptId`=?, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table) return session.ExecCtx(ctx, query, nickname, email, phone, remark, deptId, now, id, expectedUpdateTime) }, sysUserIdKey, sysUserUsernameKey) if err != nil { return err } affected, _ := res.RowsAffected() if affected == 0 { return ErrUpdateConflict } return nil } func (m *customSysUserModel) UpdatePassword(ctx context.Context, id int64, username string, password string, mustChangePassword, expectedUpdateTime int64) error { sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id) sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username) // 审计 H-R11-1:expectedUpdateTime 必须来自**外层校验旧密码时读到的**快照。不要再内部 FindOne // 自取 data.UpdateTime —— 那会让 CAS 在本函数内自我对齐,退化为 last-write-wins,让"会话持有 + // 知道旧密码"的攻击者可以把 admin 紧急改过的新密码盖回到自己手里那一份。 res, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) { query := fmt.Sprintf("UPDATE %s SET `password` = ?, `mustChangePassword` = ?, `tokenVersion` = `tokenVersion` + 1, `updateTime` = ? WHERE `id` = ? AND `updateTime` = ?", m.table) return conn.ExecCtx(ctx, query, password, mustChangePassword, time.Now().Unix(), id, expectedUpdateTime) }, sysUserIdKey, sysUserUsernameKey) if err != nil { return err } if affected, _ := res.RowsAffected(); affected == 0 { // 行被删除或被并发改过(任何字段,包括 status/password/profile):统一回 ErrUpdateConflict。 return ErrUpdateConflict } return nil } // UpdateStatus 修改用户 status,并强制递增 tokenVersion 让已签发令牌失效。 // 审计 L-N4:WHERE 必须带 `updateTime=?` 乐观锁,和 UpdateProfile / UpdatePassword 语义对齐。 // 上游 UpdateUserStatusLogic 已经从 ValidateStatusChange 拿到 sysUser,调用方应把 // `sysUser.UpdateTime` 当作 expectedUpdateTime 传入: // - expectedUpdateTime 不匹配 → ErrUpdateConflict;上层统一回 409 "数据已被其他操作修改"。 // - 避免并发冻结/解冻请求走"last-write-wins",出现两个 admin 同时点"冻结"/"解冻" // 时后到者覆盖先到者、tokenVersion 被连续加两次把刚刚解冻的用户再次踢下线的诡异现象。 func (m *customSysUserModel) UpdateStatus(ctx context.Context, id int64, username string, status int64, expectedUpdateTime int64) error { // 审计 M-R11-2:username 由调用方(ValidateStatusChange 返回的目标用户对象)显式透传, // 不再内部 FindOne。真实并发安全继续靠 `WHERE updateTime = expectedUpdateTime` 乐观锁兜底。 sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id) sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username) res, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) { query := fmt.Sprintf("UPDATE %s SET `status` = ?, `tokenVersion` = `tokenVersion` + 1, `updateTime` = ? WHERE `id` = ? AND `updateTime` = ?", m.table) return conn.ExecCtx(ctx, query, status, time.Now().Unix(), id, expectedUpdateTime) }, sysUserIdKey, sysUserUsernameKey) if err != nil { return err } if affected, _ := res.RowsAffected(); affected == 0 { // 行被删除或被并发改过:对外统一回 ErrUpdateConflict,避免对已删除 / 被并发改过的行 // 返回 nil 让上层误判为"冻结生效"(审计 M-2 / L-N4)。 return ErrUpdateConflict } return nil } // IncrementTokenVersion 强制递增当前用户的 tokenVersion,让**所有**已签发的 access/refresh 立即失效。 // // WARN: 仅限"强制全量会话失效"场景调用——主动 Logout 或封禁/重置密码。Refresh / Rotate 场景 // 必须调用 IncrementTokenVersionIfMatch 走 CAS 语义,否则会回到 R5 以前的并发 rotate 窗口, // 两次并发 refresh 都能换到新令牌,等同于会话劫持(见审计 L-2)。 // 调用前请先走 TokenOpLimiter 等限流,避免被反复触发把合法用户 kick 出登录。 func (m *customSysUserModel) IncrementTokenVersion(ctx context.Context, id int64, username string) (int64, error) { // 审计 M-R11-2:username 由调用方显式透传,不再内部 FindOne 仅为构造缓存键。 // 调用方(Logout 唯一入口)在 UserDetailsLoader.Load 之后天然已经持有 ud.Username。 sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id) sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username) var newVersion int64 err := m.TransactCtx(ctx, func(ctx context.Context, session sqlx.Session) error { query := fmt.Sprintf("UPDATE %s SET `tokenVersion` = LAST_INSERT_ID(`tokenVersion` + 1), `updateTime` = ? WHERE `id` = ?", m.table) res, err := session.ExecCtx(ctx, query, time.Now().Unix(), id) if err != nil { return err } // 审计 L-R10-3:FindOne 与本次 UPDATE 之间若有并发 DELETE,affected=0 仍会返回 nil 让 // SELECT LAST_INSERT_ID() 在全新事务里读出 0,调用方无法区分"真的递增成功"和"目标已被删除"。 // 对外统一回 ErrUpdateConflict,调用方据此做日志/告警或跳过后续 Clean,避免对 tokenVersion=0 // 之类的伪返回值做错误假设(如"版本从 1 起递增"的默认契约踩坑)。 if affected, _ := res.RowsAffected(); affected == 0 { return ErrUpdateConflict } return session.QueryRowCtx(ctx, &newVersion, "SELECT LAST_INSERT_ID()") }) if err != nil { return 0, err } _ = m.DelCacheCtx(ctx, sysUserIdKey, sysUserUsernameKey) return newVersion, nil } // IncrementTokenVersionIfMatch 原子递增 tokenVersion;仅当 DB 里当前 tokenVersion == expected 时才会生效。 // 这是 refreshToken rotation 的原子 CAS:两个并发的刷新请求只有一个能命中 WHERE tokenVersion=expected, // 另一个 affected=0 返回 ErrTokenVersionMismatch,从而避免"两边都换到新令牌"的会话劫持窗口。 // // 由上游透传 username 以便构造 cacheSysUserUsernamePrefix 的缓存键进行失效,避免为此多查一次 FindOne // (见审计 M-8)。上游通常已经通过 UserDetailsLoader.Load 拿到 username,零额外成本。 func (m *customSysUserModel) IncrementTokenVersionIfMatch(ctx context.Context, id int64, username string, expected int64) (int64, error) { sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id) sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username) var newVersion int64 err := m.TransactCtx(ctx, func(ctx context.Context, session sqlx.Session) error { query := fmt.Sprintf("UPDATE %s SET `tokenVersion` = LAST_INSERT_ID(`tokenVersion` + 1), `updateTime` = ? WHERE `id` = ? AND `tokenVersion` = ?", m.table) res, err := session.ExecCtx(ctx, query, time.Now().Unix(), id, expected) if err != nil { return err } affected, _ := res.RowsAffected() if affected == 0 { return ErrTokenVersionMismatch } return session.QueryRowCtx(ctx, &newVersion, "SELECT LAST_INSERT_ID()") }) if err != nil { return 0, err } _ = m.DelCacheCtx(ctx, sysUserIdKey, sysUserUsernameKey) return newVersion, nil } func (m *customSysUserModel) FindByIds(ctx context.Context, ids []int64) ([]*SysUser, error) { if len(ids) == 0 { return nil, nil } placeholders := make([]string, len(ids)) args := make([]interface{}, len(ids)) for i, id := range ids { placeholders[i] = "?" args[i] = id } var list []*SysUser query := fmt.Sprintf("SELECT %s FROM %s WHERE `id` IN (%s)", sysUserRows, m.table, strings.Join(placeholders, ",")) if err := m.QueryRowsNoCacheCtx(ctx, &list, query, args...); err != nil { return nil, err } return list, nil }