ratelimitMiddleware_test.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. package middleware
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "math/rand"
  6. "net/http"
  7. "net/http/httptest"
  8. "testing"
  9. "time"
  10. "perms-system-server/internal/response"
  11. "perms-system-server/internal/testutil"
  12. "github.com/zeromicro/go-zero/core/stores/redis"
  13. "github.com/zeromicro/go-zero/rest/httpx"
  14. "github.com/stretchr/testify/assert"
  15. "github.com/stretchr/testify/require"
  16. )
  17. func init() {
  18. response.Setup()
  19. }
  // uniqueIP returns a pseudo-random dotted-quad address of the form
// 10.a.b.c, each of a, b, c drawn independently from [0, 255]. Tests use it
// to give each simulated client its own address.
func uniqueIP() string {
	var octets [3]int
	for i := range octets {
		octets[i] = rand.Intn(256)
	}
	return fmt.Sprintf("10.%d.%d.%d", octets[0], octets[1], octets[2])
}
  23. func newTestRedis() *redis.Redis {
  24. cfg := testutil.GetTestConfig()
  25. return redis.MustNewRedis(cfg.CacheRedis.Nodes[0].RedisConf)
  26. }
  27. func newTestMiddleware(rds *redis.Redis, quota int) *RateLimitMiddleware {
  28. prefix := fmt.Sprintf("test_rl_%d_%d", time.Now().UnixNano(), rand.Intn(100000))
  29. return NewRateLimitMiddleware(rds, 60, quota, prefix)
  30. }
  31. // TC-0525: 正常请求(未超限)
  32. func TestRateLimit_NormalRequest(t *testing.T) {
  33. rds := newTestRedis()
  34. m := newTestMiddleware(rds, 10)
  35. nextCalled := false
  36. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  37. nextCalled = true
  38. w.WriteHeader(http.StatusOK)
  39. })
  40. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  41. req.Header.Set("X-Forwarded-For", uniqueIP())
  42. w := httptest.NewRecorder()
  43. handler(w, req)
  44. assert.True(t, nextCalled, "next handler should be called")
  45. assert.Equal(t, http.StatusOK, w.Code)
  46. }
  47. // TC-0526: 超限请求被拒绝
  48. func TestRateLimit_OverQuotaRejected(t *testing.T) {
  49. rds := newTestRedis()
  50. m := newTestMiddleware(rds, 2)
  51. ip := uniqueIP()
  52. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  53. httpx.OkJson(w, nil)
  54. })
  55. for i := 0; i < 2; i++ {
  56. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  57. req.Header.Set("X-Forwarded-For", ip)
  58. w := httptest.NewRecorder()
  59. handler(w, req)
  60. }
  61. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  62. req.Header.Set("X-Forwarded-For", ip)
  63. w := httptest.NewRecorder()
  64. handler(w, req)
  65. var body response.Body
  66. err := json.Unmarshal(w.Body.Bytes(), &body)
  67. require.NoError(t, err)
  68. assert.Equal(t, 429, body.Code)
  69. assert.Equal(t, "请求过于频繁,请稍后再试", body.Msg)
  70. }
  71. // TC-0527: X-Forwarded-For被忽略(M-1安全修复验证)
  72. func TestRateLimit_XForwardedForIgnored(t *testing.T) {
  73. rds := newTestRedis()
  74. m := newTestMiddleware(rds, 1)
  75. var nextCount int
  76. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  77. nextCount++
  78. w.WriteHeader(http.StatusOK)
  79. })
  80. remoteAddr := uniqueIP() + ":12345"
  81. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  82. req.RemoteAddr = remoteAddr
  83. req.Header.Set("X-Forwarded-For", uniqueIP())
  84. handler(httptest.NewRecorder(), req)
  85. assert.Equal(t, 1, nextCount)
  86. req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  87. req2.RemoteAddr = remoteAddr
  88. req2.Header.Set("X-Forwarded-For", uniqueIP())
  89. handler(httptest.NewRecorder(), req2)
  90. assert.Equal(t, 1, nextCount, "different X-Forwarded-For should NOT bypass rate limit; RemoteAddr is used")
  91. }
  92. // TC-0528: X-Real-IP被忽略(M-1安全修复验证)
  93. func TestRateLimit_XRealIPIgnored(t *testing.T) {
  94. rds := newTestRedis()
  95. m := newTestMiddleware(rds, 1)
  96. var nextCount int
  97. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  98. nextCount++
  99. w.WriteHeader(http.StatusOK)
  100. })
  101. remoteAddr := uniqueIP() + ":12345"
  102. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  103. req.RemoteAddr = remoteAddr
  104. req.Header.Set("X-Real-IP", uniqueIP())
  105. handler(httptest.NewRecorder(), req)
  106. assert.Equal(t, 1, nextCount)
  107. req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  108. req2.RemoteAddr = remoteAddr
  109. req2.Header.Set("X-Real-IP", uniqueIP())
  110. handler(httptest.NewRecorder(), req2)
  111. assert.Equal(t, 1, nextCount, "different X-Real-IP should NOT bypass rate limit; RemoteAddr is used")
  112. }
  113. // TC-0529: IP从RemoteAddr获取
  114. func TestRateLimit_IPFromRemoteAddr(t *testing.T) {
  115. rds := newTestRedis()
  116. m := newTestMiddleware(rds, 1)
  117. ip := uniqueIP()
  118. remoteAddr := ip + ":12345"
  119. var gotNext bool
  120. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  121. gotNext = true
  122. w.WriteHeader(http.StatusOK)
  123. })
  124. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  125. req.RemoteAddr = remoteAddr
  126. w := httptest.NewRecorder()
  127. handler(w, req)
  128. assert.True(t, gotNext)
  129. gotNext = false
  130. req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  131. req2.RemoteAddr = remoteAddr
  132. w2 := httptest.NewRecorder()
  133. handler(w2, req2)
  134. assert.False(t, gotNext, "should be rate limited by RemoteAddr")
  135. }
  136. // TC-0530: 不同RemoteAddr独立限流
  137. func TestRateLimit_DifferentIPsIndependent(t *testing.T) {
  138. rds := newTestRedis()
  139. m := newTestMiddleware(rds, 1)
  140. addr1 := uniqueIP() + ":12345"
  141. addr2 := uniqueIP() + ":12345"
  142. var nextCount int
  143. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  144. nextCount++
  145. w.WriteHeader(http.StatusOK)
  146. })
  147. req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  148. req1.RemoteAddr = addr1
  149. handler(httptest.NewRecorder(), req1)
  150. assert.Equal(t, 1, nextCount)
  151. req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  152. req2.RemoteAddr = addr2
  153. handler(httptest.NewRecorder(), req2)
  154. assert.Equal(t, 2, nextCount, "different RemoteAddr should have independent quotas")
  155. req3 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  156. req3.RemoteAddr = addr1
  157. handler(httptest.NewRecorder(), req3)
  158. assert.Equal(t, 2, nextCount, "addr1 should be over quota")
  159. req4 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  160. req4.RemoteAddr = addr2
  161. handler(httptest.NewRecorder(), req4)
  162. assert.Equal(t, 2, nextCount, "addr2 should be over quota")
  163. }