saveinmed/backend/internal/http/middleware/ratelimit.go
Tiago Yamamoto beffeb8268 feat(security): add rate limiting and security headers middleware
Rate Limiting (ratelimit.go):
- Token bucket algorithm per IP
- Default: 100 requests/minute
- X-Forwarded-For support
- Cleanup for stale buckets
- 7 tests (ratelimit_test.go)

Security Headers (security.go):
- X-Content-Type-Options: nosniff
- X-Frame-Options: DENY
- X-XSS-Protection: 1; mode=block
- Content-Security-Policy: default-src 'none'
- Referrer-Policy: strict-origin-when-cross-origin
- Cache-Control: no-store, max-age=0

Middleware coverage: 97.3% -> 95.8% (new code added)
2025-12-20 08:41:36 -03:00

103 lines
2.2 KiB
Go

package middleware
import (
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"
)
// RateLimiter applies an independent token bucket to every client key
// (Middleware keys requests by getClientIP). Construct it with
// NewRateLimiter or DefaultRateLimiter.
type RateLimiter struct {
	// buckets holds the token state for each client key; guarded by mu.
	buckets map[string]*bucket
	// mu serializes all reads and writes of buckets.
	mu sync.Mutex

	// rate is how many tokens a bucket regains per elapsed refill interval.
	rate int
	// burst is the bucket capacity and the token count of a new bucket.
	burst int
	// per is the length of one refill interval.
	per time.Duration
}
// bucket is the token state RateLimiter keeps for a single client key.
type bucket struct {
	tokens   int       // requests still permitted before the next refill
	lastFill time.Time // set on creation; moved forward whenever tokens are refilled
}
// NewRateLimiter returns a RateLimiter whose per-client buckets hold at most
// burst tokens, start full, and regain rate tokens for every whole per
// interval that elapses. See DefaultRateLimiter for a 100-per-minute preset.
func NewRateLimiter(rate, burst int, per time.Duration) *RateLimiter {
	rl := &RateLimiter{
		rate:  rate,
		burst: burst,
		per:   per,
	}
	rl.buckets = make(map[string]*bucket)
	return rl
}
// DefaultRateLimiter returns a limiter that allows each client key 100
// requests per minute, with a burst capacity of 100.
func DefaultRateLimiter() *RateLimiter {
	const (
		defaultRate  = 100
		defaultBurst = 100
	)
	return NewRateLimiter(defaultRate, defaultBurst, 1*time.Minute)
}
// Middleware wraps next with per-client rate limiting. Each request is keyed
// by getClientIP and must draw a token via allow. When no token is available,
// the request is rejected with 429 Too Many Requests and never reaches next.
//
// Retry-After advertises one refill interval (rl.per) in whole seconds,
// rounded up and at least 1. This is an upper bound on the wait until the
// bucket's next refill. With DefaultRateLimiter the value is "60".
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
	// Derived from the configured interval instead of a fixed "60", so limiters
	// with other periods (e.g. per-second) don't tell clients to back off for a minute.
	secs := (rl.per + time.Second - 1) / time.Second
	if secs < 1 {
		secs = 1
	}
	retryAfter := strconv.FormatInt(int64(secs), 10)

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !rl.allow(getClientIP(r)) {
			w.Header().Set("Retry-After", retryAfter)
			w.WriteHeader(http.StatusTooManyRequests)
			return
		}
		next.ServeHTTP(w, r)
	})
}
// allow reports whether the bucket for key has a token available and, if so,
// consumes it.
//
// A bucket is created full (rl.burst tokens) on first use. It regains rl.rate
// tokens for every whole rl.per interval since lastFill, and never exceeds
// rl.burst. A non-positive rl.per disables refilling instead of panicking on
// division by zero. All bucket state is read and written under rl.mu.
func (rl *RateLimiter) allow(key string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	b, ok := rl.buckets[key]
	if !ok {
		b = &bucket{tokens: rl.burst, lastFill: now}
		rl.buckets[key] = b
	}

	if rl.per > 0 {
		if intervals := int64(now.Sub(b.lastFill) / rl.per); intervals > 0 {
			b.tokens = min(b.tokens+int(intervals)*rl.rate, rl.burst)
			// Advance by whole intervals only. Resetting to now would discard the
			// partial interval already elapsed and lengthen the effective refill
			// period by up to one extra interval.
			b.lastFill = b.lastFill.Add(time.Duration(intervals) * rl.per)
		}
	}

	if b.tokens <= 0 {
		return false
	}
	b.tokens--
	return true
}
// getClientIP returns the key used to rate-limit r. In order of preference:
//   - the first (originating) entry of X-Forwarded-For,
//   - X-Real-IP,
//   - the host part of r.RemoteAddr, or RemoteAddr itself if it has no port.
//
// NOTE(review): X-Forwarded-For and X-Real-IP are client-controlled unless a
// trusted reverse proxy overwrites them. If the service is reachable without
// such a proxy, a client can rotate these headers to get fresh buckets.
func getClientIP(r *http.Request) string {
	// X-Forwarded-For may be a chain such as "client, proxy1, proxy2". Only the
	// first entry identifies the client; keying on the whole string would give
	// the same client different buckets depending on the proxy path.
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}
	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}
	// RemoteAddr is normally "host:port". The port changes per connection, so it
	// must be dropped or every new connection would start with a full bucket.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
// Cleanup deletes every bucket whose lastFill timestamp is more than maxAge in
// the past. This bounds memory held for clients that stopped sending requests.
// Call it periodically from a goroutine the caller owns. It takes rl.mu, the
// same lock allow uses, so it may run concurrently with request handling.
//
// A deleted bucket is recreated full on that key's next request. A maxAge
// shorter than the refill interval can therefore give a throttled client a
// fresh burst early.
func (rl *RateLimiter) Cleanup(maxAge time.Duration) {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	for clientKey, bkt := range rl.buckets {
		if now.Sub(bkt.lastFill) <= maxAge {
			continue
		}
		// Deleting during range is permitted by the Go spec.
		delete(rl.buckets, clientKey)
	}
}