feat: Integrate sqlglot for enhanced SQL parsing and complexity analysis in query tools

This commit is contained in:
william.dias 2026-01-23 10:14:48 -03:00
parent a7afdfac8b
commit 45034f4cbd
4 changed files with 157 additions and 54 deletions

View file

@@ -1,31 +1,32 @@
python-dotenv>=1.0.1
uvicorn>=0.30.0
fastapi>=0.111.0
python-dotenv==1.2.1
uvicorn==0.40.0
fastapi==0.128.0
# Core framework
agno>=0.7.1
agno==2.4.2
# Playground and storage
sqlalchemy>=2.0.31
aiosqlite>=0.20.0
sqlalchemy==2.0.46
aiosqlite==0.22.1
# Providers used in code examples (install what you use)
openai>=1.40.0
google-generativeai>=0.7.2
google-cloud-aiplatform>=1.66.0
mistralai>=1.2.4
ollama>=0.3.3
groq>=0.9.0
PyJWT>=2.8.0
openai==2.15.0
google-generativeai==0.8.6
google-cloud-aiplatform==1.134.0
mistralai==1.10.1
ollama==0.6.1
groq==1.0.0
PyJWT==2.10.1
# Original optimizer dependencies
anthropic>=0.25.0
google-genai>=1.25.0
typer[all]>=0.9.0
structlog>=23.0.0
oci>=2.157.0
aiofiles>=25.1.0
types-aiofiles>=24.1.0
oracledb>=2.2.0
pymssql>=2.3.0
sqlparse>=0.5.0
anthropic==0.76.0
google-genai==1.60.0
typer[all]==0.21.1
structlog==25.5.0
oci==2.165.1
aiofiles==25.1.0
types-aiofiles==25.1.0.20251011
oracledb==3.4.1
pymssql==2.3.11
sqlparse==0.5.5
sqlglot==28.6.0

View file

