@secemp9
Created January 29, 2026 11:00
AST-based Python module splitter
#!/usr/bin/env python3
"""
General-purpose AST-based Python module splitter.
Analyzes any monolithic .py file, builds a full dependency graph with structural
fingerprints, auto-discovers natural clusters, and produces split module files.
Capabilities:
- AST parsing with full symbol table (ast + inspect-style signatures)
- Structural hashing (hashlib) to fingerprint each symbol for drift detection
- Dependency graph with cycle detection (iterative DFS)
- Auto-clustering via connectivity analysis when no spec is provided
- Import filtering: only carries over what each module actually uses
- Inline import hoisting: detects `import X` inside functions, promotes to top
- Decorator-aware extraction: includes @dataclass, @property, etc.
- Dry run with diff-style output or --write to commit
Usage:
python3 tools/split_modules.py <source.py> [options]
# Auto-discover clusters (no config needed):
python3 tools/split_modules.py niwa/cli.py --auto
# Use a split spec (JSON or inline; an example spec is sketched below this docstring):
python3 tools/split_modules.py niwa/cli.py --spec split_spec.json
# Dry run (default): prints analysis + file contents to stdout
python3 tools/split_modules.py niwa/cli.py --auto
# Verify only (no file output):
python3 tools/split_modules.py niwa/cli.py --auto --verify
# Actually write the files:
python3 tools/split_modules.py niwa/cli.py --auto --write
# Output directory (default: same as source):
python3 tools/split_modules.py niwa/cli.py --auto --outdir niwa/
"""
import argparse
import ast
import hashlib
import json
import re
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
# ── Data structures ─────────────────────────────────────────────────────────
@dataclass
class SymbolInfo:
"""Full metadata for a top-level symbol extracted via AST."""
name: str
kind: str # 'function', 'async_function', 'class', 'constant'
node: ast.AST
start_line: int
end_line: int
num_lines: int
decorators: list[str] = field(default_factory=list)
bases: list[str] = field(default_factory=list) # class bases
signature: Optional[str] = None # function signature string
docstring: Optional[str] = None
references: set[str] = field(default_factory=set) # other top-level names used
inline_imports: list[tuple[str, int, int]] = field(default_factory=list) # (text, lineno, end_lineno)
structural_hash: str = "" # content-based fingerprint (original)
structural_hash_posthoist: str = "" # hash after stripping stdlib inline imports
module: Optional[str] = None # assigned module after clustering
@property
def is_class(self) -> bool:
return self.kind == 'class'
@property
def is_function(self) -> bool:
return self.kind in ('function', 'async_function')
@property
def is_constant(self) -> bool:
return self.kind == 'constant'
@dataclass
class ModuleFile:
"""A generated output module."""
name: str
symbols: list[str]
source: str = ""
line_count: int = 0
deps: set[tuple[str, str]] = field(default_factory=set) # (module, symbol)
# ── AST analysis ────────────────────────────────────────────────────────────
def parse_source(path: Path) -> tuple[str, ast.Module]:
"""Parse a Python source file into source text + AST."""
source = path.read_text()
tree = ast.parse(source, filename=str(path))
return source, tree
def compute_structural_hash(source: str, node: ast.AST, strip_stdlib_imports: bool = False) -> str:
"""
Compute a content-based fingerprint for a symbol's AST subtree.
Uses ast.dump (which normalizes whitespace/comments away) hashed with sha256.
Two symbols with identical structure produce the same hash, even if
whitespace or comments differ. Useful for detecting drift between the
original monolith and the split files.
If strip_stdlib_imports=True, removes stdlib import nodes from function/method
bodies before hashing. This produces the "expected" hash after hoisting, so
original (stripped) == split (already stripped) proves isomorphism.
"""
if strip_stdlib_imports:
node = _strip_stdlib_import_nodes(node)
# ast.dump gives a canonical string repr of the subtree
canonical = ast.dump(node, annotate_fields=True, include_attributes=False)
return hashlib.sha256(canonical.encode()).hexdigest()[:12]
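# Illustrative property of the fingerprint (whitespace and comments are ignored,
# and the unused `source` argument may be empty):
#     a = ast.parse("def f(x):\n    return x + 1").body[0]
#     b = ast.parse("def f(x):  # doc\n    return x+1").body[0]
#     compute_structural_hash("", a) == compute_structural_hash("", b)  # True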
def _strip_stdlib_import_nodes(node: ast.AST) -> ast.AST:
"""Return a deep copy of node with stdlib import statements removed from ALL bodies.
Recursively strips stdlib imports from every statement list (function bodies,
if/for/while/with/try bodies, except handlers, etc.) at any nesting depth.
This lets us hash the "post-hoist" version of a symbol from the original AST,
so it can be compared against the hash of the same symbol in the split file
(where those imports were physically removed).
"""
import copy
node = copy.deepcopy(node)
def _filter_body(body: list) -> list:
return [stmt for stmt in body if not _is_stdlib_import_stmt(stmt)]
def _recurse(n):
"""Strip stdlib imports from any node that has statement lists, recursively."""
# All AST node types that contain statement lists
body_attrs = ('body', 'orelse', 'finalbody')
for attr in body_attrs:
if hasattr(n, attr):
stmts = getattr(n, attr)
if isinstance(stmts, list):
setattr(n, attr, _filter_body(stmts))
for child in getattr(n, attr):
_recurse(child)
# Try/ExceptHandler handlers
if hasattr(n, 'handlers'):
for handler in n.handlers:
handler.body = _filter_body(handler.body)
for child in handler.body:
_recurse(child)
_recurse(node)
return node
def _is_stdlib_import_stmt(stmt: ast.AST) -> bool:
"""Check if an AST statement node is a stdlib import."""
if isinstance(stmt, ast.Import):
return all(
            alias.name.split('.')[0] in STDLIB_MODULES  # check the module name, not the bound alias
for alias in stmt.names
)
elif isinstance(stmt, ast.ImportFrom):
if stmt.module:
return stmt.module.split('.')[0] in STDLIB_MODULES
return False
def extract_signature(node: ast.AST) -> Optional[str]:
"""Extract a human-readable function signature from AST."""
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
return None
args = node.args
parts = []
# Positional args
num_defaults = len(args.defaults)
num_positional = len(args.args)
for i, arg in enumerate(args.args):
name = arg.arg
annotation = ""
if arg.annotation:
annotation = f": {ast.unparse(arg.annotation)}"
# Check if this arg has a default
default_idx = i - (num_positional - num_defaults)
if default_idx >= 0:
default = ast.unparse(args.defaults[default_idx])
parts.append(f"{name}{annotation}={default}")
else:
parts.append(f"{name}{annotation}")
# *args
if args.vararg:
va = args.vararg
ann = f": {ast.unparse(va.annotation)}" if va.annotation else ""
parts.append(f"*{va.arg}{ann}")
elif args.kwonlyargs:
parts.append("*")
# keyword-only
for i, arg in enumerate(args.kwonlyargs):
ann = f": {ast.unparse(arg.annotation)}" if arg.annotation else ""
if i < len(args.kw_defaults) and args.kw_defaults[i] is not None:
default = ast.unparse(args.kw_defaults[i])
parts.append(f"{arg.arg}{ann}={default}")
else:
parts.append(f"{arg.arg}{ann}")
# **kwargs
if args.kwarg:
kw = args.kwarg
ann = f": {ast.unparse(kw.annotation)}" if kw.annotation else ""
parts.append(f"**{kw.arg}{ann}")
# Return annotation
ret = ""
if node.returns:
ret = f" -> {ast.unparse(node.returns)}"
prefix = "async def" if isinstance(node, ast.AsyncFunctionDef) else "def"
return f"{prefix} {node.name}({', '.join(parts)}){ret}"
def extract_decorators(node: ast.AST) -> list[str]:
"""Extract decorator names from a class or function."""
if not hasattr(node, 'decorator_list'):
return []
decorators = []
for dec in node.decorator_list:
try:
decorators.append(ast.unparse(dec))
except Exception:
decorators.append("(unknown)")
return decorators
def extract_docstring(node: ast.AST) -> Optional[str]:
    """Extract the docstring from a function/class/module node."""
    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Module)):
        # ast.get_docstring handles the Constant-node check and works on
        # Python 3.12+, where the deprecated ast.Str node no longer exists.
        return ast.get_docstring(node, clean=False)
    return None
def extract_class_bases(node: ast.AST) -> list[str]:
"""Extract base class names."""
if not isinstance(node, ast.ClassDef):
return []
return [ast.unparse(b) for b in node.bases]
def find_inline_imports(node: ast.AST) -> list[tuple[str, int, int]]:
"""Find import statements inside function/method bodies (not at module level).
Returns list of (import_text, lineno, end_lineno) tuples.
"""
imports = []
def _collect(parent):
for child in ast.walk(parent):
if child is parent:
continue
if isinstance(child, (ast.Import, ast.ImportFrom)):
try:
text = ast.unparse(child)
imports.append((text, child.lineno, child.end_lineno or child.lineno))
except Exception:
pass
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
_collect(node)
elif isinstance(node, ast.ClassDef):
for method in node.body:
if isinstance(method, (ast.FunctionDef, ast.AsyncFunctionDef)):
_collect(method)
return imports
def _get_stdlib_modules() -> frozenset[str]:
"""Return the set of standard library module names for the running Python."""
return frozenset(sys.stdlib_module_names)
STDLIB_MODULES = _get_stdlib_modules()
def is_stdlib_import(import_text: str) -> bool:
"""Check if an import statement refers to a stdlib module.
Handles both 'import X' and 'from X import Y' forms.
"""
import_text = import_text.strip()
if import_text.startswith('from '):
# 'from datetime import datetime' -> 'datetime'
parts = import_text.split()
if len(parts) >= 2:
mod = parts[1].split('.')[0]
return mod in STDLIB_MODULES
elif import_text.startswith('import '):
# 'import sys' -> 'sys', 'import os.path' -> 'os'
parts = import_text.split()
if len(parts) >= 2:
mod = parts[1].split('.')[0].rstrip(',')
return mod in STDLIB_MODULES
return False
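# Behaviour sketch (assuming CPython's sys.stdlib_module_names):
#     is_stdlib_import("import os.path")                 -> True
#     is_stdlib_import("from datetime import datetime")  -> True
#     is_stdlib_import("import requests")                -> False  (third-party)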
def collect_references(node: ast.AST, known_names: set[str], own_name: str) -> set[str]:
"""Find which known top-level names are referenced inside a node."""
target_names = known_names - {own_name}
refs = set()
class RefVisitor(ast.NodeVisitor):
def visit_Name(self, n):
if n.id in target_names:
refs.add(n.id)
self.generic_visit(n)
RefVisitor().visit(node)
return refs
def analyze_file(source: str, tree: ast.Module) -> tuple[list[SymbolInfo], list[ast.AST]]:
"""
Full AST analysis: extract all top-level symbols with metadata,
and all top-level imports.
"""
symbols = []
top_level_imports = []
lines = source.splitlines()
# Collect top-level defs
for node in ast.iter_child_nodes(tree):
if isinstance(node, (ast.Import, ast.ImportFrom)):
top_level_imports.append(node)
continue
# try: import X / except ImportError: ... blocks
if isinstance(node, ast.Try) and node.body and all(
isinstance(stmt, (ast.Import, ast.ImportFrom)) for stmt in node.body
):
top_level_imports.append(node)
continue
info = None
if isinstance(node, ast.FunctionDef):
info = SymbolInfo(
name=node.name,
kind='function',
node=node,
start_line=node.lineno,
end_line=node.end_lineno,
num_lines=node.end_lineno - node.lineno + 1,
decorators=extract_decorators(node),
signature=extract_signature(node),
docstring=extract_docstring(node),
inline_imports=find_inline_imports(node),
structural_hash=compute_structural_hash(source, node),
)
elif isinstance(node, ast.AsyncFunctionDef):
info = SymbolInfo(
name=node.name,
kind='async_function',
node=node,
start_line=node.lineno,
end_line=node.end_lineno,
num_lines=node.end_lineno - node.lineno + 1,
decorators=extract_decorators(node),
signature=extract_signature(node),
docstring=extract_docstring(node),
inline_imports=find_inline_imports(node),
structural_hash=compute_structural_hash(source, node),
)
elif isinstance(node, ast.ClassDef):
info = SymbolInfo(
name=node.name,
kind='class',
node=node,
start_line=node.lineno,
end_line=node.end_lineno,
num_lines=node.end_lineno - node.lineno + 1,
decorators=extract_decorators(node),
bases=extract_class_bases(node),
docstring=extract_docstring(node),
inline_imports=find_inline_imports(node),
structural_hash=compute_structural_hash(source, node),
)
elif isinstance(node, ast.Assign):
for t in node.targets:
if isinstance(t, ast.Name):
info = SymbolInfo(
name=t.id,
kind='constant',
node=node,
start_line=node.lineno,
end_line=node.end_lineno,
num_lines=node.end_lineno - node.lineno + 1,
structural_hash=compute_structural_hash(source, node),
)
break # only first target
if info:
# Compute post-hoist hash (strips stdlib inline imports from AST before hashing)
info.structural_hash_posthoist = compute_structural_hash(
source, info.node, strip_stdlib_imports=True
)
symbols.append(info)
# Now compute cross-references
known_names = {s.name for s in symbols}
for sym in symbols:
sym.references = collect_references(sym.node, known_names, sym.name)
return symbols, top_level_imports
# ── Clustering ──────────────────────────────────────────────────────────────
def auto_cluster(symbols: list[SymbolInfo]) -> dict[str, list[str]]:
"""
Auto-discover module clusters via graph-based hierarchical agglomerative
clustering with modularity-based stopping.
General-purpose — no keyword matching, no name-based heuristics.
Works on any Python codebase regardless of domain.
Algorithm:
1. Build a weighted undirected affinity graph from the directed dep graph:
- Direct reference (A uses B): weight 1.0
- Shared reference (A and B both use C): weight 0.5
- Co-referenced (C uses both A and B): weight 0.3
- Kind affinity (both enums, both dataclasses): weight 0.2
2. Seed: group enums + dataclasses together (universal convention)
3. Hierarchical agglomerative merge: repeatedly merge the two clusters
with the highest inter-cluster affinity, normalized by size.
4. Stop when the best merge score drops below a threshold (modularity
gain becomes negative) or all clusters are within a target size range.
5. Name modules from their contents (heaviest symbol, dominant class, etc.)
"""
if not symbols:
return {}
names = [s.name for s in symbols]
sym_by_name = {s.name: s for s in symbols}
n = len(names)
name_to_idx = {name: i for i, name in enumerate(names)}
# ── Step 1: Build affinity matrix (symmetric) ──
# affinity[i][j] = how strongly symbol i and j should be in the same module
affinity = [[0.0] * n for _ in range(n)]
# 1a. Direct references: A references B → edge(A, B) += 1.0
for s in symbols:
i = name_to_idx[s.name]
for ref in s.references:
if ref in name_to_idx:
j = name_to_idx[ref]
affinity[i][j] += 1.0
affinity[j][i] += 1.0
# 1b. Shared references: A and B both reference C → edge(A, B) += 0.5
# Build reverse index: who references each symbol?
referenced_by = defaultdict(set)
for s in symbols:
for ref in s.references:
if ref in name_to_idx:
referenced_by[ref].add(s.name)
for target, referrers in referenced_by.items():
referrer_list = sorted(referrers)
for ai in range(len(referrer_list)):
for bi in range(ai + 1, len(referrer_list)):
i = name_to_idx[referrer_list[ai]]
j = name_to_idx[referrer_list[bi]]
affinity[i][j] += 0.5
affinity[j][i] += 0.5
# 1c. Co-referenced: C references both A and B → edge(A, B) += 0.3
for s in symbols:
refs = sorted(r for r in s.references if r in name_to_idx)
for ai in range(len(refs)):
for bi in range(ai + 1, len(refs)):
i = name_to_idx[refs[ai]]
j = name_to_idx[refs[bi]]
affinity[i][j] += 0.3
affinity[j][i] += 0.3
# 1d. Kind affinity: same structural kind gets a small boost
for ai in range(n):
for bi in range(ai + 1, n):
sa, sb = symbols[ai], symbols[bi]
# Enum+Enum or dataclass+dataclass
a_is_data = (sa.is_class and ('Enum' in sa.bases or
any(d in ('dataclass', 'dataclasses.dataclass') for d in sa.decorators)))
b_is_data = (sb.is_class and ('Enum' in sb.bases or
any(d in ('dataclass', 'dataclasses.dataclass') for d in sb.decorators)))
if a_is_data and b_is_data:
affinity[ai][bi] += 0.2
affinity[bi][ai] += 0.2
# ── Step 2: Initialize clusters ──
# Each symbol starts as its own cluster (index-based)
# cluster_id -> set of symbol indices
clusters: dict[int, set[int]] = {i: {i} for i in range(n)}
next_cluster_id = n
# Seed: merge all enums + dataclasses into one initial cluster
data_type_indices = []
for s in symbols:
if s.is_class and ('Enum' in s.bases or
any(d in ('dataclass', 'dataclasses.dataclass') for d in s.decorators)):
data_type_indices.append(name_to_idx[s.name])
if len(data_type_indices) > 1:
# Merge all data type clusters into one
merged = set()
ids_to_remove = []
for idx in data_type_indices:
# Find which cluster contains this index
for cid, members in clusters.items():
if idx in members:
merged |= members
ids_to_remove.append(cid)
break
for cid in set(ids_to_remove):
del clusters[cid]
clusters[next_cluster_id] = merged
next_cluster_id += 1
# ── Step 3: Hierarchical agglomerative merging ──
# Total edge weight in the graph (for modularity calculation)
total_weight = sum(affinity[i][j] for i in range(n) for j in range(i + 1, n))
if total_weight == 0:
total_weight = 1.0 # avoid division by zero for isolated symbols
def cluster_affinity(c1: set[int], c2: set[int]) -> float:
"""Average-linkage affinity between two clusters."""
if not c1 or not c2:
return 0.0
total = sum(affinity[i][j] for i in c1 for j in c2)
return total / (len(c1) * len(c2))
def modularity_gain(c1: set[int], c2: set[int]) -> float:
"""Estimate modularity gain from merging c1 and c2.
Based on Newman-Girvan modularity: edges within the merged cluster
minus expected edges if connections were random.
"""
inter = sum(affinity[i][j] for i in c1 for j in c2)
# Degree of each cluster (sum of all edges from cluster to anywhere)
deg1 = sum(affinity[i][j] for i in c1 for j in range(n) if j not in c1)
deg2 = sum(affinity[i][j] for i in c2 for j in range(n) if j not in c2)
# Modularity delta: actual inter-edges minus expected
expected = (deg1 * deg2) / (2 * total_weight) if total_weight > 0 else 0
return inter - expected
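    # Worked example with illustrative numbers: inter = 1.5, deg1 = 2.0,
    # deg2 = 3.0, total_weight = 10.0 gives 1.5 - (2.0 * 3.0) / 20.0 = 1.2,
    # so merging those two clusters would be favoured.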
# Size constraints
# Target: each module should be a manageable chunk. The original file is
# being split because it's too large, so aim for modules well below the
# original size. We use original_lines/3 as a soft cap to encourage
# splitting into at least 3 modules, with an absolute floor of 200 lines
# (don't bother splitting tiny files) and ceiling of 1500 lines.
original_lines = sum(s.num_lines for s in symbols)
min_lines = 50
max_lines = max(200, min(1500, original_lines // 3))
def cluster_lines(c: set[int]) -> int:
return sum(symbols[i].num_lines for i in c)
# Merge loop
max_iterations = n * n # safety bound
for _ in range(max_iterations):
if len(clusters) <= 1:
break
# Find best merge candidate
best_score = float('-inf')
best_pair = None
cluster_ids = list(clusters.keys())
for ai in range(len(cluster_ids)):
for bi in range(ai + 1, len(cluster_ids)):
cid_a, cid_b = cluster_ids[ai], cluster_ids[bi]
ca, cb = clusters[cid_a], clusters[cid_b]
# Skip if merge would exceed max size
merged_lines = cluster_lines(ca) + cluster_lines(cb)
if merged_lines > max_lines:
continue
# Score = modularity gain + affinity, penalized by size
mg = modularity_gain(ca, cb)
aff = cluster_affinity(ca, cb)
# Size penalty: as merged size approaches max_lines, penalize
# heavily so large clusters resist absorbing more symbols.
                # The cubic ramp maps size_ratio 0→1 onto a penalty of 0→2.0.
size_ratio = merged_lines / max_lines
size_penalty = (size_ratio ** 3) * 2.0 # cubic ramp
# Combined score: modularity-driven, affinity tiebreaker, size brake
score = mg + aff * 0.1 - size_penalty
if score > best_score:
best_score = score
best_pair = (cid_a, cid_b)
# Stop if no beneficial merge exists
if best_pair is None or best_score < 0:
break
# Merge
cid_a, cid_b = best_pair
merged = clusters[cid_a] | clusters[cid_b]
del clusters[cid_a]
del clusters[cid_b]
clusters[next_cluster_id] = merged
next_cluster_id += 1
# ── Step 4: Absorb tiny clusters ──
# Clusters below min_lines get merged into their most-connected neighbor
changed = True
while changed:
changed = False
tiny = [(cid, members) for cid, members in clusters.items()
if cluster_lines(members) < min_lines and len(clusters) > 1]
for cid, members in tiny:
# Find most-connected other cluster
best_target = None
best_aff = -1
for other_cid, other_members in clusters.items():
if other_cid == cid:
continue
aff = cluster_affinity(members, other_members)
if aff > best_aff:
best_aff = aff
best_target = other_cid
if best_target is not None:
clusters[best_target] |= members
del clusters[cid]
changed = True
break # restart scan since dict changed
# ── Step 5: Name modules from contents ──
result = {}
used_names = set()
for cid, member_indices in clusters.items():
member_syms = [symbols[i] for i in sorted(member_indices)]
name = _derive_module_name(member_syms, used_names)
used_names.add(name)
result[name] = [s.name for s in member_syms]
return result
def _derive_module_name(member_syms: list[SymbolInfo], used_names: set[str]) -> str:
"""Derive a module name from the symbols inside a cluster.
Strategy: find the name that best describes the *group*, not just the
largest member. A module named after one function is a code smell — it
means the cluster has no coherent identity.
Priority:
1. All enums/dataclasses → "models"
2. Dominant class (>50% of lines) → snake_case(ClassName)
3. All constants → "constants"
4. Find common semantic root across member names (longest common prefix/suffix)
5. Name after the most-referenced symbol (the "hub" others depend on)
6. Fallback → "module_N"
"""
if not member_syms:
return _unique_name("module", used_names)
total_lines = sum(s.num_lines for s in member_syms)
# 1. All data types → "models"
all_data = all(
s.is_class and ('Enum' in s.bases or
any(d in ('dataclass', 'dataclasses.dataclass') for d in s.decorators))
for s in member_syms
)
if all_data:
return _unique_name("models", used_names)
# 2. Dominant class (>50% of lines)
classes = [s for s in member_syms if s.is_class]
for cls in sorted(classes, key=lambda s: s.num_lines, reverse=True):
if cls.num_lines > total_lines * 0.5:
return _unique_name(_slugify(cls.name), used_names)
# 3. All constants
if all(s.is_constant for s in member_syms):
return _unique_name("constants", used_names)
# 4. Semantic root: split names into words, find shared words
candidate = _find_common_theme(member_syms)
if candidate:
return _unique_name(candidate, used_names)
# 5. Hub symbol: the one most-referenced by other members in this cluster
member_names = {s.name for s in member_syms}
ref_counts = defaultdict(int)
for s in member_syms:
for ref in s.references:
if ref in member_names:
ref_counts[ref] += 1
if ref_counts:
hub = max(ref_counts, key=ref_counts.get)
hub_sym = next(s for s in member_syms if s.name == hub)
if hub_sym.is_class:
return _unique_name(_slugify(hub), used_names)
# 6. Composition-based: name by what kinds of symbols are present
functions = [s for s in member_syms if s.is_function]
constants = [s for s in member_syms if s.is_constant]
classes = [s for s in member_syms if s.is_class]
# Single-symbol cluster
if len(member_syms) == 1:
s = member_syms[0]
slug = _slugify(s.name)
if s.is_class:
return _unique_name(slug, used_names)
# For a single function, "cli" if it's main/entry-point-like
if s.is_function and s.name in ('main', 'cli', 'run', 'app', 'entry'):
return _unique_name("cli", used_names)
if len(slug) <= 15:
return _unique_name(slug, used_names)
# Functions + constants mixed → "core" (they're the glue)
if functions and constants and not classes:
return _unique_name("core", used_names)
# Only functions
if functions and not constants and not classes:
# Try to find if there's a single entry point referencing others
entry = [f for f in functions if not any(
f.name in other.references for other in functions if other is not f)]
if len(entry) == 1:
slug = _slugify(entry[0].name)
if slug in ('main', 'cli', 'run', 'app'):
return _unique_name("cli", used_names)
if len(slug) <= 15:
return _unique_name(slug, used_names)
return _unique_name("utils", used_names)
# Only constants
if constants and not functions and not classes:
return _unique_name("constants", used_names)
return _unique_name("module", used_names)
def _find_common_theme(syms: list[SymbolInfo]) -> Optional[str]:
"""Find a common semantic theme across symbol names.
Splits each name into words (snake_case or CamelCase), finds words that
appear in multiple names. Returns the most common shared word as a
module name candidate.
"""
word_counts = defaultdict(int)
total = len(syms)
for s in syms:
# Split name into words: "handle_hook_event" → ["handle", "hook", "event"]
# Also handle CamelCase: "ConflictAnalysis" → ["conflict", "analysis"]
slug = _slugify(s.name)
words = set(slug.split('_'))
# Remove generic words that don't carry meaning
generic = {'get', 'set', 'make', 'create', 'build', 'do', 'run', 'is',
'has', 'can', 'the', 'a', 'an', 'to', 'from', 'for', 'with',
'handle', 'process', 'generate', 'print', 'show', 'display',
'check', 'find', 'search', 'update', 'delete', 'add', 'remove',
'init', 'setup', 'config', 'main', 'type', 'info', 'data',
'result', 'error', 'name', 'value', 'list', 'item', 'node'}
words -= generic
for w in words:
if len(w) > 2: # skip tiny fragments
word_counts[w] += 1
if not word_counts:
return None
# A word must appear in at least 40% of members to be a theme
threshold = max(2, total * 0.4)
candidates = [(word, count) for word, count in word_counts.items()
if count >= threshold]
if not candidates:
return None
# Return the most common word
best_word = max(candidates, key=lambda x: x[1])[0]
return best_word
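# Illustrative example: for symbols named parse_hook_event, emit_hook_config and
# hook_registry (hypothetical names), "hook" survives the generic-word filter and
# appears in all three, clearing the 40% threshold, so the theme is "hook".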
def _unique_name(base: str, used: set[str]) -> str:
"""Return base if available, else base_2, base_3, etc."""
if base not in used:
return base
i = 2
while f"{base}_{i}" in used:
i += 1
return f"{base}_{i}"
def _slugify(name: str) -> str:
"""Convert CamelCase to snake_case for module names."""
s = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', name)
s = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', s)
return s.lower()
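# e.g. _slugify("ConflictAnalysis") -> "conflict_analysis",
#      _slugify("HTTPServer")       -> "http_server"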
# ── Dependency graph ────────────────────────────────────────────────────────
def build_module_deps(
symbols: list[SymbolInfo],
module_spec: dict[str, list[str]],
) -> dict[str, set[tuple[str, str]]]:
"""Build module-level dependency graph: module -> set of (dep_module, dep_symbol)."""
sym_to_mod = {}
for mod, sym_names in module_spec.items():
for sn in sym_names:
sym_to_mod[sn] = mod
sym_by_name = {s.name: s for s in symbols}
module_deps = defaultdict(set)
for mod, sym_names in module_spec.items():
for sn in sym_names:
s = sym_by_name.get(sn)
if not s:
continue
for ref in s.references:
ref_mod = sym_to_mod.get(ref)
if ref_mod and ref_mod != mod:
module_deps[mod].add((ref_mod, ref))
return dict(module_deps)
def detect_cycles(module_deps: dict[str, set[tuple[str, str]]]) -> list[list[str]]:
"""
    Detect cycles using iterative depth-first search with white/gray/black marking.
Returns list of cycles, where each cycle is a list of module names.
"""
edges = defaultdict(set)
for mod, deps in module_deps.items():
for dep_mod, _ in deps:
edges[mod].add(dep_mod)
all_nodes = set(edges.keys())
for targets in edges.values():
all_nodes |= targets
WHITE, GRAY, BLACK = 0, 1, 2
color = {n: WHITE for n in all_nodes}
parent = {n: None for n in all_nodes}
cycles = []
for start in all_nodes:
if color[start] != WHITE:
continue
stack = [(start, iter(edges.get(start, set())))]
color[start] = GRAY
while stack:
node, children = stack[-1]
try:
child = next(children)
if color[child] == GRAY:
# Found cycle — trace back
cycle = [child]
for frame_node, _ in reversed(stack):
cycle.append(frame_node)
if frame_node == child:
break
cycles.append(cycle)
elif color[child] == WHITE:
color[child] = GRAY
parent[child] = node
stack.append((child, iter(edges.get(child, set()))))
except StopIteration:
color[node] = BLACK
stack.pop()
return cycles
def topological_sort(module_deps: dict[str, set[tuple[str, str]]], all_modules: list[str]) -> list[str]:
"""Sort modules so dependencies come first."""
edges = defaultdict(set)
for mod, deps in module_deps.items():
for dep_mod, _ in deps:
edges[mod].add(dep_mod)
visited = set()
order = []
def visit(node):
if node in visited:
return
visited.add(node)
for dep in edges.get(node, set()):
visit(dep)
order.append(node)
for m in all_modules:
visit(m)
return order
# ── Import analysis ─────────────────────────────────────────────────────────
def filter_imports_for_symbols(
all_imports: list[ast.AST],
source: str,
sym_names: list[str],
symbols: list[SymbolInfo],
) -> list[str]:
"""Determine which top-level imports a set of symbols actually needs."""
sym_by_name = {s.name: s for s in symbols}
# Collect all names used in these symbols
used_names = set()
used_modules = set() # for `module.attr` patterns
for sn in sym_names:
s = sym_by_name.get(sn)
if not s:
continue
class UsageVisitor(ast.NodeVisitor):
def visit_Name(self, n):
used_names.add(n.id)
self.generic_visit(n)
def visit_Attribute(self, n):
if isinstance(n.value, ast.Name):
used_modules.add(n.value.id)
self.generic_visit(n)
UsageVisitor().visit(s.node)
all_used = used_names | used_modules
# Filter top-level imports
lines = source.splitlines()
needed = []
seen = set()
for imp_node in all_imports:
if isinstance(imp_node, ast.Import):
for alias in imp_node.names:
bound_name = alias.asname or alias.name
if bound_name in all_used or bound_name.split('.')[0] in all_used:
text = _extract_lines(lines, imp_node.lineno, imp_node.end_lineno)
if text not in seen:
needed.append(text)
seen.add(text)
break
elif isinstance(imp_node, ast.ImportFrom):
relevant = []
for alias in imp_node.names:
bound_name = alias.asname or alias.name
if bound_name in all_used:
relevant.append(alias)
if relevant:
if len(relevant) < len(imp_node.names):
# Reconstruct with only needed names
names_str = ', '.join(
f"{a.name} as {a.asname}" if a.asname else a.name
for a in relevant
)
text = f"from {imp_node.module} import {names_str}"
else:
text = _extract_lines(lines, imp_node.lineno, imp_node.end_lineno)
if text not in seen:
needed.append(text)
seen.add(text)
elif isinstance(imp_node, ast.Try):
# try: import X / except ImportError: ... blocks
# Check if any import inside the try body is used
any_used = False
for stmt in imp_node.body:
if isinstance(stmt, ast.Import):
for alias in stmt.names:
bound_name = alias.asname or alias.name
if bound_name in all_used or bound_name.split('.')[0] in all_used:
any_used = True
break
elif isinstance(stmt, ast.ImportFrom):
for alias in stmt.names:
bound_name = alias.asname or alias.name
if bound_name in all_used:
any_used = True
break
if any_used:
break
if any_used:
# Emit the entire try/except block as-is
text = _extract_lines(lines, imp_node.lineno, imp_node.end_lineno)
if text not in seen:
needed.append(text)
seen.add(text)
return needed
def _extract_lines(lines: list[str], start: int, end: int) -> str:
"""Extract source lines (1-based, inclusive)."""
return '\n'.join(lines[start - 1:end])
# ── Code generation ─────────────────────────────────────────────────────────
def generate_module(
mod_name: str,
sym_names: list[str],
source: str,
symbols: list[SymbolInfo],
all_imports: list[ast.AST],
module_deps: dict[str, set[tuple[str, str]]],
package_name: str,
) -> ModuleFile:
"""Generate a complete module file.
Hoists stdlib inline imports to module top-level and strips them from
function bodies. Non-stdlib inline imports are left in place.
"""
sym_by_name = {s.name: s for s in symbols}
lines = source.splitlines()
parts = []
# Collect stdlib inline imports to hoist (deduplicated)
hoisted_imports: list[str] = []
hoisted_seen: set[str] = set()
# Track which source lines to strip (1-based line numbers)
strip_lines: set[int] = set()
for sn in sym_names:
s = sym_by_name.get(sn)
if not s:
continue
for imp_text, lineno, end_lineno in s.inline_imports:
if is_stdlib_import(imp_text) and imp_text not in hoisted_seen:
hoisted_imports.append(imp_text)
hoisted_seen.add(imp_text)
if is_stdlib_import(imp_text):
for ln in range(lineno, end_lineno + 1):
strip_lines.add(ln)
# Header
parts.append(f'"""{package_name}.{mod_name} - Auto-split module"""')
parts.append('')
# Stdlib/third-party imports (from original top-level)
needed = filter_imports_for_symbols(all_imports, source, sym_names, symbols)
if needed:
parts.extend(needed)
# Hoisted stdlib imports from function bodies
if hoisted_imports:
for imp in hoisted_imports:
if imp not in needed:
parts.append(imp)
if needed or hoisted_imports:
parts.append('')
# Cross-module imports
deps = module_deps.get(mod_name, set())
if deps:
by_mod = defaultdict(list)
for dep_mod, dep_sym in deps:
by_mod[dep_mod].append(dep_sym)
for dep_mod in sorted(by_mod):
syms = sorted(by_mod[dep_mod])
parts.append(f"from .{dep_mod} import {', '.join(syms)}")
parts.append('')
# Symbol bodies (strip hoisted stdlib imports from source)
for sn in sym_names:
s = sym_by_name.get(sn)
if not s:
parts.append(f"# WARNING: {sn} not found in source")
parts.append('')
continue
# Include decorator lines (AST lineno starts at the `def`/`class`, not decorator)
actual_start = s.start_line
if hasattr(s.node, 'decorator_list') and s.node.decorator_list:
first_dec = s.node.decorator_list[0]
actual_start = first_dec.lineno
# Build source, stripping hoisted inline imports
sym_lines = []
for line_idx in range(actual_start - 1, s.end_line):
lineno_1based = line_idx + 1
if lineno_1based in strip_lines:
continue
sym_lines.append(lines[line_idx])
sym_source = '\n'.join(sym_lines)
parts.append('')
parts.append(sym_source)
parts.append('')
content = '\n'.join(parts)
return ModuleFile(
name=mod_name,
symbols=sym_names,
source=content,
line_count=len(content.splitlines()),
deps=deps,
)
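# Hoisting sketch (illustrative): a symbol whose body contains
#     def load(path):
#         import json
#         return json.loads(path.read_text())
# is emitted with `import json` promoted into the module's import block and the
# inline line dropped from the body, while a non-stdlib inline import such as
# `import yaml` stays where it is, as described in the docstring above.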
def generate_init(
module_spec: dict[str, list[str]],
existing_init: Optional[str],
package_name: str,
) -> str:
"""Generate __init__.py preserving version and re-exporting all symbols."""
parts = []
parts.append(f'"""{package_name} - Auto-split package"""')
parts.append('')
# Preserve __version__ if it exists
if existing_init:
for line in existing_init.splitlines():
if line.strip().startswith('__version__'):
parts.append(line)
parts.append('')
break
# Re-export from each module
for mod_name, sym_names in module_spec.items():
sym_list = ', '.join(sym_names)
parts.append(f"from .{mod_name} import {sym_list}")
parts.append('')
# __all__
all_syms = []
for sym_names in module_spec.values():
all_syms.extend(sym_names)
# Include __version__ if found
if existing_init and '__version__' in existing_init:
all_syms.append('__version__')
parts.append(f"__all__ = {all_syms!r}")
parts.append('')
return '\n'.join(parts)
# ── Verification ────────────────────────────────────────────────────────────
def verify_split(
symbols: list[SymbolInfo],
module_spec: dict[str, list[str]],
module_deps: dict[str, set[tuple[str, str]]],
generated: dict[str, ModuleFile],
source: str,
) -> list[tuple[str, str]]:
"""
Run all verification checks. Returns list of (level, message) tuples.
level is one of: 'error', 'warn', 'ok'.
"""
results = []
sym_by_name = {s.name: s for s in symbols}
# 1. Completeness: every symbol assigned to exactly one module
assigned = {}
for mod, sym_names in module_spec.items():
for sn in sym_names:
if sn in assigned:
results.append(('error', f"DUPLICATE: '{sn}' in both '{assigned[sn]}' and '{mod}'"))
assigned[sn] = mod
for s in symbols:
if s.name not in assigned:
results.append(('error', f"UNASSIGNED: '{s.name}' (L{s.start_line}-{s.end_line}, {s.num_lines} lines)"))
for mod, sym_names in module_spec.items():
for sn in sym_names:
if sn not in sym_by_name:
results.append(('error', f"MISSING: '{sn}' in spec for '{mod}' but not in source"))
if not any(r[0] == 'error' for r in results if 'UNASSIGNED' in r[1] or 'DUPLICATE' in r[1] or 'MISSING' in r[1]):
results.append(('ok', "All symbols assigned to exactly one module"))
# 2. Circular dependencies
cycles = detect_cycles(module_deps)
if cycles:
for cycle in cycles:
results.append(('error', f"CYCLE: {' -> '.join(cycle)}"))
else:
results.append(('ok', "No circular dependencies"))
# 3. Cross-module import targets exist
for mod in module_spec:
for dep_mod, dep_sym in module_deps.get(mod, set()):
if dep_sym in module_spec.get(dep_mod, []):
results.append(('ok', f"{mod} imports {dep_sym} from {dep_mod}"))
else:
results.append(('error', f"{mod} needs {dep_sym} from {dep_mod} — NOT IN THAT MODULE"))
# 4. AST isomorphism check: re-parse each generated module, find each symbol,
# hash it, and compare against the original's post-hoist hash.
# If hashes match, the split is structurally identical (same AST minus position).
for mod_name, mod_file in generated.items():
if mod_name == '__init__':
continue
try:
split_tree = ast.parse(mod_file.source)
except SyntaxError as e:
results.append(('error', f"[{mod_name}] SYNTAX ERROR in generated module: {e}"))
continue
# Index split symbols by name
split_syms = {}
for split_node in ast.iter_child_nodes(split_tree):
if isinstance(split_node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
split_syms[split_node.name] = split_node
elif isinstance(split_node, ast.Assign):
for t in split_node.targets:
if isinstance(t, ast.Name):
split_syms[t.id] = split_node
break
for sn in mod_file.symbols:
s = sym_by_name.get(sn)
if not s:
continue
split_node = split_syms.get(sn)
if not split_node:
results.append(('error', f"[{mod_name}] {sn} NOT FOUND in generated module AST"))
continue
# Hash the split version (no strip needed — stdlib imports already removed)
split_hash = compute_structural_hash(mod_file.source, split_node)
original_posthoist = s.structural_hash_posthoist
if split_hash == original_posthoist:
if s.structural_hash != original_posthoist:
results.append(('ok',
f"[{mod_name}] {sn} AST isomorphic ✓ "
f"(original: {s.structural_hash} → hoisted: {original_posthoist} == split: {split_hash})"))
else:
results.append(('ok',
f"[{mod_name}] {sn} AST isomorphic ✓ (hash: {split_hash})"))
else:
results.append(('error',
f"[{mod_name}] {sn} AST MISMATCH! "
f"original(posthoist): {original_posthoist} != split: {split_hash}"))
# 5. Inline import hoisting report
for s in symbols:
if s.inline_imports:
mod = assigned.get(s.name, '?')
for imp_text, lineno, end_lineno in s.inline_imports:
if is_stdlib_import(imp_text):
results.append(('ok', f"[{mod}] {s.name} L{lineno}: hoisted stdlib '{imp_text}' to top-level"))
else:
results.append(('ok', f"[{mod}] {s.name} L{lineno}: non-stdlib inline import kept in body: {imp_text}"))
return results
# ── Output formatting ───────────────────────────────────────────────────────
def print_analysis(
symbols: list[SymbolInfo],
module_spec: dict[str, list[str]],
module_deps: dict[str, set[tuple[str, str]]],
generated: dict[str, ModuleFile],
source: str,
source_path: Path,
):
"""Print the full analysis report."""
sym_by_name = {s.name: s for s in symbols}
print("=" * 70)
print("MODULE SPLITTER - AST Analysis Report")
print(f"Source: {source_path}")
print(f"Symbols: {len(symbols)} | Modules: {len(module_spec)}")
print("=" * 70)
# Symbol table
print("\n── Symbol Table ──")
print(f" {'Name':<35} {'Kind':<10} {'Lines':>8} {'Hash':>14} {'Module':<12}")
print(f" {'─' * 35} {'─' * 10} {'─' * 8} {'─' * 14} {'─' * 12}")
assigned = {}
for mod, syms in module_spec.items():
for sn in syms:
assigned[sn] = mod
for s in symbols:
mod = assigned.get(s.name, '???')
line_range = f"L{s.start_line}-{s.end_line}"
print(f" {s.name:<35} {s.kind:<10} {line_range:>8} {s.structural_hash:>14} {mod:<12}")
if s.signature:
sig = s.signature
if len(sig) > 65:
sig = sig[:62] + "..."
print(f" {sig}")
if s.decorators:
print(f" decorators: {', '.join(s.decorators)}")
if s.bases:
print(f" bases: {', '.join(s.bases)}")
# Dependencies
print("\n── Module Dependencies ──")
topo = topological_sort(module_deps, list(module_spec.keys()))
for mod in topo:
deps = module_deps.get(mod, set())
if deps:
by_mod = defaultdict(list)
for d_mod, d_sym in deps:
by_mod[d_mod].append(d_sym)
dep_strs = [f"{m}({', '.join(sorted(s))})" for m, s in sorted(by_mod.items())]
print(f" {mod} → {', '.join(dep_strs)}")
else:
print(f" {mod} → (leaf)")
print(f"\n Topological order: {' → '.join(topo)}")
# Generated files
print("\n── Generated Files ──")
total = 0
for mod_name in list(module_spec.keys()) + ['__init__']:
mf = generated.get(mod_name)
if mf:
total += mf.line_count
print(f" {mf.name + '.py':<20} {mf.line_count:>6} lines [{', '.join(mf.symbols[:4])}{'...' if len(mf.symbols) > 4 else ''}]")
elif mod_name == '__init__':
init_src = generated.get('__init__')
if init_src:
lc = init_src.line_count if isinstance(init_src, ModuleFile) else len(init_src.splitlines()) if isinstance(init_src, str) else 0
total += lc
print(f" __init__.py {lc:>6} lines [re-exports]")
original_lines = len(source.splitlines())
print(f"\n Original: {original_lines:,} lines (1 file)")
print(f" Split: {total:,} lines ({len(generated)} files)")
print(f" Overhead: {total - original_lines:+,} lines")
def print_verification(results: list[tuple[str, str]]):
"""Print verification results."""
print("\n── Verification ──")
icons = {'ok': '✅', 'warn': '⚠️ ', 'error': '❌'}
errors = [r for r in results if r[0] == 'error']
warns = [r for r in results if r[0] == 'warn']
oks = [r for r in results if r[0] == 'ok']
# Show errors and warnings always, oks only in summary
for level, msg in errors:
print(f" {icons[level]} {msg}")
for level, msg in warns:
print(f" {icons[level]} {msg}")
print(f"\n Summary: {len(oks)} ok, {len(warns)} warnings, {len(errors)} errors")
if not errors:
print(" ✅ Split is safe to apply")
else:
print(" ❌ Fix errors before applying")
# ── Main ────────────────────────────────────────────────────────────────────
def main():
parser = argparse.ArgumentParser(
description="General-purpose AST-based Python module splitter",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument('source', type=Path, help='Python source file to split')
parser.add_argument('--auto', action='store_true', help='Auto-discover module clusters')
parser.add_argument('--spec', type=Path, help='JSON file with module spec: {"mod": ["sym1", ...]}')
parser.add_argument('--write', action='store_true', help='Actually write the split files')
parser.add_argument('--verify', action='store_true', help='Run verification checks')
parser.add_argument('--outdir', type=Path, default=None, help='Output directory (default: same as source)')
parser.add_argument('--package', type=str, default=None, help='Package name (default: parent dir name)')
parser.add_argument('--symbols', action='store_true', help='Only print symbol table, no split')
parser.add_argument('--deps', action='store_true', help='Only print dependency graph')
parser.add_argument('--hashes', action='store_true', help='Print structural hashes for drift detection')
parser.add_argument('--isomorphic', action='store_true',
help='End-to-end AST isomorphism test: write split to tmpdir, '
're-parse, hash every symbol independently, compare to original')
args = parser.parse_args()
source_path = args.source
if not source_path.exists():
print(f"ERROR: {source_path} not found", file=sys.stderr)
sys.exit(1)
outdir = args.outdir or source_path.parent
package_name = args.package or outdir.name
# Parse
source, tree = parse_source(source_path)
symbols, top_level_imports = analyze_file(source, tree)
# Symbols-only mode
if args.symbols:
print(f"{'Name':<35} {'Kind':<10} {'Lines':>10} {'Hash':>14} {'Refs'}")
print(f"{'─' * 35} {'─' * 10} {'─' * 10} {'─' * 14} {'─' * 30}")
for s in symbols:
refs = ', '.join(sorted(s.references)) if s.references else '(none)'
lr = f"L{s.start_line}-{s.end_line}"
print(f"{s.name:<35} {s.kind:<10} {lr:>10} {s.structural_hash:>14} {refs}")
if s.signature:
print(f" {s.signature}")
if s.inline_imports:
for imp_text, lineno, end_lineno in s.inline_imports:
stdlib_tag = "stdlib→hoist" if is_stdlib_import(imp_text) else "non-stdlib"
print(f" [inline L{lineno}] ({stdlib_tag}) {imp_text}")
return
# Deps-only mode
if args.deps:
print("Dependency graph (symbol → references):")
for s in symbols:
refs = ', '.join(sorted(s.references)) if s.references else '(none)'
print(f" {s.name} → {refs}")
return
# Hashes-only mode
if args.hashes:
print(f"Structural hashes for {source_path}:")
for s in symbols:
if s.structural_hash != s.structural_hash_posthoist:
print(f" {s.structural_hash} {s.name} (L{s.start_line}-{s.end_line})")
print(f" {s.structural_hash_posthoist} └─ post-hoist (stdlib imports stripped)")
else:
print(f" {s.structural_hash} {s.name} (L{s.start_line}-{s.end_line})")
return
# Determine module spec
if args.spec:
with open(args.spec) as f:
module_spec = json.load(f)
elif args.auto:
module_spec = auto_cluster(symbols)
else:
# Default: use niwa-specific spec
module_spec = {
"models": ["ConflictType", "Edit", "ConflictAnalysis", "EditResult"],
"prompts": ["LLM_SYSTEM_PROMPT", "COMMAND_HELP", "ERROR_PROMPTS"],
"db": ["Niwa"],
"hooks": ["generate_claude_hooks_config", "get_niwa_usage_guide",
"handle_hook_event", "setup_claude_hooks"],
"cli": ["print_error", "print_command_help", "main"],
}
# Build dependency graph
module_deps = build_module_deps(symbols, module_spec)
# Generate files
generated = {}
for mod_name, sym_names in module_spec.items():
mf = generate_module(
mod_name, sym_names, source, symbols,
top_level_imports, module_deps, package_name,
)
generated[mod_name] = mf
# Generate __init__.py
existing_init = None
init_path = outdir / "__init__.py"
if init_path.exists():
existing_init = init_path.read_text()
init_source = generate_init(module_spec, existing_init, package_name)
generated['__init__'] = ModuleFile(
name='__init__',
symbols=[],
source=init_source,
line_count=len(init_source.splitlines()),
)
# Print analysis
print_analysis(symbols, module_spec, module_deps, generated, source, source_path)
# Verification
if args.verify or args.write:
results = verify_split(symbols, module_spec, module_deps, generated, source)
print_verification(results)
if args.write and any(r[0] == 'error' for r in results):
print("\n❌ Cannot write: fix errors first")
sys.exit(1)
# Isomorphic test: independent end-to-end AST hash comparison
if args.isomorphic:
import tempfile
import shutil
print("\n── Isomorphic AST Test ──")
print(" Writing split to temp dir, re-parsing, hashing independently...\n")
tmpdir = Path(tempfile.mkdtemp(prefix='split_iso_'))
try:
# Write generated files to tmpdir
for mod_name, mf in generated.items():
if mod_name == '__init__':
path = tmpdir / "__init__.py"
else:
path = tmpdir / f"{mod_name}.py"
path.write_text(mf.source if isinstance(mf, ModuleFile) else mf)
# Build expected hashes from original (post-hoist)
expected_hashes = {}
for s in symbols:
expected_hashes[s.name] = s.structural_hash_posthoist
# Re-parse each split file and hash symbols independently
ok_count = 0
fail_count = 0
for py_file in sorted(tmpdir.glob('*.py')):
if py_file.name == '__init__.py':
continue
split_src = py_file.read_text()
try:
split_tree = ast.parse(split_src)
except SyntaxError as e:
print(f" ✗ {py_file.name}: SYNTAX ERROR: {e}")
fail_count += 1
continue
for split_node in ast.iter_child_nodes(split_tree):
name = None
if isinstance(split_node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
name = split_node.name
elif isinstance(split_node, ast.Assign):
for t in split_node.targets:
if isinstance(t, ast.Name):
name = t.id
break
if name and name in expected_hashes:
split_hash = compute_structural_hash(split_src, split_node)
expected = expected_hashes[name]
if split_hash == expected:
ok_count += 1
print(f" ✓ {name}: {split_hash}")
else:
fail_count += 1
print(f" ✗ {name}: split={split_hash} != expected={expected}")
# Check coverage: every original symbol should appear in some split file
found_in_split = set()
for py_file in sorted(tmpdir.glob('*.py')):
if py_file.name == '__init__.py':
continue
split_tree = ast.parse(py_file.read_text())
for split_node in ast.iter_child_nodes(split_tree):
if isinstance(split_node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
found_in_split.add(split_node.name)
elif isinstance(split_node, ast.Assign):
for t in split_node.targets:
if isinstance(t, ast.Name):
found_in_split.add(t.id)
break
missing = set(expected_hashes.keys()) - found_in_split
if missing:
for name in sorted(missing):
print(f" ✗ {name}: NOT FOUND in any split file")
fail_count += 1
print(f"\n Result: {ok_count} match, {fail_count} mismatch, "
f"{len(expected_hashes)} symbols total")
if fail_count == 0:
print(" ✅ Split is isomorphic — identical AST structure proven via SHA-256")
else:
print(" ❌ AST isomorphism FAILED — split diverges from original")
sys.exit(1)
finally:
shutil.rmtree(tmpdir)
if not args.write:
return
# Write or dry-run
if args.write:
print("\n── Writing Files ──")
outdir.mkdir(parents=True, exist_ok=True)
# Backup original BEFORE writing (source may be inside outdir)
backup = source_path.with_suffix('.py.bak')
if source_path.exists() and outdir.resolve() == source_path.resolve().parent:
source_path.rename(backup)
print(f" 📦 Backed up original to {backup}")
for mod_name, mf in generated.items():
if mod_name == '__init__':
path = outdir / "__init__.py"
else:
path = outdir / f"{mod_name}.py"
path.write_text(mf.source if isinstance(mf, ModuleFile) else mf)
print(f" ✅ {path} ({mf.line_count} lines)")
else:
# Dry run — print generated files
for mod_name in list(module_spec.keys()) + ['__init__']:
mf = generated.get(mod_name)
if not mf:
continue
fname = f"{outdir}/{mod_name}.py" if mod_name != '__init__' else f"{outdir}/__init__.py"
print(f"\n{'═' * 70}")
print(f"FILE: {fname}")
print(f"{'═' * 70}")
src = mf.source if isinstance(mf, ModuleFile) else mf
print(src)
if __name__ == "__main__":
main()