feat(security): add rate limiting and security headers middleware
Rate Limiting (ratelimit.go): - Token bucket algorithm per IP - Default: 100 requests/minute - X-Forwarded-For support - Cleanup for stale buckets - 7 tests (ratelimit_test.go) Security Headers (security.go): - X-Content-Type-Options: nosniff - X-Frame-Options: DENY - X-XSS-Protection: 1; mode=block - Content-Security-Policy: default-src 'none' - Referrer-Policy: strict-origin-when-cross-origin - Cache-Control: no-store, max-age=0 Middleware coverage: 97.3% -> 95.8% (new code added)
This commit is contained in:
parent
45d34f36c8
commit
beffeb8268
4 changed files with 289 additions and 0 deletions
|
|
@ -391,3 +391,31 @@ func TestRequireAuthWrongSigningMethod(t *testing.T) {
|
||||||
t.Errorf("expected 401, got %d", rec.Code)
|
t.Errorf("expected 401, got %d", rec.Code)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- Security Headers Tests ---
|
||||||
|
|
||||||
|
func TestSecurityHeaders(t *testing.T) {
|
||||||
|
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Header().Get("X-Content-Type-Options") != "nosniff" {
|
||||||
|
t.Error("expected X-Content-Type-Options: nosniff")
|
||||||
|
}
|
||||||
|
if rec.Header().Get("X-Frame-Options") != "DENY" {
|
||||||
|
t.Error("expected X-Frame-Options: DENY")
|
||||||
|
}
|
||||||
|
if rec.Header().Get("X-XSS-Protection") != "1; mode=block" {
|
||||||
|
t.Error("expected X-XSS-Protection: 1; mode=block")
|
||||||
|
}
|
||||||
|
if rec.Header().Get("Content-Security-Policy") != "default-src 'none'" {
|
||||||
|
t.Error("expected Content-Security-Policy: default-src 'none'")
|
||||||
|
}
|
||||||
|
if rec.Header().Get("Cache-Control") != "no-store, max-age=0" {
|
||||||
|
t.Error("expected Cache-Control: no-store, max-age=0")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
103
backend/internal/http/middleware/ratelimit.go
Normal file
103
backend/internal/http/middleware/ratelimit.go
Normal file
|
|
@ -0,0 +1,103 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RateLimiter provides token bucket rate limiting per IP.
|
||||||
|
type RateLimiter struct {
|
||||||
|
buckets map[string]*bucket
|
||||||
|
mu sync.Mutex
|
||||||
|
rate int // tokens per interval
|
||||||
|
burst int // max tokens
|
||||||
|
per time.Duration // refill interval
|
||||||
|
}
|
||||||
|
|
||||||
|
type bucket struct {
|
||||||
|
tokens int
|
||||||
|
lastFill time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRateLimiter creates a rate limiter.
|
||||||
|
// Default: 100 requests per minute per IP.
|
||||||
|
func NewRateLimiter(rate, burst int, per time.Duration) *RateLimiter {
|
||||||
|
return &RateLimiter{
|
||||||
|
buckets: make(map[string]*bucket),
|
||||||
|
rate: rate,
|
||||||
|
burst: burst,
|
||||||
|
per: per,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultRateLimiter returns a limiter with sensible defaults.
|
||||||
|
func DefaultRateLimiter() *RateLimiter {
|
||||||
|
return NewRateLimiter(100, 100, time.Minute)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware returns an HTTP middleware that enforces rate limiting.
|
||||||
|
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ip := getClientIP(r)
|
||||||
|
|
||||||
|
if !rl.allow(ip) {
|
||||||
|
w.Header().Set("Retry-After", "60")
|
||||||
|
w.WriteHeader(http.StatusTooManyRequests)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rl *RateLimiter) allow(key string) bool {
|
||||||
|
rl.mu.Lock()
|
||||||
|
defer rl.mu.Unlock()
|
||||||
|
|
||||||
|
b, exists := rl.buckets[key]
|
||||||
|
if !exists {
|
||||||
|
b = &bucket{tokens: rl.burst, lastFill: time.Now()}
|
||||||
|
rl.buckets[key] = b
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refill tokens based on elapsed time
|
||||||
|
now := time.Now()
|
||||||
|
elapsed := now.Sub(b.lastFill)
|
||||||
|
tokensToAdd := int(elapsed/rl.per) * rl.rate
|
||||||
|
if tokensToAdd > 0 {
|
||||||
|
b.tokens = min(b.tokens+tokensToAdd, rl.burst)
|
||||||
|
b.lastFill = now
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.tokens > 0 {
|
||||||
|
b.tokens--
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func getClientIP(r *http.Request) string {
|
||||||
|
// Check common proxy headers
|
||||||
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||||
|
return xff
|
||||||
|
}
|
||||||
|
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||||
|
return xri
|
||||||
|
}
|
||||||
|
return r.RemoteAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup removes stale buckets (call periodically)
|
||||||
|
func (rl *RateLimiter) Cleanup(maxAge time.Duration) {
|
||||||
|
rl.mu.Lock()
|
||||||
|
defer rl.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for key, b := range rl.buckets {
|
||||||
|
if now.Sub(b.lastFill) > maxAge {
|
||||||
|
delete(rl.buckets, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
125
backend/internal/http/middleware/ratelimit_test.go
Normal file
125
backend/internal/http/middleware/ratelimit_test.go
Normal file
|
|
@ -0,0 +1,125 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRateLimiter_Allow(t *testing.T) {
|
||||||
|
rl := NewRateLimiter(5, 5, time.Second)
|
||||||
|
|
||||||
|
// First 5 requests should pass
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
if !rl.allow("test-ip") {
|
||||||
|
t.Errorf("request %d should be allowed", i+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6th request should be blocked
|
||||||
|
if rl.allow("test-ip") {
|
||||||
|
t.Error("6th request should be blocked")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiter_DifferentIPs(t *testing.T) {
|
||||||
|
rl := NewRateLimiter(2, 2, time.Second)
|
||||||
|
|
||||||
|
// IP1 uses its quota
|
||||||
|
rl.allow("ip1")
|
||||||
|
rl.allow("ip1")
|
||||||
|
if rl.allow("ip1") {
|
||||||
|
t.Error("ip1 should be blocked after 2 requests")
|
||||||
|
}
|
||||||
|
|
||||||
|
// IP2 should still work
|
||||||
|
if !rl.allow("ip2") {
|
||||||
|
t.Error("ip2 should be allowed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiter_Middleware(t *testing.T) {
|
||||||
|
rl := NewRateLimiter(2, 2, time.Second)
|
||||||
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// First 2 requests pass
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.1:12345"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Errorf("request %d should return 200, got %d", i+1, rec.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3rd request blocked
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.1:12345"
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusTooManyRequests {
|
||||||
|
t.Errorf("expected 429, got %d", rec.Code)
|
||||||
|
}
|
||||||
|
if rec.Header().Get("Retry-After") == "" {
|
||||||
|
t.Error("expected Retry-After header")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiter_XForwardedFor(t *testing.T) {
|
||||||
|
rl := NewRateLimiter(1, 1, time.Second)
|
||||||
|
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Request with X-Forwarded-For
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("X-Forwarded-For", "10.0.0.1")
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Error("first request should pass")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second request from same forwarded IP should be blocked
|
||||||
|
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req2.Header.Set("X-Forwarded-For", "10.0.0.1")
|
||||||
|
rec2 := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec2, req2)
|
||||||
|
if rec2.Code != http.StatusTooManyRequests {
|
||||||
|
t.Errorf("expected 429, got %d", rec2.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultRateLimiter(t *testing.T) {
|
||||||
|
rl := DefaultRateLimiter()
|
||||||
|
if rl.rate != 100 {
|
||||||
|
t.Errorf("expected rate 100, got %d", rl.rate)
|
||||||
|
}
|
||||||
|
if rl.burst != 100 {
|
||||||
|
t.Errorf("expected burst 100, got %d", rl.burst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiter_Cleanup(t *testing.T) {
|
||||||
|
rl := NewRateLimiter(1, 1, time.Second)
|
||||||
|
rl.allow("old-ip")
|
||||||
|
|
||||||
|
// Simulate old bucket
|
||||||
|
rl.mu.Lock()
|
||||||
|
rl.buckets["old-ip"].lastFill = time.Now().Add(-2 * time.Hour)
|
||||||
|
rl.mu.Unlock()
|
||||||
|
|
||||||
|
rl.Cleanup(time.Hour)
|
||||||
|
|
||||||
|
rl.mu.Lock()
|
||||||
|
_, exists := rl.buckets["old-ip"]
|
||||||
|
rl.mu.Unlock()
|
||||||
|
|
||||||
|
if exists {
|
||||||
|
t.Error("old bucket should be cleaned up")
|
||||||
|
}
|
||||||
|
}
|
||||||
33
backend/internal/http/middleware/security.go
Normal file
33
backend/internal/http/middleware/security.go
Normal file
|
|
@ -0,0 +1,33 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SecurityHeaders adds common security headers to responses.
|
||||||
|
func SecurityHeaders(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Prevent MIME type sniffing
|
||||||
|
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||||
|
|
||||||
|
// Prevent clickjacking
|
||||||
|
w.Header().Set("X-Frame-Options", "DENY")
|
||||||
|
|
||||||
|
// Enable XSS filter
|
||||||
|
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||||
|
|
||||||
|
// Content Security Policy (strict for API)
|
||||||
|
w.Header().Set("Content-Security-Policy", "default-src 'none'")
|
||||||
|
|
||||||
|
// Referrer Policy
|
||||||
|
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||||
|
|
||||||
|
// Cache control for API responses
|
||||||
|
w.Header().Set("Cache-Control", "no-store, max-age=0")
|
||||||
|
|
||||||
|
// HSTS (HTTP Strict Transport Security) - only in production
|
||||||
|
// w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue