ratelimitMiddleware_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. package middleware
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "github.com/stretchr/testify/assert"
  6. "github.com/stretchr/testify/require"
  7. "github.com/zeromicro/go-zero/core/stores/redis"
  8. "github.com/zeromicro/go-zero/rest/httpx"
  9. "math/rand"
  10. "net/http"
  11. "net/http/httptest"
  12. "perms-system-server/internal/response"
  13. "perms-system-server/internal/testutil"
  14. "testing"
  15. "time"
  16. )
  17. func init() {
  18. response.Setup()
  19. }
  20. func uniqueIP() string {
  21. return fmt.Sprintf("10.%d.%d.%d", rand.Intn(256), rand.Intn(256), rand.Intn(256))
  22. }
  23. func newTestRedis() *redis.Redis {
  24. cfg := testutil.GetTestConfig()
  25. return redis.MustNewRedis(cfg.CacheRedis.Nodes[0].RedisConf)
  26. }
  27. func newTestMiddleware(rds *redis.Redis, quota int) *RateLimitMiddleware {
  28. prefix := fmt.Sprintf("test_rl_%d_%d", time.Now().UnixNano(), rand.Intn(100000))
  29. return NewRateLimitMiddleware(rds, 60, quota, prefix, false)
  30. }
  31. // TC-0546: 正常请求(未超限)
  32. func TestRateLimit_NormalRequest(t *testing.T) {
  33. rds := newTestRedis()
  34. m := newTestMiddleware(rds, 10)
  35. nextCalled := false
  36. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  37. nextCalled = true
  38. w.WriteHeader(http.StatusOK)
  39. })
  40. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  41. req.Header.Set("X-Forwarded-For", uniqueIP())
  42. w := httptest.NewRecorder()
  43. handler(w, req)
  44. assert.True(t, nextCalled, "next handler should be called")
  45. assert.Equal(t, http.StatusOK, w.Code)
  46. }
  47. // TC-0547: 超限请求被拒绝
  48. func TestRateLimit_OverQuotaRejected(t *testing.T) {
  49. rds := newTestRedis()
  50. m := newTestMiddleware(rds, 2)
  51. ip := uniqueIP()
  52. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  53. httpx.OkJson(w, nil)
  54. })
  55. for i := 0; i < 2; i++ {
  56. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  57. req.Header.Set("X-Forwarded-For", ip)
  58. w := httptest.NewRecorder()
  59. handler(w, req)
  60. }
  61. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  62. req.Header.Set("X-Forwarded-For", ip)
  63. w := httptest.NewRecorder()
  64. handler(w, req)
  65. var body response.Body
  66. err := json.Unmarshal(w.Body.Bytes(), &body)
  67. require.NoError(t, err)
  68. assert.False(t, body.Success)
  69. assert.Equal(t, 429, body.ErrorCode)
  70. assert.Equal(t, "请求过于频繁,请稍后再试", body.ErrorMessage)
  71. }
  72. // TC-0548: behindProxy=false时XFF被忽略
  73. func TestRateLimit_XForwardedForIgnored(t *testing.T) {
  74. rds := newTestRedis()
  75. m := newTestMiddleware(rds, 1)
  76. var nextCount int
  77. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  78. nextCount++
  79. w.WriteHeader(http.StatusOK)
  80. })
  81. remoteAddr := uniqueIP() + ":12345"
  82. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  83. req.RemoteAddr = remoteAddr
  84. req.Header.Set("X-Forwarded-For", uniqueIP())
  85. handler(httptest.NewRecorder(), req)
  86. assert.Equal(t, 1, nextCount)
  87. req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  88. req2.RemoteAddr = remoteAddr
  89. req2.Header.Set("X-Forwarded-For", uniqueIP())
  90. handler(httptest.NewRecorder(), req2)
  91. assert.Equal(t, 1, nextCount, "different X-Forwarded-For should NOT bypass rate limit; RemoteAddr is used")
  92. }
  93. // TC-0549: behindProxy=false时X-Real-IP被忽略
  94. func TestRateLimit_XRealIPIgnored(t *testing.T) {
  95. rds := newTestRedis()
  96. m := newTestMiddleware(rds, 1)
  97. var nextCount int
  98. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  99. nextCount++
  100. w.WriteHeader(http.StatusOK)
  101. })
  102. remoteAddr := uniqueIP() + ":12345"
  103. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  104. req.RemoteAddr = remoteAddr
  105. req.Header.Set("X-Real-IP", uniqueIP())
  106. handler(httptest.NewRecorder(), req)
  107. assert.Equal(t, 1, nextCount)
  108. req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  109. req2.RemoteAddr = remoteAddr
  110. req2.Header.Set("X-Real-IP", uniqueIP())
  111. handler(httptest.NewRecorder(), req2)
  112. assert.Equal(t, 1, nextCount, "different X-Real-IP should NOT bypass rate limit; RemoteAddr is used")
  113. }
  114. // TC-0550: IP从RemoteAddr解析
  115. func TestRateLimit_IPFromRemoteAddr(t *testing.T) {
  116. rds := newTestRedis()
  117. m := newTestMiddleware(rds, 1)
  118. ip := uniqueIP()
  119. remoteAddr := ip + ":12345"
  120. var gotNext bool
  121. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  122. gotNext = true
  123. w.WriteHeader(http.StatusOK)
  124. })
  125. req := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  126. req.RemoteAddr = remoteAddr
  127. w := httptest.NewRecorder()
  128. handler(w, req)
  129. assert.True(t, gotNext)
  130. gotNext = false
  131. req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  132. req2.RemoteAddr = remoteAddr
  133. w2 := httptest.NewRecorder()
  134. handler(w2, req2)
  135. assert.False(t, gotNext, "should be rate limited by RemoteAddr")
  136. }
  137. // TC-0551: 不同RemoteAddr独立限流
  138. func TestRateLimit_DifferentIPsIndependent(t *testing.T) {
  139. rds := newTestRedis()
  140. m := newTestMiddleware(rds, 1)
  141. addr1 := uniqueIP() + ":12345"
  142. addr2 := uniqueIP() + ":12345"
  143. var nextCount int
  144. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  145. nextCount++
  146. w.WriteHeader(http.StatusOK)
  147. })
  148. req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  149. req1.RemoteAddr = addr1
  150. handler(httptest.NewRecorder(), req1)
  151. assert.Equal(t, 1, nextCount)
  152. req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  153. req2.RemoteAddr = addr2
  154. handler(httptest.NewRecorder(), req2)
  155. assert.Equal(t, 2, nextCount, "different RemoteAddr should have independent quotas")
  156. req3 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  157. req3.RemoteAddr = addr1
  158. handler(httptest.NewRecorder(), req3)
  159. assert.Equal(t, 2, nextCount, "addr1 should be over quota")
  160. req4 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  161. req4.RemoteAddr = addr2
  162. handler(httptest.NewRecorder(), req4)
  163. assert.Equal(t, 2, nextCount, "addr2 should be over quota")
  164. }
  165. func newTestMiddlewareProxy(rds *redis.Redis, quota int) *RateLimitMiddleware {
  166. prefix := fmt.Sprintf("test_rl_%d_%d", time.Now().UnixNano(), rand.Intn(100000))
  167. return NewRateLimitMiddleware(rds, 60, quota, prefix, true)
  168. }
  169. // TC-0552: behindProxy=true时信任X-Real-IP
  170. func TestRateLimit_BehindProxy_TrustsXRealIP(t *testing.T) {
  171. rds := newTestRedis()
  172. m := newTestMiddlewareProxy(rds, 1)
  173. remoteAddr := uniqueIP() + ":12345"
  174. var nextCount int
  175. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  176. nextCount++
  177. w.WriteHeader(http.StatusOK)
  178. })
  179. req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  180. req1.RemoteAddr = remoteAddr
  181. req1.Header.Set("X-Real-IP", uniqueIP())
  182. handler(httptest.NewRecorder(), req1)
  183. assert.Equal(t, 1, nextCount)
  184. req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  185. req2.RemoteAddr = remoteAddr
  186. req2.Header.Set("X-Real-IP", uniqueIP())
  187. handler(httptest.NewRecorder(), req2)
  188. assert.Equal(t, 2, nextCount, "different X-Real-IP should have independent quotas when behindProxy=true")
  189. }
  190. // TC-0553: behindProxy=true时无X-Real-IP回退RemoteAddr
  191. func TestRateLimit_BehindProxy_FallbackToRemoteAddr(t *testing.T) {
  192. rds := newTestRedis()
  193. m := newTestMiddlewareProxy(rds, 1)
  194. remoteAddr := uniqueIP() + ":12345"
  195. var nextCount int
  196. handler := m.Handle(func(w http.ResponseWriter, r *http.Request) {
  197. nextCount++
  198. w.WriteHeader(http.StatusOK)
  199. })
  200. req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  201. req1.RemoteAddr = remoteAddr
  202. handler(httptest.NewRecorder(), req1)
  203. assert.Equal(t, 1, nextCount)
  204. req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  205. req2.RemoteAddr = remoteAddr
  206. handler(httptest.NewRecorder(), req2)
  207. assert.Equal(t, 1, nextCount, "should fall back to RemoteAddr when X-Real-IP is absent")
  208. }
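// The former TC-0554 ("with behindProxy=true, XFF is still ignored") was inverted by the fix:
// with behindProxy=true the first X-Forwarded-For segment now takes precedence. That contract
// is now covered by TC-0862~0866 (ratelimitMiddlewareXff_audit_test.go).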
  211. // =============================================================================
  212. // audit 回归:产品登录与管后登录必须使用独立的限流桶
  213. // 修复前:两个入口共享同一个 keyPrefix,导致攻击者对产品登录的爆破会消耗管后登录的配额(或反之)
  214. // 修复后:ProductLoginRateLimit 使用 "...:rl:login:product",AdminLoginRateLimit 使用 "...:rl:login:admin"
  215. // =============================================================================
  216. // TC-0710: 两个不同 keyPrefix 的限流中间件在同一 IP 上互不影响
  217. func TestRateLimit_ProductAndAdminBucketsAreIndependent(t *testing.T) {
  218. rds := newTestRedis()
  219. // 模拟 servicecontext.go 里的两个独立桶
  220. prefixBase := fmt.Sprintf("test_rl_l2_%d_%d", time.Now().UnixNano(), rand.Intn(100000))
  221. productM := NewRateLimitMiddleware(rds, 60, 1, prefixBase+":rl:login:product", false)
  222. adminM := NewRateLimitMiddleware(rds, 60, 1, prefixBase+":rl:login:admin", false)
  223. ip := uniqueIP()
  224. remoteAddr := ip + ":12345"
  225. var productNext, adminNext int
  226. productHandler := productM.Handle(func(w http.ResponseWriter, r *http.Request) {
  227. productNext++
  228. w.WriteHeader(http.StatusOK)
  229. })
  230. adminHandler := adminM.Handle(func(w http.ResponseWriter, r *http.Request) {
  231. adminNext++
  232. w.WriteHeader(http.StatusOK)
  233. })
  234. // 对产品登录打一枪(配额=1,刚好用完)
  235. req1 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  236. req1.RemoteAddr = remoteAddr
  237. productHandler(httptest.NewRecorder(), req1)
  238. require.Equal(t, 1, productNext)
  239. // 再对产品登录打一枪 → 被限流
  240. req2 := httptest.NewRequest(http.MethodPost, "/api/auth/login", nil)
  241. req2.RemoteAddr = remoteAddr
  242. productHandler(httptest.NewRecorder(), req2)
  243. require.Equal(t, 1, productNext, "产品登录桶已耗尽")
  244. // 关键:同 IP 对管后登录仍应放行(独立桶)
  245. req3 := httptest.NewRequest(http.MethodPost, "/api/auth/adminLogin", nil)
  246. req3.RemoteAddr = remoteAddr
  247. adminHandler(httptest.NewRecorder(), req3)
  248. assert.Equal(t, 1, adminNext,
  249. "产品登录限流不应影响管后登录(不同 keyPrefix)")
  250. // 再打管后一枪 → 管后桶也应耗尽,但产品桶已经耗尽在先
  251. req4 := httptest.NewRequest(http.MethodPost, "/api/auth/adminLogin", nil)
  252. req4.RemoteAddr = remoteAddr
  253. adminHandler(httptest.NewRecorder(), req4)
  254. assert.Equal(t, 1, adminNext, "管后桶配额=1,第二次应被限流")
  255. }
  256. // TC-0555: RemoteAddr无端口格式
  257. func TestExtractClientIP_RemoteAddrNoPort(t *testing.T) {
  258. req := httptest.NewRequest(http.MethodPost, "/api/test", nil)
  259. req.RemoteAddr = "1.2.3.4"
  260. ip := ExtractClientIP(req, false)
  261. assert.Equal(t, "1.2.3.4", ip, "should return raw RemoteAddr when SplitHostPort fails")
  262. ip2 := ExtractClientIP(req, true)
  263. assert.Equal(t, "1.2.3.4", ip2, "behindProxy=true without X-Real-IP should also fallback")
  264. }
  265. func TestExtractClientIP_XFFFirstValid(t *testing.T) {
  266. r := httptest.NewRequest("POST", "/x", nil)
  267. r.Header.Set("X-Forwarded-For", "1.1.1.1, 2.2.2.2, 3.3.3.3")
  268. r.Header.Set("X-Real-IP", "9.9.9.9") // 不应被用
  269. r.RemoteAddr = "5.5.5.5:8080" // 不应被用
  270. assert.Equal(t, "1.1.1.1", ExtractClientIP(r, true),
  271. "XFF 首段合法时优先返回,高于 XRI / RemoteAddr")
  272. }
  273. // TC-0863: behindProxy=true + XFF 全非法 + XRI 合法 → fallthrough 到 XRI。
  274. func TestExtractClientIP_XFFAllInvalid_FallbackXRI(t *testing.T) {
  275. r := httptest.NewRequest("POST", "/x", nil)
  276. r.Header.Set("X-Forwarded-For", "garbage, not-an-ip")
  277. r.Header.Set("X-Real-IP", "10.0.0.1")
  278. r.RemoteAddr = "5.5.5.5:8080"
  279. assert.Equal(t, "10.0.0.1", ExtractClientIP(r, true),
  280. "XFF 全不合法应当 fallthrough 到 X-Real-IP,不得返回 garbage 或 RemoteAddr")
  281. }
  282. // TC-0864: behindProxy=true + 两头均空 → 回落到 RemoteAddr 剥端口后的 host。
  283. func TestExtractClientIP_NoHeaders_FallbackRemoteAddr(t *testing.T) {
  284. r := httptest.NewRequest("POST", "/x", nil)
  285. r.RemoteAddr = "198.51.100.9:13579"
  286. assert.Equal(t, "198.51.100.9", ExtractClientIP(r, true),
  287. "所有代理头缺失时最终仍能回落到 RemoteAddr 剥端口")
  288. }
  289. // TC-0865: behindProxy=true + XFF 首段带两端空白 → trim 后仍解析合法,返回 trimmed 结果。
  290. func TestExtractClientIP_XFFWhitespaceTrimmed(t *testing.T) {
  291. r := httptest.NewRequest("POST", "/x", nil)
  292. r.Header.Set("X-Forwarded-For", " 3.3.3.3 , 4.4.4.4")
  293. assert.Equal(t, "3.3.3.3", ExtractClientIP(r, true),
  294. "XFF 首段 trim 后合法应当被采用;严禁保留首尾空白而误判")
  295. }
  296. // TC-0866: behindProxy=false —— 完全忽略 XFF / XRI,防止客户端伪造头。
  297. func TestExtractClientIP_BehindProxyFalse_IgnoreHeaders(t *testing.T) {
  298. r := httptest.NewRequest("POST", "/x", nil)
  299. r.Header.Set("X-Forwarded-For", "1.1.1.1") // 应被忽略
  300. r.Header.Set("X-Real-IP", "2.2.2.2") // 应被忽略
  301. r.RemoteAddr = "5.5.5.5:8080"
  302. assert.Equal(t, "5.5.5.5", ExtractClientIP(r, false),
  303. "behindProxy=false 时应完全忽略客户端注入的代理头")
  304. }
  305. // 补充:XFF 包含空段("1.1.1.1,,2.2.2.2")不应 panic,空段跳过后首段合法。
  306. func TestExtractClientIP_XFFEmptySegmentsSkipped(t *testing.T) {
  307. r := httptest.NewRequest("POST", "/x", nil)
  308. r.Header.Set("X-Forwarded-For", ",,,1.1.1.1,2.2.2.2")
  309. assert.Equal(t, "1.1.1.1", ExtractClientIP(r, true),
  310. "XFF 中空段必须跳过,不得 panic 或返回空串")
  311. }
  312. // 补充:XFF 全为合法 IPv6 地址也应能返回首段。
  313. func TestExtractClientIP_XFFIPv6FirstValid(t *testing.T) {
  314. r := httptest.NewRequest("POST", "/x", nil)
  315. r.Header.Set("X-Forwarded-For", "2001:db8::1, 2001:db8::2")
  316. assert.Equal(t, "2001:db8::1", ExtractClientIP(r, true),
  317. "IPv6 也是 net.ParseIP 合法值,XFF 首段应返回 IPv6")
  318. }