feat: Integrate sqlglot for enhanced SQL parsing and complexity analysis in query tools
This commit is contained in:
parent
a7afdfac8b
commit
45034f4cbd
4 changed files with 157 additions and 54 deletions
|
|
@ -1,31 +1,32 @@
|
|||
python-dotenv>=1.0.1
|
||||
uvicorn>=0.30.0
|
||||
fastapi>=0.111.0
|
||||
python-dotenv==1.2.1
|
||||
uvicorn==0.40.0
|
||||
fastapi==0.128.0
|
||||
|
||||
# Core framework
|
||||
agno>=0.7.1
|
||||
agno==2.4.2
|
||||
|
||||
# Playground and storage
|
||||
sqlalchemy>=2.0.31
|
||||
aiosqlite>=0.20.0
|
||||
sqlalchemy==2.0.46
|
||||
aiosqlite==0.22.1
|
||||
|
||||
# Providers used in code examples (install what you use)
|
||||
openai>=1.40.0
|
||||
google-generativeai>=0.7.2
|
||||
google-cloud-aiplatform>=1.66.0
|
||||
mistralai>=1.2.4
|
||||
ollama>=0.3.3
|
||||
groq>=0.9.0
|
||||
PyJWT>=2.8.0
|
||||
openai==2.15.0
|
||||
google-generativeai==0.8.6
|
||||
google-cloud-aiplatform==1.134.0
|
||||
mistralai==1.10.1
|
||||
ollama==0.6.1
|
||||
groq==1.0.0
|
||||
PyJWT==2.10.1
|
||||
|
||||
# Original optimizer dependencies
|
||||
anthropic>=0.25.0
|
||||
google-genai>=1.25.0
|
||||
typer[all]>=0.9.0
|
||||
structlog>=23.0.0
|
||||
oci>=2.157.0
|
||||
aiofiles>=25.1.0
|
||||
types-aiofiles>=24.1.0
|
||||
oracledb>=2.2.0
|
||||
pymssql>=2.3.0
|
||||
sqlparse>=0.5.0
|
||||
anthropic==0.76.0
|
||||
google-genai==1.60.0
|
||||
typer[all]==0.21.1
|
||||
structlog==25.5.0
|
||||
oci==2.165.1
|
||||
aiofiles==25.1.0
|
||||
types-aiofiles==25.1.0.20251011
|
||||
oracledb==3.4.1
|
||||
pymssql==2.3.11
|
||||
sqlparse==0.5.5
|
||||
sqlglot==28.6.0
|
||||
|
|
|
|||
|
|
@ -9,6 +9,22 @@ import hashlib
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
try:
|
||||
import sqlglot # type: ignore[import-not-found]
|
||||
from sqlglot import exp # type: ignore[import-not-found]
|
||||
except Exception: # pragma: no cover - optional dependency at runtime
|
||||
sqlglot = None # type: ignore[assignment]
|
||||
exp = None # type: ignore[assignment]
|
||||
|
||||
|
||||
def _parse_sql(sql_text: str):
|
||||
if sqlglot is None or exp is None:
|
||||
return None
|
||||
try:
|
||||
return sqlglot.parse_one(sql_text, error_level="ignore")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryComplexityResult:
|
||||
|
|
@ -116,15 +132,20 @@ class QueryAnalyzer:
|
|||
"""
|
||||
result = QueryComplexityResult()
|
||||
|
||||
# Normalize SQL for analysis
|
||||
normalized = self._normalize_sql(sql_text)
|
||||
|
||||
# Count elements
|
||||
result.column_count = self._count_columns(normalized)
|
||||
result.table_count = self._count_tables(normalized)
|
||||
result.subquery_count = self._count_subqueries(normalized)
|
||||
result.case_statement_count = self._count_case_statements(normalized)
|
||||
result.join_count = self._count_joins(normalized)
|
||||
parsed = _parse_sql(sql_text)
|
||||
if parsed is not None:
|
||||
result.column_count = self._count_columns_ast(parsed)
|
||||
result.table_count = self._count_tables_ast(parsed)
|
||||
result.subquery_count = self._count_subqueries_ast(parsed)
|
||||
result.case_statement_count = self._count_case_statements_ast(parsed)
|
||||
result.join_count = self._count_joins_ast(parsed)
|
||||
else:
|
||||
normalized = self._normalize_sql(sql_text)
|
||||
result.column_count = self._count_columns(normalized)
|
||||
result.table_count = self._count_tables(normalized)
|
||||
result.subquery_count = self._count_subqueries(normalized)
|
||||
result.case_statement_count = self._count_case_statements(normalized)
|
||||
result.join_count = self._count_joins(normalized)
|
||||
|
||||
# Calculate complexity score
|
||||
result.complexity_score = (
|
||||
|
|
@ -180,6 +201,22 @@ class QueryAnalyzer:
|
|||
# Normalize whitespace
|
||||
sql = re.sub(r'\s+', ' ', sql)
|
||||
return sql.upper().strip()
|
||||
|
||||
def _count_columns_ast(self, parsed) -> int:
|
||||
select = parsed.find(exp.Select) if exp is not None else None
|
||||
return len(select.expressions) if select else 0
|
||||
|
||||
def _count_tables_ast(self, parsed) -> int:
|
||||
return len({t.name for t in parsed.find_all(exp.Table) if t.name})
|
||||
|
||||
def _count_subqueries_ast(self, parsed) -> int:
|
||||
return max(sum(1 for _ in parsed.find_all(exp.Select)) - 1, 0)
|
||||
|
||||
def _count_case_statements_ast(self, parsed) -> int:
|
||||
return sum(1 for _ in parsed.find_all(exp.Case))
|
||||
|
||||
def _count_joins_ast(self, parsed) -> int:
|
||||
return sum(1 for _ in parsed.find_all(exp.Join))
|
||||
|
||||
def _count_columns(self, sql: str) -> int:
|
||||
"""Count approximate number of columns in SELECT."""
|
||||
|
|
@ -356,29 +393,38 @@ class OptimizationValidator:
|
|||
|
||||
def _extract_tables(self, sql: str) -> set[str]:
|
||||
"""Extract table names from SQL."""
|
||||
parsed = _parse_sql(sql)
|
||||
if parsed is not None and exp is not None:
|
||||
return {t.name.lower() for t in parsed.find_all(exp.Table) if t.name}
|
||||
|
||||
tables: set[str] = set()
|
||||
normalized = sql.upper()
|
||||
|
||||
# FROM table
|
||||
from_matches = re.findall(r'FROM\s+([#\w]+)', normalized)
|
||||
tables.update(t.lower() for t in from_matches)
|
||||
|
||||
# JOIN table
|
||||
join_matches = re.findall(r'JOIN\s+([#\w]+)', normalized)
|
||||
tables.update(t.lower() for t in join_matches)
|
||||
|
||||
return tables
|
||||
|
||||
def _extract_column_aliases(self, sql: str) -> set[str]:
|
||||
"""Extract column aliases from SELECT clause."""
|
||||
aliases: set[str] = set()
|
||||
|
||||
# Pattern: column AS alias or alias = expression
|
||||
as_pattern = re.findall(r'\bAS\s+(\w+)', sql, re.IGNORECASE)
|
||||
aliases.update(a.lower() for a in as_pattern)
|
||||
|
||||
equals_pattern = re.findall(r'^\s*(\w+)\s*=', sql, re.MULTILINE | re.IGNORECASE)
|
||||
aliases.update(a.lower() for a in equals_pattern)
|
||||
|
||||
return aliases
|
||||
parsed = _parse_sql(sql)
|
||||
if parsed is not None and exp is not None:
|
||||
select = parsed.find(exp.Select)
|
||||
if not select:
|
||||
return set()
|
||||
aliases: set[str] = set()
|
||||
for projection in select.expressions:
|
||||
if isinstance(projection, exp.Star):
|
||||
continue
|
||||
alias = projection.alias_or_name
|
||||
if alias:
|
||||
aliases.add(alias.lower())
|
||||
return aliases
|
||||
|
||||
fallback_aliases: set[str] = set()
|
||||
as_pattern = re.findall(r'\bAS\s+(\w+)', sql, re.IGNORECASE)
|
||||
fallback_aliases.update(a.lower() for a in as_pattern)
|
||||
equals_pattern = re.findall(r'^\s*(\w+)\s*=', sql, re.MULTILINE | re.IGNORECASE)
|
||||
fallback_aliases.update(a.lower() for a in equals_pattern)
|
||||
return fallback_aliases
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,11 @@ from typing import Any
|
|||
|
||||
import sqlparse
|
||||
|
||||
try:
|
||||
import sqlglot # type: ignore[import-not-found]
|
||||
except Exception: # pragma: no cover - optional dependency at runtime
|
||||
sqlglot = None # type: ignore[assignment]
|
||||
|
||||
|
||||
def _format_sql(sql_text: str) -> str:
|
||||
"""Format SQL text with proper indentation and line breaks.
|
||||
|
|
@ -27,16 +32,24 @@ def _format_sql(sql_text: str) -> str:
|
|||
Returns:
|
||||
Formatted SQL with indentation
|
||||
"""
|
||||
cleaned = sql_text.replace('\r\n', '\n').replace('\r', '\n')
|
||||
|
||||
cleaned = sql_text.replace("\r\n", "\n").replace("\r", "\n")
|
||||
|
||||
if sqlglot is not None:
|
||||
try:
|
||||
parsed = sqlglot.parse_one(cleaned, error_level="ignore")
|
||||
if parsed is not None:
|
||||
return parsed.sql(pretty=True).strip()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
formatted = sqlparse.format(
|
||||
cleaned,
|
||||
reindent=True,
|
||||
keyword_case='upper',
|
||||
keyword_case="upper",
|
||||
indent_width=4,
|
||||
wrap_after=80,
|
||||
)
|
||||
|
||||
|
||||
return formatted.strip()
|
||||
|
||||
|
||||
|
|
@ -423,12 +436,12 @@ class QueryHistoryManager:
|
|||
|
||||
lines.extend([
|
||||
"",
|
||||
f"-- -----------------------------------------------------------------------------",
|
||||
"-- -----------------------------------------------------------------------------",
|
||||
f"-- Query #{idx}",
|
||||
f"-- Hash: {query_hash[:16]}...",
|
||||
f"-- First Seen: {first_seen}",
|
||||
f"-- Times Seen: {times_seen}",
|
||||
f"-- -----------------------------------------------------------------------------",
|
||||
"-- -----------------------------------------------------------------------------",
|
||||
"",
|
||||
formatted_sql,
|
||||
"",
|
||||
|
|
|
|||
|
|
@ -8,6 +8,13 @@ from pathlib import Path
|
|||
from typing import Iterable
|
||||
import difflib
|
||||
|
||||
try:
|
||||
import sqlglot # type: ignore[import-not-found]
|
||||
from sqlglot import exp # type: ignore[import-not-found]
|
||||
except Exception: # pragma: no cover - optional dependency at runtime
|
||||
sqlglot = None # type: ignore[assignment]
|
||||
exp = None # type: ignore[assignment]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class QueryComplexityResult:
|
||||
|
|
@ -117,7 +124,7 @@ def _find_main_select_list(sql: str) -> str | None:
|
|||
return None
|
||||
|
||||
|
||||
def compute_query_complexity(sql: str) -> QueryComplexityResult:
|
||||
def _compute_query_complexity_fallback(sql: str) -> QueryComplexityResult:
|
||||
normalized = _strip_comments(sql)
|
||||
lowered = normalized.lower()
|
||||
|
||||
|
|
@ -149,6 +156,42 @@ def compute_query_complexity(sql: str) -> QueryComplexityResult:
|
|||
)
|
||||
|
||||
|
||||
def compute_query_complexity(sql: str) -> QueryComplexityResult:
|
||||
if sqlglot is None or exp is None:
|
||||
return _compute_query_complexity_fallback(sql)
|
||||
|
||||
try:
|
||||
parsed = sqlglot.parse_one(sql, error_level="ignore")
|
||||
if parsed is None:
|
||||
return _compute_query_complexity_fallback(sql)
|
||||
|
||||
select = parsed.find(exp.Select)
|
||||
column_count = len(select.expressions) if select else 0
|
||||
table_count = len({t.name for t in parsed.find_all(exp.Table) if t.name})
|
||||
join_count = sum(1 for _ in parsed.find_all(exp.Join))
|
||||
case_count = sum(1 for _ in parsed.find_all(exp.Case))
|
||||
subquery_count = max(sum(1 for _ in parsed.find_all(exp.Select)) - 1, 0)
|
||||
|
||||
score = table_count + join_count + subquery_count + case_count
|
||||
if score >= 10:
|
||||
complexity = "high"
|
||||
elif score >= 5:
|
||||
complexity = "medium"
|
||||
else:
|
||||
complexity = "low"
|
||||
|
||||
return QueryComplexityResult(
|
||||
column_count=column_count,
|
||||
table_count=table_count,
|
||||
subquery_count=subquery_count,
|
||||
case_statement_count=case_count,
|
||||
join_count=join_count,
|
||||
complexity_level=complexity,
|
||||
)
|
||||
except Exception:
|
||||
return _compute_query_complexity_fallback(sql)
|
||||
|
||||
|
||||
def load_sql_from_file(path: str) -> str:
|
||||
file_path = Path(path).expanduser().resolve()
|
||||
return file_path.read_text(encoding="utf-8")
|
||||
|
|
|
|||
Loading…
Reference in a new issue