Skip to content

Instantly share code, notes, and snippets.

@EmbraceLife
Created August 22, 2024 06:52
Show Gist options
  • Save EmbraceLife/731c0162a924e371a567e83d242d7fa2 to your computer and use it in GitHub Desktop.
Save EmbraceLife/731c0162a924e371a567e83d242d7fa2 to your computer and use it in GitHub Desktop.
import ast
import re
from pathlib import Path
from inspect import signature
from fastcore.utils import *
from fastcore.meta import delegates
def _get_tree(mod):
    """Return the result of `parse` applied to the source text `getsource` yields for `mod`."""
    source_text = getsource(mod)
    return parse(source_text)
@patch
def __repr__(self:ast.AST):
    """Represent an AST node as whatever `unparse` produces for it."""
    rendered = unparse(self)
    return rendered
@patch
def _repr_markdown_(self:ast.AST):
    """Return the node's `repr` wrapped in a fenced ``python`` Markdown code block."""
    return "```python\n" + repr(self) + "\n```"
# AST node classes for function definitions (plain and async), for use with isinstance().
functypes = (ast.FunctionDef, ast.AsyncFunctionDef)
def _deco_id(d:Union[ast.Name,ast.Attribute,ast.Call])->Union[str,None]:
    """Return the bare name of the decorator expression `d`.

    Supported shapes:

    * ``@name``                            -> ``'name'``
    * ``@pkg.name``                        -> ``'name'`` (last attribute component)
    * ``@name(...)`` / ``@pkg.name(...)``  -> name of the object being called

    Any other expression yields ``None``, so a caller comparing the result
    with a string simply gets no match instead of an ``AttributeError``.
    """
    # Unwrap (possibly nested) calls such as `@deco(a)(b)` to reach the callee.
    while isinstance(d, ast.Call):
        d = d.func
    if isinstance(d, ast.Name):
        return d.id
    if isinstance(d, ast.Attribute):
        return d.attr
    return None
def has_deco(node:Union[ast.FunctionDef,ast.AsyncFunctionDef], name:str)->bool:
    """Report whether `node` carries a decorator whose `_deco_id` equals `name`.

    A node without a `decorator_list` attribute is treated as undecorated.
    """
    for deco in getattr(node, 'decorator_list', []):
        if _deco_id(deco) == name:
            return True
    return False
def _get_proc(node):
    """Pick the processing function for AST statement `node`.

    Class definitions map to `_proc_class`; statements that are neither a
    class nor a function definition map to ``None``. A function without a
    `delegates` decorator maps to `_proc_body`; a delegated function maps to
    `_proc_patched` when it also has a `patch` decorator, else to `_proc_func`.
    """
    if isinstance(node, ast.ClassDef):
        return _proc_class
    if isinstance(node, functypes):
        if not has_deco(node, 'delegates'):
            handler = _proc_body
        elif has_deco(node, 'patch'):
            handler = _proc_patched
        else:
            handler = _proc_func
        return handler
    return None
def _proc_tree(tree, mod, local_types):
    """Apply the processor chosen by `_get_proc` to each statement in `tree.body`.

    Statements for which `_get_proc` returns a falsy value are skipped.
    """
    for stmt in tree.body:
        handler = _get_proc(stmt)
        if not handler:
            continue
        handler(stmt, mod, local_types)
def _proc_mod(mod, local_types):
    """Build the processed AST for module `mod`.

    The tree from `_get_tree` is processed in place by `_proc_tree`, then
    every top-level `ast.FunctionDef` decorated with `patch` is removed
    from its body. Returns the resulting tree.
    """
    tree = _get_tree(mod)
    _proc_tree(tree, mod, local_types)
    kept = []
    for stmt in tree.body:
        if isinstance(stmt, ast.FunctionDef) and has_deco(stmt, 'patch'):
            continue
        kept.append(stmt)
    tree.body = kept
    return tree
def sig2str(sig, local_types):
    """Convert `sig` to text and tidy the type names that appear in it.

    Three rewrites are applied in order: ``<class 'X'>`` becomes ``X``; for
    each name in `local_types`, a dotted prefix in front of that name is
    dropped; and every ``dynamic_module.`` occurrence is removed.
    """
    text = re.sub(r"<class '(.*?)'>", r'\1', str(sig))
    # Module prefixes are stripped only for the names listed in local_types.
    for type_name in local_types:
        qualified = re.compile(rf"(\w+\.)+{type_name}\b")
        text = qualified.sub(type_name, text)
    return re.sub(r"dynamic_module\.", "", text)
def ast_args(func, local_types):
    """Return the `ast.arguments` node matching the signature of `func`.

    The signature is rendered with `sig2str`, embedded in a throwaway
    ``def _(...): ...`` statement, parsed, and the parsed function's `args`
    node is returned.
    """
    sig = signature(func)

    def clean_annotation(param_name, anno):
        # Currently a pass-through: annotations are kept unchanged.
        return anno

    cleaned_params = []
    for param in sig.parameters.values():
        anno = clean_annotation(param.name, param.annotation)
        cleaned_params.append(param.replace(annotation=anno))
    cleaned_sig = sig.replace(parameters=cleaned_params)
    stub_src = f"def _{sig2str(cleaned_sig, local_types)}: ..."
    return ast.parse(stub_src).body[0].args
