Last active
October 25, 2022 13:01
-
-
Save geblanco/83fb55279cab43a43d981efb90c91460 to your computer and use it in GitHub Desktop.
Naive __init__.py generator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Generate an __init__.py for a given folder.

A deliberately naive generator: it parses the AST of every Python file directly
inside the given folder (subfolders are not visited) and collects the names each
file defines at top level, which become the package's re-exported names.

TODO:
    - Recursively parse subfolders and any existing __init__.py files.
    - Improve how the package base is prepended to each import.
"""
import os | |
import ast | |
import argparse | |
from typing import List, Optional | |
def parse_flags(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Argument list to parse. When ``None`` (the default) argparse reads
            ``sys.argv[1:]``, exactly as before; passing an explicit list lets
            other code and tests drive the CLI without touching ``sys.argv``.

    Returns:
        Namespace with ``data_dir``, ``output``, ``pkg_base``, ``pkg_excludes``
        and ``var_excludes`` attributes. These names match the keyword
        parameters of ``main`` so the namespace can be splatted into it.
    """
    parser = argparse.ArgumentParser(
        description="Generate an __init__.py that re-exports the top-level "
        "names defined in the .py files of a folder."
    )
    parser.add_argument(
        dest="data_dir",
        type=str,
        help="Directory containing the .py files to generate the __init__.py from",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        required=False,
        default=None,
        help="File to write the generated code to (default: stdout)",
    )
    parser.add_argument(
        "--pkg_base",
        type=str,
        required=False,
        default=None,
        help="Package prefix for every import, e.g. 'mypkg' gives "
        "'from mypkg.module import ...' (default: relative 'from .module import ...')",
    )
    parser.add_argument(
        "--pkg_excludes",
        nargs="*",
        help="File names (e.g. utils.py) to skip when collecting exported names",
    )
    parser.add_argument(
        "--var_excludes",
        nargs="*",
        help="Names to leave out of the generated imports (matched case-insensitively)",
    )
    return parser.parse_args(argv)
def get_files(data_dir: str, pkg_excludes: Optional[List[str]] = None) -> List[str]:
    """Return the Python source files in ``data_dir`` to collect exports from.

    Only regular files ending in ``.py`` are returned. Every returned file is
    later parsed with ``ast``, so any other file (README, data, ...) would make
    the generator fail. Sub-directories are not descended into, and
    ``__init__.py`` itself is always skipped.

    Args:
        data_dir: Directory to scan (non-recursively).
        pkg_excludes: File names to skip. Each entry may be given with its
            extension (``"utils.py"``) or as a bare module name (``"utils"``).

    Returns:
        Sorted list of bare file names (not full paths).
    """
    skip = {"__init__.py"}
    for entry in pkg_excludes or []:
        skip.add(entry)
        if not entry.endswith(".py"):
            skip.add(entry + ".py")  # "utils" is shorthand for "utils.py"
    # Sorted so the result does not depend on os.listdir's arbitrary order.
    return sorted(
        name
        for name in os.listdir(data_dir)
        if name.endswith(".py")
        and name not in skip
        and os.path.isfile(os.path.join(data_dir, name))
    )
def get_exports(file: str, var_excludes: Optional[List[str]] = None) -> List[str]:
    """Return the names bound at the top level of a Python source file.

    Collected names are functions (sync and async), classes, and variables bound
    by a plain or value-bearing annotated assignment at module level, including
    every name in tuple/list unpacking and chained assignments. Names brought in
    by imports, and attribute/subscript targets such as ``obj.x = 1``, are
    ignored.

    Args:
        file: Path to the ``.py`` file. It is parsed, never executed.
        var_excludes: Names to drop from the result, compared case-insensitively.

    Returns:
        Sorted, de-duplicated list of exported names.

    Raises:
        OSError: If ``file`` cannot be read.
        SyntaxError: If ``file`` is not valid Python.
    """
    # Read bytes so ast/compile honors a PEP 263 coding cookie (UTF-8 by
    # default) instead of the locale encoding; `with` closes the handle.
    with open(file, "rb") as src:
        tree = ast.parse(src.read(), filename=file)

    names = set()
    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            names.add(node.name)
        elif isinstance(node, ast.Assign):
            # `a = b = 1` has several targets; `a, b = ...` nests Names in a Tuple.
            for target in node.targets:
                names.update(_assigned_names(target))
        elif isinstance(node, ast.AnnAssign) and node.value is not None:
            # A bare annotation (`x: int`) binds nothing at runtime, so it
            # cannot be imported.
            names.update(_assigned_names(node.target))

    excluded = {name.lower() for name in var_excludes or []}
    return [name for name in sorted(names) if name.lower() not in excluded]


def _assigned_names(target: ast.expr) -> List[str]:
    """Return the plain variable names bound by one assignment target node."""
    if isinstance(target, ast.Name):
        return [target.id]
    if isinstance(target, ast.Starred):
        return _assigned_names(target.value)
    if isinstance(target, (ast.Tuple, ast.List)):
        return [name for elt in target.elts for name in _assigned_names(elt)]
    # Attribute (obj.x) and Subscript (obj[i]) targets bind no new module name.
    return []
def format_import_line(line: str, sep=" ") -> str:
    """Prefix one imported name with ``sep``, which acts as its indentation.

    Passing ``sep=None`` means "no prefix": the name is returned unchanged.
    """
    prefix = sep if sep is not None else ""
    return f"{prefix}{line}"
def format_file_line(file: str) -> str:
    """Return the module name for a file name: ``file`` without its last extension."""
    stem, _extension = os.path.splitext(file)
    return stem
def format_import(file: str, pkgs: List[str], pkg_base: Optional[str] = None):
    """Build the parenthesized ``from <base>.<module> import (...)`` statement.

    Args:
        file: File name whose extension is stripped to get the module name.
        pkgs: Names to import, one per line inside the parentheses.
        pkg_base: Prefix placed before ``.<module>``. When it is empty or
            ``None`` the statement becomes a relative ``from .<module> import``.

    Returns:
        A one-entry dict mapping the module name to the formatted statement.
    """
    module = format_file_line(file)
    base = pkg_base if pkg_base else ""
    body = ",\n".join(format_import_line(name, sep=" ") for name in pkgs)
    # noqa: F401 stops flake8 flagging the re-exported names as unused imports.
    statement = f"from {base}.{module} import ( # noqa: F401\n{body}\n)"
    return {module: statement}
def main(
    data_dir: str,
    output: Optional[str] = None,
    pkg_base: Optional[str] = None,
    pkg_excludes: Optional[List[str]] = None,
    var_excludes: Optional[List[str]] = None,
) -> None:
    """Generate the ``__init__.py`` source for ``data_dir`` and emit it.

    Args:
        data_dir: Directory whose files are scanned via ``get_files``.
        output: File to write the generated code to. Printed to stdout when ``None``.
        pkg_base: Package prefix forwarded to ``format_import`` for every import.
        pkg_excludes: File names to skip, forwarded to ``get_files``.
        var_excludes: Names to leave out, forwarded to ``get_exports``.
    """
    imports = {}
    for file in get_files(data_dir, pkg_excludes):
        exported = get_exports(os.path.join(data_dir, file), var_excludes)
        # Files that define nothing exportable produce no import statement.
        if exported:
            imports.update(format_import(file, exported, pkg_base))
    # Order by module name so the output is independent of directory order.
    exports_str = "\n\n".join(imports[key] for key in sorted(imports))
    if output is None:
        print(exports_str)
        return
    with open(output, "w", encoding="utf-8") as fout:
        # print() terminates stdout output with a newline; do the same for a
        # generated file so it does not lack a final newline (flake8 W292).
        fout.write(exports_str + "\n" if exports_str else exports_str)
if __name__ == "__main__":
    # The parsed namespace's attribute names match main()'s keyword parameters.
    cli_args = parse_flags()
    main(**vars(cli_args))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment