| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465 |
- package middleware
- import (
- "context"
- "fmt"
- "net"
- "net/http"
- "perms-system-server/internal/response"
- "github.com/zeromicro/go-zero/core/limit"
- "github.com/zeromicro/go-zero/core/stores/redis"
- "github.com/zeromicro/go-zero/rest/httpx"
- )
- const ctxKeyClientIP contextKey = "clientIP"
- func WithClientIP(ctx context.Context, ip string) context.Context {
- return context.WithValue(ctx, ctxKeyClientIP, ip)
- }
- func GetClientIP(ctx context.Context) string {
- ip, _ := ctx.Value(ctxKeyClientIP).(string)
- return ip
- }
- type RateLimitMiddleware struct {
- limiter *limit.PeriodLimit
- behindProxy bool
- }
- func NewRateLimitMiddleware(rds *redis.Redis, period int, quota int, keyPrefix string, behindProxy bool) *RateLimitMiddleware {
- limiter := limit.NewPeriodLimit(period, quota, rds, keyPrefix)
- return &RateLimitMiddleware{limiter: limiter, behindProxy: behindProxy}
- }
- func (m *RateLimitMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
- return func(w http.ResponseWriter, r *http.Request) {
- ip := ExtractClientIP(r, m.behindProxy)
- key := fmt.Sprintf("ip:%s", ip)
- code, _ := m.limiter.Take(key)
- if code == limit.OverQuota {
- httpx.ErrorCtx(r.Context(), w, response.ErrTooManyRequests("请求过于频繁,请稍后再试"))
- return
- }
- ctx := WithClientIP(r.Context(), ip)
- next(w, r.WithContext(ctx))
- }
- }
// ExtractClientIP returns the client IP address for r.
//
// When behindProxy is true, the X-Real-IP header set by the reverse proxy is
// used, provided it parses as an IP address; the header value is returned
// unchanged. A missing or malformed header falls back to RemoteAddr so that
// arbitrary header text is never returned as a client IP.
//
// When behindProxy is false the header is ignored and only RemoteAddr is
// consulted. If RemoteAddr is not in host:port form it is returned as-is.
func ExtractClientIP(r *http.Request, behindProxy bool) string {
	if behindProxy {
		// net.ParseIP returns nil for "" as well as for malformed values, so
		// both cases fall through to RemoteAddr.
		if ip := r.Header.Get("X-Real-IP"); net.ParseIP(ip) != nil {
			return ip
		}
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr could not be split into host and port; use it verbatim.
		return r.RemoteAddr
	}
	return host
}
|