package role

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"sort"
	"strings"

	"perms-system-server/internal/consts"

	"github.com/zeromicro/go-zero/core/stores/cache"
	"github.com/zeromicro/go-zero/core/stores/sqlx"
)

// ErrUpdateConflict is returned by UpdateWithOptLock when the guarded UPDATE
// affects no row, i.e. the (id, updateTime) pair no longer matches.
var ErrUpdateConflict = errors.New("update conflict: data has been modified by another operation")

// Compile-time check that customSysRoleModel satisfies SysRoleModel.
var _ SysRoleModel = (*customSysRoleModel)(nil)

type (
	// SysRoleModel extends the generated sysRoleModel with hand-written
	// queries and transactional locking helpers for the sys_role table.
	SysRoleModel interface {
		sysRoleModel

		// FindListByProductCode returns one page of roles for productCode
		// together with the total number of matching rows.
		FindListByProductCode(ctx context.Context, productCode string, page, pageSize int64) ([]*SysRole, int64, error)

		// FindByIds loads the roles whose id is in ids, bypassing the cache.
		FindByIds(ctx context.Context, ids []int64) ([]*SysRole, error)

		// FindMinPermsLevelByUserIdAndProductCode returns the smallest
		// permsLevel among the enabled roles the user holds in productCode.
		FindMinPermsLevelByUserIdAndProductCode(ctx context.Context, userId int64, productCode string) (int64, error)

		// UpdateWithOptLock updates a role only if its stored updateTime still
		// equals expectedUpdateTime; otherwise it returns ErrUpdateConflict.
		UpdateWithOptLock(ctx context.Context, data *SysRole, expectedUpdateTime int64) error

		// LockByIdTx locks the sys_role row inside the current transaction
		// (SELECT ... FOR UPDATE). It serializes concurrent BindRolePerms
		// overwrites of the same role, eliminating the "third-state merge"
		// problem caused by reading the existing set outside the transaction
		// and then deleting/inserting inside it (see audit M-R10-2).
		LockByIdTx(ctx context.Context, session sqlx.Session, id int64) (*SysRole, error)

		// LockRolesForShareTx takes shared (S) locks on a batch of sys_role
		// rows inside the current transaction (SELECT ... LOCK IN SHARE MODE).
		// It closes the write skew between BindRoles and DeleteRole (audit
		// M-R12-1): DeleteRole takes an X lock on sys_role[R] at the end of its
		// transaction, which blocks on this S lock; once BindRoles commits,
		// DeleteRole's FindUserIdsByRoleIdForUpdateTx sees the newly inserted
		// sys_user_role rows and can either reject the delete or clean them
		// up, so no orphan roleId pointing at a deleted sys_role is left.
		//
		// If the number of locked rows differs from the number of distinct ids,
		// or any row has status != Enabled, sqlx.ErrNotFound is returned so the
		// caller can turn it into a 400 "contains invalid role IDs": whether
		// DeleteRole removed the row or UpdateRole disabled the role, it must
		// no longer be bound.
		//
		// This method bypasses the cache and must be called under TransactCtx /
		// a Session. ids are sorted ascending internally before locking to
		// avoid deadlocks.
		LockRolesForShareTx(ctx context.Context, session sqlx.Session, ids []int64) error
	}

	// customSysRoleModel implements SysRoleModel on top of the generated
	// defaultSysRoleModel.
	customSysRoleModel struct {
		*defaultSysRoleModel
	}
)

// NewSysRoleModel builds a SysRoleModel backed by conn and the cache
// configuration c; cachePrefix and opts are forwarded to newSysRoleModel.
func NewSysRoleModel(conn sqlx.SqlConn, c cache.CacheConf, cachePrefix string, opts ...cache.Option) SysRoleModel {
	return &customSysRoleModel{
		defaultSysRoleModel: newSysRoleModel(conn, c, cachePrefix, opts...),
	}
}

