package middleware import ( "encoding/json" "fmt" "math/rand" "net/http" "net/http/httptest" "testing" "time" "perms-system-server/internal/response" "perms-system-server/internal/testutil" "github.com/zeromicro/go-zero/core/stores/redis" "github.com/zeromicro/go-zero/rest/httpx" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func init() { response.Setup() } func uniqueIP() string { return fmt.Sprintf("10.%d.%d.%d", rand.Intn(256), rand.Intn(256), rand.Intn(256)) } func newTestRedis() *redis.Redis { cfg := testutil.GetTestConfig() return redis.MustNewRedis(cfg.CacheRedis.Nodes[0].RedisConf) } func newTestMiddleware(rds *redis.Redis, quota int) *RateLimitMiddleware { prefix := fmt.Sprintf("test_rl_%d_%d", time.Now().UnixNano(), rand.Intn(100000)) return NewRateLimitMiddleware(rds, 60, quota, prefix) } // TC-0536: 正常请求(未超限) func TestRateLimit_NormalRequest(t *testing.T) { rds := newTestRedis() m := newTestMiddleware(rds, 10) nextCalled := false handler := m.Handle(func(w http.ResponseWriter, r *http.Request) { nextCalled = true w.WriteHeader(http.StatusOK) }) req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req.Header.Set("X-Forwarded-For", uniqueIP()) w := httptest.NewRecorder() handler(w, req) assert.True(t, nextCalled, "next handler should be called") assert.Equal(t, http.StatusOK, w.Code) } // TC-0537: 超限请求被拒绝 func TestRateLimit_OverQuotaRejected(t *testing.T) { rds := newTestRedis() m := newTestMiddleware(rds, 2) ip := uniqueIP() handler := m.Handle(func(w http.ResponseWriter, r *http.Request) { httpx.OkJson(w, nil) }) for i := 0; i < 2; i++ { req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req.Header.Set("X-Forwarded-For", ip) w := httptest.NewRecorder() handler(w, req) } req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req.Header.Set("X-Forwarded-For", ip) w := httptest.NewRecorder() handler(w, req) var body response.Body err := json.Unmarshal(w.Body.Bytes(), &body) require.NoError(t, err) assert.Equal(t, 429, body.Code) assert.Equal(t, "请求过于频繁,请稍后再试", body.Msg) } // TC-0538: IP从X-Forwarded-For获取 func TestRateLimit_IPFromXForwardedFor(t *testing.T) { rds := newTestRedis() m := newTestMiddleware(rds, 1) ip := uniqueIP() var gotNext bool handler := m.Handle(func(w http.ResponseWriter, r *http.Request) { gotNext = true w.WriteHeader(http.StatusOK) }) req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req.Header.Set("X-Forwarded-For", ip) req.Header.Set("X-Real-IP", uniqueIP()) w := httptest.NewRecorder() handler(w, req) assert.True(t, gotNext) gotNext = false req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req2.Header.Set("X-Forwarded-For", ip) req2.Header.Set("X-Real-IP", uniqueIP()) w2 := httptest.NewRecorder() handler(w2, req2) assert.False(t, gotNext, "should be rate limited by X-Forwarded-For IP") } // TC-0539: IP从X-Real-IP获取 func TestRateLimit_IPFromXRealIP(t *testing.T) { rds := newTestRedis() m := newTestMiddleware(rds, 1) ip := uniqueIP() var gotNext bool handler := m.Handle(func(w http.ResponseWriter, r *http.Request) { gotNext = true w.WriteHeader(http.StatusOK) }) req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req.Header.Set("X-Real-IP", ip) w := httptest.NewRecorder() handler(w, req) assert.True(t, gotNext) gotNext = false req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req2.Header.Set("X-Real-IP", ip) w2 := httptest.NewRecorder() handler(w2, req2) assert.False(t, gotNext, "should be rate limited by X-Real-IP") } // TC-0540: IP从RemoteAddr获取 func TestRateLimit_IPFromRemoteAddr(t *testing.T) { rds := newTestRedis() m := newTestMiddleware(rds, 1) ip := uniqueIP() remoteAddr := ip + ":12345" var gotNext bool handler := m.Handle(func(w http.ResponseWriter, r *http.Request) { gotNext = true w.WriteHeader(http.StatusOK) }) req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req.RemoteAddr = remoteAddr w := httptest.NewRecorder() handler(w, req) assert.True(t, gotNext) gotNext = false req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req2.RemoteAddr = remoteAddr w2 := httptest.NewRecorder() handler(w2, req2) assert.False(t, gotNext, "should be rate limited by RemoteAddr") } // TC-0541: 不同IP独立限流 func TestRateLimit_DifferentIPsIndependent(t *testing.T) { rds := newTestRedis() m := newTestMiddleware(rds, 1) ip1 := uniqueIP() ip2 := uniqueIP() var nextCount int handler := m.Handle(func(w http.ResponseWriter, r *http.Request) { nextCount++ w.WriteHeader(http.StatusOK) }) req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req1.Header.Set("X-Forwarded-For", ip1) handler(httptest.NewRecorder(), req1) assert.Equal(t, 1, nextCount) req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req2.Header.Set("X-Forwarded-For", ip2) handler(httptest.NewRecorder(), req2) assert.Equal(t, 2, nextCount, "different IPs should have independent quotas") req3 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req3.Header.Set("X-Forwarded-For", ip1) handler(httptest.NewRecorder(), req3) assert.Equal(t, 2, nextCount, "ip1 should be over quota") req4 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req4.Header.Set("X-Forwarded-For", ip2) handler(httptest.NewRecorder(), req4) assert.Equal(t, 2, nextCount, "ip2 should be over quota") }