ratelimitMiddleware.go 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. package middleware
  2. import (
  3. "context"
  4. "fmt"
  5. "net"
  6. "net/http"
  7. "perms-system-server/internal/response"
  8. "github.com/zeromicro/go-zero/core/limit"
  9. "github.com/zeromicro/go-zero/core/stores/redis"
  10. "github.com/zeromicro/go-zero/rest/httpx"
  11. )
  12. const ctxKeyClientIP contextKey = "clientIP"
  13. func WithClientIP(ctx context.Context, ip string) context.Context {
  14. return context.WithValue(ctx, ctxKeyClientIP, ip)
  15. }
  16. func GetClientIP(ctx context.Context) string {
  17. ip, _ := ctx.Value(ctxKeyClientIP).(string)
  18. return ip
  19. }
  20. type RateLimitMiddleware struct {
  21. limiter *limit.PeriodLimit
  22. behindProxy bool
  23. }
  24. func NewRateLimitMiddleware(rds *redis.Redis, period int, quota int, keyPrefix string, behindProxy bool) *RateLimitMiddleware {
  25. limiter := limit.NewPeriodLimit(period, quota, rds, keyPrefix)
  26. return &RateLimitMiddleware{limiter: limiter, behindProxy: behindProxy}
  27. }
  28. func (m *RateLimitMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
  29. return func(w http.ResponseWriter, r *http.Request) {
  30. ip := ExtractClientIP(r, m.behindProxy)
  31. key := fmt.Sprintf("ip:%s", ip)
  32. code, _ := m.limiter.Take(key)
  33. if code == limit.OverQuota {
  34. httpx.ErrorCtx(r.Context(), w, response.ErrTooManyRequests("请求过于频繁,请稍后再试"))
  35. return
  36. }
  37. ctx := WithClientIP(r.Context(), ip)
  38. next(w, r.WithContext(ctx))
  39. }
  40. }
  41. // ExtractClientIP extracts client IP from the request.
  42. // When behindProxy is true, it trusts X-Real-IP header set by the reverse proxy.
  43. // When false, it only uses RemoteAddr for security.
  44. func ExtractClientIP(r *http.Request, behindProxy bool) string {
  45. if behindProxy {
  46. if ip := r.Header.Get("X-Real-IP"); ip != "" {
  47. return ip
  48. }
  49. }
  50. host, _, err := net.SplitHostPort(r.RemoteAddr)
  51. if err != nil {
  52. return r.RemoteAddr
  53. }
  54. return host
  55. }