feat: Integrate sqlglot for enhanced SQL parsing and complexity analysis in query tools
This commit is contained in:
parent
a7afdfac8b
commit
45034f4cbd
4 changed files with 157 additions and 54 deletions
|
|
@ -1,31 +1,32 @@
|
||||||
python-dotenv>=1.0.1
|
python-dotenv==1.2.1
|
||||||
uvicorn>=0.30.0
|
uvicorn==0.40.0
|
||||||
fastapi>=0.111.0
|
fastapi==0.128.0
|
||||||
|
|
||||||
# Core framework
|
# Core framework
|
||||||
agno>=0.7.1
|
agno==2.4.2
|
||||||
|
|
||||||
# Playground and storage
|
# Playground and storage
|
||||||
sqlalchemy>=2.0.31
|
sqlalchemy==2.0.46
|
||||||
aiosqlite>=0.20.0
|
aiosqlite==0.22.1
|
||||||
|
|
||||||
# Providers used in code examples (install what you use)
|
# Providers used in code examples (install what you use)
|
||||||
openai>=1.40.0
|
openai==2.15.0
|
||||||
google-generativeai>=0.7.2
|
google-generativeai==0.8.6
|
||||||
google-cloud-aiplatform>=1.66.0
|
google-cloud-aiplatform==1.134.0
|
||||||
mistralai>=1.2.4
|
mistralai==1.10.1
|
||||||
ollama>=0.3.3
|
ollama==0.6.1
|
||||||
groq>=0.9.0
|
groq==1.0.0
|
||||||
PyJWT>=2.8.0
|
PyJWT==2.10.1
|
||||||
|
|
||||||
# Original optimizer dependencies
|
# Original optimizer dependencies
|
||||||
anthropic>=0.25.0
|
anthropic==0.76.0
|
||||||
google-genai>=1.25.0
|
google-genai==1.60.0
|
||||||
typer[all]>=0.9.0
|
typer[all]==0.21.1
|
||||||
structlog>=23.0.0
|
structlog==25.5.0
|
||||||
oci>=2.157.0
|
oci==2.165.1
|
||||||
aiofiles>=25.1.0
|
aiofiles==25.1.0
|
||||||
types-aiofiles>=24.1.0
|
types-aiofiles==25.1.0.20251011
|
||||||
oracledb>=2.2.0
|
oracledb==3.4.1
|
||||||
pymssql>=2.3.0
|
pymssql==2.3.11
|
||||||
sqlparse>=0.5.0
|
sqlparse==0.5.5
|
||||||
|
sqlglot==28.6.0
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,22 @@ import hashlib
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
try:
|
||||||
|
import sqlglot # type: ignore[import-not-found]
|
||||||
|
from sqlglot import exp # type: ignore[import-not-found]
|
||||||
|
except Exception: # pragma: no cover - optional dependency at runtime
|
||||||
|
sqlglot = None # type: ignore[assignment]
|
||||||
|
exp = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_sql(sql_text: str):
    """Parse *sql_text* into a sqlglot expression tree, or return ``None``.

    ``None`` is returned when sqlglot (or its ``exp`` module) could not be
    imported, and when parsing raises for any reason, so callers can fall
    back to their regex-based analysis.

    Args:
        sql_text: Raw SQL statement to parse.

    Returns:
        The expression produced by ``sqlglot.parse_one``, or ``None``.
    """
    if sqlglot is None or exp is None:
        return None
    try:
        # sqlglot's parser compares ``error_level`` against ``ErrorLevel``
        # enum members, so pass the enum instead of the bare string "ignore".
        return sqlglot.parse_one(sql_text, error_level=sqlglot.ErrorLevel.IGNORE)
    except Exception:
        # Best effort: any parser failure simply means "no AST available".
        return None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class QueryComplexityResult:
