sysUserModel.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. package user
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "fmt"
  7. "strings"
  8. "time"
  9. "github.com/zeromicro/go-zero/core/stores/cache"
  10. "github.com/zeromicro/go-zero/core/stores/sqlx"
  11. )
  // ErrUpdateConflict reports that an optimistic-lock update wrote nothing:
// no row with the given id still carried the updateTime the caller expected.
var (
	ErrUpdateConflict = errors.New("update conflict: data has been modified by another operation")
)
  13. var _ SysUserModel = (*customSysUserModel)(nil)
  14. type (
  15. SysUserModel interface {
  16. sysUserModel
  17. FindListByPage(ctx context.Context, page, pageSize int64) ([]*SysUser, int64, error)
  18. FindListByProductMembers(ctx context.Context, productCode string, page, pageSize int64) ([]*SysUser, map[int64]string, int64, error)
  19. FindByIds(ctx context.Context, ids []int64) ([]*SysUser, error)
  20. FindIdsByDeptId(ctx context.Context, deptId int64) ([]int64, error)
  21. UpdateProfile(ctx context.Context, id int64, username string, nickname, email, phone, remark string, deptId, newStatus int64, statusChanged bool, expectedUpdateTime int64) error
  22. UpdatePassword(ctx context.Context, id int64, password string, mustChangePassword int64) error
  23. UpdateStatus(ctx context.Context, id int64, status int64) error
  24. IncrementTokenVersion(ctx context.Context, id int64) (int64, error)
  25. }
  26. customSysUserModel struct {
  27. *defaultSysUserModel
  28. }
  29. )
  30. func NewSysUserModel(conn sqlx.SqlConn, c cache.CacheConf, cachePrefix string, opts ...cache.Option) SysUserModel {
  31. return &customSysUserModel{
  32. defaultSysUserModel: newSysUserModel(conn, c, cachePrefix, opts...),
  33. }
  34. }
  35. func (m *customSysUserModel) FindListByPage(ctx context.Context, page, pageSize int64) ([]*SysUser, int64, error) {
  36. var total int64
  37. countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s", m.table)
  38. if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery); err != nil {
  39. return nil, 0, err
  40. }
  41. var list []*SysUser
  42. query := fmt.Sprintf("SELECT %s FROM %s ORDER BY id DESC LIMIT ?,?", sysUserRows, m.table)
  43. if err := m.QueryRowsNoCacheCtx(ctx, &list, query, (page-1)*pageSize, pageSize); err != nil {
  44. return nil, 0, err
  45. }
  46. return list, total, nil
  47. }
  // UserWithMemberType is the row shape scanned by FindListByProductMembers:
