package middleware

import (
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/rede5/gohorsejobs/backend/internal/utils"
)

// TestLoggingMiddleware checks that LoggingMiddleware passes the downstream
// handler's status code through unchanged.
func TestLoggingMiddleware(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	mw := LoggingMiddleware(handler)

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	rr := httptest.NewRecorder()
	mw.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d", rr.Code)
	}
}

// TestAuthMiddleware_Success checks that a request carrying a freshly issued
// bearer token reaches the handler with *utils.Claims stored under UserKey.
func TestAuthMiddleware_Success(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		claims, ok := r.Context().Value(UserKey).(*utils.Claims)
		if !ok {
			t.Error("Claims not found in context")
			w.WriteHeader(http.StatusUnauthorized)
			return
		}
		if claims.UserID != 1 {
			t.Errorf("Expected userID 1, got %d", claims.UserID)
		}
		w.WriteHeader(http.StatusOK)
	})
	mw := AuthMiddleware(handler)

	// Fail fast: a token-generation problem would otherwise show up as a
	// confusing 401 from the middleware rather than as its real cause.
	token, err := utils.GenerateJWT(1, "test-user", "user")
	if err != nil {
		t.Fatalf("GenerateJWT: %v", err)
	}

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	req.Header.Set("Authorization", "Bearer "+token)
	rr := httptest.NewRecorder()
	mw.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d", rr.Code)
	}
}

// TestRateLimiter_isAllowed checks that a limiter built with NewRateLimiter(3, ...)
// admits three calls for one key, rejects the fourth, and tracks keys separately.
func TestRateLimiter_isAllowed(t *testing.T) {
	limiter := NewRateLimiter(3, time.Minute)

	// First 3 requests should be allowed
	for i := 0; i < 3; i++ {
		if !limiter.isAllowed("192.168.1.1") {
			t.Errorf("Request %d should be allowed", i+1)
		}
	}

	// 4th request should be denied
	if limiter.isAllowed("192.168.1.1") {
		t.Error("Request 4 should be denied")
	}

	// Different IP should still be allowed
	if !limiter.isAllowed("192.168.1.2") {
		t.Error("Different IP should be allowed")
	}
}

// TestRateLimitMiddleware checks that, with a limit of 2, the first two
// requests from one RemoteAddr succeed and the third gets 429.
func TestRateLimitMiddleware(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	mw := RateLimitMiddleware(2, time.Minute)(handler)

	for i := 0; i < 3; i++ {
		req := httptest.NewRequest(http.MethodGet, "/test", nil)
		req.RemoteAddr = "192.168.1.100:12345"
		rr := httptest.NewRecorder()
		mw.ServeHTTP(rr, req)

		want := http.StatusOK
		if i >= 2 {
			want = http.StatusTooManyRequests
		}
		if rr.Code != want {
			t.Errorf("Request %d: expected status %d, got %d", i+1, want, rr.Code)
		}
	}
}

// TestSecurityHeadersMiddleware checks the exact values of the security
// headers that SecurityHeadersMiddleware sets on every response.
func TestSecurityHeadersMiddleware(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	mw := SecurityHeadersMiddleware(handler)

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	rr := httptest.NewRecorder()
	mw.ServeHTTP(rr, req)

	expectedHeaders := map[string]string{
		"X-Frame-Options":        "DENY",
		"X-Content-Type-Options": "nosniff",
		"X-XSS-Protection":       "1; mode=block",
	}
	for header, expected := range expectedHeaders {
		if actual := rr.Header().Get(header); actual != expected {
			t.Errorf("Header %s: expected %q, got %q", header, expected, actual)
		}
	}
}

// TestAuthMiddleware_NoAuthHeader checks that a request without an
// Authorization header is rejected with 401.
func TestAuthMiddleware_NoAuthHeader(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	mw := AuthMiddleware(handler)

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	rr := httptest.NewRecorder()
	mw.ServeHTTP(rr, req)

	if rr.Code != http.StatusUnauthorized {
		t.Errorf("Expected status 401, got %d", rr.Code)
	}
}

// TestAuthMiddleware_InvalidFormat checks that an Authorization header without
// the "Bearer " scheme is rejected with 401.
func TestAuthMiddleware_InvalidFormat(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	mw := AuthMiddleware(handler)

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	req.Header.Set("Authorization", "InvalidFormat")
	rr := httptest.NewRecorder()
	mw.ServeHTTP(rr, req)

	if rr.Code != http.StatusUnauthorized {
		t.Errorf("Expected status 401, got %d", rr.Code)
	}
}

// TestAuthMiddleware_InvalidToken checks that a well-formed bearer header
// carrying an unverifiable token is rejected with 401.
func TestAuthMiddleware_InvalidToken(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	mw := AuthMiddleware(handler)

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	req.Header.Set("Authorization", "Bearer invalid.token.here")
	rr := httptest.NewRecorder()
	mw.ServeHTTP(rr, req)

	if rr.Code != http.StatusUnauthorized {
		t.Errorf("Expected status 401, got %d", rr.Code)
	}
}

