228 lines
6.6 KiB
Python
228 lines
6.6 KiB
Python
"""Utility tools for SQL handling and lightweight complexity heuristics."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
import re
|
|
from pathlib import Path
|
|
from typing import Iterable
|
|
import difflib
|
|
|
|
|
|
@dataclass(frozen=True)
class QueryComplexityResult:
    """Immutable bundle of heuristic complexity metrics for one SQL query.

    Produced by :func:`compute_query_complexity`; all counts are cheap
    keyword/regex heuristics, not results of real SQL parsing.
    """

    # Number of top-level expressions in the main SELECT list.
    column_count: int
    # Count of FROM keywords plus JOIN keywords (rough table-reference count).
    table_count: int
    # Count of "(SELECT" occurrences, i.e. nested subqueries.
    subquery_count: int
    # Count of CASE keywords.
    case_statement_count: int
    # Count of JOIN keywords.
    join_count: int
    # Overall bucket derived from the counts: "low", "medium" or "high".
    complexity_level: str
|
|
|
|
|
|
@dataclass
|
|
class _ScanState:
|
|
depth: int = 0
|
|
in_single: bool = False
|
|
in_double: bool = False
|
|
|
|
def update_quote(self, ch: str) -> None:
|
|
if ch == "'" and not self.in_double:
|
|
self.in_single = not self.in_single
|
|
elif ch == '"' and not self.in_single:
|
|
self.in_double = not self.in_double
|
|
|
|
def in_quotes(self) -> bool:
|
|
return self.in_single or self.in_double
|
|
|
|
def update_paren(self, ch: str) -> None:
|
|
if self.in_quotes():
|
|
return
|
|
if ch == "(":
|
|
self.depth += 1
|
|
elif ch == ")":
|
|
self.depth = max(self.depth - 1, 0)
|
|
|
|
def at_top_level(self) -> bool:
|
|
return not self.in_quotes() and self.depth == 0
|
|
|
|
|
|
def _skip_line_comment(sql: str, start: int) -> int:
|
|
end = sql.find("\n", start)
|
|
return len(sql) if end == -1 else end
|
|
|
|
|
|
def _skip_block_comment(sql: str, start: int) -> int:
|
|
end = sql.find("*/", start)
|
|
return len(sql) if end == -1 else end + 2
|
|
|
|
|
|
def _strip_comments(sql: str) -> str:
    """Remove `--` line comments and `/* */` block comments from *sql*.

    Comment markers inside quoted strings/identifiers are preserved verbatim;
    the newline that terminates a line comment is kept.
    """
    out: list[str] = []
    scan = _ScanState()
    pos = 0
    total = len(sql)
    while pos < total:
        # Comment openers only count when we are outside quoted regions.
        if not scan.in_quotes():
            if sql.startswith("--", pos):
                pos = _skip_line_comment(sql, pos + 2)
                continue
            if sql.startswith("/*", pos):
                pos = _skip_block_comment(sql, pos + 2)
                continue
        ch = sql[pos]
        scan.update_quote(ch)
        out.append(ch)
        pos += 1
    return "".join(out)
|
|
|
|
|
|
def _split_top_level_commas(segment: str) -> list[str]:
    """Split *segment* on commas outside quotes and parentheses.

    Each resulting piece is stripped of surrounding whitespace; empty
    pieces are dropped.
    """
    pieces: list[str] = []
    buffer: list[str] = []
    scan = _ScanState()

    def flush() -> None:
        # Emit the accumulated piece if it has any non-whitespace content.
        piece = "".join(buffer).strip()
        if piece:
            pieces.append(piece)
        buffer.clear()

    for ch in segment:
        scan.update_quote(ch)
        scan.update_paren(ch)
        if ch == "," and scan.at_top_level():
            flush()
        else:
            buffer.append(ch)
    flush()
    return pieces
|
|
|
|
|
|
def _find_main_select_list(sql: str) -> str | None:
    """Return the text between the first SELECT and its top-level FROM.

    Returns None when no SELECT is found or no top-level FROM follows it.

    Bug fix: FROM was previously only recognized when surrounded by literal
    spaces (``" from "``), so queries such as ``SELECT a\\nFROM t`` produced
    no select list (and therefore a column count of 0).  Any whitespace
    delimiter — space, newline, tab — is now accepted, including a FROM that
    ends the string.
    """
    lowered = sql.lower()
    select_match = re.search(r"\bselect\b", lowered)
    if not select_match:
        return None
    start = select_match.end()
    state = _ScanState()
    for i in range(start, len(sql)):
        ch = sql[i]
        state.update_quote(ch)
        state.update_paren(ch)
        # A top-level "from" keyword delimited by whitespace on the left and
        # whitespace (or end of input) on the right ends the select list.
        if (
            state.at_top_level()
            and lowered.startswith("from", i)
            and lowered[i - 1].isspace()
            and (i + 4 == len(lowered) or lowered[i + 4].isspace())
        ):
            return sql[start:i]
    return None
|
|
|
|
|
|
def compute_query_complexity(sql: str) -> QueryComplexityResult:
    """Score a SQL statement with cheap keyword-counting heuristics.

    Comments are stripped first; keyword matches inside string literals are
    still counted (the heuristics do not parse SQL).
    """
    cleaned = _strip_comments(sql)
    lowered = cleaned.lower()

    def count(pattern: str) -> int:
        # Number of regex matches in the lowercased, comment-free SQL.
        return len(re.findall(pattern, lowered))

    select_list = _find_main_select_list(cleaned)
    column_count = len(_split_top_level_commas(select_list)) if select_list else 0

    join_count = count(r"\bjoin\b")
    table_count = count(r"\bfrom\b") + join_count
    subquery_count = count(r"\(\s*select\b")
    case_count = count(r"\bcase\b")

    # Bucket the aggregate score into three coarse levels.
    score = table_count + join_count + subquery_count + case_count
    if score >= 10:
        level = "high"
    elif score >= 5:
        level = "medium"
    else:
        level = "low"

    return QueryComplexityResult(
        column_count=column_count,
        table_count=table_count,
        subquery_count=subquery_count,
        case_statement_count=case_count,
        join_count=join_count,
        complexity_level=level,
    )
|
|
|
|
|
|
def load_sql_from_file(path: str) -> str:
    """Return the UTF-8 contents of *path*, expanding `~` and resolving symlinks."""
    sql_file = Path(path).expanduser()
    return sql_file.resolve().read_text(encoding="utf-8")
|
|
|
|
|
|
def normalize_sql(sql: str) -> str:
    """Return *sql* with comments removed and surrounding whitespace trimmed."""
    without_comments = _strip_comments(sql)
    return without_comments.strip()
|
|
|
|
|
|
def ensure_non_empty(sql: str) -> str:
    """Validate that *sql* has real content once comments are stripped.

    Returns the ORIGINAL text untouched (comments included); raises
    ValueError when nothing meaningful remains after normalization.
    """
    if not normalize_sql(sql):
        raise ValueError("SQL vazio ou inválido.")
    return sql
|
|
|
|
|
|
def list_supported_databases(values: Iterable[str]) -> str:
    """Render the supported database names as one comma-separated string."""
    separator = ", "
    return separator.join(values)
|
|
|
|
|
|
def _escape_html(text: str) -> str:
|
|
return (
|
|
text.replace("&", "&")
|
|
.replace("<", "<")
|
|
.replace(">", ">")
|
|
)
|
|
|
|
|
|
def diff_sql(original_sql: str, optimized_sql: str) -> str:
    """Generate a side-by-side diff between original and optimized SQL.

    Args:
        original_sql: Original SQL text
        optimized_sql: Optimized SQL text

    Returns:
        Side-by-side diff in HTML table format

    Improvement: the ``<tr>…`` row markup was duplicated across all four
    opcode branches, and a redundant ``if not rows`` early return repeated
    the final table string; both are folded into one row-rendering helper
    and a single return.  Output is unchanged.
    """

    def render_row(left: str, right: str) -> str:
        # One table row: original text on the left, optimized on the right.
        return (
            f"<tr><td><pre>{_escape_html(left)}</pre></td>"
            f"<td><pre>{_escape_html(right)}</pre></td></tr>"
        )

    original_lines = original_sql.strip().splitlines()
    optimized_lines = optimized_sql.strip().splitlines()

    matcher = difflib.SequenceMatcher(a=original_lines, b=optimized_lines)
    rows: list[str] = []

    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            # Equal chunks have identical lengths, so zip is lossless here.
            rows.extend(
                render_row(left, right)
                for left, right in zip(original_lines[i1:i2], optimized_lines[j1:j2])
            )
        elif tag == "replace":
            left_chunk = original_lines[i1:i2]
            right_chunk = optimized_lines[j1:j2]
            # Pad the shorter side with blanks so replaced chunks stay aligned.
            for idx in range(max(len(left_chunk), len(right_chunk))):
                left = left_chunk[idx] if idx < len(left_chunk) else ""
                right = right_chunk[idx] if idx < len(right_chunk) else ""
                rows.append(render_row(left, right))
        elif tag == "delete":
            rows.extend(render_row(left, "") for left in original_lines[i1:i2])
        elif tag == "insert":
            rows.extend(render_row("", right) for right in optimized_lines[j1:j2])

    header = "<tr><th>Original</th><th>Otimizada</th></tr>"
    return f"<table>{header}{''.join(rows)}</table>"
|