"""Utility tools for SQL handling and lightweight complexity heuristics.""" from __future__ import annotations from dataclasses import dataclass import re from pathlib import Path from typing import Iterable import difflib @dataclass(frozen=True) class QueryComplexityResult: column_count: int table_count: int subquery_count: int case_statement_count: int join_count: int complexity_level: str @dataclass class _ScanState: depth: int = 0 in_single: bool = False in_double: bool = False def update_quote(self, ch: str) -> None: if ch == "'" and not self.in_double: self.in_single = not self.in_single elif ch == '"' and not self.in_single: self.in_double = not self.in_double def in_quotes(self) -> bool: return self.in_single or self.in_double def update_paren(self, ch: str) -> None: if self.in_quotes(): return if ch == "(": self.depth += 1 elif ch == ")": self.depth = max(self.depth - 1, 0) def at_top_level(self) -> bool: return not self.in_quotes() and self.depth == 0 def _skip_line_comment(sql: str, start: int) -> int: end = sql.find("\n", start) return len(sql) if end == -1 else end def _skip_block_comment(sql: str, start: int) -> int: end = sql.find("*/", start) return len(sql) if end == -1 else end + 2 def _strip_comments(sql: str) -> str: state = _ScanState() result: list[str] = [] i = 0 length = len(sql) while i < length: ch = sql[i] if state.in_quotes(): state.update_quote(ch) result.append(ch) i += 1 continue if sql.startswith("--", i): i = _skip_line_comment(sql, i + 2) continue if sql.startswith("/*", i): i = _skip_block_comment(sql, i + 2) continue state.update_quote(ch) result.append(ch) i += 1 return "".join(result) def _split_top_level_commas(segment: str) -> list[str]: parts: list[str] = [] current: list[str] = [] state = _ScanState() for ch in segment: state.update_quote(ch) state.update_paren(ch) if ch == "," and state.at_top_level(): part = "".join(current).strip() if part: parts.append(part) current = [] continue current.append(ch) tail = 
"".join(current).strip() if tail: parts.append(tail) return parts def _find_main_select_list(sql: str) -> str | None: lowered = sql.lower() select_match = re.search(r"\bselect\b", lowered) if not select_match: return None # Find the first FROM after the first SELECT at top level idx = select_match.end() state = _ScanState() for i in range(idx, len(sql)): ch = sql[i] state.update_quote(ch) state.update_paren(ch) if state.at_top_level() and lowered.startswith(" from ", i): return sql[idx:i] return None def compute_query_complexity(sql: str) -> QueryComplexityResult: normalized = _strip_comments(sql) lowered = normalized.lower() select_list = _find_main_select_list(normalized) column_count = 0 if select_list: column_count = len(_split_top_level_commas(select_list)) table_count = len(re.findall(r"\bfrom\b", lowered)) + len(re.findall(r"\bjoin\b", lowered)) join_count = len(re.findall(r"\bjoin\b", lowered)) subquery_count = len(re.findall(r"\(\s*select\b", lowered)) case_count = len(re.findall(r"\bcase\b", lowered)) score = table_count + join_count + subquery_count + case_count if score >= 10: complexity = "high" elif score >= 5: complexity = "medium" else: complexity = "low" return QueryComplexityResult( column_count=column_count, table_count=table_count, subquery_count=subquery_count, case_statement_count=case_count, join_count=join_count, complexity_level=complexity, ) def load_sql_from_file(path: str) -> str: file_path = Path(path).expanduser().resolve() return file_path.read_text(encoding="utf-8") def normalize_sql(sql: str) -> str: return _strip_comments(sql).strip() def ensure_non_empty(sql: str) -> str: normalized = normalize_sql(sql) if not normalized: raise ValueError("SQL vazio ou inválido.") return sql def list_supported_databases(values: Iterable[str]) -> str: return ", ".join(values) def _escape_html(text: str) -> str: return ( text.replace("&", "&") .replace("<", "<") .replace(">", ">") ) def diff_sql(original_sql: str, optimized_sql: str) -> str: 
def diff_sql(original_sql: str, optimized_sql: str) -> str:
    """Generate a side-by-side diff between original and optimized SQL.

    Args:
        original_sql: Original SQL text
        optimized_sql: Optimized SQL text

    Returns:
        Side-by-side diff as an HTML table (columns: Original, Otimizada).

    NOTE(review): the HTML markup inside the original f-strings was lost
    when this file was mangled; the tag/class names below are a
    reconstruction — confirm against the consuming template/UI.
    """
    original_lines = original_sql.strip().splitlines()
    optimized_lines = optimized_sql.strip().splitlines()
    matcher = difflib.SequenceMatcher(a=original_lines, b=optimized_lines)
    rows: list[str] = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            for left, right in zip(original_lines[i1:i2], optimized_lines[j1:j2]):
                rows.append(
                    "<tr>"
                    f'<td class="diff-equal">{_escape_html(left)}</td>'
                    f'<td class="diff-equal">{_escape_html(right)}</td>'
                    "</tr>"
                )
        elif tag == "replace":
            # Pair replaced lines positionally; pad the shorter side with
            # empty strings so both columns stay aligned row by row.
            left_chunk = original_lines[i1:i2]
            right_chunk = optimized_lines[j1:j2]
            max_len = max(len(left_chunk), len(right_chunk))
            for idx in range(max_len):
                left = left_chunk[idx] if idx < len(left_chunk) else ""
                right = right_chunk[idx] if idx < len(right_chunk) else ""
                rows.append(
                    "<tr>"
                    f'<td class="diff-removed">{_escape_html(left)}</td>'
                    f'<td class="diff-added">{_escape_html(right)}</td>'
                    "</tr>"
                )
        elif tag == "delete":
            # Line exists only on the original side; right cell stays empty.
            for left in original_lines[i1:i2]:
                rows.append(
                    "<tr>"
                    f'<td class="diff-removed">{_escape_html(left)}</td>'
                    '<td class="diff-empty"></td>'
                    "</tr>"
                )
        elif tag == "insert":
            # Line exists only on the optimized side; left cell stays empty.
            for right in optimized_lines[j1:j2]:
                rows.append(
                    "<tr>"
                    '<td class="diff-empty"></td>'
                    f'<td class="diff-added">{_escape_html(right)}</td>'
                    "</tr>"
                )
    header = "<tr><th>Original</th><th>Otimizada</th></tr>"
    if not rows:
        # Both inputs empty after stripping: emit just the header table.
        return f'<table class="sql-diff">{header}</table>'
    body = "".join(rows)
    return f'<table class="sql-diff">{header}{body}</table>'