| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202 |
- package middleware
import (
	"encoding/json"
	"fmt"
	"math/rand"
	"net/http"
	"net/http/httptest"
	"sync/atomic"
	"testing"
	"time"

	"perms-system-server/internal/response"
	"perms-system-server/internal/testutil"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/zeromicro/go-zero/core/stores/redis"
	"github.com/zeromicro/go-zero/rest/httpx"
)
- func init() {
- response.Setup()
- }
// uniqueIPSeq is the sequence behind uniqueIP. It starts at a random offset so
// different test runs tend to use different addresses, while the atomic
// increment guarantees distinct values within one process.
var uniqueIPSeq = rand.Uint32()

// uniqueIP returns an address in 10.0.0.0/8 that differs from every other
// address returned in this process, for up to 2^24 calls before wrapping.
// It is safe for concurrent use.
func uniqueIP() string {
	n := atomic.AddUint32(&uniqueIPSeq, 1)
	// Only the low 24 bits are used; they fill the last three octets.
	return fmt.Sprintf("10.%d.%d.%d", byte(n>>16), byte(n>>8), byte(n))
}
- func newTestRedis() *redis.Redis {
- cfg := testutil.GetTestConfig()
- return redis.MustNewRedis(cfg.CacheRedis.Nodes[0].RedisConf)
- }
- func newTestMiddleware(rds *redis.Redis, quota int) *RateLimitMiddleware {
- prefix := fmt.Sprintf("test_rl_%d_%d", time.Now().UnixNano(), rand.Intn(100000))
- return NewRateLimitMiddleware(rds, 60, quota, prefix)
- }
- // TC-0536: 正常请求(未超限)
- func TestRateLimit_NormalRequest(t *testing.T) {
- rds := newTestRedis()
- m := newTestMiddleware(rds, 10)
- nextCalled := false
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- nextCalled = true
- w.WriteHeader(http.StatusOK)
- })
- req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req.Header.Set("X-Forwarded-For", uniqueIP())
- w := httptest.NewRecorder()
- handler(w, req)
- assert.True(t, nextCalled, "next handler should be called")
- assert.Equal(t, http.StatusOK, w.Code)
- }
- // TC-0537: 超限请求被拒绝
- func TestRateLimit_OverQuotaRejected(t *testing.T) {
- rds := newTestRedis()
- m := newTestMiddleware(rds, 2)
- ip := uniqueIP()
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- httpx.OkJson(w, nil)
- })
- for i := 0; i < 2; i++ {
- req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req.Header.Set("X-Forwarded-For", ip)
- w := httptest.NewRecorder()
- handler(w, req)
- }
- req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req.Header.Set("X-Forwarded-For", ip)
- w := httptest.NewRecorder()
- handler(w, req)
- var body response.Body
- err := json.Unmarshal(w.Body.Bytes(), &body)
- require.NoError(t, err)
- assert.Equal(t, 429, body.Code)
- assert.Equal(t, "请求过于频繁,请稍后再试", body.Msg)
- }
- // TC-0538: IP从X-Forwarded-For获取
- func TestRateLimit_IPFromXForwardedFor(t *testing.T) {
- rds := newTestRedis()
- m := newTestMiddleware(rds, 1)
- ip := uniqueIP()
- var gotNext bool
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- gotNext = true
- w.WriteHeader(http.StatusOK)
- })
- req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req.Header.Set("X-Forwarded-For", ip)
- req.Header.Set("X-Real-IP", uniqueIP())
- w := httptest.NewRecorder()
- handler(w, req)
- assert.True(t, gotNext)
- gotNext = false
- req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req2.Header.Set("X-Forwarded-For", ip)
- req2.Header.Set("X-Real-IP", uniqueIP())
- w2 := httptest.NewRecorder()
- handler(w2, req2)
- assert.False(t, gotNext, "should be rate limited by X-Forwarded-For IP")
- }
- // TC-0539: IP从X-Real-IP获取
- func TestRateLimit_IPFromXRealIP(t *testing.T) {
- rds := newTestRedis()
- m := newTestMiddleware(rds, 1)
- ip := uniqueIP()
- var gotNext bool
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- gotNext = true
- w.WriteHeader(http.StatusOK)
- })
- req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req.Header.Set("X-Real-IP", ip)
- w := httptest.NewRecorder()
- handler(w, req)
- assert.True(t, gotNext)
- gotNext = false
- req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req2.Header.Set("X-Real-IP", ip)
- w2 := httptest.NewRecorder()
- handler(w2, req2)
- assert.False(t, gotNext, "should be rate limited by X-Real-IP")
- }
- // TC-0540: IP从RemoteAddr获取
- func TestRateLimit_IPFromRemoteAddr(t *testing.T) {
- rds := newTestRedis()
- m := newTestMiddleware(rds, 1)
- ip := uniqueIP()
- remoteAddr := ip + ":12345"
- var gotNext bool
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- gotNext = true
- w.WriteHeader(http.StatusOK)
- })
- req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req.RemoteAddr = remoteAddr
- w := httptest.NewRecorder()
- handler(w, req)
- assert.True(t, gotNext)
- gotNext = false
- req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req2.RemoteAddr = remoteAddr
- w2 := httptest.NewRecorder()
- handler(w2, req2)
- assert.False(t, gotNext, "should be rate limited by RemoteAddr")
- }
- // TC-0541: 不同IP独立限流
- func TestRateLimit_DifferentIPsIndependent(t *testing.T) {
- rds := newTestRedis()
- m := newTestMiddleware(rds, 1)
- ip1 := uniqueIP()
- ip2 := uniqueIP()
- var nextCount int
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- nextCount++
- w.WriteHeader(http.StatusOK)
- })
- req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req1.Header.Set("X-Forwarded-For", ip1)
- handler(httptest.NewRecorder(), req1)
- assert.Equal(t, 1, nextCount)
- req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req2.Header.Set("X-Forwarded-For", ip2)
- handler(httptest.NewRecorder(), req2)
- assert.Equal(t, 2, nextCount, "different IPs should have independent quotas")
- req3 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req3.Header.Set("X-Forwarded-For", ip1)
- handler(httptest.NewRecorder(), req3)
- assert.Equal(t, 2, nextCount, "ip1 should be over quota")
- req4 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req4.Header.Set("X-Forwarded-For", ip2)
- handler(httptest.NewRecorder(), req4)
- assert.Equal(t, 2, nextCount, "ip2 should be over quota")
- }
|