package productmember

import (
	"context"
	"fmt"
	"strings"

	"perms-system-server/internal/consts"

	"github.com/zeromicro/go-zero/core/stores/cache"
	"github.com/zeromicro/go-zero/core/stores/sqlx"
)

// Compile-time check that customSysProductMemberModel satisfies SysProductMemberModel.
var _ SysProductMemberModel = (*customSysProductMemberModel)(nil)

type (
	// SysProductMemberModel extends the embedded sysProductMemberModel with
	// product-scoped listing, batch lookup and admin-counting queries.
	SysProductMemberModel interface {
		sysProductMemberModel
		// FindListByProductCode returns one page of members of productCode
		// (newest id first) together with the total member count.
		FindListByProductCode(ctx context.Context, productCode string, page, pageSize int64) ([]*SysProductMember, int64, error)
		// FindMapByProductCodeUserIds returns the members of productCode whose
		// userId is in userIds, keyed by userId.
		FindMapByProductCodeUserIds(ctx context.Context, productCode string, userIds []int64) (map[int64]*SysProductMember, error)
		// CountActiveAdmins counts enabled admin members of productCode.
		CountActiveAdmins(ctx context.Context, productCode string) (int64, error)
		// CountActiveAdminsTx counts enabled admin members of productCode on
		// the given session, selecting the rows FOR UPDATE.
		CountActiveAdminsTx(ctx context.Context, session sqlx.Session, productCode string) (int64, error)
		// CountOtherActiveAdminsTx counts enabled admin members of productCode
		// other than the row excludeId, selecting the rows FOR UPDATE.
		CountOtherActiveAdminsTx(ctx context.Context, session sqlx.Session, productCode string, excludeId int64) (int64, error)
		// FindOneForUpdateTx loads the member row with the given id on the
		// given session using SELECT ... FOR UPDATE.
		FindOneForUpdateTx(ctx context.Context, session sqlx.Session, id int64) (*SysProductMember, error)
	}

	// customSysProductMemberModel adds the hand-written queries below on top
	// of the embedded defaultSysProductMemberModel.
	customSysProductMemberModel struct {
		*defaultSysProductMemberModel
	}
)

// NewSysProductMemberModel builds a SysProductMemberModel; all arguments are
// forwarded unchanged to newSysProductMemberModel.
func NewSysProductMemberModel(conn sqlx.SqlConn, c cache.CacheConf, cachePrefix string, opts ...cache.Option) SysProductMemberModel {
	return &customSysProductMemberModel{
		defaultSysProductMemberModel: newSysProductMemberModel(conn, c, cachePrefix, opts...),
	}
}

// FindListByProductCode returns one page of members belonging to productCode,
// ordered by id descending, plus the total number of members of that product.
//
// page is 1-based. Values of page below 1 are treated as 1 and a negative
// pageSize is treated as 0 (which yields an empty page), because the
// MySQL-style "LIMIT offset,count" clause used here does not accept negative
// arguments and would otherwise fail with a SQL error.
//
// Both queries are issued through the NoCache query helpers. Any query error
// is returned as-is with a nil list and a zero total.
func (m *customSysProductMemberModel) FindListByProductCode(ctx context.Context, productCode string, page, pageSize int64) ([]*SysProductMember, int64, error) {
	// Normalize paging so the computed offset and row count are never negative.
	if page < 1 {
		page = 1
	}
	if pageSize < 0 {
		pageSize = 0
	}

	var total int64
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE `productCode` = ?", m.table)
	if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery, productCode); err != nil {
		return nil, 0, err
	}

	var list []*SysProductMember
	query := fmt.Sprintf("SELECT %s FROM %s WHERE `productCode` = ? ORDER BY id DESC LIMIT ?,?", sysProductMemberRows, m.table)
	if err := m.QueryRowsNoCacheCtx(ctx, &list, query, productCode, (page-1)*pageSize, pageSize); err != nil {
		return nil, 0, err
	}
	return list, total, nil
}

// CountActiveAdmins returns the number of members of productCode whose
// memberType is consts.MemberTypeAdmin and whose status is
// consts.StatusEnabled. It runs outside any transaction and takes no row
// locks; use CountActiveAdminsTx when the count must stay stable within a
// transaction.
func (m *customSysProductMemberModel) CountActiveAdmins(ctx context.Context, productCode string) (int64, error) {
	var count int64
	query := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE `productCode` = ? AND `memberType` = ? AND `status` = ?", m.table)
	if err := m.QueryRowNoCacheCtx(ctx, &count, query, productCode, consts.MemberTypeAdmin, consts.StatusEnabled); err != nil {
		return 0, err
	}
	return count, nil
}

