package middleware

import (
	"context"
	"fmt"
	"net"
	"net/http"
	"strings"

	"perms-system-server/internal/response"

	"github.com/zeromicro/go-zero/core/limit"
	"github.com/zeromicro/go-zero/core/logx"
	"github.com/zeromicro/go-zero/core/stores/redis"
	"github.com/zeromicro/go-zero/rest/httpx"
)

// ctxKeyClientIP is the context key under which the resolved client IP is stored.
const ctxKeyClientIP contextKey = "clientIP"

// WithClientIP returns a copy of ctx carrying ip as the resolved client IP.
func WithClientIP(ctx context.Context, ip string) context.Context {
	return context.WithValue(ctx, ctxKeyClientIP, ip)
}

// GetClientIP returns the client IP stored by WithClientIP, or "" if ctx
// carries none (e.g. the request did not pass through RateLimitMiddleware).
func GetClientIP(ctx context.Context) string {
	ip, _ := ctx.Value(ctxKeyClientIP).(string)
	return ip
}

// RateLimitMiddleware enforces a per-client-IP request quota using a
// Redis-backed go-zero period limiter.
type RateLimitMiddleware struct {
	limiter     *limit.PeriodLimit
	behindProxy bool // trust X-Forwarded-For / X-Real-IP; see ExtractClientIP
}

// NewRateLimitMiddleware builds a RateLimitMiddleware that allows at most
// quota requests per client IP in each window of period (seconds, as
// interpreted by go-zero's limit.NewPeriodLimit). Limiter keys in rds are
// namespaced by keyPrefix. behindProxy controls how the client IP is
// resolved; see ExtractClientIP.
func NewRateLimitMiddleware(rds *redis.Redis, period int, quota int, keyPrefix string, behindProxy bool) *RateLimitMiddleware {
	limiter := limit.NewPeriodLimit(period, quota, rds, keyPrefix)
	return &RateLimitMiddleware{limiter: limiter, behindProxy: behindProxy}
}

// Handle wraps next with the per-IP rate limit. Requests over quota are
// rejected with response.ErrTooManyRequests. Otherwise the resolved client IP
// is stored in the request context (readable via GetClientIP) and next is
// called.
//
// If the limiter call itself fails (for example Redis is unreachable), the
// error is logged and the request is allowed through (fail open). This
// matches the previous behavior, but makes the outage visible instead of
// silent.
func (m *RateLimitMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ip := ExtractClientIP(r, m.behindProxy)
		key := fmt.Sprintf("ip:%s", ip)
		code, err := m.limiter.Take(key)
		if err != nil {
			// Fail open on purpose: a limiter outage should not take the whole
			// API down. It must be logged, though, or the rate limit is
			// silently disabled.
			logx.WithContext(r.Context()).Errorf("RateLimitMiddleware: limiter.Take failed for key=%s: %v", key, err)
		}
		if code == limit.OverQuota {
			httpx.ErrorCtx(r.Context(), w, response.ErrTooManyRequests("请求过于频繁,请稍后再试"))
			return
		}
		ctx := WithClientIP(r.Context(), ip)
		next(w, r.WithContext(ctx))
	}
}

// ExtractClientIP resolves the client IP of r.
//
// With behindProxy=true the usual reverse-proxy precedence applies:
//  1. the first valid entry of X-Forwarded-For,
//  2. then X-Real-IP,
//  3. then RemoteAddr.
//
// Header values are validated by firstValidIP. Empty, malformed or
// unspecified values fall through to the next source. If neither header
// yields an IP, an error-level log is written (to help operators spot a
// misconfigured proxy chain) and RemoteAddr is used.
//
// NOTE(security): the leftmost X-Forwarded-For entry is whatever the client
// sent, unless the trusted proxy in front of this service overwrites (rather
// than appends to) that header. Only enable behindProxy when that proxy
// sanitizes X-Forwarded-For. Otherwise clients can choose their own
// rate-limit key.
//
// With behindProxy=false only RemoteAddr is used and all headers are
// ignored, so clients cannot spoof their IP (audit M-6).
//
// For RemoteAddr, the host part is returned without the port. If RemoteAddr
// is not in host:port form, it is returned unchanged.
func ExtractClientIP(r *http.Request, behindProxy bool) string {
	if behindProxy {
		if ip := firstValidIP(r.Header.Get("X-Forwarded-For")); ip != "" {
			return ip
		}
		if ip := firstValidIP(r.Header.Get("X-Real-IP")); ip != "" {
			return ip
		}
		logx.WithContext(r.Context()).Errorf("ExtractClientIP: behindProxy=true but no valid X-Forwarded-For / X-Real-IP header, falling back to RemoteAddr=%s; please check your reverse proxy configuration", r.RemoteAddr)
	}
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}

// firstValidIP scans a possibly comma-separated IP header value (the usual
// X-Forwarded-For format). It returns, in canonical textual form, the first
// entry that net.ParseIP accepts and that is not an unspecified address
// (0.0.0.0 or ::). Empty, malformed (e.g. "unknown") and unspecified entries
// are skipped, so values like "0.0.0.0, ..." cannot be used as a key. It
// returns "" if nothing qualifies.
//
// Canonicalization matters because the result becomes a rate-limit key.
// "::ffff:1.2.3.4" and "1.2.3.4", or differently abbreviated IPv6 spellings
// of one address, must map to the same bucket. Otherwise one client could
// rotate spellings to get fresh quota.
func firstValidIP(headerVal string) string {
	// strings.Split("", ",") yields [""], which ParseIP rejects, so an empty
	// header needs no special case.
	for _, part := range strings.Split(headerVal, ",") {
		ip := net.ParseIP(strings.TrimSpace(part))
		if ip == nil || ip.IsUnspecified() {
			continue
		}
		return ip.String()
	}
	return ""
}