def _body_ellip(n: ast.AST):
    """Reduce the body of `n` to an optional docstring followed by ``...``.

    If any top-level statement of the body is already a bare ``...``
    expression, the body is left untouched. Otherwise everything after the
    leading docstring (when there is one) is replaced by a single ``...``
    expression statement. `n` is mutated in place; nothing is returned.
    """
    def _const_expr(stmt):
        # True for an expression statement consisting of one constant.
        return isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Constant)

    if any(_const_expr(stmt) and stmt.value.value is Ellipsis for stmt in n.body):
        return
    first = n.body[0] if n.body else None
    # Only a *string* constant is a docstring; any other leading constant
    # expression (e.g. a stray `42`) is an ordinary statement and is dropped.
    has_doc = first is not None and _const_expr(first) and isinstance(first.value.value, str)
    n.body[1 if has_doc else 0:] = [ast.Expr(ast.Constant(...))]
def _update_func(node, sym, local_types):
    """Rewrite function node `node` in place using the live object `sym`.

    The arguments are replaced by those derived from `sym` via `ast_args`,
    the body is reduced with `_body_ellip`, and every decorator whose
    `_deco_id` is ``'delegates'`` is removed.
    """
    node.args = ast_args(sym, local_types)
    _body_ellip(node)
    remaining = []
    for deco in node.decorator_list:
        if _deco_id(deco) == 'delegates':
            continue
        remaining.append(deco)
    node.decorator_list = remaining
def _proc_body(node, mod, local_types):
    """Processor that only reduces the body of `node` via `_body_ellip`.

    `mod` and `local_types` are accepted for a uniform processor signature
    and are not used here.
    """
    _body_ellip(node)
def _proc_func(node, mod, local_types):
    """Update function node `node` from the attribute of `mod` named `node.name`."""
    _update_func(node, getattr(mod, node.name), local_types)
def _proc_patched(node, mod, local_types):
    """Processor that does nothing; all arguments are ignored."""
    return None
def _proc_class(node, mod, local_types):
    """Process class definition `node` against the live class found on `mod`.

    Every callable entry of the class ``__dict__`` whose name does not start
    with ``__`` and that has no function definition (sync or async) in the
    class body is added as a stub via `_add_method_to_class`. The class body
    is then processed with `_proc_tree`, with the class object taking the
    place of the module.
    """
    cls = getattr(mod, node.name)
    # Names already defined in the source, collected once. Stubs added below
    # come from distinct __dict__ keys, so this snapshot never goes stale.
    # Async defs count too; otherwise they would get a duplicate stub.
    defined = {n.name for n in node.body if isinstance(n, functypes)}
    for name, member in cls.__dict__.items():
        if name.startswith('__') or not callable(member):
            continue
        if name in defined:
            continue
        _add_method_to_class(node, name, member, local_types)
    _proc_tree(node, cls, local_types)
def _add_method_to_class(class_node, method_name, method_object, local_types):
    """Append a stub for `method_object` to the body of `class_node`.

    The stub is a function definition named `method_name` whose arguments
    come from `ast_args`. A leading ``self`` parameter has its annotation
    cleared, the method's ``__doc__`` (when truthy) becomes the docstring,
    and the body is completed by `_body_ellip`. Bare ``...`` statements
    directly in the class body are removed before the stub is appended.
    """
    stub = ast.FunctionDef(
        name=method_name,
        args=ast_args(method_object, local_types),
        body=[],
        decorator_list=[],
        lineno=1, col_offset=0, end_lineno=1, end_col_offset=0,
    )

    positional = stub.args.args
    if positional and positional[0].arg == 'self':
        positional[0].annotation = None

    doc = method_object.__doc__
    if doc:
        stub.body.append(ast.Expr(value=ast.Constant(value=doc, kind=None)))
    _body_ellip(stub)

    def _is_ellipsis_stmt(stmt):
        # Matches an expression statement that is exactly `...`.
        return (isinstance(stmt, ast.Expr)
                and isinstance(stmt.value, ast.Constant)
                and stmt.value.value is Ellipsis)

    class_node.body = [stmt for stmt in class_node.body if not _is_ellipsis_stmt(stmt)]
    class_node.body.append(stub)
def create_pyi(fn, package=None):
    """Write a ``.pyi`` file next to the Python source file `fn`.

    The source is parsed to collect the names of all classes it defines
    (`local_types`), the module is loaded with `imp_mod`, its tree is
    processed by `_proc_mod`, and the unparsed result is written to `fn`
    with the suffix replaced by ``.pyi``.

    Args:
        fn: Path (str or path-like) of the source file.
        package: Forwarded to `imp_mod` as its `package` argument.
    """
    import tokenize  # local import: not among this file's top-level imports

    fn = Path(fn)
    # tokenize.open honours a PEP 263 coding cookie or BOM and defaults to
    # UTF-8, whereas plain open() falls back to the platform locale encoding.
    with tokenize.open(fn) as file:
        tree = ast.parse(file.read())
    local_types = {node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)}
    mod = imp_mod(fn, package=package)
    tree = _proc_mod(mod, local_types)
    res = unparse(tree)
    # Write the stub as UTF-8 explicitly so the output does not depend on the locale.
    fn.with_suffix('.pyi').write_text(res, encoding='utf-8')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment