feat(security): add rate limiting and security headers middleware

Rate Limiting (ratelimit.go):
- Token bucket algorithm per IP
- Default: 100 requests/minute
- X-Forwarded-For support
- Cleanup for stale buckets
- 7 tests (ratelimit_test.go)

Security Headers (security.go):
- X-Content-Type-Options: nosniff
- X-Frame-Options: DENY
- X-XSS-Protection: 1; mode=block
- Content-Security-Policy: default-src 'none'
- Referrer-Policy: strict-origin-when-cross-origin
- Cache-Control: no-store, max-age=0

Middleware coverage: 97.3% -> 95.8% (new code added)
This commit is contained in:
Tiago Yamamoto 2025-12-20 08:41:36 -03:00
parent 45d34f36c8
commit beffeb8268
4 changed files with 289 additions and 0 deletions

View file

@ -391,3 +391,31 @@ func TestRequireAuthWrongSigningMethod(t *testing.T) {
t.Errorf("expected 401, got %d", rec.Code) t.Errorf("expected 401, got %d", rec.Code)
} }
} }
// --- Security Headers Tests ---
func TestSecurityHeaders(t *testing.T) {
handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Header().Get("X-Content-Type-Options") != "nosniff" {
t.Error("expected X-Content-Type-Options: nosniff")
}
if rec.Header().Get("X-Frame-Options") != "DENY" {
t.Error("expected X-Frame-Options: DENY")
}
if rec.Header().Get("X-XSS-Protection") != "1; mode=block" {
t.Error("expected X-XSS-Protection: 1; mode=block")
}
if rec.Header().Get("Content-Security-Policy") != "default-src 'none'" {
t.Error("expected Content-Security-Policy: default-src 'none'")
}
if rec.Header().Get("Cache-Control") != "no-store, max-age=0" {
t.Error("expected Cache-Control: no-store, max-age=0")
}
}

View file

@ -0,0 +1,103 @@
package middleware
import (
"net/http"
"sync"
"time"
)
// RateLimiter provides token bucket rate limiting per IP.
type RateLimiter struct {
buckets map[string]*bucket
mu sync.Mutex
rate int // tokens per interval
burst int // max tokens
per time.Duration // refill interval
}
type bucket struct {
tokens int
lastFill time.Time
}
// NewRateLimiter creates a rate limiter.
// Default: 100 requests per minute per IP.
func NewRateLimiter(rate, burst int, per time.Duration) *RateLimiter {
return &RateLimiter{
buckets: make(map[string]*bucket),
rate: rate,
burst: burst,
per: per,
}
}
// DefaultRateLimiter returns a limiter with sensible defaults.
func DefaultRateLimiter() *RateLimiter {
return NewRateLimiter(100, 100, time.Minute)
}
// Middleware returns an HTTP middleware that enforces rate limiting.
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip := getClientIP(r)
if !rl.allow(ip) {
w.Header().Set("Retry-After", "60")
w.WriteHeader(http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
func (rl *RateLimiter) allow(key string) bool {
rl.mu.Lock()
defer rl.mu.Unlock()
b, exists := rl.buckets[key]
if !exists {
b = &bucket{tokens: rl.burst, lastFill: time.Now()}
rl.buckets[key] = b
}
// Refill tokens based on elapsed time
now := time.Now()
elapsed := now.Sub(b.lastFill)
tokensToAdd := int(elapsed/rl.per) * rl.rate
if tokensToAdd > 0 {
b.tokens = min(b.tokens+tokensToAdd, rl.burst)
b.lastFill = now
}
if b.tokens > 0 {
b.tokens--
return true
}
return false
}
func getClientIP(r *http.Request) string {
// Check common proxy headers
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
return xff
}
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return xri
}
return r.RemoteAddr
}
// Cleanup removes stale buckets (call periodically)
func (rl *RateLimiter) Cleanup(maxAge time.Duration) {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
for key, b := range rl.buckets {
if now.Sub(b.lastFill) > maxAge {
delete(rl.buckets, key)
}
}
}

View file

@ -0,0 +1,125 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestRateLimiter_Allow(t *testing.T) {
rl := NewRateLimiter(5, 5, time.Second)
// First 5 requests should pass
for i := 0; i < 5; i++ {
if !rl.allow("test-ip") {
t.Errorf("request %d should be allowed", i+1)
}
}
// 6th request should be blocked
if rl.allow("test-ip") {
t.Error("6th request should be blocked")
}
}
func TestRateLimiter_DifferentIPs(t *testing.T) {
rl := NewRateLimiter(2, 2, time.Second)
// IP1 uses its quota
rl.allow("ip1")
rl.allow("ip1")
if rl.allow("ip1") {
t.Error("ip1 should be blocked after 2 requests")
}
// IP2 should still work
if !rl.allow("ip2") {
t.Error("ip2 should be allowed")
}
}
func TestRateLimiter_Middleware(t *testing.T) {
rl := NewRateLimiter(2, 2, time.Second)
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// First 2 requests pass
for i := 0; i < 2; i++ {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "192.168.1.1:12345"
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("request %d should return 200, got %d", i+1, rec.Code)
}
}
// 3rd request blocked
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = "192.168.1.1:12345"
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusTooManyRequests {
t.Errorf("expected 429, got %d", rec.Code)
}
if rec.Header().Get("Retry-After") == "" {
t.Error("expected Retry-After header")
}
}
func TestRateLimiter_XForwardedFor(t *testing.T) {
rl := NewRateLimiter(1, 1, time.Second)
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Request with X-Forwarded-For
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Forwarded-For", "10.0.0.1")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Error("first request should pass")
}
// Second request from same forwarded IP should be blocked
req2 := httptest.NewRequest(http.MethodGet, "/", nil)
req2.Header.Set("X-Forwarded-For", "10.0.0.1")
rec2 := httptest.NewRecorder()
handler.ServeHTTP(rec2, req2)
if rec2.Code != http.StatusTooManyRequests {
t.Errorf("expected 429, got %d", rec2.Code)
}
}
func TestDefaultRateLimiter(t *testing.T) {
rl := DefaultRateLimiter()
if rl.rate != 100 {
t.Errorf("expected rate 100, got %d", rl.rate)
}
if rl.burst != 100 {
t.Errorf("expected burst 100, got %d", rl.burst)
}
}
func TestRateLimiter_Cleanup(t *testing.T) {
rl := NewRateLimiter(1, 1, time.Second)
rl.allow("old-ip")
// Simulate old bucket
rl.mu.Lock()
rl.buckets["old-ip"].lastFill = time.Now().Add(-2 * time.Hour)
rl.mu.Unlock()
rl.Cleanup(time.Hour)
rl.mu.Lock()
_, exists := rl.buckets["old-ip"]
rl.mu.Unlock()
if exists {
t.Error("old bucket should be cleaned up")
}
}

View file

@ -0,0 +1,33 @@
package middleware
import (
"net/http"
)
// SecurityHeaders adds common security headers to responses.
func SecurityHeaders(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Prevent MIME type sniffing
w.Header().Set("X-Content-Type-Options", "nosniff")
// Prevent clickjacking
w.Header().Set("X-Frame-Options", "DENY")
// Enable XSS filter
w.Header().Set("X-XSS-Protection", "1; mode=block")
// Content Security Policy (strict for API)
w.Header().Set("Content-Security-Policy", "default-src 'none'")
// Referrer Policy
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
// Cache control for API responses
w.Header().Set("Cache-Control", "no-store, max-age=0")
// HSTS (HTTP Strict Transport Security) - only in production
// w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
next.ServeHTTP(w, r)
})
}