AST-based Python module splitter
| #!/usr/bin/env python3 | |
| """ | |
| General-purpose AST-based Python module splitter. | |
| Analyzes any monolithic .py file, builds a full dependency graph with structural | |
| fingerprints, auto-discovers natural clusters, and produces split module files. | |
| Capabilities: | |
| - AST parsing with full symbol table (ast + inspect-style signatures) | |
| - Structural hashing (hashlib) to fingerprint each symbol for drift detection | |
| - Dependency graph with cycle detection (tarjan-style) | |
| - Auto-clustering via connectivity analysis when no spec is provided | |
| - Import filtering: only carries over what each module actually uses | |
| - Inline import hoisting: detects `import X` inside functions, promotes to top | |
| - Decorator-aware extraction: includes @dataclass, @property, etc. | |
| - Dry run with diff-style output or --write to commit | |
| Usage: | |
| python3 tools/split_modules.py <source.py> [options] | |
| # Auto-discover clusters (no config needed): | |
| python3 tools/split_modules.py niwa/cli.py --auto | |
| # Use a split spec (JSON or inline): | |
| python3 tools/split_modules.py niwa/cli.py --spec split_spec.json | |
| # Dry run (default): prints analysis + file contents to stdout | |
| python3 tools/split_modules.py niwa/cli.py --auto | |
| # Verify only (no file output): | |
| python3 tools/split_modules.py niwa/cli.py --auto --verify | |
| # Actually write the files: | |
| python3 tools/split_modules.py niwa/cli.py --auto --write | |
| # Output directory (default: same as source): | |
| python3 tools/split_modules.py niwa/cli.py --auto --outdir niwa/ | |
| """ | |
| import argparse | |
| import ast | |
| import hashlib | |
| import inspect | |
| import json | |
| import re | |
| import sys | |
| import textwrap | |
| from collections import defaultdict | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Optional | |
| # ── Data structures ───────────────────────────────────────────────────────── | |
| @dataclass | |
| class SymbolInfo: | |
| """Full metadata for a top-level symbol extracted via AST.""" | |
| name: str | |
| kind: str # 'function', 'async_function', 'class', 'constant' | |
| node: ast.AST | |
| start_line: int | |
| end_line: int | |
| num_lines: int | |
| decorators: list[str] = field(default_factory=list) | |
| bases: list[str] = field(default_factory=list) # class bases | |
| signature: Optional[str] = None # function signature string | |
| docstring: Optional[str] = None | |
| references: set[str] = field(default_factory=set) # other top-level names used | |
| inline_imports: list[tuple[str, int, int]] = field(default_factory=list) # (text, lineno, end_lineno) | |
| structural_hash: str = "" # content-based fingerprint (original) | |
| structural_hash_posthoist: str = "" # hash after stripping stdlib inline imports | |
| module: Optional[str] = None # assigned module after clustering | |
| @property | |
| def is_class(self) -> bool: | |
| return self.kind == 'class' | |
| @property | |
| def is_function(self) -> bool: | |
| return self.kind in ('function', 'async_function') | |
| @property | |
| def is_constant(self) -> bool: | |
| return self.kind == 'constant' | |
| @dataclass | |
| class ModuleFile: | |
| """A generated output module.""" | |
| name: str | |
| symbols: list[str] | |
| source: str = "" | |
| line_count: int = 0 | |
| deps: set[tuple[str, str]] = field(default_factory=set) # (module, symbol) | |
| # ── AST analysis ──────────────────────────────────────────────────────────── | |
| def parse_source(path: Path) -> tuple[str, ast.Module]: | |
| """Parse a Python source file into source text + AST.""" | |
| source = path.read_text() | |
| tree = ast.parse(source, filename=str(path)) | |
| return source, tree | |
| def compute_structural_hash(source: str, node: ast.AST, strip_stdlib_imports: bool = False) -> str: | |
| """ | |
| Compute a content-based fingerprint for a symbol's AST subtree. | |
| Uses ast.dump (which normalizes whitespace/comments away) hashed with sha256. | |
| Two symbols with identical structure produce the same hash, even if | |
| whitespace or comments differ. Useful for detecting drift between the | |
| original monolith and the split files. | |
| If strip_stdlib_imports=True, removes stdlib import nodes from function/method | |
| bodies before hashing. This produces the "expected" hash after hoisting, so | |
| original (stripped) == split (already stripped) proves isomorphism. | |
| """ | |
| if strip_stdlib_imports: | |
| node = _strip_stdlib_import_nodes(node) | |
| # ast.dump gives a canonical string repr of the subtree | |
| canonical = ast.dump(node, annotate_fields=True, include_attributes=False) | |
| return hashlib.sha256(canonical.encode()).hexdigest()[:12] | |
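| # Minimal sketch of what the fingerprint ignores: the two parses below differ only in | |
| # a comment and a blank line, so their hashes match (the `source` argument is not used | |
| # by the hash itself). Never called; kept purely as an illustration. | |
| def _demo_structural_hash_equivalence() -> bool: | |
|     a = ast.parse("def f(x):\n    return x + 1").body[0] | |
|     b = ast.parse("def f(x):\n    # noqa\n\n    return x + 1").body[0] | |
|     return compute_structural_hash("", a) == compute_structural_hash("", b)  # True | |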
| def _strip_stdlib_import_nodes(node: ast.AST) -> ast.AST: | |
| """Return a deep copy of node with stdlib import statements removed from ALL bodies. | |
| Recursively strips stdlib imports from every statement list (function bodies, | |
| if/for/while/with/try bodies, except handlers, etc.) at any nesting depth. | |
| This lets us hash the "post-hoist" version of a symbol from the original AST, | |
| so it can be compared against the hash of the same symbol in the split file | |
| (where those imports were physically removed). | |
| """ | |
| import copy | |
| node = copy.deepcopy(node) | |
| def _filter_body(body: list) -> list: | |
| return [stmt for stmt in body if not _is_stdlib_import_stmt(stmt)] | |
| def _recurse(n): | |
| """Strip stdlib imports from any node that has statement lists, recursively.""" | |
| # All AST node types that contain statement lists | |
| body_attrs = ('body', 'orelse', 'finalbody') | |
| for attr in body_attrs: | |
| if hasattr(n, attr): | |
| stmts = getattr(n, attr) | |
| if isinstance(stmts, list): | |
| setattr(n, attr, _filter_body(stmts)) | |
| for child in getattr(n, attr): | |
| _recurse(child) | |
| # Try/ExceptHandler handlers | |
| if hasattr(n, 'handlers'): | |
| for handler in n.handlers: | |
| handler.body = _filter_body(handler.body) | |
| for child in handler.body: | |
| _recurse(child) | |
| _recurse(node) | |
| return node | |
| def _is_stdlib_import_stmt(stmt: ast.AST) -> bool: | |
| """Check if an AST statement node is a stdlib import.""" | |
| if isinstance(stmt, ast.Import): | |
| return all( | |
| alias.name.split('.')[0] in STDLIB_MODULES  # stdlib-ness depends on the real module name, not the asname | |
| for alias in stmt.names | |
| ) | |
| elif isinstance(stmt, ast.ImportFrom): | |
| if stmt.module: | |
| return stmt.module.split('.')[0] in STDLIB_MODULES | |
| return False | |
| def extract_signature(node: ast.AST) -> Optional[str]: | |
| """Extract a human-readable function signature from AST.""" | |
| if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): | |
| return None | |
| args = node.args | |
| parts = [] | |
| # Positional args | |
| num_defaults = len(args.defaults) | |
| num_positional = len(args.args) | |
| for i, arg in enumerate(args.args): | |
| name = arg.arg | |
| annotation = "" | |
| if arg.annotation: | |
| annotation = f": {ast.unparse(arg.annotation)}" | |
| # Check if this arg has a default | |
| default_idx = i - (num_positional - num_defaults) | |
| if default_idx >= 0: | |
| default = ast.unparse(args.defaults[default_idx]) | |
| parts.append(f"{name}{annotation}={default}") | |
| else: | |
| parts.append(f"{name}{annotation}") | |
| # *args | |
| if args.vararg: | |
| va = args.vararg | |
| ann = f": {ast.unparse(va.annotation)}" if va.annotation else "" | |
| parts.append(f"*{va.arg}{ann}") | |
| elif args.kwonlyargs: | |
| parts.append("*") | |
| # keyword-only | |
| for i, arg in enumerate(args.kwonlyargs): | |
| ann = f": {ast.unparse(arg.annotation)}" if arg.annotation else "" | |
| if i < len(args.kw_defaults) and args.kw_defaults[i] is not None: | |
| default = ast.unparse(args.kw_defaults[i]) | |
| parts.append(f"{arg.arg}{ann}={default}") | |
| else: | |
| parts.append(f"{arg.arg}{ann}") | |
| # **kwargs | |
| if args.kwarg: | |
| kw = args.kwarg | |
| ann = f": {ast.unparse(kw.annotation)}" if kw.annotation else "" | |
| parts.append(f"**{kw.arg}{ann}") | |
| # Return annotation | |
| ret = "" | |
| if node.returns: | |
| ret = f" -> {ast.unparse(node.returns)}" | |
| prefix = "async def" if isinstance(node, ast.AsyncFunctionDef) else "def" | |
| return f"{prefix} {node.name}({', '.join(parts)}){ret}" | |
| def extract_decorators(node: ast.AST) -> list[str]: | |
| """Extract decorator names from a class or function.""" | |
| if not hasattr(node, 'decorator_list'): | |
| return [] | |
| decorators = [] | |
| for dec in node.decorator_list: | |
| try: | |
| decorators.append(ast.unparse(dec)) | |
| except Exception: | |
| decorators.append("(unknown)") | |
| return decorators | |
| def extract_docstring(node: ast.AST) -> Optional[str]: | |
| """Extract the docstring from a function/class/module node.""" | |
| if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)): | |
| # ast.Str was removed in Python 3.12; ast.Constant alone covers string docstrings | |
| # on every version this script supports (sys.stdlib_module_names needs 3.10+). | |
| if (node.body and isinstance(node.body[0], ast.Expr) | |
| and isinstance(node.body[0].value, ast.Constant) | |
| and isinstance(node.body[0].value.value, str)): | |
| return node.body[0].value.value | |
| return None | |
| def extract_class_bases(node: ast.AST) -> list[str]: | |
| """Extract base class names.""" | |
| if not isinstance(node, ast.ClassDef): | |
| return [] | |
| return [ast.unparse(b) for b in node.bases] | |
| def find_inline_imports(node: ast.AST) -> list[tuple[str, int, int]]: | |
| """Find import statements inside function/method bodies (not at module level). | |
| Returns list of (import_text, lineno, end_lineno) tuples. | |
| """ | |
| imports = [] | |
| def _collect(parent): | |
| for child in ast.walk(parent): | |
| if child is parent: | |
| continue | |
| if isinstance(child, (ast.Import, ast.ImportFrom)): | |
| try: | |
| text = ast.unparse(child) | |
| imports.append((text, child.lineno, child.end_lineno or child.lineno)) | |
| except Exception: | |
| pass | |
| if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): | |
| _collect(node) | |
| elif isinstance(node, ast.ClassDef): | |
| for method in node.body: | |
| if isinstance(method, (ast.FunctionDef, ast.AsyncFunctionDef)): | |
| _collect(method) | |
| return imports | |
| def _get_stdlib_modules() -> frozenset[str]: | |
| """Return the set of standard library module names for the running Python.""" | |
| return frozenset(sys.stdlib_module_names) | |
| STDLIB_MODULES = _get_stdlib_modules() | |
| def is_stdlib_import(import_text: str) -> bool: | |
| """Check if an import statement refers to a stdlib module. | |
| Handles both 'import X' and 'from X import Y' forms. | |
| """ | |
| import_text = import_text.strip() | |
| if import_text.startswith('from '): | |
| # 'from datetime import datetime' -> 'datetime' | |
| parts = import_text.split() | |
| if len(parts) >= 2: | |
| mod = parts[1].split('.')[0] | |
| return mod in STDLIB_MODULES | |
| elif import_text.startswith('import '): | |
| # 'import sys' -> 'sys', 'import os.path' -> 'os' | |
| parts = import_text.split() | |
| if len(parts) >= 2: | |
| mod = parts[1].split('.')[0].rstrip(',') | |
| return mod in STDLIB_MODULES | |
| return False | |
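| # Illustrative spot-checks against the running interpreter's stdlib list; defined only | |
| # as documentation and never called ("requests" assumed to be third-party): | |
| def _demo_is_stdlib_import() -> None: | |
|     assert is_stdlib_import("from datetime import datetime") | |
|     assert is_stdlib_import("import os.path")          # checks the top-level package "os" | |
|     assert not is_stdlib_import("import requests") | |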
| def collect_references(node: ast.AST, known_names: set[str], own_name: str) -> set[str]: | |
| """Find which known top-level names are referenced inside a node.""" | |
| target_names = known_names - {own_name} | |
| refs = set() | |
| class RefVisitor(ast.NodeVisitor): | |
| def visit_Name(self, n): | |
| if n.id in target_names: | |
| refs.add(n.id) | |
| self.generic_visit(n) | |
| RefVisitor().visit(node) | |
| return refs | |
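| # Sketch: with known_names = {"helper", "CONFIG", "run"} and own_name = "run", the body | |
| # of `def run(): return helper(CONFIG)` yields {"helper", "CONFIG"}; builtins, locals, | |
| # and the symbol's own name are ignored. | |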
| def analyze_file(source: str, tree: ast.Module) -> tuple[list[SymbolInfo], list[ast.AST]]: | |
| """ | |
| Full AST analysis: extract all top-level symbols with metadata, | |
| and all top-level imports. | |
| """ | |
| symbols = [] | |
| top_level_imports = [] | |
| lines = source.splitlines() | |
| # Collect top-level defs | |
| for node in ast.iter_child_nodes(tree): | |
| if isinstance(node, (ast.Import, ast.ImportFrom)): | |
| top_level_imports.append(node) | |
| continue | |
| # try: import X / except ImportError: ... blocks | |
| if isinstance(node, ast.Try) and node.body and all( | |
| isinstance(stmt, (ast.Import, ast.ImportFrom)) for stmt in node.body | |
| ): | |
| top_level_imports.append(node) | |
| continue | |
| info = None | |
| if isinstance(node, ast.FunctionDef): | |
| info = SymbolInfo( | |
| name=node.name, | |
| kind='function', | |
| node=node, | |
| start_line=node.lineno, | |
| end_line=node.end_lineno, | |
| num_lines=node.end_lineno - node.lineno + 1, | |
| decorators=extract_decorators(node), | |
| signature=extract_signature(node), | |
| docstring=extract_docstring(node), | |
| inline_imports=find_inline_imports(node), | |
| structural_hash=compute_structural_hash(source, node), | |
| ) | |
| elif isinstance(node, ast.AsyncFunctionDef): | |
| info = SymbolInfo( | |
| name=node.name, | |
| kind='async_function', | |
| node=node, | |
| start_line=node.lineno, | |
| end_line=node.end_lineno, | |
| num_lines=node.end_lineno - node.lineno + 1, | |
| decorators=extract_decorators(node), | |
| signature=extract_signature(node), | |
| docstring=extract_docstring(node), | |
| inline_imports=find_inline_imports(node), | |
| structural_hash=compute_structural_hash(source, node), | |
| ) | |
| elif isinstance(node, ast.ClassDef): | |
| info = SymbolInfo( | |
| name=node.name, | |
| kind='class', | |
| node=node, | |
| start_line=node.lineno, | |
| end_line=node.end_lineno, | |
| num_lines=node.end_lineno - node.lineno + 1, | |
| decorators=extract_decorators(node), | |
| bases=extract_class_bases(node), | |
| docstring=extract_docstring(node), | |
| inline_imports=find_inline_imports(node), | |
| structural_hash=compute_structural_hash(source, node), | |
| ) | |
| elif isinstance(node, ast.Assign): | |
| for t in node.targets: | |
| if isinstance(t, ast.Name): | |
| info = SymbolInfo( | |
| name=t.id, | |
| kind='constant', | |
| node=node, | |
| start_line=node.lineno, | |
| end_line=node.end_lineno, | |
| num_lines=node.end_lineno - node.lineno + 1, | |
| structural_hash=compute_structural_hash(source, node), | |
| ) | |
| break # only first target | |
| if info: | |
| # Compute post-hoist hash (strips stdlib inline imports from AST before hashing) | |
| info.structural_hash_posthoist = compute_structural_hash( | |
| source, info.node, strip_stdlib_imports=True | |
| ) | |
| symbols.append(info) | |
| # Now compute cross-references | |
| known_names = {s.name for s in symbols} | |
| for sym in symbols: | |
| sym.references = collect_references(sym.node, known_names, sym.name) | |
| return symbols, top_level_imports | |
| # ── Clustering ────────────────────────────────────────────────────────────── | |
| def auto_cluster(symbols: list[SymbolInfo]) -> dict[str, list[str]]: | |
| """ | |
| Auto-discover module clusters via graph-based hierarchical agglomerative | |
| clustering with modularity-based stopping. | |
| General-purpose — no keyword matching, no name-based heuristics. | |
| Works on any Python codebase regardless of domain. | |
| Algorithm: | |
| 1. Build a weighted undirected affinity graph from the directed dep graph: | |
| - Direct reference (A uses B): weight 1.0 | |
| - Shared reference (A and B both use C): weight 0.5 | |
| - Co-referenced (C uses both A and B): weight 0.3 | |
| - Kind affinity (both enums, both dataclasses): weight 0.2 | |
| 2. Seed: group enums + dataclasses together (universal convention) | |
| 3. Hierarchical agglomerative merge: repeatedly merge the two clusters | |
| with the highest inter-cluster affinity, normalized by size. | |
| 4. Stop when the best merge score drops below a threshold (modularity | |
| gain becomes negative) or all clusters are within a target size range. | |
| 5. Name modules from their contents (heaviest symbol, dominant class, etc.) | |
| """ | |
| if not symbols: | |
| return {} | |
| names = [s.name for s in symbols] | |
| sym_by_name = {s.name: s for s in symbols} | |
| n = len(names) | |
| name_to_idx = {name: i for i, name in enumerate(names)} | |
| # ── Step 1: Build affinity matrix (symmetric) ── | |
| # affinity[i][j] = how strongly symbol i and j should be in the same module | |
| affinity = [[0.0] * n for _ in range(n)] | |
| # 1a. Direct references: A references B → edge(A, B) += 1.0 | |
| for s in symbols: | |
| i = name_to_idx[s.name] | |
| for ref in s.references: | |
| if ref in name_to_idx: | |
| j = name_to_idx[ref] | |
| affinity[i][j] += 1.0 | |
| affinity[j][i] += 1.0 | |
| # 1b. Shared references: A and B both reference C → edge(A, B) += 0.5 | |
| # Build reverse index: who references each symbol? | |
| referenced_by = defaultdict(set) | |
| for s in symbols: | |
| for ref in s.references: | |
| if ref in name_to_idx: | |
| referenced_by[ref].add(s.name) | |
| for target, referrers in referenced_by.items(): | |
| referrer_list = sorted(referrers) | |
| for ai in range(len(referrer_list)): | |
| for bi in range(ai + 1, len(referrer_list)): | |
| i = name_to_idx[referrer_list[ai]] | |
| j = name_to_idx[referrer_list[bi]] | |
| affinity[i][j] += 0.5 | |
| affinity[j][i] += 0.5 | |
| # 1c. Co-referenced: C references both A and B → edge(A, B) += 0.3 | |
| for s in symbols: | |
| refs = sorted(r for r in s.references if r in name_to_idx) | |
| for ai in range(len(refs)): | |
| for bi in range(ai + 1, len(refs)): | |
| i = name_to_idx[refs[ai]] | |
| j = name_to_idx[refs[bi]] | |
| affinity[i][j] += 0.3 | |
| affinity[j][i] += 0.3 | |
| # 1d. Kind affinity: same structural kind gets a small boost | |
| for ai in range(n): | |
| for bi in range(ai + 1, n): | |
| sa, sb = symbols[ai], symbols[bi] | |
| # Enum+Enum or dataclass+dataclass | |
| a_is_data = (sa.is_class and ('Enum' in sa.bases or | |
| any(d in ('dataclass', 'dataclasses.dataclass') for d in sa.decorators))) | |
| b_is_data = (sb.is_class and ('Enum' in sb.bases or | |
| any(d in ('dataclass', 'dataclasses.dataclass') for d in sb.decorators))) | |
| if a_is_data and b_is_data: | |
| affinity[ai][bi] += 0.2 | |
| affinity[bi][ai] += 0.2 | |
| # ── Step 2: Initialize clusters ── | |
| # Each symbol starts as its own cluster (index-based) | |
| # cluster_id -> set of symbol indices | |
| clusters: dict[int, set[int]] = {i: {i} for i in range(n)} | |
| next_cluster_id = n | |
| # Seed: merge all enums + dataclasses into one initial cluster | |
| data_type_indices = [] | |
| for s in symbols: | |
| if s.is_class and ('Enum' in s.bases or | |
| any(d in ('dataclass', 'dataclasses.dataclass') for d in s.decorators)): | |
| data_type_indices.append(name_to_idx[s.name]) | |
| if len(data_type_indices) > 1: | |
| # Merge all data type clusters into one | |
| merged = set() | |
| ids_to_remove = [] | |
| for idx in data_type_indices: | |
| # Find which cluster contains this index | |
| for cid, members in clusters.items(): | |
| if idx in members: | |
| merged |= members | |
| ids_to_remove.append(cid) | |
| break | |
| for cid in set(ids_to_remove): | |
| del clusters[cid] | |
| clusters[next_cluster_id] = merged | |
| next_cluster_id += 1 | |
| # ── Step 3: Hierarchical agglomerative merging ── | |
| # Total edge weight in the graph (for modularity calculation) | |
| total_weight = sum(affinity[i][j] for i in range(n) for j in range(i + 1, n)) | |
| if total_weight == 0: | |
| total_weight = 1.0 # avoid division by zero for isolated symbols | |
| def cluster_affinity(c1: set[int], c2: set[int]) -> float: | |
| """Average-linkage affinity between two clusters.""" | |
| if not c1 or not c2: | |
| return 0.0 | |
| total = sum(affinity[i][j] for i in c1 for j in c2) | |
| return total / (len(c1) * len(c2)) | |
| def modularity_gain(c1: set[int], c2: set[int]) -> float: | |
| """Estimate modularity gain from merging c1 and c2. | |
| Based on Newman-Girvan modularity: edges within the merged cluster | |
| minus expected edges if connections were random. | |
| """ | |
| inter = sum(affinity[i][j] for i in c1 for j in c2) | |
| # Degree of each cluster (sum of all edges from cluster to anywhere) | |
| deg1 = sum(affinity[i][j] for i in c1 for j in range(n) if j not in c1) | |
| deg2 = sum(affinity[i][j] for i in c2 for j in range(n) if j not in c2) | |
| # Modularity delta: actual inter-edges minus expected | |
| expected = (deg1 * deg2) / (2 * total_weight) if total_weight > 0 else 0 | |
| return inter - expected | |
| # Size constraints | |
| # Target: each module should be a manageable chunk. The original file is | |
| # being split because it's too large, so aim for modules well below the | |
| # original size. We use original_lines/3 as a soft cap to encourage | |
| # splitting into at least 3 modules, with an absolute floor of 200 lines | |
| # (don't bother splitting tiny files) and ceiling of 1500 lines. | |
| original_lines = sum(s.num_lines for s in symbols) | |
| min_lines = 50 | |
| max_lines = max(200, min(1500, original_lines // 3)) | |
| def cluster_lines(c: set[int]) -> int: | |
| return sum(symbols[i].num_lines for i in c) | |
| # Merge loop | |
| max_iterations = n * n # safety bound | |
| for _ in range(max_iterations): | |
| if len(clusters) <= 1: | |
| break | |
| # Find best merge candidate | |
| best_score = float('-inf') | |
| best_pair = None | |
| cluster_ids = list(clusters.keys()) | |
| for ai in range(len(cluster_ids)): | |
| for bi in range(ai + 1, len(cluster_ids)): | |
| cid_a, cid_b = cluster_ids[ai], cluster_ids[bi] | |
| ca, cb = clusters[cid_a], clusters[cid_b] | |
| # Skip if merge would exceed max size | |
| merged_lines = cluster_lines(ca) + cluster_lines(cb) | |
| if merged_lines > max_lines: | |
| continue | |
| # Score = modularity gain + affinity, penalized by size | |
| mg = modularity_gain(ca, cb) | |
| aff = cluster_affinity(ca, cb) | |
| # Size penalty: as merged size approaches max_lines, penalize | |
| # heavily so large clusters resist absorbing more symbols. | |
| # Ratio 0→1 maps to penalty 0→2.0 (cubic ramp). | |
| size_ratio = merged_lines / max_lines | |
| size_penalty = (size_ratio ** 3) * 2.0 # cubic ramp | |
| # Combined score: modularity-driven, affinity tiebreaker, size brake | |
| score = mg + aff * 0.1 - size_penalty | |
| if score > best_score: | |
| best_score = score | |
| best_pair = (cid_a, cid_b) | |
| # Stop if no beneficial merge exists | |
| if best_pair is None or best_score < 0: | |
| break | |
| # Merge | |
| cid_a, cid_b = best_pair | |
| merged = clusters[cid_a] | clusters[cid_b] | |
| del clusters[cid_a] | |
| del clusters[cid_b] | |
| clusters[next_cluster_id] = merged | |
| next_cluster_id += 1 | |
| # ── Step 4: Absorb tiny clusters ── | |
| # Clusters below min_lines get merged into their most-connected neighbor | |
| changed = True | |
| while changed: | |
| changed = False | |
| tiny = [(cid, members) for cid, members in clusters.items() | |
| if cluster_lines(members) < min_lines and len(clusters) > 1] | |
| for cid, members in tiny: | |
| # Find most-connected other cluster | |
| best_target = None | |
| best_aff = -1 | |
| for other_cid, other_members in clusters.items(): | |
| if other_cid == cid: | |
| continue | |
| aff = cluster_affinity(members, other_members) | |
| if aff > best_aff: | |
| best_aff = aff | |
| best_target = other_cid | |
| if best_target is not None: | |
| clusters[best_target] |= members | |
| del clusters[cid] | |
| changed = True | |
| break # restart scan since dict changed | |
| # ── Step 5: Name modules from contents ── | |
| result = {} | |
| used_names = set() | |
| for cid, member_indices in clusters.items(): | |
| member_syms = [symbols[i] for i in sorted(member_indices)] | |
| name = _derive_module_name(member_syms, used_names) | |
| used_names.add(name) | |
| result[name] = [s.name for s in member_syms] | |
| return result | |
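| # Usage sketch: the returned mapping has the same shape as a --spec file, with module | |
| # names derived from cluster contents (the keys shown are only an example outcome): | |
| #   spec = auto_cluster(symbols)   # e.g. {"models": [...], "core": [...], "cli": [...]} | |
| #   deps = build_module_deps(symbols, spec) | |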
| def _derive_module_name(member_syms: list[SymbolInfo], used_names: set[str]) -> str: | |
| """Derive a module name from the symbols inside a cluster. | |
| Strategy: find the name that best describes the *group*, not just the | |
| largest member. A module named after one function is a code smell — it | |
| means the cluster has no coherent identity. | |
| Priority: | |
| 1. All enums/dataclasses → "models" | |
| 2. Dominant class (>50% of lines) → snake_case(ClassName) | |
| 3. All constants → "constants" | |
| 4. Find common semantic root across member names (longest common prefix/suffix) | |
| 5. Name after the most-referenced symbol (the "hub" others depend on) | |
| 6. Fallback → "module_N" | |
| """ | |
| if not member_syms: | |
| return _unique_name("module", used_names) | |
| total_lines = sum(s.num_lines for s in member_syms) | |
| # 1. All data types → "models" | |
| all_data = all( | |
| s.is_class and ('Enum' in s.bases or | |
| any(d in ('dataclass', 'dataclasses.dataclass') for d in s.decorators)) | |
| for s in member_syms | |
| ) | |
| if all_data: | |
| return _unique_name("models", used_names) | |
| # 2. Dominant class (>50% of lines) | |
| classes = [s for s in member_syms if s.is_class] | |
| for cls in sorted(classes, key=lambda s: s.num_lines, reverse=True): | |
| if cls.num_lines > total_lines * 0.5: | |
| return _unique_name(_slugify(cls.name), used_names) | |
| # 3. All constants | |
| if all(s.is_constant for s in member_syms): | |
| return _unique_name("constants", used_names) | |
| # 4. Semantic root: split names into words, find shared words | |
| candidate = _find_common_theme(member_syms) | |
| if candidate: | |
| return _unique_name(candidate, used_names) | |
| # 5. Hub symbol: the one most-referenced by other members in this cluster | |
| member_names = {s.name for s in member_syms} | |
| ref_counts = defaultdict(int) | |
| for s in member_syms: | |
| for ref in s.references: | |
| if ref in member_names: | |
| ref_counts[ref] += 1 | |
| if ref_counts: | |
| hub = max(ref_counts, key=ref_counts.get) | |
| hub_sym = next(s for s in member_syms if s.name == hub) | |
| if hub_sym.is_class: | |
| return _unique_name(_slugify(hub), used_names) | |
| # 6. Composition-based: name by what kinds of symbols are present | |
| functions = [s for s in member_syms if s.is_function] | |
| constants = [s for s in member_syms if s.is_constant] | |
| classes = [s for s in member_syms if s.is_class] | |
| # Single-symbol cluster | |
| if len(member_syms) == 1: | |
| s = member_syms[0] | |
| slug = _slugify(s.name) | |
| if s.is_class: | |
| return _unique_name(slug, used_names) | |
| # For a single function, "cli" if it's main/entry-point-like | |
| if s.is_function and s.name in ('main', 'cli', 'run', 'app', 'entry'): | |
| return _unique_name("cli", used_names) | |
| if len(slug) <= 15: | |
| return _unique_name(slug, used_names) | |
| # Functions + constants mixed → "core" (they're the glue) | |
| if functions and constants and not classes: | |
| return _unique_name("core", used_names) | |
| # Only functions | |
| if functions and not constants and not classes: | |
| # Try to find if there's a single entry point referencing others | |
| entry = [f for f in functions if not any( | |
| f.name in other.references for other in functions if other is not f)] | |
| if len(entry) == 1: | |
| slug = _slugify(entry[0].name) | |
| if slug in ('main', 'cli', 'run', 'app'): | |
| return _unique_name("cli", used_names) | |
| if len(slug) <= 15: | |
| return _unique_name(slug, used_names) | |
| return _unique_name("utils", used_names) | |
| # Only constants | |
| if constants and not functions and not classes: | |
| return _unique_name("constants", used_names) | |
| return _unique_name("module", used_names) | |
| def _find_common_theme(syms: list[SymbolInfo]) -> Optional[str]: | |
| """Find a common semantic theme across symbol names. | |
| Splits each name into words (snake_case or CamelCase), finds words that | |
| appear in multiple names. Returns the most common shared word as a | |
| module name candidate. | |
| """ | |
| word_counts = defaultdict(int) | |
| total = len(syms) | |
| for s in syms: | |
| # Split name into words: "handle_hook_event" → ["handle", "hook", "event"] | |
| # Also handle CamelCase: "ConflictAnalysis" → ["conflict", "analysis"] | |
| slug = _slugify(s.name) | |
| words = set(slug.split('_')) | |
| # Remove generic words that don't carry meaning | |
| generic = {'get', 'set', 'make', 'create', 'build', 'do', 'run', 'is', | |
| 'has', 'can', 'the', 'a', 'an', 'to', 'from', 'for', 'with', | |
| 'handle', 'process', 'generate', 'print', 'show', 'display', | |
| 'check', 'find', 'search', 'update', 'delete', 'add', 'remove', | |
| 'init', 'setup', 'config', 'main', 'type', 'info', 'data', | |
| 'result', 'error', 'name', 'value', 'list', 'item', 'node'} | |
| words -= generic | |
| for w in words: | |
| if len(w) > 2: # skip tiny fragments | |
| word_counts[w] += 1 | |
| if not word_counts: | |
| return None | |
| # A word must appear in at least 40% of members to be a theme | |
| threshold = max(2, total * 0.4) | |
| candidates = [(word, count) for word, count in word_counts.items() | |
| if count >= threshold] | |
| if not candidates: | |
| return None | |
| # Return the most common word | |
| best_word = max(candidates, key=lambda x: x[1])[0] | |
| return best_word | |
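| # Worked example (hypothetical names): parse_hook_event, hook_registry and | |
| # handle_hook_timeout all share the non-generic word "hook" ("handle" is filtered as | |
| # generic), and 3 of 3 members clears the 40% threshold, so "hook" is returned. | |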
| def _unique_name(base: str, used: set[str]) -> str: | |
| """Return base if available, else base_2, base_3, etc.""" | |
| if base not in used: | |
| return base | |
| i = 2 | |
| while f"{base}_{i}" in used: | |
| i += 1 | |
| return f"{base}_{i}" | |
| def _slugify(name: str) -> str: | |
| """Convert CamelCase to snake_case for module names.""" | |
| s = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', name) | |
| s = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', s) | |
| return s.lower() | |
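| # e.g. _slugify("ConflictAnalysis") -> "conflict_analysis", _slugify("HTTPServer") -> | |
| # "http_server"; names that are already snake_case pass through unchanged. | |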
| # ── Dependency graph ──────────────────────────────────────────────────────── | |
| def build_module_deps( | |
| symbols: list[SymbolInfo], | |
| module_spec: dict[str, list[str]], | |
| ) -> dict[str, set[tuple[str, str]]]: | |
| """Build module-level dependency graph: module -> set of (dep_module, dep_symbol).""" | |
| sym_to_mod = {} | |
| for mod, sym_names in module_spec.items(): | |
| for sn in sym_names: | |
| sym_to_mod[sn] = mod | |
| sym_by_name = {s.name: s for s in symbols} | |
| module_deps = defaultdict(set) | |
| for mod, sym_names in module_spec.items(): | |
| for sn in sym_names: | |
| s = sym_by_name.get(sn) | |
| if not s: | |
| continue | |
| for ref in s.references: | |
| ref_mod = sym_to_mod.get(ref) | |
| if ref_mod and ref_mod != mod: | |
| module_deps[mod].add((ref_mod, ref)) | |
| return dict(module_deps) | |
| def detect_cycles(module_deps: dict[str, set[tuple[str, str]]]) -> list[list[str]]: | |
| """ | |
| Detect cycles using iterative DFS (Tarjan-style). | |
| Returns list of cycles, where each cycle is a list of module names. | |
| """ | |
| edges = defaultdict(set) | |
| for mod, deps in module_deps.items(): | |
| for dep_mod, _ in deps: | |
| edges[mod].add(dep_mod) | |
| all_nodes = set(edges.keys()) | |
| for targets in edges.values(): | |
| all_nodes |= targets | |
| WHITE, GRAY, BLACK = 0, 1, 2 | |
| color = {n: WHITE for n in all_nodes} | |
| parent = {n: None for n in all_nodes} | |
| cycles = [] | |
| for start in all_nodes: | |
| if color[start] != WHITE: | |
| continue | |
| stack = [(start, iter(edges.get(start, set())))] | |
| color[start] = GRAY | |
| while stack: | |
| node, children = stack[-1] | |
| try: | |
| child = next(children) | |
| if color[child] == GRAY: | |
| # Found cycle — trace back | |
| cycle = [child] | |
| for frame_node, _ in reversed(stack): | |
| cycle.append(frame_node) | |
| if frame_node == child: | |
| break | |
| cycles.append(cycle) | |
| elif color[child] == WHITE: | |
| color[child] = GRAY | |
| parent[child] = node | |
| stack.append((child, iter(edges.get(child, set())))) | |
| except StopIteration: | |
| color[node] = BLACK | |
| stack.pop() | |
| return cycles | |
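| # Illustrative only, never called: two modules that import from each other are | |
| # reported as one cycle (the rotation of the path depends on set iteration order). | |
| def _demo_detect_cycles() -> list[list[str]]: | |
|     deps = {"a": {("b", "x")}, "b": {("a", "y")}} | |
|     return detect_cycles(deps)  # e.g. [["a", "b", "a"]] | |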
| def topological_sort(module_deps: dict[str, set[tuple[str, str]]], all_modules: list[str]) -> list[str]: | |
| """Sort modules so dependencies come first.""" | |
| edges = defaultdict(set) | |
| for mod, deps in module_deps.items(): | |
| for dep_mod, _ in deps: | |
| edges[mod].add(dep_mod) | |
| visited = set() | |
| order = [] | |
| def visit(node): | |
| if node in visited: | |
| return | |
| visited.add(node) | |
| for dep in edges.get(node, set()): | |
| visit(dep) | |
| order.append(node) | |
| for m in all_modules: | |
| visit(m) | |
| return order | |
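| # Illustrative only, never called (module names echo the default niwa spec in main()): | |
| def _demo_topological_sort() -> list[str]: | |
|     deps = {"cli": {("db", "Niwa")}, "db": {("models", "Edit")}} | |
|     return topological_sort(deps, ["cli", "db", "models"])  # -> ["models", "db", "cli"] | |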
| # ── Import analysis ───────────────────────────────────────────────────────── | |
| def filter_imports_for_symbols( | |
| all_imports: list[ast.AST], | |
| source: str, | |
| sym_names: list[str], | |
| symbols: list[SymbolInfo], | |
| ) -> list[str]: | |
| """Determine which top-level imports a set of symbols actually needs.""" | |
| sym_by_name = {s.name: s for s in symbols} | |
| # Collect all names used in these symbols | |
| used_names = set() | |
| used_modules = set() # for `module.attr` patterns | |
| for sn in sym_names: | |
| s = sym_by_name.get(sn) | |
| if not s: | |
| continue | |
| class UsageVisitor(ast.NodeVisitor): | |
| def visit_Name(self, n): | |
| used_names.add(n.id) | |
| self.generic_visit(n) | |
| def visit_Attribute(self, n): | |
| if isinstance(n.value, ast.Name): | |
| used_modules.add(n.value.id) | |
| self.generic_visit(n) | |
| UsageVisitor().visit(s.node) | |
| all_used = used_names | used_modules | |
| # Filter top-level imports | |
| lines = source.splitlines() | |
| needed = [] | |
| seen = set() | |
| for imp_node in all_imports: | |
| if isinstance(imp_node, ast.Import): | |
| for alias in imp_node.names: | |
| bound_name = alias.asname or alias.name | |
| if bound_name in all_used or bound_name.split('.')[0] in all_used: | |
| text = _extract_lines(lines, imp_node.lineno, imp_node.end_lineno) | |
| if text not in seen: | |
| needed.append(text) | |
| seen.add(text) | |
| break | |
| elif isinstance(imp_node, ast.ImportFrom): | |
| relevant = [] | |
| for alias in imp_node.names: | |
| bound_name = alias.asname or alias.name | |
| if bound_name in all_used: | |
| relevant.append(alias) | |
| if relevant: | |
| if len(relevant) < len(imp_node.names): | |
| # Reconstruct with only needed names | |
| names_str = ', '.join( | |
| f"{a.name} as {a.asname}" if a.asname else a.name | |
| for a in relevant | |
| ) | |
| text = f"from {imp_node.module} import {names_str}" | |
| else: | |
| text = _extract_lines(lines, imp_node.lineno, imp_node.end_lineno) | |
| if text not in seen: | |
| needed.append(text) | |
| seen.add(text) | |
| elif isinstance(imp_node, ast.Try): | |
| # try: import X / except ImportError: ... blocks | |
| # Check if any import inside the try body is used | |
| any_used = False | |
| for stmt in imp_node.body: | |
| if isinstance(stmt, ast.Import): | |
| for alias in stmt.names: | |
| bound_name = alias.asname or alias.name | |
| if bound_name in all_used or bound_name.split('.')[0] in all_used: | |
| any_used = True | |
| break | |
| elif isinstance(stmt, ast.ImportFrom): | |
| for alias in stmt.names: | |
| bound_name = alias.asname or alias.name | |
| if bound_name in all_used: | |
| any_used = True | |
| break | |
| if any_used: | |
| break | |
| if any_used: | |
| # Emit the entire try/except block as-is | |
| text = _extract_lines(lines, imp_node.lineno, imp_node.end_lineno) | |
| if text not in seen: | |
| needed.append(text) | |
| seen.add(text) | |
| return needed | |
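| # Sketch of the filtering effect: if the chosen symbols only reference json.loads and | |
| # Path, "import json" is kept and a combined "from pathlib import Path, PurePath" is | |
| # trimmed down to "from pathlib import Path"; unused top-level imports are dropped. | |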
| def _extract_lines(lines: list[str], start: int, end: int) -> str: | |
| """Extract source lines (1-based, inclusive).""" | |
| return '\n'.join(lines[start - 1:end]) | |
| # ── Code generation ───────────────────────────────────────────────────────── | |
| def generate_module( | |
| mod_name: str, | |
| sym_names: list[str], | |
| source: str, | |
| symbols: list[SymbolInfo], | |
| all_imports: list[ast.AST], | |
| module_deps: dict[str, set[tuple[str, str]]], | |
| package_name: str, | |
| ) -> ModuleFile: | |
| """Generate a complete module file. | |
| Hoists stdlib inline imports to module top-level and strips them from | |
| function bodies. Non-stdlib inline imports are left in place. | |
| """ | |
| sym_by_name = {s.name: s for s in symbols} | |
| lines = source.splitlines() | |
| parts = [] | |
| # Collect stdlib inline imports to hoist (deduplicated) | |
| hoisted_imports: list[str] = [] | |
| hoisted_seen: set[str] = set() | |
| # Track which source lines to strip (1-based line numbers) | |
| strip_lines: set[int] = set() | |
| for sn in sym_names: | |
| s = sym_by_name.get(sn) | |
| if not s: | |
| continue | |
| for imp_text, lineno, end_lineno in s.inline_imports: | |
| if is_stdlib_import(imp_text) and imp_text not in hoisted_seen: | |
| hoisted_imports.append(imp_text) | |
| hoisted_seen.add(imp_text) | |
| if is_stdlib_import(imp_text): | |
| for ln in range(lineno, end_lineno + 1): | |
| strip_lines.add(ln) | |
| # Header | |
| parts.append(f'"""{package_name}.{mod_name} - Auto-split module"""') | |
| parts.append('') | |
| # Stdlib/third-party imports (from original top-level) | |
| needed = filter_imports_for_symbols(all_imports, source, sym_names, symbols) | |
| if needed: | |
| parts.extend(needed) | |
| # Hoisted stdlib imports from function bodies | |
| if hoisted_imports: | |
| for imp in hoisted_imports: | |
| if imp not in needed: | |
| parts.append(imp) | |
| if needed or hoisted_imports: | |
| parts.append('') | |
| # Cross-module imports | |
| deps = module_deps.get(mod_name, set()) | |
| if deps: | |
| by_mod = defaultdict(list) | |
| for dep_mod, dep_sym in deps: | |
| by_mod[dep_mod].append(dep_sym) | |
| for dep_mod in sorted(by_mod): | |
| syms = sorted(by_mod[dep_mod]) | |
| parts.append(f"from .{dep_mod} import {', '.join(syms)}") | |
| parts.append('') | |
| # Symbol bodies (strip hoisted stdlib imports from source) | |
| for sn in sym_names: | |
| s = sym_by_name.get(sn) | |
| if not s: | |
| parts.append(f"# WARNING: {sn} not found in source") | |
| parts.append('') | |
| continue | |
| # Include decorator lines (AST lineno starts at the `def`/`class`, not decorator) | |
| actual_start = s.start_line | |
| if hasattr(s.node, 'decorator_list') and s.node.decorator_list: | |
| first_dec = s.node.decorator_list[0] | |
| actual_start = first_dec.lineno | |
| # Build source, stripping hoisted inline imports | |
| sym_lines = [] | |
| for line_idx in range(actual_start - 1, s.end_line): | |
| lineno_1based = line_idx + 1 | |
| if lineno_1based in strip_lines: | |
| continue | |
| sym_lines.append(lines[line_idx]) | |
| sym_source = '\n'.join(sym_lines) | |
| parts.append('') | |
| parts.append(sym_source) | |
| parts.append('') | |
| content = '\n'.join(parts) | |
| return ModuleFile( | |
| name=mod_name, | |
| symbols=sym_names, | |
| source=content, | |
| line_count=len(content.splitlines()), | |
| deps=deps, | |
| ) | |
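| # Hoisting sketch: a body like | |
| #   def load(path): | |
| #       import json | |
| #       return json.loads(path.read_text()) | |
| # is emitted with "import json" at the top of the generated module and the inline | |
| # import line stripped from the function body; non-stdlib inline imports stay put. | |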
| def generate_init( | |
| module_spec: dict[str, list[str]], | |
| existing_init: Optional[str], | |
| package_name: str, | |
| ) -> str: | |
| """Generate __init__.py preserving version and re-exporting all symbols.""" | |
| parts = [] | |
| parts.append(f'"""{package_name} - Auto-split package"""') | |
| parts.append('') | |
| # Preserve __version__ if it exists | |
| if existing_init: | |
| for line in existing_init.splitlines(): | |
| if line.strip().startswith('__version__'): | |
| parts.append(line) | |
| parts.append('') | |
| break | |
| # Re-export from each module | |
| for mod_name, sym_names in module_spec.items(): | |
| sym_list = ', '.join(sym_names) | |
| parts.append(f"from .{mod_name} import {sym_list}") | |
| parts.append('') | |
| # __all__ | |
| all_syms = [] | |
| for sym_names in module_spec.values(): | |
| all_syms.extend(sym_names) | |
| # Include __version__ if found | |
| if existing_init and '__version__' in existing_init: | |
| all_syms.append('__version__') | |
| parts.append(f"__all__ = {all_syms!r}") | |
| parts.append('') | |
| return '\n'.join(parts) | |
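| # Rough shape of the result (symbol names echo the default niwa spec in main()): | |
| #   """niwa - Auto-split package""" | |
| #   from .models import ConflictType, Edit, ConflictAnalysis, EditResult | |
| #   from .cli import print_error, print_command_help, main | |
| #   __all__ = ['ConflictType', 'Edit', ...] | |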
| # ── Verification ──────────────────────────────────────────────────────────── | |
| def verify_split( | |
| symbols: list[SymbolInfo], | |
| module_spec: dict[str, list[str]], | |
| module_deps: dict[str, set[tuple[str, str]]], | |
| generated: dict[str, ModuleFile], | |
| source: str, | |
| ) -> list[tuple[str, str]]: | |
| """ | |
| Run all verification checks. Returns list of (level, message) tuples. | |
| level is one of: 'error', 'warn', 'ok'. | |
| """ | |
| results = [] | |
| sym_by_name = {s.name: s for s in symbols} | |
| # 1. Completeness: every symbol assigned to exactly one module | |
| assigned = {} | |
| for mod, sym_names in module_spec.items(): | |
| for sn in sym_names: | |
| if sn in assigned: | |
| results.append(('error', f"DUPLICATE: '{sn}' in both '{assigned[sn]}' and '{mod}'")) | |
| assigned[sn] = mod | |
| for s in symbols: | |
| if s.name not in assigned: | |
| results.append(('error', f"UNASSIGNED: '{s.name}' (L{s.start_line}-{s.end_line}, {s.num_lines} lines)")) | |
| for mod, sym_names in module_spec.items(): | |
| for sn in sym_names: | |
| if sn not in sym_by_name: | |
| results.append(('error', f"MISSING: '{sn}' in spec for '{mod}' but not in source")) | |
| if not any(r[0] == 'error' for r in results if 'UNASSIGNED' in r[1] or 'DUPLICATE' in r[1] or 'MISSING' in r[1]): | |
| results.append(('ok', "All symbols assigned to exactly one module")) | |
| # 2. Circular dependencies | |
| cycles = detect_cycles(module_deps) | |
| if cycles: | |
| for cycle in cycles: | |
| results.append(('error', f"CYCLE: {' -> '.join(cycle)}")) | |
| else: | |
| results.append(('ok', "No circular dependencies")) | |
| # 3. Cross-module import targets exist | |
| for mod in module_spec: | |
| for dep_mod, dep_sym in module_deps.get(mod, set()): | |
| if dep_sym in module_spec.get(dep_mod, []): | |
| results.append(('ok', f"{mod} imports {dep_sym} from {dep_mod}")) | |
| else: | |
| results.append(('error', f"{mod} needs {dep_sym} from {dep_mod} — NOT IN THAT MODULE")) | |
| # 4. AST isomorphism check: re-parse each generated module, find each symbol, | |
| # hash it, and compare against the original's post-hoist hash. | |
| # If hashes match, the split is structurally identical (same AST minus position). | |
| for mod_name, mod_file in generated.items(): | |
| if mod_name == '__init__': | |
| continue | |
| try: | |
| split_tree = ast.parse(mod_file.source) | |
| except SyntaxError as e: | |
| results.append(('error', f"[{mod_name}] SYNTAX ERROR in generated module: {e}")) | |
| continue | |
| # Index split symbols by name | |
| split_syms = {} | |
| for split_node in ast.iter_child_nodes(split_tree): | |
| if isinstance(split_node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): | |
| split_syms[split_node.name] = split_node | |
| elif isinstance(split_node, ast.Assign): | |
| for t in split_node.targets: | |
| if isinstance(t, ast.Name): | |
| split_syms[t.id] = split_node | |
| break | |
| for sn in mod_file.symbols: | |
| s = sym_by_name.get(sn) | |
| if not s: | |
| continue | |
| split_node = split_syms.get(sn) | |
| if not split_node: | |
| results.append(('error', f"[{mod_name}] {sn} NOT FOUND in generated module AST")) | |
| continue | |
| # Hash the split version (no strip needed — stdlib imports already removed) | |
| split_hash = compute_structural_hash(mod_file.source, split_node) | |
| original_posthoist = s.structural_hash_posthoist | |
| if split_hash == original_posthoist: | |
| if s.structural_hash != original_posthoist: | |
| results.append(('ok', | |
| f"[{mod_name}] {sn} AST isomorphic ✓ " | |
| f"(original: {s.structural_hash} → hoisted: {original_posthoist} == split: {split_hash})")) | |
| else: | |
| results.append(('ok', | |
| f"[{mod_name}] {sn} AST isomorphic ✓ (hash: {split_hash})")) | |
| else: | |
| results.append(('error', | |
| f"[{mod_name}] {sn} AST MISMATCH! " | |
| f"original(posthoist): {original_posthoist} != split: {split_hash}")) | |
| # 5. Inline import hoisting report | |
| for s in symbols: | |
| if s.inline_imports: | |
| mod = assigned.get(s.name, '?') | |
| for imp_text, lineno, end_lineno in s.inline_imports: | |
| if is_stdlib_import(imp_text): | |
| results.append(('ok', f"[{mod}] {s.name} L{lineno}: hoisted stdlib '{imp_text}' to top-level")) | |
| else: | |
| results.append(('ok', f"[{mod}] {s.name} L{lineno}: non-stdlib inline import kept in body: {imp_text}")) | |
| return results | |
| # ── Output formatting ─────────────────────────────────────────────────────── | |
| def print_analysis( | |
| symbols: list[SymbolInfo], | |
| module_spec: dict[str, list[str]], | |
| module_deps: dict[str, set[tuple[str, str]]], | |
| generated: dict[str, ModuleFile], | |
| source: str, | |
| source_path: Path, | |
| ): | |
| """Print the full analysis report.""" | |
| sym_by_name = {s.name: s for s in symbols} | |
| print("=" * 70) | |
| print("MODULE SPLITTER - AST Analysis Report") | |
| print(f"Source: {source_path}") | |
| print(f"Symbols: {len(symbols)} | Modules: {len(module_spec)}") | |
| print("=" * 70) | |
| # Symbol table | |
| print("\n── Symbol Table ──") | |
| print(f" {'Name':<35} {'Kind':<10} {'Lines':>8} {'Hash':>14} {'Module':<12}") | |
| print(f" {'─' * 35} {'─' * 10} {'─' * 8} {'─' * 14} {'─' * 12}") | |
| assigned = {} | |
| for mod, syms in module_spec.items(): | |
| for sn in syms: | |
| assigned[sn] = mod | |
| for s in symbols: | |
| mod = assigned.get(s.name, '???') | |
| line_range = f"L{s.start_line}-{s.end_line}" | |
| print(f" {s.name:<35} {s.kind:<10} {line_range:>8} {s.structural_hash:>14} {mod:<12}") | |
| if s.signature: | |
| sig = s.signature | |
| if len(sig) > 65: | |
| sig = sig[:62] + "..." | |
| print(f" {sig}") | |
| if s.decorators: | |
| print(f" decorators: {', '.join(s.decorators)}") | |
| if s.bases: | |
| print(f" bases: {', '.join(s.bases)}") | |
| # Dependencies | |
| print("\n── Module Dependencies ──") | |
| topo = topological_sort(module_deps, list(module_spec.keys())) | |
| for mod in topo: | |
| deps = module_deps.get(mod, set()) | |
| if deps: | |
| by_mod = defaultdict(list) | |
| for d_mod, d_sym in deps: | |
| by_mod[d_mod].append(d_sym) | |
| dep_strs = [f"{m}({', '.join(sorted(s))})" for m, s in sorted(by_mod.items())] | |
| print(f" {mod} → {', '.join(dep_strs)}") | |
| else: | |
| print(f" {mod} → (leaf)") | |
| print(f"\n Topological order: {' → '.join(topo)}") | |
| # Generated files | |
| print("\n── Generated Files ──") | |
| total = 0 | |
| for mod_name in list(module_spec.keys()) + ['__init__']: | |
| mf = generated.get(mod_name) | |
| if mf: | |
| total += mf.line_count | |
| print(f" {mf.name + '.py':<20} {mf.line_count:>6} lines [{', '.join(mf.symbols[:4])}{'...' if len(mf.symbols) > 4 else ''}]") | |
| elif mod_name == '__init__': | |
| init_src = generated.get('__init__') | |
| if init_src: | |
| lc = init_src.line_count if isinstance(init_src, ModuleFile) else len(init_src.splitlines()) if isinstance(init_src, str) else 0 | |
| total += lc | |
| print(f" __init__.py {lc:>6} lines [re-exports]") | |
| original_lines = len(source.splitlines()) | |
| print(f"\n Original: {original_lines:,} lines (1 file)") | |
| print(f" Split: {total:,} lines ({len(generated)} files)") | |
| print(f" Overhead: {total - original_lines:+,} lines") | |
| def print_verification(results: list[tuple[str, str]]): | |
| """Print verification results.""" | |
| print("\n── Verification ──") | |
| icons = {'ok': '✅', 'warn': '⚠️ ', 'error': '❌'} | |
| errors = [r for r in results if r[0] == 'error'] | |
| warns = [r for r in results if r[0] == 'warn'] | |
| oks = [r for r in results if r[0] == 'ok'] | |
| # Show errors and warnings always, oks only in summary | |
| for level, msg in errors: | |
| print(f" {icons[level]} {msg}") | |
| for level, msg in warns: | |
| print(f" {icons[level]} {msg}") | |
| print(f"\n Summary: {len(oks)} ok, {len(warns)} warnings, {len(errors)} errors") | |
| if not errors: | |
| print(" ✅ Split is safe to apply") | |
| else: | |
| print(" ❌ Fix errors before applying") | |
| # ── Main ──────────────────────────────────────────────────────────────────── | |
| def main(): | |
| parser = argparse.ArgumentParser( | |
| description="General-purpose AST-based Python module splitter", | |
| formatter_class=argparse.RawDescriptionHelpFormatter, | |
| ) | |
| parser.add_argument('source', type=Path, help='Python source file to split') | |
| parser.add_argument('--auto', action='store_true', help='Auto-discover module clusters') | |
| parser.add_argument('--spec', type=Path, help='JSON file with module spec: {"mod": ["sym1", ...]}') | |
| parser.add_argument('--write', action='store_true', help='Actually write the split files') | |
| parser.add_argument('--verify', action='store_true', help='Run verification checks') | |
| parser.add_argument('--outdir', type=Path, default=None, help='Output directory (default: same as source)') | |
| parser.add_argument('--package', type=str, default=None, help='Package name (default: parent dir name)') | |
| parser.add_argument('--symbols', action='store_true', help='Only print symbol table, no split') | |
| parser.add_argument('--deps', action='store_true', help='Only print dependency graph') | |
| parser.add_argument('--hashes', action='store_true', help='Print structural hashes for drift detection') | |
| parser.add_argument('--isomorphic', action='store_true', | |
| help='End-to-end AST isomorphism test: write split to tmpdir, ' | |
| 're-parse, hash every symbol independently, compare to original') | |
| args = parser.parse_args() | |
| source_path = args.source | |
| if not source_path.exists(): | |
| print(f"ERROR: {source_path} not found", file=sys.stderr) | |
| sys.exit(1) | |
| outdir = args.outdir or source_path.parent | |
| package_name = args.package or outdir.name | |
| # Parse | |
| source, tree = parse_source(source_path) | |
| symbols, top_level_imports = analyze_file(source, tree) | |
| # Symbols-only mode | |
| if args.symbols: | |
| print(f"{'Name':<35} {'Kind':<10} {'Lines':>10} {'Hash':>14} {'Refs'}") | |
| print(f"{'─' * 35} {'─' * 10} {'─' * 10} {'─' * 14} {'─' * 30}") | |
| for s in symbols: | |
| refs = ', '.join(sorted(s.references)) if s.references else '(none)' | |
| lr = f"L{s.start_line}-{s.end_line}" | |
| print(f"{s.name:<35} {s.kind:<10} {lr:>10} {s.structural_hash:>14} {refs}") | |
| if s.signature: | |
| print(f" {s.signature}") | |
| if s.inline_imports: | |
| for imp_text, lineno, end_lineno in s.inline_imports: | |
| stdlib_tag = "stdlib→hoist" if is_stdlib_import(imp_text) else "non-stdlib" | |
| print(f" [inline L{lineno}] ({stdlib_tag}) {imp_text}") | |
| return | |
| # Deps-only mode | |
| if args.deps: | |
| print("Dependency graph (symbol → references):") | |
| for s in symbols: | |
| refs = ', '.join(sorted(s.references)) if s.references else '(none)' | |
| print(f" {s.name} → {refs}") | |
| return | |
| # Hashes-only mode | |
| if args.hashes: | |
| print(f"Structural hashes for {source_path}:") | |
| for s in symbols: | |
| if s.structural_hash != s.structural_hash_posthoist: | |
| print(f" {s.structural_hash} {s.name} (L{s.start_line}-{s.end_line})") | |
| print(f" {s.structural_hash_posthoist} └─ post-hoist (stdlib imports stripped)") | |
| else: | |
| print(f" {s.structural_hash} {s.name} (L{s.start_line}-{s.end_line})") | |
| return | |
| # Determine module spec | |
| if args.spec: | |
| with open(args.spec) as f: | |
| module_spec = json.load(f) | |
| elif args.auto: | |
| module_spec = auto_cluster(symbols) | |
| else: | |
| # Default: use niwa-specific spec | |
| module_spec = { | |
| "models": ["ConflictType", "Edit", "ConflictAnalysis", "EditResult"], | |
| "prompts": ["LLM_SYSTEM_PROMPT", "COMMAND_HELP", "ERROR_PROMPTS"], | |
| "db": ["Niwa"], | |
| "hooks": ["generate_claude_hooks_config", "get_niwa_usage_guide", | |
| "handle_hook_event", "setup_claude_hooks"], | |
| "cli": ["print_error", "print_command_help", "main"], | |
| } | |
| # Build dependency graph | |
| module_deps = build_module_deps(symbols, module_spec) | |
| # Generate files | |
| generated = {} | |
| for mod_name, sym_names in module_spec.items(): | |
| mf = generate_module( | |
| mod_name, sym_names, source, symbols, | |
| top_level_imports, module_deps, package_name, | |
| ) | |
| generated[mod_name] = mf | |
| # Generate __init__.py | |
| existing_init = None | |
| init_path = outdir / "__init__.py" | |
| if init_path.exists(): | |
| existing_init = init_path.read_text() | |
| init_source = generate_init(module_spec, existing_init, package_name) | |
| generated['__init__'] = ModuleFile( | |
| name='__init__', | |
| symbols=[], | |
| source=init_source, | |
| line_count=len(init_source.splitlines()), | |
| ) | |
| # Print analysis | |
| print_analysis(symbols, module_spec, module_deps, generated, source, source_path) | |
| # Verification | |
| if args.verify or args.write: | |
| results = verify_split(symbols, module_spec, module_deps, generated, source) | |
| print_verification(results) | |
| if args.write and any(r[0] == 'error' for r in results): | |
| print("\n❌ Cannot write: fix errors first") | |
| sys.exit(1) | |
| # Isomorphic test: independent end-to-end AST hash comparison | |
| if args.isomorphic: | |
| import tempfile | |
| import shutil | |
| print("\n── Isomorphic AST Test ──") | |
| print(" Writing split to temp dir, re-parsing, hashing independently...\n") | |
| tmpdir = Path(tempfile.mkdtemp(prefix='split_iso_')) | |
| try: | |
| # Write generated files to tmpdir | |
| for mod_name, mf in generated.items(): | |
| if mod_name == '__init__': | |
| path = tmpdir / "__init__.py" | |
| else: | |
| path = tmpdir / f"{mod_name}.py" | |
| path.write_text(mf.source if isinstance(mf, ModuleFile) else mf) | |
| # Build expected hashes from original (post-hoist) | |
| expected_hashes = {} | |
| for s in symbols: | |
| expected_hashes[s.name] = s.structural_hash_posthoist | |
| # Re-parse each split file and hash symbols independently | |
| ok_count = 0 | |
| fail_count = 0 | |
| for py_file in sorted(tmpdir.glob('*.py')): | |
| if py_file.name == '__init__.py': | |
| continue | |
| split_src = py_file.read_text() | |
| try: | |
| split_tree = ast.parse(split_src) | |
| except SyntaxError as e: | |
| print(f" ✗ {py_file.name}: SYNTAX ERROR: {e}") | |
| fail_count += 1 | |
| continue | |
| for split_node in ast.iter_child_nodes(split_tree): | |
| name = None | |
| if isinstance(split_node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): | |
| name = split_node.name | |
| elif isinstance(split_node, ast.Assign): | |
| for t in split_node.targets: | |
| if isinstance(t, ast.Name): | |
| name = t.id | |
| break | |
| if name and name in expected_hashes: | |
| split_hash = compute_structural_hash(split_src, split_node) | |
| expected = expected_hashes[name] | |
| if split_hash == expected: | |
| ok_count += 1 | |
| print(f" ✓ {name}: {split_hash}") | |
| else: | |
| fail_count += 1 | |
| print(f" ✗ {name}: split={split_hash} != expected={expected}") | |
| # Check coverage: every original symbol should appear in some split file | |
| found_in_split = set() | |
| for py_file in sorted(tmpdir.glob('*.py')): | |
| if py_file.name == '__init__.py': | |
| continue | |
| split_tree = ast.parse(py_file.read_text()) | |
| for split_node in ast.iter_child_nodes(split_tree): | |
| if isinstance(split_node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): | |
| found_in_split.add(split_node.name) | |
| elif isinstance(split_node, ast.Assign): | |
| for t in split_node.targets: | |
| if isinstance(t, ast.Name): | |
| found_in_split.add(t.id) | |
| break | |
| missing = set(expected_hashes.keys()) - found_in_split | |
| if missing: | |
| for name in sorted(missing): | |
| print(f" ✗ {name}: NOT FOUND in any split file") | |
| fail_count += 1 | |
| print(f"\n Result: {ok_count} match, {fail_count} mismatch, " | |
| f"{len(expected_hashes)} symbols total") | |
| if fail_count == 0: | |
| print(" ✅ Split is isomorphic — identical AST structure proven via SHA-256") | |
| else: | |
| print(" ❌ AST isomorphism FAILED — split diverges from original") | |
| sys.exit(1) | |
| finally: | |
| shutil.rmtree(tmpdir) | |
| # Verify-only / isomorphic-only runs stop here; a plain dry run falls through so the | |
| # generated file contents are printed, as documented in the module docstring. | |
| if (args.verify or args.isomorphic) and not args.write: | |
| return | |
| # Write or dry-run | |
| if args.write: | |
| print("\n── Writing Files ──") | |
| outdir.mkdir(parents=True, exist_ok=True) | |
| # Backup original BEFORE writing (source may be inside outdir) | |
| backup = source_path.with_suffix('.py.bak') | |
| if source_path.exists() and outdir.resolve() == source_path.resolve().parent: | |
| source_path.rename(backup) | |
| print(f" 📦 Backed up original to {backup}") | |
| for mod_name, mf in generated.items(): | |
| if mod_name == '__init__': | |
| path = outdir / "__init__.py" | |
| else: | |
| path = outdir / f"{mod_name}.py" | |
| path.write_text(mf.source if isinstance(mf, ModuleFile) else mf) | |
| print(f" ✅ {path} ({mf.line_count} lines)") | |
| else: | |
| # Dry run — print generated files | |
| for mod_name in list(module_spec.keys()) + ['__init__']: | |
| mf = generated.get(mod_name) | |
| if not mf: | |
| continue | |
| fname = f"{outdir}/{mod_name}.py" if mod_name != '__init__' else f"{outdir}/__init__.py" | |
| print(f"\n{'═' * 70}") | |
| print(f"FILE: {fname}") | |
| print(f"{'═' * 70}") | |
| src = mf.source if isinstance(mf, ModuleFile) else mf | |
| print(src) | |
| if __name__ == "__main__": | |
| main() |