package middleware import ( "net/http" "sync" "time" ) // RateLimiter implements a simple in-memory rate limiter type RateLimiter struct { visitors map[string]*visitor mu sync.RWMutex rate int // requests allowed window time.Duration // time window } type visitor struct { count int lastReset time.Time } // NewRateLimiter creates a rate limiter with specified requests per window func NewRateLimiter(rate int, window time.Duration) *RateLimiter { rl := &RateLimiter{ visitors: make(map[string]*visitor), rate: rate, window: window, } // Cleanup old entries periodically go rl.cleanup() return rl } func (rl *RateLimiter) cleanup() { for { time.Sleep(rl.window) rl.mu.Lock() for ip, v := range rl.visitors { if time.Since(v.lastReset) > rl.window*2 { delete(rl.visitors, ip) } } rl.mu.Unlock() } } func (rl *RateLimiter) isAllowed(ip string) bool { rl.mu.Lock() defer rl.mu.Unlock() v, exists := rl.visitors[ip] now := time.Now() if !exists { rl.visitors[ip] = &visitor{count: 1, lastReset: now} return true } // Reset window if needed if now.Sub(v.lastReset) > rl.window { v.count = 1 v.lastReset = now return true } if v.count >= rl.rate { return false } v.count++ return true } // getIP extracts client IP from request func getIP(r *http.Request) string { // Check X-Forwarded-For first (for proxied requests) xff := r.Header.Get("X-Forwarded-For") if xff != "" { // Take the first IP in the chain for i := 0; i < len(xff); i++ { if xff[i] == ',' { return xff[:i] } } return xff } // Check X-Real-IP xri := r.Header.Get("X-Real-IP") if xri != "" { return xri } // Fallback to RemoteAddr return r.RemoteAddr } // RateLimitMiddleware returns a middleware that limits requests per IP func RateLimitMiddleware(rate int, window time.Duration) func(http.Handler) http.Handler { limiter := NewRateLimiter(rate, window) return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ip := getIP(r) if !limiter.isAllowed(ip) { w.Header().Set("Retry-After", "60") http.Error(w, "Rate limit exceeded. Please try again later.", http.StatusTooManyRequests) return } next.ServeHTTP(w, r) }) } }