package middleware

import (
	"net/http"
	"strings"
)

// CORSConfig holds the configuration for CORS middleware.
type CORSConfig struct {
	// AllowedOrigins lists the origins permitted to make cross-origin
	// requests. If any entry is "*", every origin is allowed. Other entries
	// are matched case-insensitively against the request's Origin header.
	AllowedOrigins []string
}

// CORSWithConfig returns middleware that adds Cross-Origin Resource Sharing
// headers to every response.
//
// If cfg.AllowedOrigins contains "*", Access-Control-Allow-Origin is set to
// "*". Otherwise the request's Origin header is echoed back only when it
// matches (case-insensitively) an entry in the allowed list, and
// "Vary: Origin" is added to every response because the response headers then
// depend on the request's Origin.
//
// Requests using the OPTIONS method are answered with 200 OK and are not
// forwarded to the wrapped handler.
func CORSWithConfig(cfg CORSConfig) func(http.Handler) http.Handler {
	allowAll := false
	allowed := make(map[string]struct{}, len(cfg.AllowedOrigins))
	for _, origin := range cfg.AllowedOrigins {
		if origin == "*" {
			allowAll = true
			break
		}
		allowed[strings.ToLower(origin)] = struct{}{}
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			h := w.Header()

			if allowAll {
				h.Set("Access-Control-Allow-Origin", "*")
			} else {
				// The presence of Access-Control-Allow-Origin depends on the
				// request's Origin, so caches must key on it even when the
				// origin is rejected. Add (not Set) keeps Vary values written
				// by other middleware.
				h.Add("Vary", "Origin")

				origin := r.Header.Get("Origin")
				if origin != "" {
					if _, ok := allowed[strings.ToLower(origin)]; ok {
						h.Set("Access-Control-Allow-Origin", origin)
					}
				}
			}

			h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS, PATCH")
			h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
			h.Set("Access-Control-Max-Age", "86400")

			// Handle preflight requests: reply directly without invoking next.
			if r.Method == http.MethodOptions {
				w.WriteHeader(http.StatusOK)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}

// CORS is a compatibility wrapper that allows all origins.
//
// Deprecated: Use CORSWithConfig for more control.
func CORS(next http.Handler) http.Handler {
	return CORSWithConfig(CORSConfig{AllowedOrigins: []string{"*"}})(next)
}