| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344 |
- package middleware
- import (
- "encoding/json"
- "fmt"
- "math/rand"
- "net/http"
- "net/http/httptest"
- "testing"
- "time"
- "perms-system-server/internal/response"
- "perms-system-server/internal/testutil"
- "github.com/zeromicro/go-zero/core/stores/redis"
- "github.com/zeromicro/go-zero/rest/httpx"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- )
- func init() {
- response.Setup()
- }
// uniqueIP returns a random address inside 10.0.0.0/8 ("10.a.b.c", each
// octet in [0, 255]). Tests use it to make collisions between client
// addresses unlikely; uniqueness is probabilistic, not guaranteed.
func uniqueIP() string {
	var octets [3]int
	for i := range octets {
		octets[i] = rand.Intn(256)
	}
	return fmt.Sprintf("10.%d.%d.%d", octets[0], octets[1], octets[2])
}
- func newTestRedis() *redis.Redis {
- cfg := testutil.GetTestConfig()
- return redis.MustNewRedis(cfg.CacheRedis.Nodes[0].RedisConf)
- }
- func newTestMiddleware(rds *redis.Redis, quota int) *RateLimitMiddleware {
- prefix := fmt.Sprintf("test_rl_%d_%d", time.Now().UnixNano(), rand.Intn(100000))
- return NewRateLimitMiddleware(rds, 60, quota, prefix, false)
- }
- // TC-0546: 正常请求(未超限)
- func TestRateLimit_NormalRequest(t *testing.T) {
- rds := newTestRedis()
- m := newTestMiddleware(rds, 10)
- nextCalled := false
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- nextCalled = true
- w.WriteHeader(http.StatusOK)
- })
- req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req.Header.Set("X-Forwarded-For", uniqueIP())
- w := httptest.NewRecorder()
- handler(w, req)
- assert.True(t, nextCalled, "next handler should be called")
- assert.Equal(t, http.StatusOK, w.Code)
- }
- // TC-0547: 超限请求被拒绝
- func TestRateLimit_OverQuotaRejected(t *testing.T) {
- rds := newTestRedis()
- m := newTestMiddleware(rds, 2)
- ip := uniqueIP()
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- httpx.OkJson(w, nil)
- })
- for i := 0; i < 2; i++ {
- req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req.Header.Set("X-Forwarded-For", ip)
- w := httptest.NewRecorder()
- handler(w, req)
- }
- req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req.Header.Set("X-Forwarded-For", ip)
- w := httptest.NewRecorder()
- handler(w, req)
- var body response.Body
- err := json.Unmarshal(w.Body.Bytes(), &body)
- require.NoError(t, err)
- assert.Equal(t, 429, body.Code)
- assert.Equal(t, "请求过于频繁,请稍后再试", body.Msg)
- }
- // TC-0548: behindProxy=false时XFF被忽略
- func TestRateLimit_XForwardedForIgnored(t *testing.T) {
- rds := newTestRedis()
- m := newTestMiddleware(rds, 1)
- var nextCount int
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- nextCount++
- w.WriteHeader(http.StatusOK)
- })
- remoteAddr := uniqueIP() + ":12345"
- req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req.RemoteAddr = remoteAddr
- req.Header.Set("X-Forwarded-For", uniqueIP())
- handler(httptest.NewRecorder(), req)
- assert.Equal(t, 1, nextCount)
- req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req2.RemoteAddr = remoteAddr
- req2.Header.Set("X-Forwarded-For", uniqueIP())
- handler(httptest.NewRecorder(), req2)
- assert.Equal(t, 1, nextCount, "different X-Forwarded-For should NOT bypass rate limit; RemoteAddr is used")
- }
- // TC-0549: behindProxy=false时X-Real-IP被忽略
- func TestRateLimit_XRealIPIgnored(t *testing.T) {
- rds := newTestRedis()
- m := newTestMiddleware(rds, 1)
- var nextCount int
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- nextCount++
- w.WriteHeader(http.StatusOK)
- })
- remoteAddr := uniqueIP() + ":12345"
- req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req.RemoteAddr = remoteAddr
- req.Header.Set("X-Real-IP", uniqueIP())
- handler(httptest.NewRecorder(), req)
- assert.Equal(t, 1, nextCount)
- req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req2.RemoteAddr = remoteAddr
- req2.Header.Set("X-Real-IP", uniqueIP())
- handler(httptest.NewRecorder(), req2)
- assert.Equal(t, 1, nextCount, "different X-Real-IP should NOT bypass rate limit; RemoteAddr is used")
- }
- // TC-0550: IP从RemoteAddr解析
- func TestRateLimit_IPFromRemoteAddr(t *testing.T) {
- rds := newTestRedis()
- m := newTestMiddleware(rds, 1)
- ip := uniqueIP()
- remoteAddr := ip + ":12345"
- var gotNext bool
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- gotNext = true
- w.WriteHeader(http.StatusOK)
- })
- req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req.RemoteAddr = remoteAddr
- w := httptest.NewRecorder()
- handler(w, req)
- assert.True(t, gotNext)
- gotNext = false
- req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req2.RemoteAddr = remoteAddr
- w2 := httptest.NewRecorder()
- handler(w2, req2)
- assert.False(t, gotNext, "should be rate limited by RemoteAddr")
- }
- // TC-0551: 不同RemoteAddr独立限流
- func TestRateLimit_DifferentIPsIndependent(t *testing.T) {
- rds := newTestRedis()
- m := newTestMiddleware(rds, 1)
- addr1 := uniqueIP() + ":12345"
- addr2 := uniqueIP() + ":12345"
- var nextCount int
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- nextCount++
- w.WriteHeader(http.StatusOK)
- })
- req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req1.RemoteAddr = addr1
- handler(httptest.NewRecorder(), req1)
- assert.Equal(t, 1, nextCount)
- req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req2.RemoteAddr = addr2
- handler(httptest.NewRecorder(), req2)
- assert.Equal(t, 2, nextCount, "different RemoteAddr should have independent quotas")
- req3 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req3.RemoteAddr = addr1
- handler(httptest.NewRecorder(), req3)
- assert.Equal(t, 2, nextCount, "addr1 should be over quota")
- req4 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req4.RemoteAddr = addr2
- handler(httptest.NewRecorder(), req4)
- assert.Equal(t, 2, nextCount, "addr2 should be over quota")
- }
- func newTestMiddlewareProxy(rds *redis.Redis, quota int) *RateLimitMiddleware {
- prefix := fmt.Sprintf("test_rl_%d_%d", time.Now().UnixNano(), rand.Intn(100000))
- return NewRateLimitMiddleware(rds, 60, quota, prefix, true)
- }
- // TC-0552: behindProxy=true时信任X-Real-IP
- func TestRateLimit_BehindProxy_TrustsXRealIP(t *testing.T) {
- rds := newTestRedis()
- m := newTestMiddlewareProxy(rds, 1)
- remoteAddr := uniqueIP() + ":12345"
- var nextCount int
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- nextCount++
- w.WriteHeader(http.StatusOK)
- })
- req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req1.RemoteAddr = remoteAddr
- req1.Header.Set("X-Real-IP", uniqueIP())
- handler(httptest.NewRecorder(), req1)
- assert.Equal(t, 1, nextCount)
- req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req2.RemoteAddr = remoteAddr
- req2.Header.Set("X-Real-IP", uniqueIP())
- handler(httptest.NewRecorder(), req2)
- assert.Equal(t, 2, nextCount, "different X-Real-IP should have independent quotas when behindProxy=true")
- }
- // TC-0553: behindProxy=true时无X-Real-IP回退RemoteAddr
- func TestRateLimit_BehindProxy_FallbackToRemoteAddr(t *testing.T) {
- rds := newTestRedis()
- m := newTestMiddlewareProxy(rds, 1)
- remoteAddr := uniqueIP() + ":12345"
- var nextCount int
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- nextCount++
- w.WriteHeader(http.StatusOK)
- })
- req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req1.RemoteAddr = remoteAddr
- handler(httptest.NewRecorder(), req1)
- assert.Equal(t, 1, nextCount)
- req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req2.RemoteAddr = remoteAddr
- handler(httptest.NewRecorder(), req2)
- assert.Equal(t, 1, nextCount, "should fall back to RemoteAddr when X-Real-IP is absent")
- }
- // TC-0554: behindProxy=true时XFF仍被忽略
- func TestRateLimit_BehindProxy_XFFStillIgnored(t *testing.T) {
- rds := newTestRedis()
- m := newTestMiddlewareProxy(rds, 1)
- remoteAddr := uniqueIP() + ":12345"
- var nextCount int
- handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
- nextCount++
- w.WriteHeader(http.StatusOK)
- })
- req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req1.RemoteAddr = remoteAddr
- req1.Header.Set("X-Forwarded-For", uniqueIP())
- handler(httptest.NewRecorder(), req1)
- assert.Equal(t, 1, nextCount)
- req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req2.RemoteAddr = remoteAddr
- req2.Header.Set("X-Forwarded-For", uniqueIP())
- handler(httptest.NewRecorder(), req2)
- assert.Equal(t, 1, nextCount, "X-Forwarded-For should NOT bypass rate limit even with behindProxy=true")
- }
- // =============================================================================
- // audit L-2 回归:产品登录与管后登录必须使用独立的限流桶
- // 修复前:两个入口共享同一个 keyPrefix,导致攻击者对产品登录的爆破会消耗管后登录的配额(或反之)
- // 修复后:ProductLoginRateLimit 使用 "...:rl:login:product",AdminLoginRateLimit 使用 "...:rl:login:admin"
- // =============================================================================
- // TC-0710: 两个不同 keyPrefix 的限流中间件在同一 IP 上互不影响
- func TestRateLimit_ProductAndAdminBucketsAreIndependent(t *testing.T) {
- rds := newTestRedis()
- // 模拟 servicecontext.go 里的两个独立桶
- prefixBase := fmt.Sprintf("test_rl_l2_%d_%d", time.Now().UnixNano(), rand.Intn(100000))
- productM := NewRateLimitMiddleware(rds, 60, 1, prefixBase+":rl:login:product", false)
- adminM := NewRateLimitMiddleware(rds, 60, 1, prefixBase+":rl:login:admin", false)
- ip := uniqueIP()
- remoteAddr := ip + ":12345"
- var productNext, adminNext int
- productHandler := productM.Handle(func(w http.ResponseWriter, r *http.Request) {
- productNext++
- w.WriteHeader(http.StatusOK)
- })
- adminHandler := adminM.Handle(func(w http.ResponseWriter, r *http.Request) {
- adminNext++
- w.WriteHeader(http.StatusOK)
- })
- // 对产品登录打一枪(配额=1,刚好用完)
- req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req1.RemoteAddr = remoteAddr
- productHandler(httptest.NewRecorder(), req1)
- require.Equal(t, 1, productNext)
- // 再对产品登录打一枪 → 被限流
- req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
- req2.RemoteAddr = remoteAddr
- productHandler(httptest.NewRecorder(), req2)
- require.Equal(t, 1, productNext, "产品登录桶已耗尽")
- // 关键:同 IP 对管后登录仍应放行(独立桶)
- req3 := httptest.NewRequest(http.MethodPost, "/api/auth/adminLogin", nil)
- req3.RemoteAddr = remoteAddr
- adminHandler(httptest.NewRecorder(), req3)
- assert.Equal(t, 1, adminNext,
- "audit L-2: 产品登录限流不应影响管后登录(不同 keyPrefix)")
- // 再打管后一枪 → 管后桶也应耗尽,但产品桶已经耗尽在先
- req4 := httptest.NewRequest(http.MethodPost, "/api/auth/adminLogin", nil)
- req4.RemoteAddr = remoteAddr
- adminHandler(httptest.NewRecorder(), req4)
- assert.Equal(t, 1, adminNext, "管后桶配额=1,第二次应被限流")
- }
- // TC-0555: RemoteAddr无端口格式
- func TestExtractClientIP_RemoteAddrNoPort(t *testing.T) {
- req := httptest.NewRequest(http.MethodPost, "/api/test", nil)
- req.RemoteAddr = "1.2.3.4"
- ip := ExtractClientIP(req, false)
- assert.Equal(t, "1.2.3.4", ip, "should return raw RemoteAddr when SplitHostPort fails")
- ip2 := ExtractClientIP(req, true)
- assert.Equal(t, "1.2.3.4", ip2, "behindProxy=true without X-Real-IP should also fallback")
- }
|