// CountActiveAdminsTx returns the number of enabled admin members of
// productCode, executed on the supplied session.
//
// The matching ids are selected with FOR UPDATE and counted in Go rather than
// with COUNT(*), so the matching rows are read with a locking SELECT.
func (m *customSysProductMemberModel) CountActiveAdminsTx(ctx context.Context, session sqlx.Session, productCode string) (int64, error) {
	var ids []int64
	query := fmt.Sprintf("SELECT `id` FROM %s WHERE `productCode` = ? AND `memberType` = ? AND `status` = ? FOR UPDATE", m.table)
	if err := session.QueryRowsCtx(ctx, &ids, query, productCode, consts.MemberTypeAdmin, consts.StatusEnabled); err != nil {
		return 0, err
	}
	return int64(len(ids)), nil
}

// CountOtherActiveAdminsTx counts the enabled ADMIN members of productCode
// other than the row identified by excludeId. Callers typically pass the id of
// the target row that is about to be deleted or demoted; a result of 0 means
// the target is the last active admin and must not be touched. Compared with
// the inverse reasoning of CountActiveAdminsTx followed by an
// "adminCount <= 1" check, this matches the business semantics more closely
// (see audit item L-5).
//
// It still uses FOR UPDATE to lock the scanned range, serializing against
// concurrent demotions and deletions.
func (m *customSysProductMemberModel) CountOtherActiveAdminsTx(ctx context.Context, session sqlx.Session, productCode string, excludeId int64) (int64, error) {
	var ids []int64
	query := fmt.Sprintf("SELECT `id` FROM %s WHERE `productCode` = ? AND `memberType` = ? AND `status` = ? AND `id` != ? FOR UPDATE", m.table)
	if err := session.QueryRowsCtx(ctx, &ids, query, productCode, consts.MemberTypeAdmin, consts.StatusEnabled, excludeId); err != nil {
		return 0, err
	}
	return int64(len(ids)), nil
}

// FindOneForUpdateTx loads the member row with the given primary-key id on the
// supplied session using SELECT ... FOR UPDATE. The query goes directly
// through the session, not through the model's cached query helpers. Any
// error from the session (including the one reported when no row matches) is
// returned unchanged with a nil result.
func (m *customSysProductMemberModel) FindOneForUpdateTx(ctx context.Context, session sqlx.Session, id int64) (*SysProductMember, error) {
	var data SysProductMember
	query := fmt.Sprintf("SELECT %s FROM %s WHERE `id` = ? FOR UPDATE", sysProductMemberRows, m.table)
	if err := session.QueryRowCtx(ctx, &data, query, id); err != nil {
		return nil, err
	}
	return &data, nil
}

// FindMapByProductCodeUserIds returns the members of productCode whose userId
// appears in userIds, keyed by userId. Requested ids with no matching row are
// simply absent from the map. An empty userIds yields an empty, non-nil map
// without querying the database.
func (m *customSysProductMemberModel) FindMapByProductCodeUserIds(ctx context.Context, productCode string, userIds []int64) (map[int64]*SysProductMember, error) {
	if len(userIds) == 0 {
		return make(map[int64]*SysProductMember), nil
	}

	// Build one "?" placeholder per id; args holds productCode followed by
	// the ids, matching the placeholder order in the query.
	placeholders := make([]string, len(userIds))
	args := make([]interface{}, 0, len(userIds)+1)
	args = append(args, productCode)
	for i, id := range userIds {
		placeholders[i] = "?"
		args = append(args, id)
	}

	var list []*SysProductMember
	query := fmt.Sprintf("SELECT %s FROM %s WHERE `productCode` = ? AND `userId` IN (%s)", sysProductMemberRows, m.table, strings.Join(placeholders, ","))
	if err := m.QueryRowsNoCacheCtx(ctx, &list, query, args...); err != nil {
		return nil, err
	}

	result := make(map[int64]*SysProductMember, len(list))
	for _, pm := range list {
		result[pm.UserId] = pm
	}
	return result, nil
}