// all user columns (through the embedded SysUser) plus the member type taken
// from the product-member join table.
type UserWithMemberType struct {
	SysUser
	// MemberType is filled from the `memberType` column selected from the join.
	MemberType string `db:"memberType"`
}
  52. func (m *customSysUserModel) FindListByProductMembers(ctx context.Context, productCode string, page, pageSize int64) ([]*SysUser, map[int64]string, int64, error) {
  53. memberTable := "`sys_product_member`"
  54. var total int64
  55. countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s u INNER JOIN %s pm ON u.`id` = pm.`userId` WHERE pm.`productCode` = ?", m.table, memberTable)
  56. if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery, productCode); err != nil {
  57. return nil, nil, 0, err
  58. }
  59. var list []*UserWithMemberType
  60. fields := strings.Join(sysUserFieldNames, ",u.")
  61. query := fmt.Sprintf("SELECT u.%s, pm.`memberType` FROM %s u INNER JOIN %s pm ON u.`id` = pm.`userId` WHERE pm.`productCode` = ? ORDER BY u.`id` DESC LIMIT ?,?", fields, m.table, memberTable)
  62. if err := m.QueryRowsNoCacheCtx(ctx, &list, query, productCode, (page-1)*pageSize, pageSize); err != nil {
  63. return nil, nil, 0, err
  64. }
  65. users := make([]*SysUser, len(list))
  66. memberMap := make(map[int64]string, len(list))
  67. for i, item := range list {
  68. users[i] = &item.SysUser
  69. memberMap[item.Id] = item.MemberType
  70. }
  71. return users, memberMap, total, nil
  72. }
  73. func (m *customSysUserModel) FindIdsByDeptId(ctx context.Context, deptId int64) ([]int64, error) {
  74. var ids []int64
  75. query := fmt.Sprintf("SELECT `id` FROM %s WHERE `deptId` = ?", m.table)
  76. if err := m.QueryRowsNoCacheCtx(ctx, &ids, query, deptId); err != nil {
  77. return nil, err
  78. }
  79. return ids, nil
  80. }
  81. func (m *customSysUserModel) UpdateProfile(ctx context.Context, id int64, username string, nickname, email, phone, remark string, deptId, newStatus int64, statusChanged bool, expectedUpdateTime int64) error {
  82. sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
  83. sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, username)
  84. now := time.Now().Unix()
  85. res, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
  86. if statusChanged {
  87. query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `email`=?, `phone`=?, `remark`=?, `deptId`=?, `status`=?, `tokenVersion`=`tokenVersion`+1, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table)
  88. return conn.ExecCtx(ctx, query, nickname, email, phone, remark, deptId, newStatus, now, id, expectedUpdateTime)
  89. }
  90. query := fmt.Sprintf("UPDATE %s SET `nickname`=?, `email`=?, `phone`=?, `remark`=?, `deptId`=?, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table)
  91. return conn.ExecCtx(ctx, query, nickname, email, phone, remark, deptId, now, id, expectedUpdateTime)
  92. }, sysUserIdKey, sysUserUsernameKey)
  93. if err != nil {
  94. return err
  95. }
  96. affected, _ := res.RowsAffected()
  97. if affected == 0 {
  98. return ErrUpdateConflict
  99. }
  100. return nil
  101. }
  102. func (m *customSysUserModel) UpdatePassword(ctx context.Context, id int64, password string, mustChangePassword int64) error {
  103. data, err := m.FindOne(ctx, id)
  104. if err != nil {
  105. return err
  106. }
  107. sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
  108. sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, data.Username)
  109. _, err = m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
  110. query := fmt.Sprintf("UPDATE %s SET `password` = ?, `mustChangePassword` = ?, `tokenVersion` = `tokenVersion` + 1, `updateTime` = ? WHERE `id` = ?", m.table)
  111. return conn.ExecCtx(ctx, query, password, mustChangePassword, time.Now().Unix(), id)
  112. }, sysUserIdKey, sysUserUsernameKey)
  113. return err
  114. }
  115. func (m *customSysUserModel) UpdateStatus(ctx context.Context, id int64, status int64) error {
  116. data, err := m.FindOne(ctx, id)
  117. if err != nil {
  118. return err
  119. }
  120. sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
  121. sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, data.Username)
  122. _, err = m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
  123. query := fmt.Sprintf("UPDATE %s SET `status` = ?, `tokenVersion` = `tokenVersion` + 1, `updateTime` = ? WHERE `id` = ?", m.table)
  124. return conn.ExecCtx(ctx, query, status, time.Now().Unix(), id)
  125. }, sysUserIdKey, sysUserUsernameKey)
  126. return err
  127. }
  128. func (m *customSysUserModel) IncrementTokenVersion(ctx context.Context, id int64) (int64, error) {
  129. data, err := m.FindOne(ctx, id)
  130. if err != nil {
  131. return 0, err
  132. }
  133. sysUserIdKey := fmt.Sprintf("%s%v", cacheSysUserIdPrefix, id)
  134. sysUserUsernameKey := fmt.Sprintf("%s%v", cacheSysUserUsernamePrefix, data.Username)
  135. var newVersion int64
  136. err = m.TransactCtx(ctx, func(ctx context.Context, session sqlx.Session) error {
  137. query := fmt.Sprintf("UPDATE %s SET `tokenVersion` = LAST_INSERT_ID(`tokenVersion` + 1), `updateTime` = ? WHERE `id` = ?", m.table)
  138. if _, err := session.ExecCtx(ctx, query, time.Now().Unix(), id); err != nil {
  139. return err
  140. }
  141. return session.QueryRowCtx(ctx, &newVersion, "SELECT LAST_INSERT_ID()")
  142. })
  143. if err != nil {
  144. return 0, err
  145. }
  146. _ = m.DelCacheCtx(ctx, sysUserIdKey, sysUserUsernameKey)
  147. return newVersion, nil
  148. }
  149. func (m *customSysUserModel) FindByIds(ctx context.Context, ids []int64) ([]*SysUser, error) {
  150. if len(ids) == 0 {
  151. return nil, nil
  152. }
  153. placeholders := make([]string, len(ids))
  154. args := make([]interface{}, len(ids))
  155. for i, id := range ids {
  156. placeholders[i] = "?"
  157. args[i] = id
  158. }
  159. var list []*SysUser
  160. query := fmt.Sprintf("SELECT %s FROM %s WHERE `id` IN (%s)", sysUserRows, m.table, strings.Join(placeholders, ","))
  161. if err := m.QueryRowsNoCacheCtx(ctx, &list, query, args...); err != nil {
  162. return nil, err
  163. }
  164. return list, nil
  165. }