package middleware import ( "encoding/json" "fmt" "math/rand" "net/http" "net/http/httptest" "testing" "time" "perms-system-server/internal/response" "perms-system-server/internal/testutil" "github.com/zeromicro/go-zero/core/stores/redis" "github.com/zeromicro/go-zero/rest/httpx" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func init() { response.Setup() } func uniqueIP() string { return fmt.Sprintf("10.%d.%d.%d", rand.Intn(256), rand.Intn(256), rand.Intn(256)) } func newTestRedis() *redis.Redis { cfg := testutil.GetTestConfig() return redis.MustNewRedis(cfg.CacheRedis.Nodes[0].RedisConf) } func newTestMiddleware(rds *redis.Redis, quota int) *RateLimitMiddleware { prefix := fmt.Sprintf("test_rl_%d_%d", time.Now().UnixNano(), rand.Intn(100000)) return NewRateLimitMiddleware(rds, 60, quota, prefix) } // TC-0525: 正常请求(未超限) func TestRateLimit_NormalRequest(t *testing.T) { rds := newTestRedis() m := newTestMiddleware(rds, 10) nextCalled := false handler := m.Handle(func(w http.ResponseWriter, r *http.Request) { nextCalled = true w.WriteHeader(http.StatusOK) }) req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req.Header.Set("X-Forwarded-For", uniqueIP()) w := httptest.NewRecorder() handler(w, req) assert.True(t, nextCalled, "next handler should be called") assert.Equal(t, http.StatusOK, w.Code) } // TC-0526: 超限请求被拒绝 func TestRateLimit_OverQuotaRejected(t *testing.T) { rds := newTestRedis() m := newTestMiddleware(rds, 2) ip := uniqueIP() handler := m.Handle(func(w http.ResponseWriter, r *http.Request) { httpx.OkJson(w, nil) }) for i := 0; i < 2; i++ { req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req.Header.Set("X-Forwarded-For", ip) w := httptest.NewRecorder() handler(w, req) } req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req.Header.Set("X-Forwarded-For", ip) w := httptest.NewRecorder() handler(w, req) var body response.Body err := json.Unmarshal(w.Body.Bytes(), &body) require.NoError(t, err) assert.Equal(t, 429, body.Code) assert.Equal(t, "请求过于频繁,请稍后再试", body.Msg) } // TC-0527: X-Forwarded-For被忽略(M-1安全修复验证) func TestRateLimit_XForwardedForIgnored(t *testing.T) { rds := newTestRedis() m := newTestMiddleware(rds, 1) var nextCount int handler := m.Handle(func(w http.ResponseWriter, r *http.Request) { nextCount++ w.WriteHeader(http.StatusOK) }) remoteAddr := uniqueIP() + ":12345" req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req.RemoteAddr = remoteAddr req.Header.Set("X-Forwarded-For", uniqueIP()) handler(httptest.NewRecorder(), req) assert.Equal(t, 1, nextCount) req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req2.RemoteAddr = remoteAddr req2.Header.Set("X-Forwarded-For", uniqueIP()) handler(httptest.NewRecorder(), req2) assert.Equal(t, 1, nextCount, "different X-Forwarded-For should NOT bypass rate limit; RemoteAddr is used") } // TC-0528: X-Real-IP被忽略(M-1安全修复验证) func TestRateLimit_XRealIPIgnored(t *testing.T) { rds := newTestRedis() m := newTestMiddleware(rds, 1) var nextCount int handler := m.Handle(func(w http.ResponseWriter, r *http.Request) { nextCount++ w.WriteHeader(http.StatusOK) }) remoteAddr := uniqueIP() + ":12345" req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req.RemoteAddr = remoteAddr req.Header.Set("X-Real-IP", uniqueIP()) handler(httptest.NewRecorder(), req) assert.Equal(t, 1, nextCount) req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req2.RemoteAddr = remoteAddr req2.Header.Set("X-Real-IP", uniqueIP()) handler(httptest.NewRecorder(), req2) assert.Equal(t, 1, nextCount, "different X-Real-IP should NOT bypass rate limit; RemoteAddr is used") } // TC-0529: IP从RemoteAddr获取 func TestRateLimit_IPFromRemoteAddr(t *testing.T) { rds := newTestRedis() m := newTestMiddleware(rds, 1) ip := uniqueIP() remoteAddr := ip + ":12345" var gotNext bool handler := m.Handle(func(w http.ResponseWriter, r *http.Request) { gotNext = true w.WriteHeader(http.StatusOK) }) req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req.RemoteAddr = remoteAddr w := httptest.NewRecorder() handler(w, req) assert.True(t, gotNext) gotNext = false req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req2.RemoteAddr = remoteAddr w2 := httptest.NewRecorder() handler(w2, req2) assert.False(t, gotNext, "should be rate limited by RemoteAddr") } // TC-0530: 不同RemoteAddr独立限流 func TestRateLimit_DifferentIPsIndependent(t *testing.T) { rds := newTestRedis() m := newTestMiddleware(rds, 1) addr1 := uniqueIP() + ":12345" addr2 := uniqueIP() + ":12345" var nextCount int handler := m.Handle(func(w http.ResponseWriter, r *http.Request) { nextCount++ w.WriteHeader(http.StatusOK) }) req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req1.RemoteAddr = addr1 handler(httptest.NewRecorder(), req1) assert.Equal(t, 1, nextCount) req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req2.RemoteAddr = addr2 handler(httptest.NewRecorder(), req2) assert.Equal(t, 2, nextCount, "different RemoteAddr should have independent quotas") req3 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req3.RemoteAddr = addr1 handler(httptest.NewRecorder(), req3) assert.Equal(t, 2, nextCount, "addr1 should be over quota") req4 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil) req4.RemoteAddr = addr2 handler(httptest.NewRecorder(), req4) assert.Equal(t, 2, nextCount, "addr2 should be over quota") }