Skip to content

Instantly share code, notes, and snippets.

@milkey-mouse
Last active October 13, 2024 03:33
Show Gist options
  • Save milkey-mouse/400eb1da0d99ec86752b9fe18b5d8eba to your computer and use it in GitHub Desktop.
Save milkey-mouse/400eb1da0d99ec86752b9fe18b5d8eba to your computer and use it in GitHub Desktop.
import ast
from collections import defaultdict
from dataclasses import dataclass
from textwrap import dedent
@dataclass
class FunctionSpan:
    """Location of one function definition inside a list of source lines.

    All line fields are indices into ``source.splitlines()`` and are meant to
    be used as slice bounds.
    """

    # column offset of the ``def`` keyword (the AST node's ``col_offset``)
    indent: int
    # 0-based index of the ``def`` line (the AST node's ``lineno`` minus one)
    definition_line: int
    # 0-based index of the first line that gets replaced by a stub; this is
    # past the docstring when the function has one
    body_line: int
    # exclusive 0-based end of the function (the AST node's ``end_lineno``)
    end_line: int
class ExtractFunctions(ast.NodeVisitor):
    """Record where every function definition in a syntax tree lives.

    After ``visit(tree)``, ``spans`` maps a dotted name -- the names of the
    enclosing functions joined with "." -- to the FunctionSpan of that
    function.  Class names are not part of the dotted name, and a later
    definition with the same dotted name replaces the earlier entry.
    """

    def __init__(self):
        self.prefix = []  # names of the functions currently being visited
        self.spans = {}  # dotted name -> FunctionSpan
        super().__init__()

    def visit_FunctionDef(self, f):
        """Store the span of *f*, then descend to pick up nested functions."""
        self.prefix.append(f.name)
        first = f.body[0] if f.body else None
        if first is None:
            # only possible for hand-built trees; parsed functions have a body
            body_line = f.lineno
        elif (
            isinstance(first, ast.Expr)
            and isinstance(first.value, ast.Constant)
            and isinstance(first.value.value, str)
        ):
            # function has a docstring; body starts after it
            # (ast.Str is gone in Python 3.12, string literals are ast.Constant)
            body_line = first.end_lineno
        else:
            # Start at the first statement (or its first decorator) instead of
            # right after the ``def`` line, so that a signature spanning
            # several lines is kept whole.  ``max`` keeps the ``def`` line
            # itself out of the body for one-line functions.
            decorators = getattr(first, "decorator_list", ())
            start = min([first.lineno, *(d.lineno for d in decorators)])
            body_line = max(f.lineno, start - 1)
        self.spans[".".join(self.prefix)] = FunctionSpan(
            indent=f.col_offset,
            definition_line=f.lineno - 1,
            body_line=body_line,
            end_line=f.end_lineno,
        )
        self.generic_visit(f)
        self.prefix.pop()

    # ``async def`` functions are recorded exactly like plain ones
    visit_AsyncFunctionDef = visit_FunctionDef
def extract_functions(source, parsed=None):
    """Split *source* into an outline and the source text of each function.

    Every function found by ExtractFunctions has its body (everything after
    the signature and docstring) replaced by an ``# omitted`` / ``pass`` stub.

    Args:
        source: the module source code.
        parsed: optional, already parsed tree of *source*; parsed here when
            omitted.

    Returns:
        A tuple ``(outline, functions)``: *outline* is the source with all
        function bodies stubbed out, *functions* maps each dotted function
        name to its dedented source.  Nested functions are processed before
        the function containing them, so the text stored for an outer
        function shows its nested functions already stubbed.
    """
    e = ExtractFunctions()
    e.visit(ast.parse(source) if parsed is None else parsed)
    functions = {}
    lines = source.splitlines()
    # One entry per stub written so far: (body start in ORIGINAL line
    # coordinates, net number of lines the replacement removed).
    edits = []

    def shifted(line):
        # Translate an original line index into the current state of
        # ``lines``: only edits that start before it have moved it.
        return line - sum(removed for position, removed in edits if position < line)

    for name, span in reversed(e.spans.items()):
        definition_line = shifted(span.definition_line)
        body_line = shifted(span.body_line)
        end_line = shifted(span.end_line)
        # TODO: dedent only strips whitespace common to all lines, so a
        # docstring line indented less than the function limits the dedent.
        functions[name] = dedent("\n".join(lines[definition_line:end_line]))
        lines[body_line:end_line] = (
            " " * span.indent + " # omitted",
            " " * span.indent + " pass",
        )
        # Two stub lines were inserted; the count is negative when the body
        # was shorter than the stub.
        edits.append((span.body_line, (end_line - body_line) - 2))
    return "\n".join(lines), functions
