diff --git a/backend/internal/http/middleware/middleware_test.go b/backend/internal/http/middleware/middleware_test.go index 0651d28..b45d33f 100644 --- a/backend/internal/http/middleware/middleware_test.go +++ b/backend/internal/http/middleware/middleware_test.go @@ -391,3 +391,31 @@ func TestRequireAuthWrongSigningMethod(t *testing.T) { t.Errorf("expected 401, got %d", rec.Code) } } + +// --- Security Headers Tests --- + +func TestSecurityHeaders(t *testing.T) { + handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Header().Get("X-Content-Type-Options") != "nosniff" { + t.Error("expected X-Content-Type-Options: nosniff") + } + if rec.Header().Get("X-Frame-Options") != "DENY" { + t.Error("expected X-Frame-Options: DENY") + } + if rec.Header().Get("X-XSS-Protection") != "1; mode=block" { + t.Error("expected X-XSS-Protection: 1; mode=block") + } + if rec.Header().Get("Content-Security-Policy") != "default-src 'none'" { + t.Error("expected Content-Security-Policy: default-src 'none'") + } + if rec.Header().Get("Cache-Control") != "no-store, max-age=0" { + t.Error("expected Cache-Control: no-store, max-age=0") + } +} diff --git a/backend/internal/http/middleware/ratelimit.go b/backend/internal/http/middleware/ratelimit.go new file mode 100644 index 0000000..d740e8d --- /dev/null +++ b/backend/internal/http/middleware/ratelimit.go @@ -0,0 +1,103 @@ +package middleware + +import ( + "net/http" + "sync" + "time" +) + +// RateLimiter provides token bucket rate limiting per IP. +type RateLimiter struct { + buckets map[string]*bucket + mu sync.Mutex + rate int // tokens per interval + burst int // max tokens + per time.Duration // refill interval +} + +type bucket struct { + tokens int + lastFill time.Time +} + +// NewRateLimiter creates a rate limiter. +// Default: 100 requests per minute per IP. +func NewRateLimiter(rate, burst int, per time.Duration) *RateLimiter { + return &RateLimiter{ + buckets: make(map[string]*bucket), + rate: rate, + burst: burst, + per: per, + } +} + +// DefaultRateLimiter returns a limiter with sensible defaults. +func DefaultRateLimiter() *RateLimiter { + return NewRateLimiter(100, 100, time.Minute) +} + +// Middleware returns an HTTP middleware that enforces rate limiting. +func (rl *RateLimiter) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ip := getClientIP(r) + + if !rl.allow(ip) { + w.Header().Set("Retry-After", "60") + w.WriteHeader(http.StatusTooManyRequests) + return + } + + next.ServeHTTP(w, r) + }) +} + +func (rl *RateLimiter) allow(key string) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + b, exists := rl.buckets[key] + if !exists { + b = &bucket{tokens: rl.burst, lastFill: time.Now()} + rl.buckets[key] = b + } + + // Refill tokens based on elapsed time + now := time.Now() + elapsed := now.Sub(b.lastFill) + tokensToAdd := int(elapsed/rl.per) * rl.rate + if tokensToAdd > 0 { + b.tokens = min(b.tokens+tokensToAdd, rl.burst) + b.lastFill = now + } + + if b.tokens > 0 { + b.tokens-- + return true + } + + return false +} + +func getClientIP(r *http.Request) string { + // Check common proxy headers + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + return xff + } + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return xri + } + return r.RemoteAddr +} + +// Cleanup removes stale buckets (call periodically) +func (rl *RateLimiter) Cleanup(maxAge time.Duration) { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + for key, b := range rl.buckets { + if now.Sub(b.lastFill) > maxAge { + delete(rl.buckets, key) + } + } +} diff --git a/backend/internal/http/middleware/ratelimit_test.go b/backend/internal/http/middleware/ratelimit_test.go new file mode 100644 index 0000000..86ead4f --- /dev/null +++ b/backend/internal/http/middleware/ratelimit_test.go @@ -0,0 +1,125 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestRateLimiter_Allow(t *testing.T) { + rl := NewRateLimiter(5, 5, time.Second) + + // First 5 requests should pass + for i := 0; i < 5; i++ { + if !rl.allow("test-ip") { + t.Errorf("request %d should be allowed", i+1) + } + } + + // 6th request should be blocked + if rl.allow("test-ip") { + t.Error("6th request should be blocked") + } +} + +func TestRateLimiter_DifferentIPs(t *testing.T) { + rl := NewRateLimiter(2, 2, time.Second) + + // IP1 uses its quota + rl.allow("ip1") + rl.allow("ip1") + if rl.allow("ip1") { + t.Error("ip1 should be blocked after 2 requests") + } + + // IP2 should still work + if !rl.allow("ip2") { + t.Error("ip2 should be allowed") + } +} + +func TestRateLimiter_Middleware(t *testing.T) { + rl := NewRateLimiter(2, 2, time.Second) + handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First 2 requests pass + for i := 0; i < 2; i++ { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "192.168.1.1:12345" + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Errorf("request %d should return 200, got %d", i+1, rec.Code) + } + } + + // 3rd request blocked + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "192.168.1.1:12345" + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + if rec.Code != http.StatusTooManyRequests { + t.Errorf("expected 429, got %d", rec.Code) + } + if rec.Header().Get("Retry-After") == "" { + t.Error("expected Retry-After header") + } +} + +func TestRateLimiter_XForwardedFor(t *testing.T) { + rl := NewRateLimiter(1, 1, time.Second) + handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Request with X-Forwarded-For + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Forwarded-For", "10.0.0.1") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Error("first request should pass") + } + + // Second request from same forwarded IP should be blocked + req2 := httptest.NewRequest(http.MethodGet, "/", nil) + req2.Header.Set("X-Forwarded-For", "10.0.0.1") + rec2 := httptest.NewRecorder() + handler.ServeHTTP(rec2, req2) + if rec2.Code != http.StatusTooManyRequests { + t.Errorf("expected 429, got %d", rec2.Code) + } +} + +func TestDefaultRateLimiter(t *testing.T) { + rl := DefaultRateLimiter() + if rl.rate != 100 { + t.Errorf("expected rate 100, got %d", rl.rate) + } + if rl.burst != 100 { + t.Errorf("expected burst 100, got %d", rl.burst) + } +} + +func TestRateLimiter_Cleanup(t *testing.T) { + rl := NewRateLimiter(1, 1, time.Second) + rl.allow("old-ip") + + // Simulate old bucket + rl.mu.Lock() + rl.buckets["old-ip"].lastFill = time.Now().Add(-2 * time.Hour) + rl.mu.Unlock() + + rl.Cleanup(time.Hour) + + rl.mu.Lock() + _, exists := rl.buckets["old-ip"] + rl.mu.Unlock() + + if exists { + t.Error("old bucket should be cleaned up") + } +} diff --git a/backend/internal/http/middleware/security.go b/backend/internal/http/middleware/security.go new file mode 100644 index 0000000..fda4543 --- /dev/null +++ b/backend/internal/http/middleware/security.go @@ -0,0 +1,33 @@ +package middleware + +import ( + "net/http" +) + +// SecurityHeaders adds common security headers to responses. +func SecurityHeaders(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Prevent MIME type sniffing + w.Header().Set("X-Content-Type-Options", "nosniff") + + // Prevent clickjacking + w.Header().Set("X-Frame-Options", "DENY") + + // Enable XSS filter + w.Header().Set("X-XSS-Protection", "1; mode=block") + + // Content Security Policy (strict for API) + w.Header().Set("Content-Security-Policy", "default-src 'none'") + + // Referrer Policy + w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") + + // Cache control for API responses + w.Header().Set("Cache-Control", "no-store, max-age=0") + + // HSTS (HTTP Strict Transport Security) - only in production + // w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains") + + next.ServeHTTP(w, r) + }) +}