poc-v1/src/sql_optimizer_team/tools/sql_tools.py

228 lines
6.6 KiB
Python

"""Utility tools for SQL handling and lightweight complexity heuristics."""
from __future__ import annotations
from dataclasses import dataclass
import re
from pathlib import Path
from typing import Iterable
import difflib
@dataclass(frozen=True)
class QueryComplexityResult:
    """Immutable summary of the lightweight complexity heuristics for one SQL query."""

    # Number of top-level expressions in the main SELECT list.
    column_count: int
    # Count of FROM keywords plus JOIN keywords (rough table-reference tally).
    table_count: int
    # Occurrences of "(select" — nested queries.
    subquery_count: int
    # Occurrences of the CASE keyword.
    case_statement_count: int
    # Occurrences of the JOIN keyword.
    join_count: int
    # Bucketed score: "low", "medium", or "high".
    complexity_level: str
@dataclass
class _ScanState:
depth: int = 0
in_single: bool = False
in_double: bool = False
def update_quote(self, ch: str) -> None:
if ch == "'" and not self.in_double:
self.in_single = not self.in_single
elif ch == '"' and not self.in_single:
self.in_double = not self.in_double
def in_quotes(self) -> bool:
return self.in_single or self.in_double
def update_paren(self, ch: str) -> None:
if self.in_quotes():
return
if ch == "(":
self.depth += 1
elif ch == ")":
self.depth = max(self.depth - 1, 0)
def at_top_level(self) -> bool:
return not self.in_quotes() and self.depth == 0
def _skip_line_comment(sql: str, start: int) -> int:
end = sql.find("\n", start)
return len(sql) if end == -1 else end
def _skip_block_comment(sql: str, start: int) -> int:
end = sql.find("*/", start)
return len(sql) if end == -1 else end + 2
def _strip_comments(sql: str) -> str:
    """Remove `--` line comments and `/* */` block comments, preserving quoted text."""
    tracker = _ScanState()
    kept: list[str] = []
    pos = 0
    total = len(sql)
    while pos < total:
        ch = sql[pos]
        # Comment markers only count outside string literals.
        if not tracker.in_quotes():
            if sql.startswith("--", pos):
                pos = _skip_line_comment(sql, pos + 2)
                continue
            if sql.startswith("/*", pos):
                pos = _skip_block_comment(sql, pos + 2)
                continue
        tracker.update_quote(ch)
        kept.append(ch)
        pos += 1
    return "".join(kept)
def _split_top_level_commas(segment: str) -> list[str]:
    """Split *segment* on commas that sit outside quotes and parentheses.

    Empty pieces (after stripping whitespace) are dropped.
    """
    pieces: list[str] = []
    buffer: list[str] = []
    tracker = _ScanState()
    for ch in segment:
        tracker.update_quote(ch)
        tracker.update_paren(ch)
        if ch == "," and tracker.at_top_level():
            piece = "".join(buffer).strip()
            if piece:
                pieces.append(piece)
            buffer = []
        else:
            buffer.append(ch)
    trailing = "".join(buffer).strip()
    if trailing:
        pieces.append(trailing)
    return pieces
def _find_main_select_list(sql: str) -> str | None:
    """Return the text between the first SELECT and its top-level FROM, or None.

    The FROM keyword is matched with word boundaries rather than requiring a
    literal space on each side, so queries where FROM is delimited by
    newlines or tabs (e.g. ``SELECT a,\\nb\\nFROM t``) are handled — the old
    ``startswith(" from ", i)`` check missed those and the caller then
    reported zero columns. Quoted or parenthesized FROMs (subqueries) are
    still skipped via the scan state.
    """
    lowered = sql.lower()
    select_match = re.search(r"\bselect\b", lowered)
    if select_match is None:
        return None
    # Compiled once outside the scan loop; Pattern.match(lowered, i) anchors
    # at position i while \b still sees the preceding character.
    from_keyword = re.compile(r"\bfrom\b")
    start = select_match.end()
    state = _ScanState()
    for i in range(start, len(sql)):
        ch = sql[i]
        state.update_quote(ch)
        state.update_paren(ch)
        # Only a FROM outside quotes and parentheses ends the main select list.
        if state.at_top_level() and from_keyword.match(lowered, i):
            return sql[start:i]
    return None
