ratelimitMiddleware.go 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. package middleware
  2. import (
  3. "context"
  4. "fmt"
  5. "net"
  6. "net/http"
  7. "strings"
  8. "perms-system-server/internal/response"
  9. "github.com/zeromicro/go-zero/core/limit"
  10. "github.com/zeromicro/go-zero/core/logx"
  11. "github.com/zeromicro/go-zero/core/stores/redis"
  12. "github.com/zeromicro/go-zero/rest/httpx"
  13. )
  14. const ctxKeyClientIP contextKey = "clientIP"
  15. func WithClientIP(ctx context.Context, ip string) context.Context {
  16. return context.WithValue(ctx, ctxKeyClientIP, ip)
  17. }
  18. func GetClientIP(ctx context.Context) string {
  19. ip, _ := ctx.Value(ctxKeyClientIP).(string)
  20. return ip
  21. }
  22. type RateLimitMiddleware struct {
  23. limiter *limit.PeriodLimit
  24. behindProxy bool
  25. }
  26. func NewRateLimitMiddleware(rds *redis.Redis, period int, quota int, keyPrefix string, behindProxy bool) *RateLimitMiddleware {
  27. limiter := limit.NewPeriodLimit(period, quota, rds, keyPrefix)
  28. return &RateLimitMiddleware{limiter: limiter, behindProxy: behindProxy}
  29. }
  30. func (m *RateLimitMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
  31. return func(w http.ResponseWriter, r *http.Request) {
  32. ip := ExtractClientIP(r, m.behindProxy)
  33. key := fmt.Sprintf("ip:%s", ip)
  34. code, _ := m.limiter.Take(key)
  35. if code == limit.OverQuota {
  36. httpx.ErrorCtx(r.Context(), w, response.ErrTooManyRequests("请求过于频繁,请稍后再试"))
  37. return
  38. }
  39. ctx := WithClientIP(r.Context(), ip)
  40. next(w, r.WithContext(ctx))
  41. }
  42. }
  43. // ExtractClientIP 从请求中解析出客户端真实 IP。
  44. // 当 behindProxy=true 时按常规反向代理优先级解析:X-Forwarded-For 首段 → X-Real-IP → RemoteAddr;
  45. // 所有候选值都会经 net.ParseIP 校验合法性,非法或空时自动 fallthrough 到下一个来源,
  46. // 最终仍拿不到合法 IP 时打印 warn 日志并回落到 RemoteAddr 的原始字符串(方便运维排查代理链漏配)。
  47. // 当 behindProxy=false 时只采用 RemoteAddr,忽略任何请求头,防止客户端伪造(见审计 M-6)。
  48. func ExtractClientIP(r *http.Request, behindProxy bool) string {
  49. if behindProxy {
  50. if ip := firstValidIP(r.Header.Get("X-Forwarded-For")); ip != "" {
  51. return ip
  52. }
  53. if ip := firstValidIP(r.Header.Get("X-Real-IP")); ip != "" {
  54. return ip
  55. }
  56. logx.WithContext(r.Context()).Errorf("ExtractClientIP: behindProxy=true but no valid X-Forwarded-For / X-Real-IP header, falling back to RemoteAddr=%s; please check your reverse proxy configuration", r.RemoteAddr)
  57. }
  58. host, _, err := net.SplitHostPort(r.RemoteAddr)
  59. if err != nil {
  60. return r.RemoteAddr
  61. }
  62. return host
  63. }
  // firstValidIP scans a comma-separated header value (the usual X-Forwarded-For
  // shape; a single-valued header such as X-Real-IP works too). It returns the
  // first entry that yields a usable IP address, in canonical net.IP.String form,
  // or "" if there is none.
  //
  // Entries are trimmed of surrounding whitespace. Empty and unparseable entries
  // are skipped. Unspecified addresses (0.0.0.0, ::) are skipped as well: they can
  // never be a real client, and skipping them keeps a bogus "0.0.0.0, ..." prefix
  // from becoming the rate-limit key.
  //
  // Canonicalizing makes equivalent spellings of one address share a key, e.g.
  // "::ffff:1.2.3.4" and "1.2.3.4", or upper- and lower-case IPv6 hex. An entry
  // that carries a port ("1.2.3.4:5678", "[2001:db8::1]:443"), as some proxies
  // emit, is accepted and the port is dropped.
  func firstValidIP(headerVal string) string {
  	if headerVal == "" {
  		return ""
  	}
  	for _, part := range strings.Split(headerVal, ",") {
  		if ip := parseForwardedIP(strings.TrimSpace(part)); ip != nil {
  			return ip.String()
  		}
  	}
  	return ""
  }

  // parseForwardedIP parses one trimmed header entry, allowing an optional port.
  // It returns nil for empty, unparseable, or unspecified addresses.
  func parseForwardedIP(entry string) net.IP {
  	if entry == "" {
  		return nil
  	}
  	ip := net.ParseIP(entry)
  	if ip == nil {
  		// Not a bare address. Retry as "host:port" / "[v6]:port".
  		if host, _, err := net.SplitHostPort(entry); err == nil {
  			ip = net.ParseIP(host)
  		}
  	}
  	if ip == nil || ip.IsUnspecified() {
  		return nil
  	}
  	return ip
  }