ratelimitMiddleware_test.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. package middleware
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "math/rand"
  6. "net/http"
  7. "net/http/httptest"
  8. "testing"
  9. "time"
  10. "perms-system-server/internal/response"
  11. "perms-system-server/internal/testutil"
  12. "github.com/zeromicro/go-zero/core/stores/redis"
  13. "github.com/zeromicro/go-zero/rest/httpx"
  14. "github.com/stretchr/testify/assert"
  15. "github.com/stretchr/testify/require"
  16. )
  17. func init() {
  18. response.Setup()
  19. }
  // uniqueIP returns a random address inside 10.0.0.0/8 ("10.a.b.c", each
// octet in [0, 255]). It lets each test use its own client identity.
// Distinctness is only probabilistic; callers that need two different
// addresses must check for themselves.
func uniqueIP() string {
	// The three draws are made in the same order as the octets appear.
	second, third, fourth := rand.Intn(256), rand.Intn(256), rand.Intn(256)
	return fmt.Sprintf("10.%d.%d.%d", second, third, fourth)
}
  23. func newTestRedis() *redis.Redis {
  24. cfg := testutil.GetTestConfig()
  25. return redis.MustNewRedis(cfg.CacheRedis.Nodes[0].RedisConf)
  26. }
  27. func newTestMiddleware(rds *redis.Redis, quota int) *RateLimitMiddleware {
  28. prefix := fmt.Sprintf("test_rl_%d_%d", time.Now().UnixNano(), rand.Intn(100000))
  29. return NewRateLimitMiddleware(rds, 60, quota, prefix)
  30. }
  31. // TC-0536: 正常请求(未超限)
  32. func TestRateLimit_NormalRequest(t *testing.T) {
  33. rds := newTestRedis()
  34. m := newTestMiddleware(rds, 10)
  35. nextCalled := false
  36. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  37. nextCalled = true
  38. w.WriteHeader(http.StatusOK)
  39. })
  40. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  41. req.Header.Set("X-Forwarded-For", uniqueIP())
  42. w := httptest.NewRecorder()
  43. handler(w, req)
  44. assert.True(t, nextCalled, "next handler should be called")
  45. assert.Equal(t, http.StatusOK, w.Code)
  46. }
  47. // TC-0537: 超限请求被拒绝
  48. func TestRateLimit_OverQuotaRejected(t *testing.T) {
  49. rds := newTestRedis()
  50. m := newTestMiddleware(rds, 2)
  51. ip := uniqueIP()
  52. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  53. httpx.OkJson(w, nil)
  54. })
  55. for i := 0; i < 2; i++ {
  56. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  57. req.Header.Set("X-Forwarded-For", ip)
  58. w := httptest.NewRecorder()
  59. handler(w, req)
  60. }
  61. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  62. req.Header.Set("X-Forwarded-For", ip)
  63. w := httptest.NewRecorder()
  64. handler(w, req)
  65. var body response.Body
  66. err := json.Unmarshal(w.Body.Bytes(), &body)
  67. require.NoError(t, err)
  68. assert.Equal(t, 429, body.Code)
  69. assert.Equal(t, "请求过于频繁,请稍后再试", body.Msg)
  70. }
  71. // TC-0538: IP从X-Forwarded-For获取
  72. func TestRateLimit_IPFromXForwardedFor(t *testing.T) {
  73. rds := newTestRedis()
  74. m := newTestMiddleware(rds, 1)
  75. ip := uniqueIP()
  76. var gotNext bool
  77. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  78. gotNext = true
  79. w.WriteHeader(http.StatusOK)
  80. })
  81. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  82. req.Header.Set("X-Forwarded-For", ip)
  83. req.Header.Set("X-Real-IP", uniqueIP())
  84. w := httptest.NewRecorder()
  85. handler(w, req)
  86. assert.True(t, gotNext)
  87. gotNext = false
  88. req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  89. req2.Header.Set("X-Forwarded-For", ip)
  90. req2.Header.Set("X-Real-IP", uniqueIP())
  91. w2 := httptest.NewRecorder()
  92. handler(w2, req2)
  93. assert.False(t, gotNext, "should be rate limited by X-Forwarded-For IP")
  94. }
  95. // TC-0539: IP从X-Real-IP获取
  96. func TestRateLimit_IPFromXRealIP(t *testing.T) {
  97. rds := newTestRedis()
  98. m := newTestMiddleware(rds, 1)
  99. ip := uniqueIP()
  100. var gotNext bool
  101. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  102. gotNext = true
  103. w.WriteHeader(http.StatusOK)
  104. })
  105. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  106. req.Header.Set("X-Real-IP", ip)
  107. w := httptest.NewRecorder()
  108. handler(w, req)
  109. assert.True(t, gotNext)
  110. gotNext = false
  111. req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  112. req2.Header.Set("X-Real-IP", ip)
  113. w2 := httptest.NewRecorder()
  114. handler(w2, req2)
  115. assert.False(t, gotNext, "should be rate limited by X-Real-IP")
  116. }
  117. // TC-0540: IP从RemoteAddr获取
  118. func TestRateLimit_IPFromRemoteAddr(t *testing.T) {
  119. rds := newTestRedis()
  120. m := newTestMiddleware(rds, 1)
  121. ip := uniqueIP()
  122. remoteAddr := ip + ":12345"
  123. var gotNext bool
  124. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  125. gotNext = true
  126. w.WriteHeader(http.StatusOK)
  127. })
  128. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  129. req.RemoteAddr = remoteAddr
  130. w := httptest.NewRecorder()
  131. handler(w, req)
  132. assert.True(t, gotNext)
  133. gotNext = false
  134. req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  135. req2.RemoteAddr = remoteAddr
  136. w2 := httptest.NewRecorder()
  137. handler(w2, req2)
  138. assert.False(t, gotNext, "should be rate limited by RemoteAddr")
  139. }
  140. // TC-0541: 不同IP独立限流
  141. func TestRateLimit_DifferentIPsIndependent(t *testing.T) {
  142. rds := newTestRedis()
  143. m := newTestMiddleware(rds, 1)
  144. ip1 := uniqueIP()
  145. ip2 := uniqueIP()
  146. var nextCount int
  147. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  148. nextCount++
  149. w.WriteHeader(http.StatusOK)
  150. })
  151. req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  152. req1.Header.Set("X-Forwarded-For", ip1)
  153. handler(httptest.NewRecorder(), req1)
  154. assert.Equal(t, 1, nextCount)
  155. req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  156. req2.Header.Set("X-Forwarded-For", ip2)
  157. handler(httptest.NewRecorder(), req2)
  158. assert.Equal(t, 2, nextCount, "different IPs should have independent quotas")
  159. req3 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  160. req3.Header.Set("X-Forwarded-For", ip1)
  161. handler(httptest.NewRecorder(), req3)
  162. assert.Equal(t, 2, nextCount, "ip1 should be over quota")
  163. req4 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  164. req4.Header.Set("X-Forwarded-For", ip2)
  165. handler(httptest.NewRecorder(), req4)
  166. assert.Equal(t, 2, nextCount, "ip2 should be over quota")
  167. }