def compute_query_complexity(sql: str) -> QueryComplexityResult:
    """Score a SQL statement with lightweight keyword-counting heuristics."""
    cleaned = _strip_comments(sql)
    lowered = cleaned.lower()

    def occurrences(pattern: str) -> int:
        # Keyword counts on the lowercased, comment-free text.
        return len(re.findall(pattern, lowered))

    select_list = _find_main_select_list(cleaned)
    columns = len(_split_top_level_commas(select_list)) if select_list else 0
    joins = occurrences(r"\bjoin\b")
    tables = occurrences(r"\bfrom\b") + joins
    subqueries = occurrences(r"\(\s*select\b")
    cases = occurrences(r"\bcase\b")

    score = tables + joins + subqueries + cases
    if score >= 10:
        level = "high"
    elif score >= 5:
        level = "medium"
    else:
        level = "low"

    return QueryComplexityResult(
        column_count=columns,
        table_count=tables,
        subquery_count=subqueries,
        case_statement_count=cases,
        join_count=joins,
        complexity_level=level,
    )
def load_sql_from_file(path: str) -> str:
    """Read a UTF-8 SQL file, expanding ``~`` and resolving the path first."""
    resolved = Path(path).expanduser().resolve()
    return resolved.read_text(encoding="utf-8")
def normalize_sql(sql: str) -> str:
    """Return *sql* with comments removed and surrounding whitespace trimmed."""
    without_comments = _strip_comments(sql)
    return without_comments.strip()
def ensure_non_empty(sql: str) -> str:
    """Validate that *sql* is non-empty after normalization.

    Returns the ORIGINAL text untouched (comments and whitespace preserved);
    only the normalized form is inspected.

    Raises:
        ValueError: if nothing remains once comments/whitespace are stripped.
    """
    if not normalize_sql(sql):
        raise ValueError("SQL vazio ou inválido.")
    return sql
def list_supported_databases(values: Iterable[str]) -> str:
    """Render database names as a single comma-separated string."""
    separator = ", "
    return separator.join(values)
def _escape_html(text: str) -> str:
return (
text.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
)
def diff_sql(original_sql: str, optimized_sql: str) -> str:
    """Generate a side-by-side diff between original and optimized SQL.

    Args:
        original_sql: Original SQL text
        optimized_sql: Optimized SQL text

    Returns:
        Side-by-side diff as an HTML table (left column original, right optimized)
    """

    def render_pairs(left_chunk: list[str], right_chunk: list[str]) -> list[str]:
        # Pad the shorter side with empty strings so every opcode kind
        # (equal / replace / delete / insert) renders through one template.
        # This replaces four branches that duplicated the same row markup;
        # the emitted HTML is unchanged ("" escapes to "" => empty <pre>).
        width = max(len(left_chunk), len(right_chunk))
        padded_left = left_chunk + [""] * (width - len(left_chunk))
        padded_right = right_chunk + [""] * (width - len(right_chunk))
        return [
            f"<tr><td><pre>{_escape_html(left)}</pre></td><td><pre>{_escape_html(right)}</pre></td></tr>"
            for left, right in zip(padded_left, padded_right)
        ]

    original_lines = original_sql.strip().splitlines()
    optimized_lines = optimized_sql.strip().splitlines()
    matcher = difflib.SequenceMatcher(a=original_lines, b=optimized_lines)
    rows: list[str] = []
    for _tag, i1, i2, j1, j2 in matcher.get_opcodes():
        rows.extend(render_pairs(original_lines[i1:i2], optimized_lines[j1:j2]))
    # Empty diff collapses naturally to the header-only table.
    header = "<tr><th>Original</th><th>Otimizada</th></tr>"
    return f"<table>{header}{''.join(rows)}</table>"