// FindListByProductCode returns the roles of productCode ordered by permsLevel
// ascending then id descending, plus the total row count for that product.
//
// page is 1-based: the offset sent to the database is (page-1)*pageSize. No
// validation happens here, so page < 1 produces a negative offset in the query;
// callers are expected to pass sane values. Both queries bypass the cache.
func (m *customSysRoleModel) FindListByProductCode(ctx context.Context, productCode string, page, pageSize int64) ([]*SysRole, int64, error) {
	var total int64
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE `productCode` = ?", m.table)
	if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery, productCode); err != nil {
		return nil, 0, err
	}

	var list []*SysRole
	query := fmt.Sprintf("SELECT %s FROM %s WHERE `productCode` = ? ORDER BY `permsLevel` ASC, id DESC LIMIT ?,?", sysRoleRows, m.table)
	if err := m.QueryRowsNoCacheCtx(ctx, &list, query, productCode, (page-1)*pageSize, pageSize); err != nil {
		return nil, 0, err
	}
	return list, total, nil
}

// FindByIds loads the roles whose id appears in ids with a single
// `id IN (...)` query that bypasses the cache. An empty ids returns (nil, nil)
// without touching the database. The query has no ORDER BY, so the result
// order is not tied to the order of ids.
func (m *customSysRoleModel) FindByIds(ctx context.Context, ids []int64) ([]*SysRole, error) {
	if len(ids) == 0 {
		return nil, nil
	}

	// One "?" placeholder and one bound argument per id.
	args := make([]interface{}, len(ids))
	marks := make([]string, len(ids))
	for i, id := range ids {
		args[i] = id
		marks[i] = "?"
	}

	var list []*SysRole
	query := fmt.Sprintf("SELECT %s FROM %s WHERE `id` IN (%s)", sysRoleRows, m.table, strings.Join(marks, ","))
	if err := m.QueryRowsNoCacheCtx(ctx, &list, query, args...); err != nil {
		return nil, err
	}
	return list, nil
}

// UpdateWithOptLock writes name, remark, status, permsLevel and updateTime of
// data to the row with id data.Id, but only if the stored updateTime still
// equals expectedUpdateTime (optimistic locking).
//
// It returns ErrUpdateConflict when no row was affected, and any error from the
// statement or from reading the affected-row count otherwise. The id cache key
// and the (productCode, name) cache key derived from data are passed to ExecCtx
// for invalidation.
//
// NOTE(review): the (productCode, name) key is built from data.Name, i.e. the
// new name. If the update renames the role, the key for the previous name is
// not passed here — confirm callers handle that invalidation if it matters.
func (m *customSysRoleModel) UpdateWithOptLock(ctx context.Context, data *SysRole, expectedUpdateTime int64) error {
	sysRoleIdKey := fmt.Sprintf("%s%v", cacheSysRoleIdPrefix, data.Id)
	sysRoleProductCodeNameKey := fmt.Sprintf("%s%v:%v", cacheSysRoleProductCodeNamePrefix, data.ProductCode, data.Name)

	res, err := m.ExecCtx(ctx, func(ctx context.Context, conn sqlx.SqlConn) (sql.Result, error) {
		query := fmt.Sprintf("UPDATE %s SET `name`=?, `remark`=?, `status`=?, `permsLevel`=?, `updateTime`=? WHERE `id`=? AND `updateTime`=?", m.table)
		return conn.ExecCtx(ctx, query, data.Name, data.Remark, data.Status, data.PermsLevel, data.UpdateTime, data.Id, expectedUpdateTime)
	}, sysRoleIdKey, sysRoleProductCodeNameKey)
	if err != nil {
		return err
	}

	// Surface a failure to read the row count instead of letting the zero
	// value masquerade as an optimistic-lock conflict.
	affected, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if affected == 0 {
		return ErrUpdateConflict
	}
	return nil
}