@@ -9,6 +9,22 @@ import hashlib
from dataclasses import dataclass, field
from typing import Optional
try:
import sqlglot # type: ignore[import-not-found]
from sqlglot import exp # type: ignore[import-not-found]
except Exception: # pragma: no cover - optional dependency at runtime
sqlglot = None # type: ignore[assignment]
exp = None # type: ignore[assignment]
def _parse_sql(sql_text: str):
if sqlglot is None or exp is None:
return None
try:
return sqlglot.parse_one(sql_text, error_level="ignore")
except Exception:
return None
@dataclass
class QueryComplexityResult:
@@ -116,10 +132,15 @@ class QueryAnalyzer:
"""
result = QueryComplexityResult()
# Normalize SQL for analysis
parsed = _parse_sql(sql_text)
if parsed is not None:
result.column_count = self._count_columns_ast(parsed)
result.table_count = self._count_tables_ast(parsed)
result.subquery_count = self._count_subqueries_ast(parsed)
result.case_statement_count = self._count_case_statements_ast(parsed)
result.join_count = self._count_joins_ast(parsed)
else:
normalized = self._normalize_sql(sql_text)
# Count elements
result.column_count = self._count_columns(normalized)
result.table_count = self._count_tables(normalized)
result.subquery_count = self._count_subqueries(normalized)
@@ -181,6 +202,22 @@ class QueryAnalyzer:
sql = re.sub(r'\s+', ' ', sql)
return sql.upper().strip()
def _count_columns_ast(self, parsed) -> int:
select = parsed.find(exp.Select) if exp is not None else None
return len(select.expressions) if select else 0
def _count_tables_ast(self, parsed) -> int:
return len({t.name for t in parsed.find_all(exp.Table) if t.name})
def _count_subqueries_ast(self, parsed) -> int:
return max(sum(1 for _ in parsed.find_all(exp.Select)) - 1, 0)
def _count_case_statements_ast(self, parsed) -> int:
return sum(1 for _ in parsed.find_all(exp.Case))
def _count_joins_ast(self, parsed) -> int:
return sum(1 for _ in parsed.find_all(exp.Join))
def _count_columns(self, sql: str) -> int:
"""Count approximate number of columns in SELECT."""
# Find SELECT ... FROM
@@ -356,29 +393,38 @@ class OptimizationValidator:
def _extract_tables(self, sql: str) -> set[str]:
"""Extract table names from SQL."""
parsed = _parse_sql(sql)
if parsed is not None and exp is not None:
return {t.name.lower() for t in parsed.find_all(exp.Table) if t.name}
tables: set[str] = set()
normalized = sql.upper()
# FROM table
from_matches = re.findall(r'FROM\s+([#\w]+)', normalized)
tables.update(t.lower() for t in from_matches)
# JOIN table
join_matches = re.findall(r'JOIN\s+([#\w]+)', normalized)
tables.update(t.lower() for t in join_matches)
return tables
def _extract_column_aliases(self, sql: str) -> set[str]:
"""Extract column aliases from SELECT clause."""
parsed = _parse_sql(sql)
if parsed is not None and exp is not None:
select = parsed.find(exp.Select)
if not select:
return set()
aliases: set[str] = set()
# Pattern: column AS alias or alias = expression
as_pattern = re.findall(r'\bAS\s+(\w+)', sql, re.IGNORECASE)
aliases.update(a.lower() for a in as_pattern)
equals_pattern = re.findall(r'^\s*(\w+)\s*=', sql, re.MULTILINE | re.IGNORECASE)
aliases.update(a.lower() for a in equals_pattern)
for projection in select.expressions:
if isinstance(projection, exp.Star):
continue
alias = projection.alias_or_name
if alias:
aliases.add(alias.lower())
return aliases
fallback_aliases: set[str] = set()
as_pattern = re.findall(r'\bAS\s+(\w+)', sql, re.IGNORECASE)
fallback_aliases.update(a.lower() for a in as_pattern)
equals_pattern = re.findall(r'^\s*(\w+)\s*=', sql, re.MULTILINE | re.IGNORECASE)
fallback_aliases.update(a.lower() for a in equals_pattern)
return fallback_aliases

View file

@@ -17,6 +17,11 @@ from typing import Any
import sqlparse
try:
import sqlglot # type: ignore[import-not-found]
except Exception: # pragma: no cover - optional dependency at runtime
sqlglot = None # type: ignore[assignment]
def _format_sql(sql_text: str) -> str:
"""Format SQL text with proper indentation and line breaks.
@@ -27,12 +32,20 @@ def _format_sql(sql_text: str) -> str:
Returns:
Formatted SQL with indentation
"""
cleaned = sql_text.replace('\r\n', '\n').replace('\r', '\n')
cleaned = sql_text.replace("\r\n", "\n").replace("\r", "\n")
if sqlglot is not None:
try:
parsed = sqlglot.parse_one(cleaned, error_level="ignore")
if parsed is not None:
return parsed.sql(pretty=True).strip()
except Exception:
pass
formatted = sqlparse.format(
cleaned,
reindent=True,
keyword_case='upper',
keyword_case="upper",
indent_width=4,
wrap_after=80,
)
@@ -423,12 +436,12 @@ class QueryHistoryManager:
lines.extend([
"",
f"-- -----------------------------------------------------------------------------",
"-- -----------------------------------------------------------------------------",
f"-- Query #{idx}",
f"-- Hash: {query_hash[:16]}...",
f"-- First Seen: {first_seen}",
f"-- Times Seen: {times_seen}",
f"-- -----------------------------------------------------------------------------",
"-- -----------------------------------------------------------------------------",
"",
formatted_sql,
"",

View file

@@ -8,6 +8,13 @@ from pathlib import Path
from typing import Iterable
import difflib
try:
import sqlglot # type: ignore[import-not-found]
from sqlglot import exp # type: ignore[import-not-found]
except Exception: # pragma: no cover - optional dependency at runtime
sqlglot = None # type: ignore[assignment]
exp = None # type: ignore[assignment]
@dataclass(frozen=True)
class QueryComplexityResult:
@@ -117,7 +124,7 @@ def _find_main_select_list(sql: str) -> str | None:
return None
def compute_query_complexity(sql: str) -> QueryComplexityResult:
def _compute_query_complexity_fallback(sql: str) -> QueryComplexityResult:
normalized = _strip_comments(sql)
lowered = normalized.lower()
@@ -149,6 +156,42 @@ def compute_query_complexity(sql: str) -> QueryComplexityResult:
)
def compute_query_complexity(sql: str) -> QueryComplexityResult:
if sqlglot is None or exp is None:
return _compute_query_complexity_fallback(sql)
try:
parsed = sqlglot.parse_one(sql, error_level="ignore")
if parsed is None:
return _compute_query_complexity_fallback(sql)
select = parsed.find(exp.Select)
column_count = len(select.expressions) if select else 0
table_count = len({t.name for t in parsed.find_all(exp.Table) if t.name})
join_count = sum(1 for _ in parsed.find_all(exp.Join))
case_count = sum(1 for _ in parsed.find_all(exp.Case))
subquery_count = max(sum(1 for _ in parsed.find_all(exp.Select)) - 1, 0)
score = table_count + join_count + subquery_count + case_count
if score >= 10:
complexity = "high"
elif score >= 5:
complexity = "medium"
else:
complexity = "low"
return QueryComplexityResult(
column_count=column_count,
table_count=table_count,
subquery_count=subquery_count,
case_statement_count=case_count,
join_count=join_count,
complexity_level=complexity,
)
except Exception:
return _compute_query_complexity_fallback(sql)
def load_sql_from_file(path: str) -> str:
file_path = Path(path).expanduser().resolve()
return file_path.read_text(encoding="utf-8")