|
class QueryComplexityResult:
|
||||||
|
|
@ -116,10 +132,15 @@ class QueryAnalyzer:
|
||||||
"""
|
"""
|
||||||
result = QueryComplexityResult()
|
result = QueryComplexityResult()
|
||||||
|
|
||||||
# Normalize SQL for analysis
|
parsed = _parse_sql(sql_text)
|
||||||
|
if parsed is not None:
|
||||||
|
result.column_count = self._count_columns_ast(parsed)
|
||||||
|
result.table_count = self._count_tables_ast(parsed)
|
||||||
|
result.subquery_count = self._count_subqueries_ast(parsed)
|
||||||
|
result.case_statement_count = self._count_case_statements_ast(parsed)
|
||||||
|
result.join_count = self._count_joins_ast(parsed)
|
||||||
|
else:
|
||||||
normalized = self._normalize_sql(sql_text)
|
normalized = self._normalize_sql(sql_text)
|
||||||
|
|
||||||
# Count elements
|
|
||||||
result.column_count = self._count_columns(normalized)
|
result.column_count = self._count_columns(normalized)
|
||||||
result.table_count = self._count_tables(normalized)
|
result.table_count = self._count_tables(normalized)
|
||||||
result.subquery_count = self._count_subqueries(normalized)
|
result.subquery_count = self._count_subqueries(normalized)
|
||||||
|
|
@ -181,6 +202,22 @@ class QueryAnalyzer:
|
||||||
sql = re.sub(r'\s+', ' ', sql)
|
sql = re.sub(r'\s+', ' ', sql)
|
||||||
return sql.upper().strip()
|
return sql.upper().strip()
|
||||||
|
|
||||||
|
def _count_columns_ast(self, parsed) -> int:
    """Return the number of projections on the first SELECT found in *parsed*.

    Yields 0 when the ``exp`` module is unavailable or no SELECT node exists.
    """
    if exp is None:
        return 0
    first_select = parsed.find(exp.Select)
    if not first_select:
        return 0
    return len(first_select.expressions)
|
||||||
|
|
||||||
|
def _count_tables_ast(self, parsed) -> int:
    """Return how many distinct, non-empty table names occur in *parsed*."""
    distinct_names = set()
    for table in parsed.find_all(exp.Table):
        if table.name:
            distinct_names.add(table.name)
    return len(distinct_names)
|
||||||
|
|
||||||
|
def _count_subqueries_ast(self, parsed) -> int:
    """Return the number of SELECT nodes in *parsed* beyond the first one.

    The result is clamped at 0, so a tree without any SELECT reports 0.
    """
    select_total = len(list(parsed.find_all(exp.Select)))
    nested = select_total - 1
    return nested if nested > 0 else 0
|
||||||
|
|
||||||
|
def _count_case_statements_ast(self, parsed) -> int:
    """Return the number of CASE expression nodes found anywhere in *parsed*."""
    case_nodes = list(parsed.find_all(exp.Case))
    return len(case_nodes)
|
||||||
|
|
||||||
|
def _count_joins_ast(self, parsed) -> int:
    """Return the number of JOIN nodes found anywhere in *parsed*."""
    join_nodes = list(parsed.find_all(exp.Join))
    return len(join_nodes)
|
||||||
|
|
||||||
def _count_columns(self, sql: str) -> int:
|
def _count_columns(self, sql: str) -> int:
|
||||||
"""Count approximate number of columns in SELECT."""
|
"""Count approximate number of columns in SELECT."""
|
||||||
# Find SELECT ... FROM
|
# Find SELECT ... FROM
|
||||||
|
|
@ -356,29 +393,38 @@ class OptimizationValidator:
|
||||||
|
|
||||||
def _extract_tables(self, sql: str) -> set[str]:
    """Collect the lower-cased table names referenced by *sql*.

    When ``_parse_sql`` yields an AST, names come from its ``exp.Table``
    nodes. Otherwise the identifier following each FROM / JOIN keyword is
    picked up with regular expressions.
    """
    ast = _parse_sql(sql)
    if ast is not None and exp is not None:
        from_ast = set()
        for table in ast.find_all(exp.Table):
            if table.name:
                from_ast.add(table.name.lower())
        return from_ast

    # Regex fallback: match against the upper-cased text, store lower-cased.
    upper_sql = sql.upper()
    found: set[str] = set()
    for keyword_pattern in (r'FROM\s+([#\w]+)', r'JOIN\s+([#\w]+)'):
        for name in re.findall(keyword_pattern, upper_sql):
            found.add(name.lower())
    return found