// LockRolesForShareTx takes shared locks on the enabled sys_role rows named by
// ids within session; see the SysRoleModel interface comment (audit M-R12-1).
//
// An empty ids is a no-op. It returns sqlx.ErrNotFound when fewer rows were
// locked than there are distinct ids.
func (m *customSysRoleModel) LockRolesForShareTx(ctx context.Context, session sqlx.Session, ids []int64) error {
	if len(ids) == 0 {
		return nil
	}

	// Deduplicate and sort ascending: this avoids lengthening the wait chain
	// by selecting the same id twice in one transaction, and makes concurrent
	// BindRoles calls acquire locks in one consistent order (no deadlock from
	// A locking 1→2 while B locks 2→1).
	seen := make(map[int64]struct{}, len(ids))
	sorted := make([]int64, 0, len(ids))
	for _, id := range ids {
		if _, ok := seen[id]; ok {
			continue
		}
		seen[id] = struct{}{}
		sorted = append(sorted, id)
	}
	sort.Slice(sorted, func(i, j int) bool { return sorted[i] < sorted[j] })

	// One placeholder per distinct id, followed by the status filter argument.
	placeholders := make([]string, len(sorted))
	args := make([]interface{}, 0, len(sorted)+1)
	for i, id := range sorted {
		placeholders[i] = "?"
		args = append(args, id)
	}
	args = append(args, consts.StatusEnabled)

	var lockedIds []int64
	query := fmt.Sprintf(
		"SELECT `id` FROM %s WHERE `id` IN (%s) AND `status` = ? ORDER BY `id` LOCK IN SHARE MODE",
		m.table, strings.Join(placeholders, ","),
	)
	if err := session.QueryRowsCtx(ctx, &lockedIds, query, args...); err != nil {
		return err
	}

	// Any id that fails to match (deleted by DeleteRole, or set to Disabled by
	// UpdateRole) collapses to ErrNotFound so the caller BindRoles aborts the
	// transaction immediately and returns 400. This function deliberately has
	// no "partial success / which ids failed" return semantics: DeleteRole's X
	// lock on sys_role is not released before this transaction commits, so
	// only an S lock that captured every row is the correct signal.
	if len(lockedIds) != len(sorted) {
		return sqlx.ErrNotFound
	}
	return nil
}

// LockByIdTx selects the sys_role row with the given id using
// SELECT ... FOR UPDATE on session; see the SysRoleModel interface comment.
//
// It bypasses the cache layer and must be called under TransactCtx / a Session;
// the row lock taken by SELECT ... FOR UPDATE is held by InnoDB until the
// transaction ends. Any query error is returned unchanged with a nil role.
func (m *customSysRoleModel) LockByIdTx(ctx context.Context, session sqlx.Session, id int64) (*SysRole, error) {
	var data SysRole
	query := fmt.Sprintf("SELECT %s FROM %s WHERE `id` = ? LIMIT 1 FOR UPDATE", sysRoleRows, m.table)
	if err := session.QueryRowCtx(ctx, &data, query, id); err != nil {
		return nil, err
	}
	return &data, nil
}

// FindMinPermsLevelByUserIdAndProductCode returns the minimum permsLevel over
// the roles joined to userId through sys_user_role, restricted to productCode
// and to roles whose status equals consts.StatusEnabled. The query bypasses the
// cache.
//
// The SQL maps "no matching role" to -1 via IFNULL, and any negative result is
// reported as ErrNotFound; a stored negative permsLevel would therefore also be
// reported as ErrNotFound.
func (m *customSysRoleModel) FindMinPermsLevelByUserIdAndProductCode(ctx context.Context, userId int64, productCode string) (int64, error) {
	var level int64
	query := fmt.Sprintf(
		"SELECT IFNULL(MIN(r.`permsLevel`), -1) FROM %s r INNER JOIN `sys_user_role` ur ON r.`id` = ur.`roleId` WHERE ur.`userId` = ? AND r.`productCode` = ? AND r.`status` = ?",
		m.table,
	)
	if err := m.QueryRowNoCacheCtx(ctx, &level, query, userId, productCode, consts.StatusEnabled); err != nil {
		return 0, err
	}
	if level < 0 {
		return 0, ErrNotFound
	}
	return level, nil
}