package user import ( "context" "database/sql" "errors" "fmt" "strings" "time" "github.com/zeromicro/go-zero/core/stores/cache" "github.com/zeromicro/go-zero/core/stores/sqlx" ) var ErrUpdateConflict = errors.New("update conflict: data has been modified by another operation") // ErrTokenVersionMismatch 表示令牌版本与数据库当前版本不一致,刷新令牌失败。 // 典型场景:refreshToken rotation 并发到达 —— 只有持有当前 tokenVersion 的那一次能原子递增成功, // 其余全部返回该错误,防止两个请求都"换到"新令牌(导致会话劫持)。 var ErrTokenVersionMismatch = errors.New("token version mismatch") var _ SysUserModel = (*customSysUserModel)(nil) type ( SysUserModel interface { sysUserModel FindListByPage(ctx context.Context, page, pageSize int64) ([]*SysUser, int64, error) FindListByProductMembers(ctx context.Context, productCode string, page, pageSize int64) ([]*SysUser, map[int64]string, int64, error) FindByIds(ctx context.Context, ids []int64) ([]*SysUser, error) FindIdsByDeptId(ctx context.Context, deptId int64) ([]int64, error) UpdateProfile(ctx context.Context, id int64, username string, nickname, email, phone, remark string, deptId, newStatus int64, statusChanged bool, expectedUpdateTime int64) error UpdatePassword(ctx context.Context, id int64, password string, mustChangePassword int64) error UpdateStatus(ctx context.Context, id int64, status int64) error IncrementTokenVersion(ctx context.Context, id int64) (int64, error) IncrementTokenVersionIfMatch(ctx context.Context, id int64, username string, expected int64) (int64, error) } customSysUserModel struct { *defaultSysUserModel } ) func NewSysUserModel(conn sqlx.SqlConn, c cache.CacheConf, cachePrefix string, opts ...cache.Option) SysUserModel { return &customSysUserModel{ defaultSysUserModel: newSysUserModel(conn, c, cachePrefix, opts...), } } func (m *customSysUserModel) FindListByPage(ctx context.Context, page, pageSize int64) ([]*SysUser, int64, error) { var total int64 countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s", m.table) if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery); err != nil { return nil, 0, err } var list []*SysUser query := fmt.Sprintf("SELECT %s FROM %s ORDER BY id DESC LIMIT ?,?", sysUserRows, m.table) if err := m.QueryRowsNoCacheCtx(ctx, &list, query, (page-1)*pageSize, pageSize); err != nil { return nil, 0, err } return list, total, nil } type UserWithMemberType struct { SysUser MemberType string `db:"memberType"` } func (m *customSysUserModel) FindListByProductMembers(ctx context.Context, productCode string, page, pageSize int64) ([]*SysUser, map[int64]string, int64, error) { memberTable := "`sys_product_member`" var total int64 countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s u INNER JOIN %s pm ON u.`id` = pm.`userId` WHERE pm.`productCode` = ?", m.table, memberTable) if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery, productCode); err != nil { return nil, nil, 0, err } var list []*UserWithMemberType fields := strings.Join(sysUserFieldNames, ",u.") query := fmt.Sprintf("SELECT u.%s, pm.`memberType` FROM %s u INNER JOIN %s pm ON u.`id` = pm.`userId` WHERE pm.`productCode` = ? ORDER BY u.`id` DESC LIMIT ?,?", fields, m.table, memberTable) if err := m.QueryRowsNoCacheCtx(ctx, &list, query, productCode, (page-1)*pageSize, pageSize); err != nil { return nil, nil, 0, err } users := make([]*SysUser, len(list)) memberMap := make(map[int64]string, len(list)) for i, item := range list { users[i] = &item.SysUser memberMap[item.Id] = item.MemberType } return users, memberMap, total, nil } func (m *customSysUserModel) FindIdsByDeptId(ctx context.Context, deptId int64) ([]int64, error) { var ids []int64 query := fmt.Sprintf("SELECT `id` FROM %s WHERE `deptId` = ?", m.table) if err := m.QueryRowsNoCacheCtx(ctx, &ids, query, deptId); err != nil { return nil, err } return ids, nil } func (m *customSysUserModel) UpdateProfile(ctx context.Context, id int64, username string, nickname, email, phone, remark string, deptId, newStatus int64, statusChanged bool, expectedUpdateTime int64) error { sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id) sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username) now := time.Now().Unix() res, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) { if statusChanged { query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `email`=?, `phone`=?, `remark`=?, `deptId`=?, `status`=?, `tokenVersion`=`tokenVersion`+1, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table) return conn.ExecCtx(ctx, query, nickname, email, phone, remark, deptId, newStatus, now, id, expectedUpdateTime) } query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `email`=?, `phone`=?, `remark`=?, `deptId`=?, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table) return conn.ExecCtx(ctx, query, nickname, email, phone, remark, deptId, now, id, expectedUpdateTime) }, sysUserIdKey, sysUserUsernameKey) if err != nil { return err } affected, _ := res.RowsAffected() if affected == 0 { return ErrUpdateConflict } return nil } func (m *customSysUserModel) UpdatePassword(ctx context.Context, id int64, password string, mustChangePassword int64) error { data, err := m.FindOne(ctx, id) if err != nil { return err } sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id) sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, data.Username) _, err = m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) { query := fmt.Sprintf("UPDATE %s SET `password` = ?, `mustChangePassword` = ?, `tokenVersion` = `tokenVersion` + 1, `updateTime` = ? WHERE `id` = ?", m.table) return conn.ExecCtx(ctx, query, password, mustChangePassword, time.Now().Unix(), id) }, sysUserIdKey, sysUserUsernameKey) return err } func (m *customSysUserModel) UpdateStatus(ctx context.Context, id int64, status int64) error { data, err := m.FindOne(ctx, id) if err != nil { return err } sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id) sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, data.Username) _, err = m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) { query := fmt.Sprintf("UPDATE %s SET `status` = ?, `tokenVersion` = `tokenVersion` + 1, `updateTime` = ? WHERE `id` = ?", m.table) return conn.ExecCtx(ctx, query, status, time.Now().Unix(), id) }, sysUserIdKey, sysUserUsernameKey) return err } func (m *customSysUserModel) IncrementTokenVersion(ctx context.Context, id int64) (int64, error) { data, err := m.FindOne(ctx, id) if err != nil { return 0, err } sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id) sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, data.Username) var newVersion int64 err = m.TransactCtx(ctx, func(ctx context.Context, session sqlx.Session) error { query := fmt.Sprintf("UPDATE %s SET `tokenVersion` = LAST_INSERT_ID(`tokenVersion` + 1), `updateTime` = ? WHERE `id` = ?", m.table) if _, err := session.ExecCtx(ctx, query, time.Now().Unix(), id); err != nil { return err } return session.QueryRowCtx(ctx, &newVersion, "SELECT LAST_INSERT_ID()") }) if err != nil { return 0, err } _ = m.DelCacheCtx(ctx, sysUserIdKey, sysUserUsernameKey) return newVersion, nil } // IncrementTokenVersionIfMatch 原子递增 tokenVersion;仅当 DB 里当前 tokenVersion == expected 时才会生效。 // 这是 refreshToken rotation 的原子 CAS:两个并发的刷新请求只有一个能命中 WHERE tokenVersion=expected, // 另一个 affected=0 返回 ErrTokenVersionMismatch,从而避免"两边都换到新令牌"的会话劫持窗口。 // // 由上游透传 username 以便构造 cacheSysUserUsernamePrefix 的缓存键进行失效,避免为此多查一次 FindOne // (见审计 M-8)。上游通常已经通过 UserDetailsLoader.Load 拿到 username,零额外成本。 func (m *customSysUserModel) IncrementTokenVersionIfMatch(ctx context.Context, id int64, username string, expected int64) (int64, error) { sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id) sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username) var newVersion int64 err := m.TransactCtx(ctx, func(ctx context.Context, session sqlx.Session) error { query := fmt.Sprintf("UPDATE %s SET `tokenVersion` = LAST_INSERT_ID(`tokenVersion` + 1), `updateTime` = ? WHERE `id` = ? AND `tokenVersion` = ?", m.table) res, err := session.ExecCtx(ctx, query, time.Now().Unix(), id, expected) if err != nil { return err } affected, _ := res.RowsAffected() if affected == 0 { return ErrTokenVersionMismatch } return session.QueryRowCtx(ctx, &newVersion, "SELECT LAST_INSERT_ID()") }) if err != nil { return 0, err } _ = m.DelCacheCtx(ctx, sysUserIdKey, sysUserUsernameKey) return newVersion, nil } func (m *customSysUserModel) FindByIds(ctx context.Context, ids []int64) ([]*SysUser, error) { if len(ids) == 0 { return nil, nil } placeholders := make([]string, len(ids)) args := make([]interface{}, len(ids)) for i, id := range ids { placeholders[i] = "?" args[i] = id } var list []*SysUser query := fmt.Sprintf("SELECT %s FROM %s WHERE `id` IN (%s)", sysUserRows, m.table, strings.Join(placeholders, ",")) if err := m.QueryRowsNoCacheCtx(ctx, &list, query, args...); err != nil { return nil, err } return list, nil }