package user

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/zeromicro/go-zero/core/stores/cache"
	"github.com/zeromicro/go-zero/core/stores/sqlx"
)

// ErrUpdateConflict is returned by UpdateProfile when its conditional UPDATE
// (matched on both `id` and the expected `updateTime`) affects no rows.
var ErrUpdateConflict = errors.New("update conflict: data has been modified by another operation")

// Compile-time check that customSysUserModel satisfies SysUserModel.
var _ SysUserModel = (*customSysUserModel)(nil)

type (
	// SysUserModel extends the generated sysUserModel with paging queries,
	// batch lookups and targeted column updates.
	SysUserModel interface {
		sysUserModel
		FindListByPage(ctx context.Context, page, pageSize int64) ([]*SysUser, int64, error)
		FindListByProductMembers(ctx context.Context, productCode string, page, pageSize int64) ([]*SysUser, int64, error)
		FindByIds(ctx context.Context, ids []int64) ([]*SysUser, error)
		FindIdsByDeptId(ctx context.Context, deptId int64) ([]int64, error)
		UpdateProfile(ctx context.Context, id int64, username string, nickname, email, phone, remark string, deptId, newStatus int64, statusChanged bool, expectedUpdateTime int64) error
		UpdatePassword(ctx context.Context, id int64, password string, mustChangePassword int64) error
		UpdateStatus(ctx context.Context, id int64, status int64) error
		IncrementTokenVersion(ctx context.Context, id int64) (int64, error)
	}

	// customSysUserModel implements SysUserModel by embedding the generated
	// default model and adding the hand-written methods below.
	customSysUserModel struct {
		*defaultSysUserModel
	}
)

// NewSysUserModel builds a SysUserModel on top of the generated default model,
// forwarding the connection, cache configuration, cache prefix and options.
func NewSysUserModel(conn sqlx.SqlConn, c cache.CacheConf, cachePrefix string, opts ...cache.Option) SysUserModel {
	return &customSysUserModel{
		defaultSysUserModel: newSysUserModel(conn, c, cachePrefix, opts...),
	}
}

// FindListByPage returns one page of users ordered by id descending together
// with the total row count of the table. Both queries bypass the cache.
//
// The offset is computed as (page-1)*pageSize, i.e. page is treated as 1-based.
// NOTE(review): page and pageSize are not validated here; a page below 1 yields
// a negative offset — presumably callers normalise these values, confirm.
func (m *customSysUserModel) FindListByPage(ctx context.Context, page, pageSize int64) ([]*SysUser, int64, error) {
	var total int64
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s", m.table)
	if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery); err != nil {
		return nil, 0, err
	}

	var list []*SysUser
	query := fmt.Sprintf("SELECT %s FROM %s ORDER BY id DESC LIMIT ?,?", sysUserRows, m.table)
	if err := m.QueryRowsNoCacheCtx(ctx, &list, query, (page-1)*pageSize, pageSize); err != nil {
		return nil, 0, err
	}
	return list, total, nil
}

// FindListByProductMembers returns one page of users that are joined to
// `sys_product_member` rows with the given productCode, ordered by user id
// descending, together with the total number of matching rows. Both queries
// bypass the cache.
//
// The offset is computed as (page-1)*pageSize, i.e. page is treated as 1-based.
func (m *customSysUserModel) FindListByProductMembers(ctx context.Context, productCode string, page, pageSize int64) ([]*SysUser, int64, error) {
	memberTable := "`sys_product_member`"

	var total int64
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s u INNER JOIN %s pm ON u.`id` = pm.`userId` WHERE pm.`productCode` = ?", m.table, memberTable)
	if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery, productCode); err != nil {
		return nil, 0, err
	}

	var list []*SysUser
	// Joining with ",u." and prefixing the format with "u." qualifies every
	// selected column with the user-table alias.
	fields := strings.Join(sysUserFieldNames, ",u.")
	query := fmt.Sprintf("SELECT u.%s FROM %s u INNER JOIN %s pm ON u.`id` = pm.`userId` WHERE pm.`productCode` = ? ORDER BY u.`id` DESC LIMIT ?,?", fields, m.table, memberTable)
	if err := m.QueryRowsNoCacheCtx(ctx, &list, query, productCode, (page-1)*pageSize, pageSize); err != nil {
		return nil, 0, err
	}
	return list, total, nil
}

// FindIdsByDeptId returns the ids of all users whose `deptId` equals deptId.
// The query bypasses the cache.
func (m *customSysUserModel) FindIdsByDeptId(ctx context.Context, deptId int64) ([]int64, error) {
	var ids []int64
	query := fmt.Sprintf("SELECT `id` FROM %s WHERE `deptId` = ?", m.table)
	if err := m.QueryRowsNoCacheCtx(ctx, &ids, query, deptId); err != nil {
		return nil, err
	}
	return ids, nil
}

