package user import ( "context" "database/sql" "fmt" "strings" "time" "github.com/zeromicro/go-zero/core/stores/cache" "github.com/zeromicro/go-zero/core/stores/sqlx" ) var _ SysUserModel = (*customSysUserModel)(nil) type ( SysUserModel interface { sysUserModel FindListByPage(ctx context.Context, page, pageSize int64) ([]*SysUser, int64, error) FindListByProductMembers(ctx context.Context, productCode string, page, pageSize int64) ([]*SysUser, int64, error) FindListByDeptIds(ctx context.Context, deptIds []int64, page, pageSize int64) ([]*SysUser, int64, error) FindByIds(ctx context.Context, ids []int64) ([]*SysUser, error) FindIdsByDeptId(ctx context.Context, deptId int64) ([]int64, error) UpdatePassword(ctx context.Context, id int64, password string, mustChangePassword int64) error UpdateStatus(ctx context.Context, id int64, status int64) error } customSysUserModel struct { *defaultSysUserModel } ) func NewSysUserModel(conn sqlx.SqlConn, c cache.CacheConf, cachePrefix string, opts ...cache.Option) SysUserModel { return &customSysUserModel{ defaultSysUserModel: newSysUserModel(conn, c, cachePrefix, opts...), } } func (m *customSysUserModel) FindListByPage(ctx context.Context, page, pageSize int64) ([]*SysUser, int64, error) { var total int64 countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s", m.table) if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery); err != nil { return nil, 0, err } var list []*SysUser query := fmt.Sprintf("SELECT %s FROM %s ORDER BY id DESC LIMIT ?,?", sysUserRows, m.table) if err := m.QueryRowsNoCacheCtx(ctx, &list, query, (page-1)*pageSize, pageSize); err != nil { return nil, 0, err } return list, total, nil } func (m *customSysUserModel) FindListByProductMembers(ctx context.Context, productCode string, page, pageSize int64) ([]*SysUser, int64, error) { memberTable := "`sys_product_member`" var total int64 countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s u INNER JOIN %s pm ON u.`id` = pm.`userId` WHERE pm.`productCode` = ?", m.table, memberTable) if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery, productCode); err != nil { return nil, 0, err } var list []*SysUser fields := strings.Join(sysUserFieldNames, ",u.") query := fmt.Sprintf("SELECT u.%s FROM %s u INNER JOIN %s pm ON u.`id` = pm.`userId` WHERE pm.`productCode` = ? ORDER BY u.`id` DESC LIMIT ?,?", fields, m.table, memberTable) if err := m.QueryRowsNoCacheCtx(ctx, &list, query, productCode, (page-1)*pageSize, pageSize); err != nil { return nil, 0, err } return list, total, nil } func (m *customSysUserModel) FindListByDeptIds(ctx context.Context, deptIds []int64, page, pageSize int64) ([]*SysUser, int64, error) { if len(deptIds) == 0 { return nil, 0, nil } placeholders := make([]string, len(deptIds)) args := make([]interface{}, len(deptIds)) for i, id := range deptIds { placeholders[i] = "?" args[i] = id } inClause := strings.Join(placeholders, ",") var total int64 countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE `deptId` IN (%s)", m.table, inClause) if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery, args...); err != nil { return nil, 0, err } var list []*SysUser pageArgs := make([]interface{}, len(args), len(args)+2) copy(pageArgs, args) pageArgs = append(pageArgs, (page-1)*pageSize, pageSize) query := fmt.Sprintf("SELECT %s FROM %s WHERE `deptId` IN (%s) ORDER BY id DESC LIMIT ?,?", sysUserRows, m.table, inClause) if err := m.QueryRowsNoCacheCtx(ctx, &list, query, pageArgs...); err != nil { return nil, 0, err } return list, total, nil } func (m *customSysUserModel) FindIdsByDeptId(ctx context.Context, deptId int64) ([]int64, error) { var ids []int64 query := fmt.Sprintf("SELECT `id` FROM %s WHERE `deptId` = ?", m.table) if err := m.QueryRowsNoCacheCtx(ctx, &ids, query, deptId); err != nil { return nil, err } return ids, nil } func (m *customSysUserModel) UpdatePassword(ctx context.Context, id int64, password string, mustChangePassword int64) error { data, err := m.FindOne(ctx, id) if err != nil { return err } sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id) sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, data.Username) _, err = m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) { query := fmt.Sprintf("UPDATE %s SET `password` = ?, `mustChangePassword` = ?, `tokenVersion` = `tokenVersion` + 1, `updateTime` = ? WHERE `id` = ?", m.table) return conn.ExecCtx(ctx, query, password, mustChangePassword, time.Now().Unix(), id) }, sysUserIdKey, sysUserUsernameKey) return err } func (m *customSysUserModel) UpdateStatus(ctx context.Context, id int64, status int64) error { data, err := m.FindOne(ctx, id) if err != nil { return err } sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id) sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, data.Username) _, err = m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) { query := fmt.Sprintf("UPDATE %s SET `status` = ?, `tokenVersion` = `tokenVersion` + 1, `updateTime` = ? WHERE `id` = ?", m.table) return conn.ExecCtx(ctx, query, status, time.Now().Unix(), id) }, sysUserIdKey, sysUserUsernameKey) return err } func (m *customSysUserModel) FindByIds(ctx context.Context, ids []int64) ([]*SysUser, error) { if len(ids) == 0 { return nil, nil } placeholders := make([]string, len(ids)) args := make([]interface{}, len(ids)) for i, id := range ids { placeholders[i] = "?" args[i] = id } var list []*SysUser query := fmt.Sprintf("SELECT %s FROM %s WHERE `id` IN (%s)", sysUserRows, m.table, strings.Join(placeholders, ",")) if err := m.QueryRowsNoCacheCtx(ctx, &list, query, args...); err != nil { return nil, err } return list, nil }