// TestRequireRole_NoClaims checks that RequireRole answers 401 when the
// request context holds no claims at all.
func TestRequireRole_NoClaims(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	mw := RequireRole("admin")(handler)

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	rr := httptest.NewRecorder()
	mw.ServeHTTP(rr, req)

	if rr.Code != http.StatusUnauthorized {
		t.Errorf("Expected status 401, got %d", rr.Code)
	}
}

// TestCORSMiddleware checks that an origin listed in CORS_ORIGINS is echoed in
// Access-Control-Allow-Origin and that an unlisted origin gets no such header.
func TestCORSMiddleware(t *testing.T) {
	// t.Setenv restores any pre-existing value after the test, unlike the
	// previous os.Setenv + os.Unsetenv, which wiped it unconditionally.
	t.Setenv("CORS_ORIGINS", "http://allowed.com,http://another.com")

	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	mw := CORSMiddleware(handler)

	// Test allowed origin
	req := httptest.NewRequest(http.MethodOptions, "/test", nil)
	req.Header.Set("Origin", "http://allowed.com")
	rr := httptest.NewRecorder()
	mw.ServeHTTP(rr, req)

	if got := rr.Header().Get("Access-Control-Allow-Origin"); got != "http://allowed.com" {
		t.Errorf("Expected allow origin http://allowed.com, got %s", got)
	}

	// Test disallowed origin
	req = httptest.NewRequest(http.MethodOptions, "/test", nil)
	req.Header.Set("Origin", "http://hacker.com")
	rr = httptest.NewRecorder()
	mw.ServeHTTP(rr, req)

	if got := rr.Header().Get("Access-Control-Allow-Origin"); got != "" {
		t.Errorf("Expected empty allow origin, got %s", got)
	}
}

// TestSanitizeMiddleware checks that SanitizeMiddleware neutralizes HTML
// markup in JSON string fields before the body reaches the next handler.
func TestSanitizeMiddleware(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Decode into a typed struct: a missing or non-string "name" yields ""
		// instead of panicking on an unchecked type assertion.
		var body struct {
			Name string `json:"name"`
		}
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			t.Errorf("decoding sanitized body: %v", err)
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		// Surface the value the handler actually saw so the test can inspect it.
		w.Header().Set("X-Sanitized-Name", body.Name)
		w.WriteHeader(http.StatusOK)
	})
	mw := SanitizeMiddleware(handler)

	jsonBody := `{"name": "<script>alert('xss')</script>"}`
	req := httptest.NewRequest(http.MethodPost, "/test", strings.NewReader(jsonBody))
	req.Header.Set("Content-Type", "application/json")
	rr := httptest.NewRecorder()
	mw.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Fatalf("Expected status 200, got %d", rr.Code)
	}
	// Whether the middleware entity-escapes or strips tags, no raw angle
	// brackets from the payload may reach the downstream handler.
	got := rr.Header().Get("X-Sanitized-Name")
	if strings.ContainsAny(got, "<>") {
		t.Errorf("Expected markup in name to be neutralized, got %q", got)
	}
}

// TestRequireRole_Success checks that a context carrying claims with the
// required role is passed through to the handler.
func TestRequireRole_Success(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	mw := RequireRole("admin")(handler)

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	// Inject claims into context manually to simulate authenticated user
	claims := &utils.Claims{
		UserID: 1,
		Role:   "admin",
	}
	ctx := context.WithValue(req.Context(), UserKey, claims)
	rr := httptest.NewRecorder()
	mw.ServeHTTP(rr, req.WithContext(ctx))

	if rr.Code != http.StatusOK {
		t.Errorf("Expected status 200, got %d", rr.Code)
	}
}

// TestRequireRole_Forbidden checks that claims with a different role are
// rejected with 403.
func TestRequireRole_Forbidden(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	mw := RequireRole("admin")(handler)

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	claims := &utils.Claims{
		UserID: 1,
		Role:   "user", // Wrong role
	}
	ctx := context.WithValue(req.Context(), UserKey, claims)
	rr := httptest.NewRecorder()
	mw.ServeHTTP(rr, req.WithContext(ctx))

	if rr.Code != http.StatusForbidden {
		t.Errorf("Expected status 403, got %d", rr.Code)
	}
}

// TestSanitizeMiddleware_InvalidJSON checks that a malformed JSON body is
// passed through to the next handler rather than rejected.
func TestSanitizeMiddleware_InvalidJSON(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	mw := SanitizeMiddleware(handler)

	jsonBody := `{"name": "broken json`
	req := httptest.NewRequest(http.MethodPost, "/test", strings.NewReader(jsonBody))
	req.Header.Set("Content-Type", "application/json")
	rr := httptest.NewRecorder()
	mw.ServeHTTP(rr, req)

	// Should pass through if JSON invalid (or handle gracefully)
	if rr.Code != http.StatusOK {
		t.Errorf("Expected status 200 (pass through), got %d", rr.Code)
	}
}