// UpdateProfile updates nickname, email, phone, remark and deptId of user id,
// guarded by an optimistic lock: the row is only updated when its current
// `updateTime` equals expectedUpdateTime. `updateTime` is set to the current
// Unix time in seconds.
//
// When statusChanged is true the statement additionally writes newStatus and
// increments `tokenVersion`; otherwise newStatus is ignored.
//
// username is not written to the table; it is only used to build the
// username cache key passed to ExecCtx along with the id cache key.
//
// It returns ErrUpdateConflict when no row was affected (no row matched both
// id and expectedUpdateTime), and any error from executing the statement or
// from reading the affected-row count.
func (m *customSysUserModel) UpdateProfile(ctx context.Context, id int64, username string, nickname, email, phone, remark string, deptId, newStatus int64, statusChanged bool, expectedUpdateTime int64) error {
	sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
	sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username)
	now := time.Now().Unix()

	res, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
		if statusChanged {
			query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `email`=?, `phone`=?, `remark`=?, `deptId`=?, `status`=?, `tokenVersion`=`tokenVersion`+1, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table)
			return conn.ExecCtx(ctx, query, nickname, email, phone, remark, deptId, newStatus, now, id, expectedUpdateTime)
		}
		query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `email`=?, `phone`=?, `remark`=?, `deptId`=?, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table)
		return conn.ExecCtx(ctx, query, nickname, email, phone, remark, deptId, now, id, expectedUpdateTime)
	}, sysUserIdKey, sysUserUsernameKey)
	if err != nil {
		return err
	}

	// A failure to read the row count must not be reported as a conflict:
	// with the error dropped, affected would be 0 and the caller would see
	// ErrUpdateConflict even though the update may have succeeded.
	affected, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if affected == 0 {
		return ErrUpdateConflict
	}
	return nil
}

// UpdatePassword stores a new password and mustChangePassword flag for user
// id, increments `tokenVersion` and sets `updateTime` to the current Unix time
// in seconds.
//
// The user is loaded first via FindOne so that its username can be used for
// the username cache key; a FindOne error is returned unchanged. The password
// value is written as given.
func (m *customSysUserModel) UpdatePassword(ctx context.Context, id int64, password string, mustChangePassword int64) error {
	data, err := m.FindOne(ctx, id)
	if err != nil {
		return err
	}

	sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
	sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, data.Username)
	_, err = m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
		query := fmt.Sprintf("UPDATE %s SET `password` = ?, `mustChangePassword` = ?, `tokenVersion` = `tokenVersion` + 1, `updateTime` = ? WHERE `id` = ?", m.table)
		return conn.ExecCtx(ctx, query, password, mustChangePassword, time.Now().Unix(), id)
	}, sysUserIdKey, sysUserUsernameKey)
	return err
}

// UpdateStatus sets the status of user id, increments `tokenVersion` and sets
// `updateTime` to the current Unix time in seconds.
//
// The user is loaded first via FindOne so that its username can be used for
// the username cache key; a FindOne error is returned unchanged.
func (m *customSysUserModel) UpdateStatus(ctx context.Context, id int64, status int64) error {
	data, err := m.FindOne(ctx, id)
	if err != nil {
		return err
	}

	sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
	sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, data.Username)
	_, err = m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
		query := fmt.Sprintf("UPDATE %s SET `status` = ?, `tokenVersion` = `tokenVersion` + 1, `updateTime` = ? WHERE `id` = ?", m.table)
		return conn.ExecCtx(ctx, query, status, time.Now().Unix(), id)
	}, sysUserIdKey, sysUserUsernameKey)
	return err
}

// IncrementTokenVersion increments `tokenVersion` of user id, sets
// `updateTime` to the current Unix time in seconds, and returns the new
// version.
//
// The returned value is computed as the TokenVersion of the row loaded by
// FindOne before the update, plus one; it is not re-read after the update.
// NOTE(review): if another writer bumps tokenVersion between the FindOne and
// the UPDATE, the returned value can differ from the stored one — confirm
// callers tolerate this.
func (m *customSysUserModel) IncrementTokenVersion(ctx context.Context, id int64) (int64, error) {
	data, err := m.FindOne(ctx, id)
	if err != nil {
		return 0, err
	}

	sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
	sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, data.Username)
	_, err = m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
		query := fmt.Sprintf("UPDATE %s SET `tokenVersion` = `tokenVersion` + 1, `updateTime` = ? WHERE `id` = ?", m.table)
		return conn.ExecCtx(ctx, query, time.Now().Unix(), id)
	}, sysUserIdKey, sysUserUsernameKey)
	if err != nil {
		return 0, err
	}
	return data.TokenVersion + 1, nil
}

// FindByIds returns the users whose id is in ids, using a single IN query
// that bypasses the cache. An empty ids slice returns (nil, nil) without
// touching the database. No ORDER BY is applied, so the result order is
// whatever the database returns.
func (m *customSysUserModel) FindByIds(ctx context.Context, ids []int64) ([]*SysUser, error) {
	if len(ids) == 0 {
		return nil, nil
	}

	// One "?" placeholder and one bound argument per id.
	placeholders := make([]string, len(ids))
	args := make([]interface{}, len(ids))
	for i, id := range ids {
		placeholders[i] = "?"
		args[i] = id
	}

	var list []*SysUser
	query := fmt.Sprintf("SELECT %s FROM %s WHERE `id` IN (%s)", sysUserRows, m.table, strings.Join(placeholders, ","))
	if err := m.QueryRowsNoCacheCtx(ctx, &list, query, args...); err != nil {
		return nil, err
	}
	return list, nil
}