package middleware

import (
	"net/http"
	"net/http/httptest"
	"testing"
)

// corsTestOKHandler returns the downstream handler wrapped by the CORS
// middleware in these tests; it always replies with 200 OK.
//
// It returns http.HandlerFunc (not http.Handler) so it can be passed to
// middleware that accepts either type.
func corsTestOKHandler() http.HandlerFunc {
	return func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}
}

// serveCORSTest sends a request with the given method for path "/" through h
// and returns the recorded response.
//
// The Origin header is set only when origin is non-empty. Tests that pass ""
// therefore send a request with no Origin header at all, not one with an
// empty value.
func serveCORSTest(h http.Handler, method, origin string) *httptest.ResponseRecorder {
	req := httptest.NewRequest(method, "/", nil)
	if origin != "" {
		req.Header.Set("Origin", origin)
	}
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	return rec
}

// TestCORSWithConfig_AllowsAllOrigins checks that a wildcard allow-list
// produces "Access-Control-Allow-Origin: *" for a request carrying no Origin.
func TestCORSWithConfig_AllowsAllOrigins(t *testing.T) {
	cfg := CORSConfig{AllowedOrigins: []string{"*"}}
	rec := serveCORSTest(CORSWithConfig(cfg)(corsTestOKHandler()), http.MethodGet, "")

	if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "*" {
		t.Errorf("expected allow origin '*', got %q", got)
	}
}

// TestCORSWithConfig_AllowsMatchingOrigin checks that an allow-listed origin
// is echoed back and that the response varies on Origin.
func TestCORSWithConfig_AllowsMatchingOrigin(t *testing.T) {
	const origin = "https://example.com"
	wrap := CORSWithConfig(CORSConfig{AllowedOrigins: []string{origin}})
	rec := serveCORSTest(wrap(corsTestOKHandler()), http.MethodGet, origin)

	hdr := rec.Header()
	if got := hdr.Get("Access-Control-Allow-Origin"); got != origin {
		t.Errorf("expected allow origin to match request, got %q", got)
	}
	if got := hdr.Get("Vary"); got != "Origin" {
		t.Errorf("expected Vary header Origin, got %q", got)
	}
}

// TestCORSWithConfig_BlocksUnknownOrigin checks that an origin missing from
// the allow-list gets no Access-Control-Allow-Origin header.
func TestCORSWithConfig_BlocksUnknownOrigin(t *testing.T) {
	wrap := CORSWithConfig(CORSConfig{AllowedOrigins: []string{"https://example.com"}})
	rec := serveCORSTest(wrap(corsTestOKHandler()), http.MethodGet, "https://unknown.com")

	if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "" {
		t.Errorf("expected no allow origin header, got %q", got)
	}
}

// TestCORSWithConfig_OptionsPreflight checks that an OPTIONS request gets a
// 200 response from the middleware without reaching the wrapped handler.
func TestCORSWithConfig_OptionsPreflight(t *testing.T) {
	var reached bool
	wrap := CORSWithConfig(CORSConfig{AllowedOrigins: []string{"*"}})
	inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	})
	rec := serveCORSTest(wrap(inner), http.MethodOptions, "")

	if rec.Code != http.StatusOK {
		t.Errorf("expected 200 for preflight, got %d", rec.Code)
	}
	if reached {
		t.Error("expected handler not to be called for preflight")
	}
}

// TestCORSWrapperAllowsAll checks that the CORS convenience wrapper answers
// with "Access-Control-Allow-Origin: *".
func TestCORSWrapperAllowsAll(t *testing.T) {
	rec := serveCORSTest(CORS(corsTestOKHandler()), http.MethodGet, "")

	if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "*" {
		t.Errorf("expected allow origin '*', got %q", got)
	}
}