| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091 |
- package middleware
- import (
- "context"
- "fmt"
- "net"
- "net/http"
- "strings"
- "perms-system-server/internal/response"
- "github.com/zeromicro/go-zero/core/limit"
- "github.com/zeromicro/go-zero/core/logx"
- "github.com/zeromicro/go-zero/core/stores/redis"
- "github.com/zeromicro/go-zero/rest/httpx"
- )
- const ctxKeyClientIP contextKey = "clientIP"
- func WithClientIP(ctx context.Context, ip string) context.Context {
- return context.WithValue(ctx, ctxKeyClientIP, ip)
- }
- func GetClientIP(ctx context.Context) string {
- ip, _ := ctx.Value(ctxKeyClientIP).(string)
- return ip
- }
- type RateLimitMiddleware struct {
- limiter *limit.PeriodLimit
- behindProxy bool
- }
- func NewRateLimitMiddleware(rds *redis.Redis, period int, quota int, keyPrefix string, behindProxy bool) *RateLimitMiddleware {
- limiter := limit.NewPeriodLimit(period, quota, rds, keyPrefix)
- return &RateLimitMiddleware{limiter: limiter, behindProxy: behindProxy}
- }
- func (m *RateLimitMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- ip := ExtractClientIP(r, m.behindProxy)
- key := fmt.Sprintf("ip:%s", ip)
- code, _ := m.limiter.Take(key)
- if code == limit.OverQuota {
- httpx.ErrorCtx(r.Context(), w, response.ErrTooManyRequests("请求过于频繁,请稍后再试"))
- return
- }
- ctx := WithClientIP(r.Context(), ip)
- next(w, r.WithContext(ctx))
- }
- }
- // ExtractClientIP 从请求中解析出客户端真实 IP。
- // 当 behindProxy=true 时按常规反向代理优先级解析:X-Forwarded-For 首段 → X-Real-IP → RemoteAddr;
- // 所有候选值都会经 net.ParseIP 校验合法性,非法或空时自动 fallthrough 到下一个来源,
- // 最终仍拿不到合法 IP 时打印 warn 日志并回落到 RemoteAddr 的原始字符串(方便运维排查代理链漏配)。
- // 当 behindProxy=false 时只采用 RemoteAddr,忽略任何请求头,防止客户端伪造(见审计 M-6)。
- func ExtractClientIP(r *http.Request, behindProxy bool) string {
- if behindProxy {
- if ip := firstValidIP(r.Header.Get("X-Forwarded-For")); ip != "" {
- return ip
- }
- if ip := firstValidIP(r.Header.Get("X-Real-IP")); ip != "" {
- return ip
- }
- logx.WithContext(r.Context()).Errorf("ExtractClientIP: behindProxy=true but no valid X-Forwarded-For / X-Real-IP header, falling back to RemoteAddr=%s; please check your reverse proxy configuration", r.RemoteAddr)
- }
- host, _, err := net.SplitHostPort(r.RemoteAddr)
- if err != nil {
- return r.RemoteAddr
- }
- return host
- }
// firstValidIP scans a header value that may hold a comma-separated list of addresses (the
// usual X-Forwarded-For format) and returns the first entry that parses as a usable IP.
//
// Each entry is trimmed of surrounding whitespace. An entry is skipped if it is:
//   - empty,
//   - unparsable by net.ParseIP (including "host:port" forms), or
//   - an unspecified address ("0.0.0.0", "::").
// This keeps obviously bogus values such as a leading "0.0.0.0" from being used as the
// client address.
//
// The result is returned in canonical form (net.IP.String). For example, "2001:DB8::1"
// becomes "2001:db8::1" and "::ffff:1.2.3.4" becomes "1.2.3.4", so different spellings of
// one address map to the same rate-limit key. It returns "" when no entry qualifies.
func firstValidIP(headerVal string) string {
	for rest := headerVal; rest != ""; {
		var part string
		part, rest, _ = strings.Cut(rest, ",")

		candidate := strings.TrimSpace(part)
		if candidate == "" {
			continue
		}
		ip := net.ParseIP(candidate)
		if ip == nil || ip.IsUnspecified() {
			continue
		}
		return ip.String()
	}
	return ""
}
|