@dataclass
class Scope:
    """One function scope on the stack kept while walking a syntax tree."""

    # name of the function that opened this scope
    name: str
    # function names visible in this scope, mapped to their qualified names
    vars: dict
class ExtractDependencies(ast.NodeVisitor):
    """Collect which functions refer to which other functions by name.

    After ``visit(tree)``, ``dependencies`` maps the qualified name of a
    function to the set of qualified names of the functions it mentions.
    A qualified name is a tuple: the names of the enclosing functions
    followed by the function's own name.
    """

    def __init__(self):
        self.dependencies = defaultdict(set)
        self.scopes = []  # stack of Scope objects, innermost last
        self.module_vars = {}  # functions declared outside any function
        super().__init__()

    def resolve(self, name):
        """Return the qualified name *name* refers to, or None if unknown."""
        for scope in reversed(self.scopes):
            if name in scope.vars:
                return scope.vars[name]
        return self.module_vars.get(name)

    def qualified_name(self, name):
        """Qualify *name* with the names of all currently open scopes."""
        return (*(scope.name for scope in self.scopes), name)

    def _declare(self, body, table):
        # Register every function defined directly in *body* (including
        # inside if/for/try/with blocks) before the body is visited, so that
        # references to functions defined further down still resolve.
        # Nested functions declare their own children when visited; class
        # bodies are skipped because methods are not reachable by bare name.
        pending = list(body)
        while pending:
            node = pending.pop()
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                table[node.name] = self.qualified_name(node.name)
            elif not isinstance(node, (ast.ClassDef, ast.Lambda)):
                pending.extend(ast.iter_child_nodes(node))

    def visit_Module(self, module):
        self._declare(module.body, self.module_vars)
        self.generic_visit(module)

    def visit_FunctionDef(self, f):
        scope = Scope(name=f.name, vars={f.name: self.qualified_name(f.name)})
        self.scopes.append(scope)
        self._declare(f.body, scope.vars)
        self.generic_visit(f)
        self.scopes.pop()

    # ``async def`` functions open a scope exactly like plain ones
    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_Name(self, name):
        # NOTE: the fact that we aren't tracking var
        # shadowing makes this analysis conservative
        if not self.scopes:
            # module-level code is not a function, so there is no caller
            return
        called = self.resolve(name.id)
        if called:
            # the names of all open scopes are exactly the qualified name
            # of the innermost function
            qualified_caller = tuple(scope.name for scope in self.scopes)
            self.dependencies[qualified_caller].add(called)
def extract_dependencies(source, parsed=None):
    """Return the function-to-function references found in *source*.

    Args:
        source: the module source code.
        parsed: optional, already parsed tree of *source*; parsed here when
            omitted.

    Returns:
        The ``dependencies`` mapping of an ExtractDependencies visitor run
        over the tree: a ``defaultdict(set)`` keyed by the qualified name
        (a tuple) of the referring function, holding the qualified names of
        the functions it refers to.
    """
    tree = ast.parse(source) if parsed is None else parsed
    e = ExtractDependencies()
    e.visit(tree)
    return e.dependencies
@milkey-mouse
Copy link
Author

License: CC0-1.0

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment