Created
August 22, 2024 06:52
-
-
Save EmbraceLife/731c0162a924e371a567e83d242d7fa2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
import re | |
from pathlib import Path | |
from inspect import signature | |
from fastcore.utils import * | |
from fastcore.meta import delegates | |
def _get_tree(mod): | |
return parse(getsource(mod)) | |
@patch
def __repr__(self: ast.AST):
    """Show an AST node as the Python source it unparses to.

    Falls back to the default object repr when the node cannot be unparsed
    (e.g. it is only partially built), because `repr` should not raise.
    """
    try:
        return ast.unparse(self)
    except (AttributeError, TypeError, ValueError):
        # Incomplete or malformed node: keep printing/debugging usable.
        return object.__repr__(self)
@patch
def _repr_markdown_(self: ast.AST):
    """Markdown rendering of the node: its `repr` wrapped in a ```python fence.

    (`_repr_markdown_` is the IPython/Jupyter rich-display hook name.)
    """
    return f"```python\n{self!r}\n```"
# AST node classes that represent a function definition, synchronous or async.
functypes = (
    ast.FunctionDef,
    ast.AsyncFunctionDef,
)
def _deco_id(d:Union[ast.Name,ast.Attribute])->bool: | |
return d.id if isinstance(d, ast.Name) else d.func.id | |
def has_deco(node: Union[ast.FunctionDef, ast.AsyncFunctionDef], name: str) -> bool:
    """Tell whether `node` carries a decorator whose `_deco_id` equals `name`.

    A node without a `decorator_list` attribute is treated as undecorated.
    """
    for deco in getattr(node, 'decorator_list', []):
        if _deco_id(deco) == name:
            return True
    return False
def _get_proc(node):
    """Choose the stub handler for a top-level statement `node`.

    - class definitions -> `_proc_class`
    - anything that is not a (sync/async) function -> None (no processing)
    - functions without `@delegates` -> `_proc_body` (only the body is stubbed)
    - `@delegates` functions that are also `@patch`-ed -> `_proc_patched`
    - other `@delegates` functions -> `_proc_func` (signature taken from the live object)
    """
    if isinstance(node, ast.ClassDef):
        return _proc_class
    if isinstance(node, functypes):
        if not has_deco(node, 'delegates'):
            return _proc_body
        return _proc_patched if has_deco(node, 'patch') else _proc_func
    return None
def _proc_tree(tree, mod, local_types):
    """Run the matching `_proc_*` handler on each statement in `tree.body`.

    `tree` is a node with a `body` list (a module or a class definition) and
    `mod` is the live namespace (module or class object) handlers look names up in.
    """
    for child in tree.body:
        handler = _get_proc(child)
        if handler is not None:
            handler(child, mod, local_types)
def _proc_mod(mod, local_types):
    """Return the stub AST for the live module `mod`.

    The module source is parsed, each top-level statement goes through its
    `_proc_*` handler, and `@patch`-decorated functions are then removed from
    the module level.

    Args:
        mod: imported module object whose source is being stubbed.
        local_types: class names defined in the module; forwarded to the handlers.

    Returns:
        The processed `ast.Module`.
    """
    tree = _get_tree(mod)
    _proc_tree(tree, mod, local_types)
    # Check both sync and async defs (`functypes`) so `async def` patches are
    # dropped too; presumably they are re-emitted inside their class stub by
    # `_proc_class`, which scans the class `__dict__`.
    tree.body = [node for node in tree.body
                 if not (isinstance(node, functypes) and has_deco(node, 'patch'))]
    return tree
def sig2str(sig, local_types):
    """Render `sig` as text suitable for a stub `def`.

    Three rewrites are applied to `str(sig)`, in order:
    `<class 'X'>` becomes `X`; a dotted prefix is stripped from any name listed
    in `local_types` (other qualified names are kept); and every
    `dynamic_module.` prefix is removed.
    """
    text = re.sub(r"<class '(.*?)'>", r'\1', str(sig))
    for tname in local_types:
        text = re.sub(rf"(\w+\.)+{tname}\b", tname, text)
    return re.sub(r"dynamic_module\.", "", text)
class _StubDefault:
    """Stand-in default value whose repr is `...`, the usual spelling in stub files."""

    def __repr__(self):
        return '...'


def ast_args(func, local_types):
    """Return an `ast.arguments` node describing the runtime signature of `func`.

    The signature is rendered with `sig2str` (so local class names lose their
    module prefix) and parsed back as the parameter list of a dummy `def`.
    If that fails because a default value's repr is not valid Python
    (e.g. `<function noop at 0x...>` or an `object()` sentinel), those defaults
    are rendered as `...` instead. In that fallback the return annotation is also
    dropped, since only the parameters are returned.

    Raises:
        ValueError/TypeError: from `inspect.signature` for unsupported callables.
        SyntaxError: if the signature still cannot be parsed after the fallback.
    """
    sig = signature(func)

    def _parse_args(s):
        return ast.parse(f"def _{sig2str(s, local_types)}: ...").body[0].args

    # Fast path: unchanged behaviour for every signature that already renders cleanly.
    try:
        return _parse_args(sig)
    except SyntaxError:
        pass

    def _is_valid_expr(value):
        try:
            ast.parse(sig2str(repr(value), local_types), mode='eval')
        except (SyntaxError, ValueError):
            return False
        return True

    params = [
        p if p.default is p.empty or _is_valid_expr(p.default)
        else p.replace(default=_StubDefault())
        for p in sig.parameters.values()
    ]
    return _parse_args(sig.replace(parameters=params, return_annotation=sig.empty))
def _body_ellip(n: ast.AST): | |
has_ellipsis = any( | |
isinstance(node, ast.Expr) and | |
isinstance(node.value, ast.Constant) and | |
node.value.value is Ellipsis | |
for node in n.body | |
) | |
if not has_ellipsis: | |
stidx = 1 if n.body and isinstance(n.body[0], ast.Expr) and isinstance(n.body[0].value, ast.Constant) else 0 | |
n.body[stidx:] = [ast.Expr(ast.Constant(...))] | |
def _update_func(node, sym, local_types):
    """Rewrite function stub `node` from the live callable `sym`.

    The parameter list comes from `sym`'s runtime signature (presumably so
    kwargs injected by `@delegates` show up). The body is replaced by `...` and
    any `@delegates` decorator is removed from the stub.
    """
    node.args = ast_args(sym, local_types)
    _body_ellip(node)
    kept = []
    for deco in node.decorator_list:
        if _deco_id(deco) != 'delegates':
            kept.append(deco)
    node.decorator_list = kept
def _proc_body(node, mod, local_types):
    """Stub out the body of function `node`, keeping its signature as written.

    `mod` and `local_types` are unused; they keep the common `_proc_*` signature.
    """
    return _body_ellip(node)
def _proc_func(node, mod, local_types):
    """Refresh stub `node` from the object of the same name found on `mod`."""
    _update_func(node, getattr(mod, node.name), local_types)
def _proc_patched(node, mod, local_types): | |
pass | |
def _proc_class(node, mod, local_types):
    """Fill in the stub for class `node` from the live class of the same name in `mod`.

    Each callable in the class `__dict__` whose name does not start with `__` and
    has no definition in the class body gets a stub appended via
    `_add_method_to_class` (e.g. methods presumably attached with `@patch`).
    The class body is then processed like a module body, with the class itself
    as the lookup namespace.
    """
    cls = getattr(mod, node.name)
    # Names the class body already defines: sync *and* async methods, plus nested
    # classes. Otherwise those would be duplicated as extra sync `def` stubs.
    defined = {n.name for n in node.body if isinstance(n, (*functypes, ast.ClassDef))}
    for name, member in cls.__dict__.items():
        if name.startswith('__') or not callable(member) or name in defined:
            continue
        _add_method_to_class(node, name, member, local_types)
    _proc_tree(node, cls, local_types)
def _add_method_to_class(class_node, method_name, method_object, local_types):
    """Append a stub for `method_object` to the class stub `class_node` (in place).

    The stub gets:
    - the live signature via `ast_args`, with any annotation on a leading `self` removed;
    - the object's `__doc__` as its docstring, when non-empty;
    - `...` as its body.

    Coroutine functions become `async def`, and `staticmethod` objects keep a
    `@staticmethod` decorator. Bare `...` placeholder statements in the class
    body are removed because the class now has a real member.

    Args:
        class_node: the `ast.ClassDef` being filled in.
        method_name: name given to the generated method.
        method_object: the callable found in the class `__dict__`.
        local_types: class names defined in the module; passed to `ast_args`.
    """
    from inspect import iscoroutinefunction  # only `signature` is imported at file level

    # staticmethod objects wrap the real function in `__func__`; check that for async-ness.
    target = getattr(method_object, '__func__', method_object)
    node_cls = ast.AsyncFunctionDef if iscoroutinefunction(target) else ast.FunctionDef
    decorators = ([ast.Name(id='staticmethod', ctx=ast.Load())]
                  if isinstance(method_object, staticmethod) else [])
    method_ast = node_cls(
        name=method_name,
        args=ast_args(method_object, local_types),
        body=[],
        decorator_list=decorators,
        lineno=1,
        col_offset=0,
        end_lineno=1,
        end_col_offset=0,
    )
    if method_ast.args.args and method_ast.args.args[0].arg == 'self':
        method_ast.args.args[0].annotation = None
    doc = method_object.__doc__
    if doc:
        method_ast.body.append(ast.Expr(value=ast.Constant(value=doc, kind=None)))
    _body_ellip(method_ast)
    # Drop a placeholder `...` body now that the class has at least one member.
    class_node.body = [node for node in class_node.body
                       if not (isinstance(node, ast.Expr) and
                               isinstance(node.value, ast.Constant) and
                               node.value.value is Ellipsis)]
    class_node.body.append(method_ast)
def create_pyi(fn, package=None):
    """Write a `.pyi` stub file next to the Python source file `fn`.

    Args:
        fn: path (str or Path) of the `.py` file to stub; it is also imported.
        package: optional package name forwarded to `imp_mod`.
    """
    fn = Path(fn)
    # Parse bytes so the interpreter's own source-decoding rules apply
    # (UTF-8 default, PEP 263 coding cookies, BOM) instead of the platform locale.
    tree = ast.parse(fn.read_bytes(), filename=str(fn))
    # Every class name in the file, including nested ones; used to shorten type names.
    local_types = {node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)}
    mod = imp_mod(fn, package=package)
    stub = _proc_mod(mod, local_types)
    fn.with_suffix('.pyi').write_text(ast.unparse(stub), encoding='utf-8')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment