ratelimitMiddleware_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. package middleware
import (
	"encoding/json"
	"fmt"
	"math/rand"
	"net/http"
	"net/http/httptest"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/zeromicro/go-zero/core/stores/redis"
	"github.com/zeromicro/go-zero/rest/httpx"

	"perms-system-server/internal/response"
	"perms-system-server/internal/testutil"
)
  17. func init() {
  18. response.Setup()
  19. }
  20. func uniqueIP() string {
  21. return fmt.Sprintf("10.%d.%d.%d", rand.Intn(256), rand.Intn(256), rand.Intn(256))
  22. }
  23. func newTestRedis() *redis.Redis {
  24. cfg := testutil.GetTestConfig()
  25. return redis.MustNewRedis(cfg.CacheRedis.Nodes[0].RedisConf)
  26. }
  27. func newTestMiddleware(rds *redis.Redis, quota int) *RateLimitMiddleware {
  28. prefix := fmt.Sprintf("test_rl_%d_%d", time.Now().UnixNano(), rand.Intn(100000))
  29. return NewRateLimitMiddleware(rds, 60, quota, prefix, false)
  30. }
  31. // TC-0546: 正常请求(未超限)
  32. func TestRateLimit_NormalRequest(t *testing.T) {
  33. rds := newTestRedis()
  34. m := newTestMiddleware(rds, 10)
  35. nextCalled := false
  36. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  37. nextCalled = true
  38. w.WriteHeader(http.StatusOK)
  39. })
  40. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  41. req.Header.Set("X-Forwarded-For", uniqueIP())
  42. w := httptest.NewRecorder()
  43. handler(w, req)
  44. assert.True(t, nextCalled, "next handler should be called")
  45. assert.Equal(t, http.StatusOK, w.Code)
  46. }
  47. // TC-0547: 超限请求被拒绝
  48. func TestRateLimit_OverQuotaRejected(t *testing.T) {
  49. rds := newTestRedis()
  50. m := newTestMiddleware(rds, 2)
  51. ip := uniqueIP()
  52. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  53. httpx.OkJson(w, nil)
  54. })
  55. for i := 0; i < 2; i++ {
  56. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  57. req.Header.Set("X-Forwarded-For", ip)
  58. w := httptest.NewRecorder()
  59. handler(w, req)
  60. }
  61. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  62. req.Header.Set("X-Forwarded-For", ip)
  63. w := httptest.NewRecorder()
  64. handler(w, req)
  65. var body response.Body
  66. err := json.Unmarshal(w.Body.Bytes(), &body)
  67. require.NoError(t, err)
  68. assert.Equal(t, 429, body.Code)
  69. assert.Equal(t, "请求过于频繁,请稍后再试", body.Msg)
  70. }
  71. // TC-0548: behindProxy=false时XFF被忽略
  72. func TestRateLimit_XForwardedForIgnored(t *testing.T) {
  73. rds := newTestRedis()
  74. m := newTestMiddleware(rds, 1)
  75. var nextCount int
  76. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  77. nextCount++
  78. w.WriteHeader(http.StatusOK)
  79. })
  80. remoteAddr := uniqueIP() + ":12345"
  81. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  82. req.RemoteAddr = remoteAddr
  83. req.Header.Set("X-Forwarded-For", uniqueIP())
  84. handler(httptest.NewRecorder(), req)
  85. assert.Equal(t, 1, nextCount)
  86. req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  87. req2.RemoteAddr = remoteAddr
  88. req2.Header.Set("X-Forwarded-For", uniqueIP())
  89. handler(httptest.NewRecorder(), req2)
  90. assert.Equal(t, 1, nextCount, "different X-Forwarded-For should NOT bypass rate limit; RemoteAddr is used")
  91. }
  92. // TC-0549: behindProxy=false时X-Real-IP被忽略
  93. func TestRateLimit_XRealIPIgnored(t *testing.T) {
  94. rds := newTestRedis()
  95. m := newTestMiddleware(rds, 1)
  96. var nextCount int
  97. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  98. nextCount++
  99. w.WriteHeader(http.StatusOK)
  100. })
  101. remoteAddr := uniqueIP() + ":12345"
  102. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  103. req.RemoteAddr = remoteAddr
  104. req.Header.Set("X-Real-IP", uniqueIP())
  105. handler(httptest.NewRecorder(), req)
  106. assert.Equal(t, 1, nextCount)
  107. req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  108. req2.RemoteAddr = remoteAddr
  109. req2.Header.Set("X-Real-IP", uniqueIP())
  110. handler(httptest.NewRecorder(), req2)
  111. assert.Equal(t, 1, nextCount, "different X-Real-IP should NOT bypass rate limit; RemoteAddr is used")
  112. }
  113. // TC-0550: IP从RemoteAddr解析
  114. func TestRateLimit_IPFromRemoteAddr(t *testing.T) {
  115. rds := newTestRedis()
  116. m := newTestMiddleware(rds, 1)
  117. ip := uniqueIP()
  118. remoteAddr := ip + ":12345"
  119. var gotNext bool
  120. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  121. gotNext = true
  122. w.WriteHeader(http.StatusOK)
  123. })
  124. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  125. req.RemoteAddr = remoteAddr
  126. w := httptest.NewRecorder()
  127. handler(w, req)
  128. assert.True(t, gotNext)
  129. gotNext = false
  130. req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  131. req2.RemoteAddr = remoteAddr
  132. w2 := httptest.NewRecorder()
  133. handler(w2, req2)
  134. assert.False(t, gotNext, "should be rate limited by RemoteAddr")
  135. }
  136. // TC-0551: 不同RemoteAddr独立限流
  137. func TestRateLimit_DifferentIPsIndependent(t *testing.T) {
  138. rds := newTestRedis()
  139. m := newTestMiddleware(rds, 1)
  140. addr1 := uniqueIP() + ":12345"
  141. addr2 := uniqueIP() + ":12345"
  142. var nextCount int
  143. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  144. nextCount++
  145. w.WriteHeader(http.StatusOK)
  146. })
  147. req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  148. req1.RemoteAddr = addr1
  149. handler(httptest.NewRecorder(), req1)
  150. assert.Equal(t, 1, nextCount)
  151. req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  152. req2.RemoteAddr = addr2
  153. handler(httptest.NewRecorder(), req2)
  154. assert.Equal(t, 2, nextCount, "different RemoteAddr should have independent quotas")
  155. req3 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  156. req3.RemoteAddr = addr1
  157. handler(httptest.NewRecorder(), req3)
  158. assert.Equal(t, 2, nextCount, "addr1 should be over quota")
  159. req4 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  160. req4.RemoteAddr = addr2
  161. handler(httptest.NewRecorder(), req4)
  162. assert.Equal(t, 2, nextCount, "addr2 should be over quota")
  163. }
  164. func newTestMiddlewareProxy(rds *redis.Redis, quota int) *RateLimitMiddleware {
  165. prefix := fmt.Sprintf("test_rl_%d_%d", time.Now().UnixNano(), rand.Intn(100000))
  166. return NewRateLimitMiddleware(rds, 60, quota, prefix, true)
  167. }
  168. // TC-0552: behindProxy=true时信任X-Real-IP
  169. func TestRateLimit_BehindProxy_TrustsXRealIP(t *testing.T) {
  170. rds := newTestRedis()
  171. m := newTestMiddlewareProxy(rds, 1)
  172. remoteAddr := uniqueIP() + ":12345"
  173. var nextCount int
  174. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  175. nextCount++
  176. w.WriteHeader(http.StatusOK)
  177. })
  178. req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  179. req1.RemoteAddr = remoteAddr
  180. req1.Header.Set("X-Real-IP", uniqueIP())
  181. handler(httptest.NewRecorder(), req1)
  182. assert.Equal(t, 1, nextCount)
  183. req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  184. req2.RemoteAddr = remoteAddr
  185. req2.Header.Set("X-Real-IP", uniqueIP())
  186. handler(httptest.NewRecorder(), req2)
  187. assert.Equal(t, 2, nextCount, "different X-Real-IP should have independent quotas when behindProxy=true")
  188. }
  189. // TC-0553: behindProxy=true时无X-Real-IP回退RemoteAddr
  190. func TestRateLimit_BehindProxy_FallbackToRemoteAddr(t *testing.T) {
  191. rds := newTestRedis()
  192. m := newTestMiddlewareProxy(rds, 1)
  193. remoteAddr := uniqueIP() + ":12345"
  194. var nextCount int
  195. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  196. nextCount++
  197. w.WriteHeader(http.StatusOK)
  198. })
  199. req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  200. req1.RemoteAddr = remoteAddr
  201. handler(httptest.NewRecorder(), req1)
  202. assert.Equal(t, 1, nextCount)
  203. req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  204. req2.RemoteAddr = remoteAddr
  205. handler(httptest.NewRecorder(), req2)
  206. assert.Equal(t, 1, nextCount, "should fall back to RemoteAddr when X-Real-IP is absent")
  207. }
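// The former TC-0554 ("X-Forwarded-For is still ignored when behindProxy=true") was
// inverted by the M-6 fix: with behindProxy=true the first X-Forwarded-For entry now
// takes precedence. That contract is covered by ratelimitMiddlewareXff_audit_test.go
// (TC-0862 to TC-0866).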
  210. // =============================================================================
  211. // audit L-2 回归:产品登录与管后登录必须使用独立的限流桶
  212. // 修复前:两个入口共享同一个 keyPrefix,导致攻击者对产品登录的爆破会消耗管后登录的配额(或反之)
  213. // 修复后:ProductLoginRateLimit 使用 "...:rl:login:product",AdminLoginRateLimit 使用 "...:rl:login:admin"
  214. // =============================================================================
  215. // TC-0710: 两个不同 keyPrefix 的限流中间件在同一 IP 上互不影响
  216. func TestRateLimit_ProductAndAdminBucketsAreIndependent(t *testing.T) {
  217. rds := newTestRedis()
  218. // 模拟 servicecontext.go 里的两个独立桶
  219. prefixBase := fmt.Sprintf("test_rl_l2_%d_%d", time.Now().UnixNano(), rand.Intn(100000))
  220. productM := NewRateLimitMiddleware(rds, 60, 1, prefixBase+":rl:login:product", false)
  221. adminM := NewRateLimitMiddleware(rds, 60, 1, prefixBase+":rl:login:admin", false)
  222. ip := uniqueIP()
  223. remoteAddr := ip + ":12345"
  224. var productNext, adminNext int
  225. productHandler := productM.Handle(func(w http.ResponseWriter, r *http.Request) {
  226. productNext++
  227. w.WriteHeader(http.StatusOK)
  228. })
  229. adminHandler := adminM.Handle(func(w http.ResponseWriter, r *http.Request) {
  230. adminNext++
  231. w.WriteHeader(http.StatusOK)
  232. })
  233. // 对产品登录打一枪(配额=1,刚好用完)
  234. req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  235. req1.RemoteAddr = remoteAddr
  236. productHandler(httptest.NewRecorder(), req1)
  237. require.Equal(t, 1, productNext)
  238. // 再对产品登录打一枪 → 被限流
  239. req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  240. req2.RemoteAddr = remoteAddr
  241. productHandler(httptest.NewRecorder(), req2)
  242. require.Equal(t, 1, productNext, "产品登录桶已耗尽")
  243. // 关键:同 IP 对管后登录仍应放行(独立桶)
  244. req3 := httptest.NewRequest(http.MethodPost, "/api/auth/adminLogin", nil)
  245. req3.RemoteAddr = remoteAddr
  246. adminHandler(httptest.NewRecorder(), req3)
  247. assert.Equal(t, 1, adminNext,
  248. "audit L-2: 产品登录限流不应影响管后登录(不同 keyPrefix)")
  249. // 再打管后一枪 → 管后桶也应耗尽,但产品桶已经耗尽在先
  250. req4 := httptest.NewRequest(http.MethodPost, "/api/auth/adminLogin", nil)
  251. req4.RemoteAddr = remoteAddr
  252. adminHandler(httptest.NewRecorder(), req4)
  253. assert.Equal(t, 1, adminNext, "管后桶配额=1,第二次应被限流")
  254. }
  255. // TC-0555: RemoteAddr无端口格式
  256. func TestExtractClientIP_RemoteAddrNoPort(t *testing.T) {
  257. req := httptest.NewRequest(http.MethodPost, "/api/test", nil)
  258. req.RemoteAddr = "1.2.3.4"
  259. ip := ExtractClientIP(req, false)
  260. assert.Equal(t, "1.2.3.4", ip, "should return raw RemoteAddr when SplitHostPort fails")
  261. ip2 := ExtractClientIP(req, true)
  262. assert.Equal(t, "1.2.3.4", ip2, "behindProxy=true without X-Real-IP should also fallback")
  263. }