|
||||||
|
|
||||||
def _extract_column_aliases(self, sql: str) -> set[str]:
    """Collect lower-cased output names from the SELECT list of *sql*.

    With an AST available, every non-star projection of the first SELECT
    contributes its ``alias_or_name``. Without one, names are taken from
    ``AS <name>`` occurrences and from ``<name> =`` at the start of a line.
    """
    ast = _parse_sql(sql)
    if ast is None or exp is None:
        # Regex fallback when no AST could be produced.
        as_names = re.findall(r'\bAS\s+(\w+)', sql, re.IGNORECASE)
        assigned_names = re.findall(r'^\s*(\w+)\s*=', sql, re.MULTILINE | re.IGNORECASE)
        return {name.lower() for name in (*as_names, *assigned_names)}

    first_select = ast.find(exp.Select)
    if not first_select:
        return set()

    return {
        projection.alias_or_name.lower()
        for projection in first_select.expressions
        if not isinstance(projection, exp.Star) and projection.alias_or_name
    }
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,11 @@ from typing import Any
|
||||||
|
|
||||||
import sqlparse
|
import sqlparse
|
||||||
|
|
||||||
|
try:
|
||||||
|
import sqlglot # type: ignore[import-not-found]
|
||||||
|
except Exception: # pragma: no cover - optional dependency at runtime
|
||||||
|
sqlglot = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
|
||||||
def _format_sql(sql_text: str) -> str:
|
def _format_sql(sql_text: str) -> str:
|
||||||
"""Format SQL text with proper indentation and line breaks.
|
"""Format SQL text with proper indentation and line breaks.
|
||||||
|
|
@ -27,12 +32,20 @@ def _format_sql(sql_text: str) -> str:
|
||||||
Returns:
|
Returns:
|
||||||
Formatted SQL with indentation
|
Formatted SQL with indentation
|
||||||
"""
|
"""
|
||||||
cleaned = sql_text.replace('\r\n', '\n').replace('\r', '\n')
|
cleaned = sql_text.replace("\r\n", "\n").replace("\r", "\n")
|
||||||
|
|
||||||
|
if sqlglot is not None:
|
||||||
|
try:
|
||||||
|
parsed = sqlglot.parse_one(cleaned, error_level="ignore")
|
||||||
|
if parsed is not None:
|
||||||
|
return parsed.sql(pretty=True).strip()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
formatted = sqlparse.format(
|
formatted = sqlparse.format(
|
||||||
cleaned,
|
cleaned,
|
||||||
reindent=True,
|
reindent=True,
|
||||||
keyword_case='upper',
|
keyword_case="upper",
|
||||||
indent_width=4,
|
indent_width=4,
|
||||||
wrap_after=80,
|
wrap_after=80,
|
||||||
)
|
)
|
||||||
|
|
@ -423,12 +436,12 @@ class QueryHistoryManager:
|
||||||
|
|
||||||
lines.extend([
|
lines.extend([
|
||||||
"",
|
"",
|
||||||
f"-- -----------------------------------------------------------------------------",
|
"-- -----------------------------------------------------------------------------",
|
||||||
f"-- Query #{idx}",
|
f"-- Query #{idx}",
|
||||||
f"-- Hash: {query_hash[:16]}...",
|
f"-- Hash: {query_hash[:16]}...",
|
||||||
f"-- First Seen: {first_seen}",
|
f"-- First Seen: {first_seen}",
|
||||||
f"-- Times Seen: {times_seen}",
|
f"-- Times Seen: {times_seen}",
|
||||||
f"-- -----------------------------------------------------------------------------",
|
"-- -----------------------------------------------------------------------------",
|
||||||
"",
|
"",
|
||||||
formatted_sql,
|
formatted_sql,
|
||||||
"",
|
"",
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,13 @@ from pathlib import Path
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
import difflib
|
import difflib
|
||||||
|
|
||||||
|
try:
|
||||||
|
import sqlglot # type: ignore[import-not-found]
|
||||||
|
from sqlglot import exp # type: ignore[import-not-found]
|
||||||
|
except Exception: # pragma: no cover - optional dependency at runtime
|
||||||
|
sqlglot = None # type: ignore[assignment]
|
||||||
|
exp = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class QueryComplexityResult:
|
class QueryComplexityResult:
|
||||||
|
|
@ -117,7 +124,7 @@ def _find_main_select_list(sql: str) -> str | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def compute_query_complexity(sql: str) -> QueryComplexityResult:
|
def _compute_query_complexity_fallback(sql: str) -> QueryComplexityResult:
|
||||||
normalized = _strip_comments(sql)
|
normalized = _strip_comments(sql)
|
||||||
lowered = normalized.lower()
|
lowered = normalized.lower()
|
||||||
|
|
||||||
|
|
@ -149,6 +156,42 @@ def compute_query_complexity(sql: str) -> QueryComplexityResult:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_query_complexity(sql: str) -> QueryComplexityResult:
    """Measure structural complexity of *sql*, preferring the sqlglot AST.

    Falls back to ``_compute_query_complexity_fallback`` when sqlglot is not
    importable, when parsing yields nothing, or when any step of the
    AST-based analysis raises.

    Args:
        sql: SQL statement to analyse.

    Returns:
        A ``QueryComplexityResult`` with column, table, subquery, CASE and
        join counts plus a ``complexity_level`` of "low", "medium" or "high".
    """
    if sqlglot is None or exp is None:
        return _compute_query_complexity_fallback(sql)

    try:
        # sqlglot's parser compares ``error_level`` against ``ErrorLevel``
        # enum members, so pass the enum instead of the bare string "ignore".
        parsed = sqlglot.parse_one(sql, error_level=sqlglot.ErrorLevel.IGNORE)
        if parsed is None:
            return _compute_query_complexity_fallback(sql)

        select = parsed.find(exp.Select)
        column_count = len(select.expressions) if select else 0
        table_count = len({t.name for t in parsed.find_all(exp.Table) if t.name})
        join_count = sum(1 for _ in parsed.find_all(exp.Join))
        case_count = sum(1 for _ in parsed.find_all(exp.Case))
        # Every SELECT node beyond the first is counted as a subquery.
        subquery_count = max(sum(1 for _ in parsed.find_all(exp.Select)) - 1, 0)

        # Column count is reported but does not contribute to the score.
        score = table_count + join_count + subquery_count + case_count
        if score >= 10:
            complexity = "high"
        elif score >= 5:
            complexity = "medium"
        else:
            complexity = "low"

        return QueryComplexityResult(
            column_count=column_count,
            table_count=table_count,
            subquery_count=subquery_count,
            case_statement_count=case_count,
            join_count=join_count,
            complexity_level=complexity,
        )
    except Exception:
        # Deliberate best effort: any AST failure degrades to the fallback.
        return _compute_query_complexity_fallback(sql)
|
||||||
|
|
||||||
|
|
||||||
def load_sql_from_file(path: str) -> str:
|
def load_sql_from_file(path: str) -> str:
|
||||||
file_path = Path(path).expanduser().resolve()
|
file_path = Path(path).expanduser().resolve()
|
||||||
return file_path.read_text(encoding="utf-8")
|
return file_path.read_text(encoding="utf-8")
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue