saveinmed/backend-old/internal/http/middleware/ratelimit_test.go
2026-01-16 10:51:52 -03:00

125 lines
3 KiB
Go

package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"time"
)
// TestRateLimiter_Allow checks that, for a limiter built with
// NewRateLimiter(5, 5, time.Second), one key gets its first five requests
// through and the sixth is refused.
func TestRateLimiter_Allow(t *testing.T) {
	const key = "test-ip"
	limiter := NewRateLimiter(5, 5, time.Second)

	// The first five calls are expected to fit within the limit.
	for n := 1; n <= 5; n++ {
		if allowed := limiter.allow(key); !allowed {
			t.Errorf("request %d should be allowed", n)
		}
	}

	// The sixth call exceeds it.
	if limiter.allow(key) {
		t.Error("6th request should be blocked")
	}
}
// TestRateLimiter_DifferentIPs verifies that limits are tracked per key:
// exhausting the quota of one IP must not affect another IP.
func TestRateLimiter_DifferentIPs(t *testing.T) {
	rl := NewRateLimiter(2, 2, time.Second)
	// IP1 uses its quota. Each call is checked so that a limiter which
	// rejects too early cannot make the "blocked" assertion below pass
	// by accident.
	for i := 0; i < 2; i++ {
		if !rl.allow("ip1") {
			t.Fatalf("ip1 request %d should be allowed", i+1)
		}
	}
	if rl.allow("ip1") {
		t.Error("ip1 should be blocked after 2 requests")
	}
	// IP2 should still work.
	if !rl.allow("ip2") {
		t.Error("ip2 should be allowed")
	}
}
// TestRateLimiter_Middleware exercises the HTTP wrapper end to end: with a
// quota of 2, two requests from one client reach the handler, and the third
// is rejected with 429 Too Many Requests and a non-empty Retry-After header.
func TestRateLimiter_Middleware(t *testing.T) {
	limiter := NewRateLimiter(2, 2, time.Second)
	okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	wrapped := limiter.Middleware(okHandler)

	// fire sends one GET / from the same client address and returns the
	// recorded response.
	fire := func() *httptest.ResponseRecorder {
		r := httptest.NewRequest(http.MethodGet, "/", nil)
		r.RemoteAddr = "192.168.1.1:12345"
		w := httptest.NewRecorder()
		wrapped.ServeHTTP(w, r)
		return w
	}

	for attempt := 1; attempt <= 2; attempt++ {
		if res := fire(); res.Code != http.StatusOK {
			t.Errorf("request %d should return 200, got %d", attempt, res.Code)
		}
	}

	blocked := fire()
	if blocked.Code != http.StatusTooManyRequests {
		t.Errorf("expected 429, got %d", blocked.Code)
	}
	if blocked.Header().Get("Retry-After") == "" {
		t.Error("expected Retry-After header")
	}
}
// TestRateLimiter_XForwardedFor verifies that the middleware identifies
// clients by the X-Forwarded-For header. Two requests carrying the same
// forwarded IP must share one quota even when they arrive from different
// RemoteAddr values.
func TestRateLimiter_XForwardedFor(t *testing.T) {
	rl := NewRateLimiter(1, 1, time.Second)
	handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	// Request with X-Forwarded-For.
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.RemoteAddr = "192.168.1.1:12345"
	req.Header.Set("X-Forwarded-For", "10.0.0.1")
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)
	if rec.Code != http.StatusOK {
		t.Errorf("first request should pass, got %d", rec.Code)
	}
	// Second request: same forwarded IP, different RemoteAddr. httptest.NewRequest
	// gives every request the same default RemoteAddr, so without changing it a
	// limiter that ignored X-Forwarded-For would also return 429 here and the
	// test would prove nothing about header handling.
	req2 := httptest.NewRequest(http.MethodGet, "/", nil)
	req2.RemoteAddr = "192.168.1.2:54321"
	req2.Header.Set("X-Forwarded-For", "10.0.0.1")
	rec2 := httptest.NewRecorder()
	handler.ServeHTTP(rec2, req2)
	if rec2.Code != http.StatusTooManyRequests {
		t.Errorf("expected 429, got %d", rec2.Code)
	}
}
// TestDefaultRateLimiter pins the preset returned by DefaultRateLimiter:
// both its rate and burst fields are expected to be 100.
func TestDefaultRateLimiter(t *testing.T) {
	limiter := DefaultRateLimiter()
	if got := limiter.rate; got != 100 {
		t.Errorf("expected rate 100, got %d", got)
	}
	if got := limiter.burst; got != 100 {
		t.Errorf("expected burst 100, got %d", got)
	}
}
// TestRateLimiter_Cleanup verifies that Cleanup(time.Hour) evicts a bucket
// whose lastFill is two hours old while keeping a bucket that was just used.
func TestRateLimiter_Cleanup(t *testing.T) {
	rl := NewRateLimiter(1, 1, time.Second)
	if !rl.allow("old-ip") {
		t.Fatal("first request for old-ip should be allowed")
	}
	// A recently used bucket; guards against a Cleanup that wipes everything.
	rl.allow("fresh-ip")

	// Simulate an old bucket by backdating its last refill time. Look it up
	// first so a missing bucket fails the test cleanly instead of panicking
	// on a nil dereference while the mutex is held.
	rl.mu.Lock()
	b, ok := rl.buckets["old-ip"]
	if ok {
		b.lastFill = time.Now().Add(-2 * time.Hour)
	}
	rl.mu.Unlock()
	if !ok {
		t.Fatal("allow should have created a bucket for old-ip")
	}

	rl.Cleanup(time.Hour)

	rl.mu.Lock()
	_, oldExists := rl.buckets["old-ip"]
	_, freshExists := rl.buckets["fresh-ip"]
	rl.mu.Unlock()
	if oldExists {
		t.Error("old bucket should be cleaned up")
	}
	if !freshExists {
		t.Error("recently used bucket should survive cleanup")
	}
}