| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255 |
- package user
- import (
- "context"
- "database/sql"
- "errors"
- "fmt"
- "strings"
- "time"
- "github.com/zeromicro/go-zero/core/stores/cache"
- "github.com/zeromicro/go-zero/core/stores/sqlx"
- )
// Sentinel errors returned by the hand-written SysUserModel methods; compare them with errors.Is.
var (
	// ErrUpdateConflict reports that a guarded UPDATE matched no row: the row was modified
	// concurrently (optimistic-lock mismatch) or no longer exists.
	ErrUpdateConflict = errors.New("update conflict: data has been modified by another operation")

	// ErrTokenVersionMismatch reports that the token version presented by the caller differs from
	// the one currently stored, so the token refresh is refused.
	//
	// Typical case: concurrent refreshToken rotations. Only the request holding the current
	// tokenVersion can atomically increment it; every other request gets this error, so two
	// requests can never both "swap in" a new token (which would amount to session hijacking).
	ErrTokenVersionMismatch = errors.New("token version mismatch")
)
- var _ SysUserModel = (*customSysUserModel)(nil)
- type (
- SysUserModel interface {
- sysUserModel
- FindListByPage(ctx context.Context, page, pageSize int64) ([]*SysUser, int64, error)
- FindListByProductMembers(ctx context.Context, productCode string, page, pageSize int64) ([]*SysUser, map[int64]string, int64, error)
- FindByIds(ctx context.Context, ids []int64) ([]*SysUser, error)
- FindIdsByDeptId(ctx context.Context, deptId int64) ([]int64, error)
- UpdateProfile(ctx context.Context, id int64, username string, nickname, email, phone, remark string, deptId, newStatus int64, statusChanged bool, expectedUpdateTime int64) error
- UpdatePassword(ctx context.Context, id int64, password string, mustChangePassword int64) error
- UpdateStatus(ctx context.Context, id int64, status int64) error
- IncrementTokenVersion(ctx context.Context, id int64) (int64, error)
- IncrementTokenVersionIfMatch(ctx context.Context, id int64, username string, expected int64) (int64, error)
- }
- customSysUserModel struct {
- *defaultSysUserModel
- }
- )
- func NewSysUserModel(conn sqlx.SqlConn, c cache.CacheConf, cachePrefix string, opts ...cache.Option) SysUserModel {
- return &customSysUserModel{
- defaultSysUserModel: newSysUserModel(conn, c, cachePrefix, opts...),
- }
- }
- func (m *customSysUserModel) FindListByPage(ctx context.Context, page, pageSize int64) ([]*SysUser, int64, error) {
- var total int64
- countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s", m.table)
- if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery); err != nil {
- return nil, 0, err
- }
- var list []*SysUser
- query := fmt.Sprintf("SELECT %s FROM %s ORDER BY id DESC LIMIT ?,?", sysUserRows, m.table)
- if err := m.QueryRowsNoCacheCtx(ctx, &list, query, (page-1)*pageSize, pageSize); err != nil {
- return nil, 0, err
- }
- return list, total, nil
- }
- type UserWithMemberType struct {
- SysUser
- MemberType string `db:"memberType"`
- }
- func (m *customSysUserModel) FindListByProductMembers(ctx context.Context, productCode string, page, pageSize int64) ([]*SysUser, map[int64]string, int64, error) {
- memberTable := "`sys_product_member`"
- var total int64
- countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s u INNER JOIN %s pm ON u.`id` = pm.`userId` WHERE pm.`productCode` = ?", m.table, memberTable)
- if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery, productCode); err != nil {
- return nil, nil, 0, err
- }
- var list []*UserWithMemberType
- fields := strings.Join(sysUserFieldNames, ",u.")
- query := fmt.Sprintf("SELECT u.%s, pm.`memberType` FROM %s u INNER JOIN %s pm ON u.`id` = pm.`userId` WHERE pm.`productCode` = ? ORDER BY u.`id` DESC LIMIT ?,?", fields, m.table, memberTable)
- if err := m.QueryRowsNoCacheCtx(ctx, &list, query, productCode, (page-1)*pageSize, pageSize); err != nil {
- return nil, nil, 0, err
- }
- users := make([]*SysUser, len(list))
- memberMap := make(map[int64]string, len(list))
- for i, item := range list {
- users[i] = &item.SysUser
- memberMap[item.Id] = item.MemberType
- }
- return users, memberMap, total, nil
- }
- func (m *customSysUserModel) FindIdsByDeptId(ctx context.Context, deptId int64) ([]int64, error) {
- var ids []int64
- query := fmt.Sprintf("SELECT `id` FROM %s WHERE `deptId` = ?", m.table)
- if err := m.QueryRowsNoCacheCtx(ctx, &ids, query, deptId); err != nil {
- return nil, err
- }
- return ids, nil
- }
- func (m *customSysUserModel) UpdateProfile(ctx context.Context, id int64, username string, nickname, email, phone, remark string, deptId, newStatus int64, statusChanged bool, expectedUpdateTime int64) error {
- sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
- sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username)
- now := time.Now().Unix()
- res, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
- if statusChanged {
- query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `email`=?, `phone`=?, `remark`=?, `deptId`=?, `status`=?, `tokenVersion`=`tokenVersion`+1, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table)
- return conn.ExecCtx(ctx, query, nickname, email, phone, remark, deptId, newStatus, now, id, expectedUpdateTime)
- }
- query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `email`=?, `phone`=?, `remark`=?, `deptId`=?, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table)
- return conn.ExecCtx(ctx, query, nickname, email, phone, remark, deptId, now, id, expectedUpdateTime)
- }, sysUserIdKey, sysUserUsernameKey)
- if err != nil {
- return err
- }
- affected, _ := res.RowsAffected()
- if affected == 0 {
- return ErrUpdateConflict
- }
- return nil
- }
- func (m *customSysUserModel) UpdatePassword(ctx context.Context, id int64, password string, mustChangePassword int64) error {
- data, err := m.FindOne(ctx, id)
- if err != nil {
- return err
- }
- sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
- sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, data.Username)
- // 乐观锁:WHERE 叠加 updateTime 与 FindOne 拿到的一致。避免 FindOne → Exec 之间并发改密把
- // 本次写盖成"最后一写赢"、或目标行被删除后仍返回成功造成语义欺骗(见审计 M-2)。
- expectedUpdateTime := data.UpdateTime
- res, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
- query := fmt.Sprintf("UPDATE %s SET `password` = ?, `mustChangePassword` = ?, `tokenVersion` = `tokenVersion` + 1, `updateTime` = ? WHERE `id` = ? AND `updateTime` = ?", m.table)
- return conn.ExecCtx(ctx, query, password, mustChangePassword, time.Now().Unix(), id, expectedUpdateTime)
- }, sysUserIdKey, sysUserUsernameKey)
- if err != nil {
- return err
- }
- if affected, _ := res.RowsAffected(); affected == 0 {
- // 行被删除或被并发改过:对外统一回 ErrUpdateConflict,避免对已删除用户返回 nil 让上层
- // 误判为"改密成功"(审计 M-2)。
- return ErrUpdateConflict
- }
- return nil
- }
- func (m *customSysUserModel) UpdateStatus(ctx context.Context, id int64, status int64) error {
- data, err := m.FindOne(ctx, id)
- if err != nil {
- return err
- }
- sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
- sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, data.Username)
- res, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
- query := fmt.Sprintf("UPDATE %s SET `status` = ?, `tokenVersion` = `tokenVersion` + 1, `updateTime` = ? WHERE `id` = ?", m.table)
- return conn.ExecCtx(ctx, query, status, time.Now().Unix(), id)
- }, sysUserIdKey, sysUserUsernameKey)
- if err != nil {
- return err
- }
- if affected, _ := res.RowsAffected(); affected == 0 {
- // 目标用户在 FindOne 后被并发删除;返回 ErrUpdateConflict 让上层区分"冻结生效"与"目标已消失"
- // (审计 M-2)。
- return ErrUpdateConflict
- }
- return nil
- }
- // IncrementTokenVersion 强制递增当前用户的 tokenVersion,让**所有**已签发的 access/refresh 立即失效。
- //
- // WARN: 仅限"强制全量会话失效"场景调用——主动 Logout 或封禁/重置密码。Refresh / Rotate 场景
- // 必须调用 IncrementTokenVersionIfMatch 走 CAS 语义,否则会回到 R5 以前的并发 rotate 窗口,
- // 两次并发 refresh 都能换到新令牌,等同于会话劫持(见审计 L-2)。
- // 调用前请先走 TokenOpLimiter 等限流,避免被反复触发把合法用户 kick 出登录。
- func (m *customSysUserModel) IncrementTokenVersion(ctx context.Context, id int64) (int64, error) {
- data, err := m.FindOne(ctx, id)
- if err != nil {
- return 0, err
- }
- sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
- sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, data.Username)
- var newVersion int64
- err = m.TransactCtx(ctx, func(ctx context.Context, session sqlx.Session) error {
- query := fmt.Sprintf("UPDATE %s SET `tokenVersion` = LAST_INSERT_ID(`tokenVersion` + 1), `updateTime` = ? WHERE `id` = ?", m.table)
- if _, err := session.ExecCtx(ctx, query, time.Now().Unix(), id); err != nil {
- return err
- }
- return session.QueryRowCtx(ctx, &newVersion, "SELECT LAST_INSERT_ID()")
- })
- if err != nil {
- return 0, err
- }
- _ = m.DelCacheCtx(ctx, sysUserIdKey, sysUserUsernameKey)
- return newVersion, nil
- }
- // IncrementTokenVersionIfMatch 原子递增 tokenVersion;仅当 DB 里当前 tokenVersion == expected 时才会生效。
- // 这是 refreshToken rotation 的原子 CAS:两个并发的刷新请求只有一个能命中 WHERE tokenVersion=expected,
- // 另一个 affected=0 返回 ErrTokenVersionMismatch,从而避免"两边都换到新令牌"的会话劫持窗口。
- //
- // 由上游透传 username 以便构造 cacheSysUserUsernamePrefix 的缓存键进行失效,避免为此多查一次 FindOne
- // (见审计 M-8)。上游通常已经通过 UserDetailsLoader.Load 拿到 username,零额外成本。
- func (m *customSysUserModel) IncrementTokenVersionIfMatch(ctx context.Context, id int64, username string, expected int64) (int64, error) {
- sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
- sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username)
- var newVersion int64
- err := m.TransactCtx(ctx, func(ctx context.Context, session sqlx.Session) error {
- query := fmt.Sprintf("UPDATE %s SET `tokenVersion` = LAST_INSERT_ID(`tokenVersion` + 1), `updateTime` = ? WHERE `id` = ? AND `tokenVersion` = ?", m.table)
- res, err := session.ExecCtx(ctx, query, time.Now().Unix(), id, expected)
- if err != nil {
- return err
- }
- affected, _ := res.RowsAffected()
- if affected == 0 {
- return ErrTokenVersionMismatch
- }
- return session.QueryRowCtx(ctx, &newVersion, "SELECT LAST_INSERT_ID()")
- })
- if err != nil {
- return 0, err
- }
- _ = m.DelCacheCtx(ctx, sysUserIdKey, sysUserUsernameKey)
- return newVersion, nil
- }
- func (m *customSysUserModel) FindByIds(ctx context.Context, ids []int64) ([]*SysUser, error) {
- if len(ids) == 0 {
- return nil, nil
- }
- placeholders := make([]string, len(ids))
- args := make([]interface{}, len(ids))
- for i, id := range ids {
- placeholders[i] = "?"
- args[i] = id
- }
- var list []*SysUser
- query := fmt.Sprintf("SELECT %s FROM %s WHERE `id` IN (%s)", sysUserRows, m.table, strings.Join(placeholders, ","))
- if err := m.QueryRowsNoCacheCtx(ctx, &list, query, args...); err != nil {
- return nil, err
- }
- return list, nil
- }
|