ratelimitMiddleware.go 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. package middleware
import (
	"fmt"
	"log"
	"net"
	"net/http"

	"github.com/zeromicro/go-zero/core/limit"
	"github.com/zeromicro/go-zero/core/stores/redis"
	"github.com/zeromicro/go-zero/rest/httpx"

	"perms-system-server/internal/response"
)
  11. type RateLimitMiddleware struct {
  12. limiter *limit.PeriodLimit
  13. behindProxy bool
  14. }
  15. func NewRateLimitMiddleware(rds *redis.Redis, period int, quota int, keyPrefix string, behindProxy bool) *RateLimitMiddleware {
  16. limiter := limit.NewPeriodLimit(period, quota, rds, keyPrefix)
  17. return &RateLimitMiddleware{limiter: limiter, behindProxy: behindProxy}
  18. }
  19. func (m *RateLimitMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
  20. return func(w http.ResponseWriter, r *http.Request) {
  21. ip := ExtractClientIP(r, m.behindProxy)
  22. key := fmt.Sprintf("ip:%s", ip)
  23. code, _ := m.limiter.Take(key)
  24. if code == limit.OverQuota {
  25. httpx.ErrorCtx(r.Context(), w, response.ErrTooManyRequests("请求过于频繁,请稍后再试"))
  26. return
  27. }
  28. next(w, r)
  29. }
  30. }
  31. // ExtractClientIP extracts client IP from the request.
  32. // When behindProxy is true, it trusts X-Real-IP header set by the reverse proxy.
  33. // When false, it only uses RemoteAddr for security.
  34. func ExtractClientIP(r *http.Request, behindProxy bool) string {
  35. if behindProxy {
  36. if ip := r.Header.Get("X-Real-IP"); ip != "" {
  37. return ip
  38. }
  39. }
  40. host, _, err := net.SplitHostPort(r.RemoteAddr)
  41. if err != nil {
  42. return r.RemoteAddr
  43. }
  44. return host
  45. }