diff --git a/requirements.txt b/requirements.txt index eba83ea..37b04de 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,31 +1,32 @@ -python-dotenv>=1.0.1 -uvicorn>=0.30.0 -fastapi>=0.111.0 +python-dotenv==1.2.1 +uvicorn==0.40.0 +fastapi==0.128.0 # Core framework -agno>=0.7.1 +agno==2.4.2 # Playground and storage -sqlalchemy>=2.0.31 -aiosqlite>=0.20.0 +sqlalchemy==2.0.46 +aiosqlite==0.22.1 # Providers used in code examples (install what you use) -openai>=1.40.0 -google-generativeai>=0.7.2 -google-cloud-aiplatform>=1.66.0 -mistralai>=1.2.4 -ollama>=0.3.3 -groq>=0.9.0 -PyJWT>=2.8.0 +openai==2.15.0 +google-generativeai==0.8.6 +google-cloud-aiplatform==1.134.0 +mistralai==1.10.1 +ollama==0.6.1 +groq==1.0.0 +PyJWT==2.10.1 # Original optimizer dependencies -anthropic>=0.25.0 -google-genai>=1.25.0 -typer[all]>=0.9.0 -structlog>=23.0.0 -oci>=2.157.0 -aiofiles>=25.1.0 -types-aiofiles>=24.1.0 -oracledb>=2.2.0 -pymssql>=2.3.0 -sqlparse>=0.5.0 +anthropic==0.76.0 +google-genai==1.60.0 +typer[all]==0.21.1 +structlog==25.5.0 +oci==2.165.1 +aiofiles==25.1.0 +types-aiofiles==25.1.0.20251011 +oracledb==3.4.1 +pymssql==2.3.11 +sqlparse==0.5.5 +sqlglot==28.6.0 diff --git a/src/sql_optimizer_team/tools/engine/analysis_tools/query_analyzer.py b/src/sql_optimizer_team/tools/engine/analysis_tools/query_analyzer.py index b75a723..6c8adf2 100644 --- a/src/sql_optimizer_team/tools/engine/analysis_tools/query_analyzer.py +++ b/src/sql_optimizer_team/tools/engine/analysis_tools/query_analyzer.py @@ -9,6 +9,22 @@ import hashlib from dataclasses import dataclass, field from typing import Optional +try: + import sqlglot # type: ignore[import-not-found] + from sqlglot import exp # type: ignore[import-not-found] +except Exception: # pragma: no cover - optional dependency at runtime + sqlglot = None # type: ignore[assignment] + exp = None # type: ignore[assignment] + + +def _parse_sql(sql_text: str): + if sqlglot is None or exp is None: + return None + try: + return sqlglot.parse_one(sql_text, error_level="ignore") + except Exception: + return None + @dataclass class QueryComplexityResult: @@ -116,15 +132,20 @@ class QueryAnalyzer: """ result = QueryComplexityResult() - # Normalize SQL for analysis - normalized = self._normalize_sql(sql_text) - - # Count elements - result.column_count = self._count_columns(normalized) - result.table_count = self._count_tables(normalized) - result.subquery_count = self._count_subqueries(normalized) - result.case_statement_count = self._count_case_statements(normalized) - result.join_count = self._count_joins(normalized) + parsed = _parse_sql(sql_text) + if parsed is not None: + result.column_count = self._count_columns_ast(parsed) + result.table_count = self._count_tables_ast(parsed) + result.subquery_count = self._count_subqueries_ast(parsed) + result.case_statement_count = self._count_case_statements_ast(parsed) + result.join_count = self._count_joins_ast(parsed) + else: + normalized = self._normalize_sql(sql_text) + result.column_count = self._count_columns(normalized) + result.table_count = self._count_tables(normalized) + result.subquery_count = self._count_subqueries(normalized) + result.case_statement_count = self._count_case_statements(normalized) + result.join_count = self._count_joins(normalized) # Calculate complexity score result.complexity_score = ( @@ -180,6 +201,22 @@ class QueryAnalyzer: # Normalize whitespace sql = re.sub(r'\s+', ' ', sql) return sql.upper().strip() + + def _count_columns_ast(self, parsed) -> int: + select = parsed.find(exp.Select) if exp is not None else None + return len(select.expressions) if select else 0 + + def _count_tables_ast(self, parsed) -> int: + return len({t.name for t in parsed.find_all(exp.Table) if t.name}) + + def _count_subqueries_ast(self, parsed) -> int: + return max(sum(1 for _ in parsed.find_all(exp.Select)) - 1, 0) + + def _count_case_statements_ast(self, parsed) -> int: + return sum(1 for _ in parsed.find_all(exp.Case)) + + def _count_joins_ast(self, parsed) -> int: + return sum(1 for _ in parsed.find_all(exp.Join)) def _count_columns(self, sql: str) -> int: """Count approximate number of columns in SELECT.""" @@ -356,29 +393,38 @@ class OptimizationValidator: def _extract_tables(self, sql: str) -> set[str]: """Extract table names from SQL.""" + parsed = _parse_sql(sql) + if parsed is not None and exp is not None: + return {t.name.lower() for t in parsed.find_all(exp.Table) if t.name} + tables: set[str] = set() normalized = sql.upper() - - # FROM table from_matches = re.findall(r'FROM\s+([#\w]+)', normalized) tables.update(t.lower() for t in from_matches) - - # JOIN table join_matches = re.findall(r'JOIN\s+([#\w]+)', normalized) tables.update(t.lower() for t in join_matches) - return tables def _extract_column_aliases(self, sql: str) -> set[str]: """Extract column aliases from SELECT clause.""" - aliases: set[str] = set() - - # Pattern: column AS alias or alias = expression - as_pattern = re.findall(r'\bAS\s+(\w+)', sql, re.IGNORECASE) - aliases.update(a.lower() for a in as_pattern) - - equals_pattern = re.findall(r'^\s*(\w+)\s*=', sql, re.MULTILINE | re.IGNORECASE) - aliases.update(a.lower() for a in equals_pattern) - - return aliases + parsed = _parse_sql(sql) + if parsed is not None and exp is not None: + select = parsed.find(exp.Select) + if not select: + return set() + aliases: set[str] = set() + for projection in select.expressions: + if isinstance(projection, exp.Star): + continue + alias = projection.alias_or_name + if alias: + aliases.add(alias.lower()) + return aliases + + fallback_aliases: set[str] = set() + as_pattern = re.findall(r'\bAS\s+(\w+)', sql, re.IGNORECASE) + fallback_aliases.update(a.lower() for a in as_pattern) + equals_pattern = re.findall(r'^\s*(\w+)\s*=', sql, re.MULTILINE | re.IGNORECASE) + fallback_aliases.update(a.lower() for a in equals_pattern) + return fallback_aliases diff --git a/src/sql_optimizer_team/tools/engine/storage_tools/query_history.py b/src/sql_optimizer_team/tools/engine/storage_tools/query_history.py index 32d0871..ec3efa1 100644 --- a/src/sql_optimizer_team/tools/engine/storage_tools/query_history.py +++ b/src/sql_optimizer_team/tools/engine/storage_tools/query_history.py @@ -17,6 +17,11 @@ from typing import Any import sqlparse +try: + import sqlglot # type: ignore[import-not-found] +except Exception: # pragma: no cover - optional dependency at runtime + sqlglot = None # type: ignore[assignment] + def _format_sql(sql_text: str) -> str: """Format SQL text with proper indentation and line breaks. @@ -27,16 +32,24 @@ def _format_sql(sql_text: str) -> str: Returns: Formatted SQL with indentation """ - cleaned = sql_text.replace('\r\n', '\n').replace('\r', '\n') - + cleaned = sql_text.replace("\r\n", "\n").replace("\r", "\n") + + if sqlglot is not None: + try: + parsed = sqlglot.parse_one(cleaned, error_level="ignore") + if parsed is not None: + return parsed.sql(pretty=True).strip() + except Exception: + pass + formatted = sqlparse.format( cleaned, reindent=True, - keyword_case='upper', + keyword_case="upper", indent_width=4, wrap_after=80, ) - + return formatted.strip() @@ -423,12 +436,12 @@ class QueryHistoryManager: lines.extend([ "", - f"-- -----------------------------------------------------------------------------", + "-- -----------------------------------------------------------------------------", f"-- Query #{idx}", f"-- Hash: {query_hash[:16]}...", f"-- First Seen: {first_seen}", f"-- Times Seen: {times_seen}", - f"-- -----------------------------------------------------------------------------", + "-- -----------------------------------------------------------------------------", "", formatted_sql, "", diff --git a/src/sql_optimizer_team/tools/sql_tools.py b/src/sql_optimizer_team/tools/sql_tools.py index 8cc6ebc..d18848f 100644 --- a/src/sql_optimizer_team/tools/sql_tools.py +++ b/src/sql_optimizer_team/tools/sql_tools.py @@ -8,6 +8,13 @@ from pathlib import Path from typing import Iterable import difflib +try: + import sqlglot # type: ignore[import-not-found] + from sqlglot import exp # type: ignore[import-not-found] +except Exception: # pragma: no cover - optional dependency at runtime + sqlglot = None # type: ignore[assignment] + exp = None # type: ignore[assignment] + @dataclass(frozen=True) class QueryComplexityResult: @@ -117,7 +124,7 @@ def _find_main_select_list(sql: str) -> str | None: return None -def compute_query_complexity(sql: str) -> QueryComplexityResult: +def _compute_query_complexity_fallback(sql: str) -> QueryComplexityResult: normalized = _strip_comments(sql) lowered = normalized.lower() @@ -149,6 +156,42 @@ def compute_query_complexity(sql: str) -> QueryComplexityResult: ) +def compute_query_complexity(sql: str) -> QueryComplexityResult: + if sqlglot is None or exp is None: + return _compute_query_complexity_fallback(sql) + + try: + parsed = sqlglot.parse_one(sql, error_level="ignore") + if parsed is None: + return _compute_query_complexity_fallback(sql) + + select = parsed.find(exp.Select) + column_count = len(select.expressions) if select else 0 + table_count = len({t.name for t in parsed.find_all(exp.Table) if t.name}) + join_count = sum(1 for _ in parsed.find_all(exp.Join)) + case_count = sum(1 for _ in parsed.find_all(exp.Case)) + subquery_count = max(sum(1 for _ in parsed.find_all(exp.Select)) - 1, 0) + + score = table_count + join_count + subquery_count + case_count + if score >= 10: + complexity = "high" + elif score >= 5: + complexity = "medium" + else: + complexity = "low" + + return QueryComplexityResult( + column_count=column_count, + table_count=table_count, + subquery_count=subquery_count, + case_statement_count=case_count, + join_count=join_count, + complexity_level=complexity, + ) + except Exception: + return _compute_query_complexity_fallback(sql) + + def load_sql_from_file(path: str) -> str: file_path = Path(path).expanduser().resolve() return file_path.read_text(encoding="utf-8")