package middleware import ( "net/http" "net/http/httptest" "testing" "time" ) func TestRateLimiter_Allow(t *testing.T) { rl := NewRateLimiter(5, 5, time.Second) // First 5 requests should pass for i := 0; i < 5; i++ { if !rl.allow("test-ip") { t.Errorf("request %d should be allowed", i+1) } } // 6th request should be blocked if rl.allow("test-ip") { t.Error("6th request should be blocked") } } func TestRateLimiter_DifferentIPs(t *testing.T) { rl := NewRateLimiter(2, 2, time.Second) // IP1 uses its quota rl.allow("ip1") rl.allow("ip1") if rl.allow("ip1") { t.Error("ip1 should be blocked after 2 requests") } // IP2 should still work if !rl.allow("ip2") { t.Error("ip2 should be allowed") } } func TestRateLimiter_Middleware(t *testing.T) { rl := NewRateLimiter(2, 2, time.Second) handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) // First 2 requests pass for i := 0; i < 2; i++ { req := httptest.NewRequest(http.MethodGet, "/", nil) req.RemoteAddr = "192.168.1.1:12345" rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Errorf("request %d should return 200, got %d", i+1, rec.Code) } } // 3rd request blocked req := httptest.NewRequest(http.MethodGet, "/", nil) req.RemoteAddr = "192.168.1.1:12345" rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusTooManyRequests { t.Errorf("expected 429, got %d", rec.Code) } if rec.Header().Get("Retry-After") == "" { t.Error("expected Retry-After header") } } func TestRateLimiter_XForwardedFor(t *testing.T) { rl := NewRateLimiter(1, 1, time.Second) handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) // Request with X-Forwarded-For req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("X-Forwarded-For", "10.0.0.1") rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Error("first request should pass") } // Second request from same forwarded IP should be blocked req2 := httptest.NewRequest(http.MethodGet, "/", nil) req2.Header.Set("X-Forwarded-For", "10.0.0.1") rec2 := httptest.NewRecorder() handler.ServeHTTP(rec2, req2) if rec2.Code != http.StatusTooManyRequests { t.Errorf("expected 429, got %d", rec2.Code) } } func TestDefaultRateLimiter(t *testing.T) { rl := DefaultRateLimiter() if rl.rate != 100 { t.Errorf("expected rate 100, got %d", rl.rate) } if rl.burst != 100 { t.Errorf("expected burst 100, got %d", rl.burst) } } func TestRateLimiter_Cleanup(t *testing.T) { rl := NewRateLimiter(1, 1, time.Second) rl.allow("old-ip") // Simulate old bucket rl.mu.Lock() rl.buckets["old-ip"].lastFill = time.Now().Add(-2 * time.Hour) rl.mu.Unlock() rl.Cleanup(time.Hour) rl.mu.Lock() _, exists := rl.buckets["old-ip"] rl.mu.Unlock() if exists { t.Error("old bucket should be cleaned up") } }