| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200 |
- package middleware
import (
	"encoding/json"
	"fmt"
	"math/rand"
	"net/http"
	"net/http/httptest"
	"sync/atomic"
	"testing"
	"time"

	"perms-system-server/internal/response"
	"perms-system-server/internal/testutil"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/zeromicro/go-zero/core/stores/redis"
	"github.com/zeromicro/go-zero/rest/httpx"
)
- func init() {
- response.Setup()
- }
// ipSeq is the counter behind uniqueIP. It starts at a random value so that
// separate test processes are unlikely to reuse each other's addresses, and
// it is only ever advanced atomically.
var ipSeq = rand.Uint32()

// uniqueIP returns an address of the form "10.a.b.c" that has not been
// returned before in this process.
//
// The three trailing octets are the low 24 bits of an atomically incremented
// counter, so successive calls are guaranteed distinct for the first 2^24
// calls. Drawing each octet at random (the previous approach) only made
// collisions unlikely, which left tests that need two different client
// addresses open to rare, hard-to-reproduce failures.
func uniqueIP() string {
	n := atomic.AddUint32(&ipSeq, 1)
	return fmt.Sprintf("10.%d.%d.%d", byte(n>>16), byte(n>>8), byte(n))
}
- func newTestRedis() *redis.Redis {
- cfg := testutil.GetTestConfig()
- return redis.MustNewRedis(cfg.CacheRedis.Nodes[0].RedisConf)
- }
- func newTestMiddleware(rds *redis.Redis, quota int) *RateLimitMiddleware {
- prefix := fmt.Sprintf("test_rl_%d_%d", time.Now().UnixNano(), rand.Intn(100000))
- return NewRateLimitMiddleware(rds, 60, quota, prefix)
- }
- // TC-0525: 正常请求(未超限)
- func TestRateLimit_NormalRequest(t *testing.T) {
- rds := newTestRedis()
- m := newTestMiddleware(rds, 10)
- nextCalled := false
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- nextCalled = true
- w.WriteHeader(http.StatusOK)
- })
- req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req.Header.Set("X-Forwarded-For", uniqueIP())
- w := httptest.NewRecorder()
- handler(w, req)
- assert.True(t, nextCalled, "next handler should be called")
- assert.Equal(t, http.StatusOK, w.Code)
- }
- // TC-0526: 超限请求被拒绝
- func TestRateLimit_OverQuotaRejected(t *testing.T) {
- rds := newTestRedis()
- m := newTestMiddleware(rds, 2)
- ip := uniqueIP()
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- httpx.OkJson(w, nil)
- })
- for i := 0; i < 2; i++ {
- req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req.Header.Set("X-Forwarded-For", ip)
- w := httptest.NewRecorder()
- handler(w, req)
- }
- req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req.Header.Set("X-Forwarded-For", ip)
- w := httptest.NewRecorder()
- handler(w, req)
- var body response.Body
- err := json.Unmarshal(w.Body.Bytes(), &body)
- require.NoError(t, err)
- assert.Equal(t, 429, body.Code)
- assert.Equal(t, "请求过于频繁,请稍后再试", body.Msg)
- }
- // TC-0527: X-Forwarded-For被忽略(M-1安全修复验证)
- func TestRateLimit_XForwardedForIgnored(t *testing.T) {
- rds := newTestRedis()
- m := newTestMiddleware(rds, 1)
- var nextCount int
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- nextCount++
- w.WriteHeader(http.StatusOK)
- })
- remoteAddr := uniqueIP() + ":12345"
- req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req.RemoteAddr = remoteAddr
- req.Header.Set("X-Forwarded-For", uniqueIP())
- handler(httptest.NewRecorder(), req)
- assert.Equal(t, 1, nextCount)
- req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req2.RemoteAddr = remoteAddr
- req2.Header.Set("X-Forwarded-For", uniqueIP())
- handler(httptest.NewRecorder(), req2)
- assert.Equal(t, 1, nextCount, "different X-Forwarded-For should NOT bypass rate limit; RemoteAddr is used")
- }
- // TC-0528: X-Real-IP被忽略(M-1安全修复验证)
- func TestRateLimit_XRealIPIgnored(t *testing.T) {
- rds := newTestRedis()
- m := newTestMiddleware(rds, 1)
- var nextCount int
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- nextCount++
- w.WriteHeader(http.StatusOK)
- })
- remoteAddr := uniqueIP() + ":12345"
- req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req.RemoteAddr = remoteAddr
- req.Header.Set("X-Real-IP", uniqueIP())
- handler(httptest.NewRecorder(), req)
- assert.Equal(t, 1, nextCount)
- req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req2.RemoteAddr = remoteAddr
- req2.Header.Set("X-Real-IP", uniqueIP())
- handler(httptest.NewRecorder(), req2)
- assert.Equal(t, 1, nextCount, "different X-Real-IP should NOT bypass rate limit; RemoteAddr is used")
- }
- // TC-0529: IP从RemoteAddr获取
- func TestRateLimit_IPFromRemoteAddr(t *testing.T) {
- rds := newTestRedis()
- m := newTestMiddleware(rds, 1)
- ip := uniqueIP()
- remoteAddr := ip + ":12345"
- var gotNext bool
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- gotNext = true
- w.WriteHeader(http.StatusOK)
- })
- req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req.RemoteAddr = remoteAddr
- w := httptest.NewRecorder()
- handler(w, req)
- assert.True(t, gotNext)
- gotNext = false
- req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req2.RemoteAddr = remoteAddr
- w2 := httptest.NewRecorder()
- handler(w2, req2)
- assert.False(t, gotNext, "should be rate limited by RemoteAddr")
- }
- // TC-0530: 不同RemoteAddr独立限流
- func TestRateLimit_DifferentIPsIndependent(t *testing.T) {
- rds := newTestRedis()
- m := newTestMiddleware(rds, 1)
- addr1 := uniqueIP() + ":12345"
- addr2 := uniqueIP() + ":12345"
- var nextCount int
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- nextCount++
- w.WriteHeader(http.StatusOK)
- })
- req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req1.RemoteAddr = addr1
- handler(httptest.NewRecorder(), req1)
- assert.Equal(t, 1, nextCount)
- req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req2.RemoteAddr = addr2
- handler(httptest.NewRecorder(), req2)
- assert.Equal(t, 2, nextCount, "different RemoteAddr should have independent quotas")
- req3 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req3.RemoteAddr = addr1
- handler(httptest.NewRecorder(), req3)
- assert.Equal(t, 2, nextCount, "addr1 should be over quota")
- req4 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req4.RemoteAddr = addr2
- handler(httptest.NewRecorder(), req4)
- assert.Equal(t, 2, nextCount, "addr2 should be over quota")
- }
|