package middleware import ( "compress/gzip" "context" "io" "net/http" "net/http/httptest" "strings" "testing" "time" "github.com/gofrs/uuid/v5" "github.com/golang-jwt/jwt/v5" ) // --- CORS Tests --- func TestCORSWithConfigAllowAll(t *testing.T) { cfg := CORSConfig{AllowedOrigins: []string{"*"}} handler := CORSWithConfig(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Origin", "https://example.com") rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Header().Get("Access-Control-Allow-Origin") != "*" { t.Errorf("expected Access-Control-Allow-Origin '*', got '%s'", rec.Header().Get("Access-Control-Allow-Origin")) } } func TestCORSWithConfigSpecificOrigins(t *testing.T) { cfg := CORSConfig{AllowedOrigins: []string{"https://allowed.com", "https://another.com"}} handler := CORSWithConfig(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) // Test allowed origin req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Origin", "https://allowed.com") rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Header().Get("Access-Control-Allow-Origin") != "https://allowed.com" { t.Errorf("expected origin 'https://allowed.com', got '%s'", rec.Header().Get("Access-Control-Allow-Origin")) } // Test blocked origin req2 := httptest.NewRequest(http.MethodGet, "/", nil) req2.Header.Set("Origin", "https://blocked.com") rec2 := httptest.NewRecorder() handler.ServeHTTP(rec2, req2) if rec2.Header().Get("Access-Control-Allow-Origin") != "" { t.Errorf("expected empty Access-Control-Allow-Origin, got '%s'", rec2.Header().Get("Access-Control-Allow-Origin")) } } func TestCORSPreflight(t *testing.T) { cfg := CORSConfig{AllowedOrigins: []string{"*"}} handler := CORSWithConfig(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusTeapot) // Should not reach here })) req := httptest.NewRequest(http.MethodOptions, "/", nil) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Errorf("expected status 200 for preflight, got %d", rec.Code) } } func TestCORSHeaders(t *testing.T) { cfg := CORSConfig{AllowedOrigins: []string{"*"}} handler := CORSWithConfig(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if !strings.Contains(rec.Header().Get("Access-Control-Allow-Methods"), "GET") { t.Error("expected Access-Control-Allow-Methods to include GET") } if !strings.Contains(rec.Header().Get("Access-Control-Allow-Headers"), "Authorization") { t.Error("expected Access-Control-Allow-Headers to include Authorization") } } // --- Auth Tests --- func createTestToken(secret string, userID uuid.UUID, role string, companyID *uuid.UUID) string { claims := jwt.MapClaims{ "sub": userID.String(), "role": role, "exp": time.Now().Add(time.Hour).Unix(), } if companyID != nil { claims["company_id"] = companyID.String() } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenStr, _ := token.SignedString([]byte(secret)) return tokenStr } func TestRequireAuthValidToken(t *testing.T) { secret := "test-secret" userID, _ := uuid.NewV4() companyID, _ := uuid.NewV4() tokenStr := createTestToken(secret, userID, "Admin", &companyID) var receivedClaims Claims handler := RequireAuth([]byte(secret))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { receivedClaims, _ = GetClaims(r.Context()) w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Authorization", "Bearer "+tokenStr) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Errorf("expected status 200, got %d", rec.Code) } if receivedClaims.UserID != userID { t.Errorf("expected userID %s, got %s", userID, receivedClaims.UserID) } if receivedClaims.Role != "Admin" { t.Errorf("expected role 'Admin', got '%s'", receivedClaims.Role) } } func TestRequireAuthMissingToken(t *testing.T) { handler := RequireAuth([]byte("secret"))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusUnauthorized { t.Errorf("expected status 401, got %d", rec.Code) } } func TestRequireAuthInvalidToken(t *testing.T) { handler := RequireAuth([]byte("secret"))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Authorization", "Bearer invalid-token") rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusUnauthorized { t.Errorf("expected status 401, got %d", rec.Code) } } func TestRequireAuthWrongSecret(t *testing.T) { userID, _ := uuid.NewV4() tokenStr := createTestToken("correct-secret", userID, "User", nil) handler := RequireAuth([]byte("wrong-secret"))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Authorization", "Bearer "+tokenStr) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusUnauthorized { t.Errorf("expected status 401, got %d", rec.Code) } } func TestRequireAuthRoleRestriction(t *testing.T) { secret := "secret" userID, _ := uuid.NewV4() tokenStr := createTestToken(secret, userID, "User", nil) handler := RequireAuth([]byte(secret), "Admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Authorization", "Bearer "+tokenStr) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusForbidden { t.Errorf("expected status 403, got %d", rec.Code) } } func TestRequireAuthRoleAllowed(t *testing.T) { secret := "secret" userID, _ := uuid.NewV4() tokenStr := createTestToken(secret, userID, "Admin", nil) handler := RequireAuth([]byte(secret), "Admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Authorization", "Bearer "+tokenStr) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Errorf("expected status 200, got %d", rec.Code) } } func TestGetClaimsFromContext(t *testing.T) { claims := Claims{ UserID: uuid.Must(uuid.NewV4()), Role: "Admin", } ctx := context.WithValue(context.Background(), claimsKey, claims) retrieved, ok := GetClaims(ctx) if !ok { t.Error("expected to retrieve claims from context") } if retrieved.UserID != claims.UserID { t.Errorf("expected userID %s, got %s", claims.UserID, retrieved.UserID) } } func TestGetClaimsNotInContext(t *testing.T) { _, ok := GetClaims(context.Background()) if ok { t.Error("expected claims to not be in context") } } // --- Gzip Tests --- func TestGzipCompression(t *testing.T) { handler := Gzip(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("Hello, World!")) })) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Accept-Encoding", "gzip") rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Header().Get("Content-Encoding") != "gzip" { t.Error("expected Content-Encoding 'gzip'") } // Decompress and verify reader, err := gzip.NewReader(rec.Body) if err != nil { t.Fatalf("failed to create gzip reader: %v", err) } defer reader.Close() body, err := io.ReadAll(reader) if err != nil { t.Fatalf("failed to read gzip body: %v", err) } if string(body) != "Hello, World!" { t.Errorf("expected 'Hello, World!', got '%s'", string(body)) } } func TestGzipNoCompression(t *testing.T) { handler := Gzip(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("Hello, World!")) })) req := httptest.NewRequest(http.MethodGet, "/", nil) // No Accept-Encoding header rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Header().Get("Content-Encoding") == "gzip" { t.Error("should not use gzip when not requested") } if rec.Body.String() != "Hello, World!" { t.Errorf("expected 'Hello, World!', got '%s'", rec.Body.String()) } } // --- Logger Tests --- func TestLoggerMiddleware(t *testing.T) { handler := Logger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/test-path", nil) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Errorf("expected status 200, got %d", rec.Code) } } // --- CORS Legacy Wrapper Test --- func TestCORSLegacyWrapper(t *testing.T) { handler := CORS(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Origin", "https://example.com") rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Header().Get("Access-Control-Allow-Origin") != "*" { t.Errorf("expected '*', got '%s'", rec.Header().Get("Access-Control-Allow-Origin")) } } // --- Additional Auth Edge Case Tests --- func TestRequireAuthExpiredToken(t *testing.T) { secret := "test-secret" userID, _ := uuid.NewV4() // Create an expired token token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ "sub": userID.String(), "role": "Admin", "exp": time.Now().Add(-time.Hour).Unix(), // expired 1 hour ago }) tokenStr, _ := token.SignedString([]byte(secret)) handler := RequireAuth([]byte(secret))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Error("handler should not be called for expired token") })) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Authorization", "Bearer "+tokenStr) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusUnauthorized { t.Errorf("expected 401, got %d", rec.Code) } } func TestRequireAuthWrongSigningMethod(t *testing.T) { // Create a token with None signing method (should be rejected) token := jwt.NewWithClaims(jwt.SigningMethodNone, jwt.MapClaims{ "sub": "test-user-id", "role": "Admin", "exp": time.Now().Add(time.Hour).Unix(), }) tokenStr, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType) handler := RequireAuth([]byte("test-secret"))(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Error("handler should not be called") })) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Authorization", "Bearer "+tokenStr) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Code != http.StatusUnauthorized { t.Errorf("expected 401, got %d", rec.Code) } } // --- Security Headers Tests --- func TestSecurityHeaders(t *testing.T) { handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) if rec.Header().Get("X-Content-Type-Options") != "nosniff" { t.Error("expected X-Content-Type-Options: nosniff") } if rec.Header().Get("X-Frame-Options") != "DENY" { t.Error("expected X-Frame-Options: DENY") } if rec.Header().Get("X-XSS-Protection") != "1; mode=block" { t.Error("expected X-XSS-Protection: 1; mode=block") } if rec.Header().Get("Content-Security-Policy") != "default-src 'none'" { t.Error("expected Content-Security-Policy: default-src 'none'") } if rec.Header().Get("Cache-Control") != "no-store, max-age=0" { t.Error("expected Cache-Control: no-store, max-age=0") } }