package middleware

import (
	"fmt"
	"net"
	"net/http"
	"strings"

	"github.com/zeromicro/go-zero/core/limit"
	"github.com/zeromicro/go-zero/core/stores/redis"
	"github.com/zeromicro/go-zero/rest/httpx"

	"perms-system-server/internal/response"
)

// RateLimitMiddleware rejects requests from a client IP once that IP exceeds
// its quota within a fixed period, using a Redis-backed go-zero PeriodLimit.
type RateLimitMiddleware struct {
	limiter     *limit.PeriodLimit
	behindProxy bool // when true, the client IP is taken from X-Real-IP (see ExtractClientIP)
}

// NewRateLimitMiddleware builds a per-IP rate limiter backed by rds.
//
// period and quota are passed straight to limit.NewPeriodLimit: at most quota
// requests per period are allowed for each client IP (period unit is the one
// go-zero's PeriodLimit uses — presumably seconds; confirm against the go-zero
// version in use). keyPrefix namespaces the Redis keys. behindProxy controls
// whether the X-Real-IP header is trusted when identifying the client.
func NewRateLimitMiddleware(rds *redis.Redis, period int, quota int, keyPrefix string, behindProxy bool) *RateLimitMiddleware {
	limiter := limit.NewPeriodLimit(period, quota, rds, keyPrefix)
	return &RateLimitMiddleware{limiter: limiter, behindProxy: behindProxy}
}

// Handle wraps next with rate limiting. Requests whose client IP is over
// quota receive response.ErrTooManyRequests via httpx.ErrorCtx; all other
// requests are passed to next unchanged.
func (m *RateLimitMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		key := fmt.Sprintf("ip:%s", ExtractClientIP(r, m.behindProxy))

		code, err := m.limiter.Take(key)
		if err != nil {
			// Fail open: if the limiter backend (Redis) is unavailable, serve the
			// request rather than rejecting all traffic.
			// NOTE(review): err is dropped here; consider logging it so limiter
			// outages are visible.
			next(w, r)
			return
		}

		if code == limit.OverQuota {
			httpx.ErrorCtx(r.Context(), w, response.ErrTooManyRequests("请求过于频繁,请稍后再试"))
			return
		}
		next(w, r)
	}
}

// ExtractClientIP returns the IP address used to identify the client of r.
//
// When behindProxy is true, the X-Real-IP header is consulted first. Its value
// is honoured only if, after trimming surrounding whitespace, it parses as an
// IP address; it is then returned in canonical form (for example,
// "::ffff:1.2.3.4" becomes "1.2.3.4"), so equivalent addresses map to the same
// rate-limit key. A missing or malformed header falls back to RemoteAddr.
// Only enable behindProxy when a trusted reverse proxy always overwrites
// X-Real-IP; otherwise clients can spoof it.
//
// When behindProxy is false, only r.RemoteAddr is used. If RemoteAddr is not
// in host:port form, it is returned unchanged.
func ExtractClientIP(r *http.Request, behindProxy bool) string {
	if behindProxy {
		// Validate the header: an unparsable value must not become a client
		// identity (or an arbitrary Redis key).
		if ip := net.ParseIP(strings.TrimSpace(r.Header.Get("X-Real-IP"))); ip != nil {
			return ip.String()
		}
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}