Skip to content

Instantly share code, notes, and snippets.

@geblanco
Last active October 25, 2022 13:01
Show Gist options
  • Save geblanco/83fb55279cab43a43d981efb90c91460 to your computer and use it in GitHub Desktop.
Save geblanco/83fb55279cab43a43d981efb90c91460 to your computer and use it in GitHub Desktop.
Naive __init__.py generator
"""Generate __init__.py for a given folder
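
A very naive generator: it parses the AST of each module file in the provided
folder to collect the top-level names that the module defines, and emits one
import statement per module that defines any.

TODO:
- Recursively parse subfolders and any existing __init__.py files.
- Improve how the package base is prepended to each import.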
"""
import os
import ast
import argparse
from typing import List, Optional
def parse_flags(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the generator.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is what the script entry point relies on;
            passing an explicit list makes the parser usable from tests.

    Returns:
        Namespace with ``data_dir``, ``output``, ``pkg_base``, ``pkg_excludes``
        and ``var_excludes`` attributes, which the entry point forwards to
        ``main`` as keyword arguments.
    """
    parser = argparse.ArgumentParser(
        description="Generate an __init__.py that imports the top-level names "
        "defined by the modules of a folder.",
    )
    # Positional argument declared through ``dest`` only (no flag strings).
    parser.add_argument(
        dest="data_dir",
        type=str,
        help="Directory containing the .py files to build the __init__.py from",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        required=False,
        default=None,
        help="Where to write the generated code (defaults to stdout)",
    )
    parser.add_argument(
        "--pkg_base",
        type=str,
        required=False,
        help="Package base to prepend to every import (default: relative imports)",
    )
    parser.add_argument(
        "--pkg_excludes",
        nargs="*",
        help="File names to skip when collecting exported names",
    )
    parser.add_argument(
        "--var_excludes",
        nargs="*",
        help="Exported names to leave out (matched case-insensitively)",
    )
    return parser.parse_args(argv)
def get_files(data_dir: str, pkg_excludes: Optional[List[str]] = None) -> List[str]:
    """List the Python module files directly inside ``data_dir``.

    Only regular files ending in ``.py`` are returned. Other files (data
    files, ``.pyc``, READMEs, ...) are not Python source and could make
    ``ast.parse`` fail when their names are later parsed. ``__init__.py``
    itself is always skipped, as is every name listed in ``pkg_excludes``.
    Subdirectories are not traversed.

    Args:
        data_dir: Folder to scan.
        pkg_excludes: Bare file names (e.g. ``"utils.py"``) to leave out.

    Returns:
        The matching file names (not full paths), sorted so the result does
        not depend on the arbitrary order of ``os.listdir``.
    """
    skip = {"__init__.py", *(pkg_excludes or [])}
    return sorted(
        name
        for name in os.listdir(data_dir)
        if name.endswith(".py")
        and name not in skip
        and os.path.isfile(os.path.join(data_dir, name))
    )
def get_exports(file: str, var_excludes: Optional[List[str]] = None) -> List[str]:
    """Return the top-level names defined by the Python source file ``file``.

    Names are collected from the following module-level constructs:

    - ``def``, ``async def`` and ``class`` statements.
    - Plain assignments, including chained ``a = b = 1`` and tuple/list
      unpacking ``a, (b, *c) = ...``.
    - Annotated assignments that carry a value (``x: int = 1``).

    Assignments to attributes or subscripts (``obj.x = 1``, ``d[k] = 1``)
    bind no new module name and are ignored.

    Args:
        file: Path to the file to parse.
        var_excludes: Names to drop from the result, compared
            case-insensitively.

    Returns:
        The sorted list of names that survive ``var_excludes``.
    """
    # Read bytes so ast.parse honours a PEP 263 coding cookie (UTF-8 by
    # default) instead of the platform encoding, and close the handle.
    with open(file, "rb") as fin:
        tree = ast.parse(fin.read(), filename=file)

    def collect(target: ast.AST, out: set) -> None:
        # Recursively pull the bound names out of an assignment target;
        # Attribute/Subscript targets fall through and add nothing.
        if isinstance(target, ast.Name):
            out.add(target.id)
        elif isinstance(target, (ast.Tuple, ast.List)):
            for elt in target.elts:
                collect(elt, out)
        elif isinstance(target, ast.Starred):
            collect(target.value, out)

    top_exported = set()
    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            top_exported.add(node.name)
        elif isinstance(node, ast.Assign):
            # ``a = b = 1`` has several targets; every one of them is bound.
            for target in node.targets:
                collect(target, top_exported)
        elif isinstance(node, ast.AnnAssign) and node.value is not None:
            # A bare ``x: int`` declares an annotation but binds nothing.
            collect(node.target, top_exported)

    excluded = {var.lower() for var in (var_excludes or [])}
    return [name for name in sorted(top_exported) if name.lower() not in excluded]
def format_import_line(line: str, sep=" ") -> str:
    """Prefix ``line`` with ``sep``; passing ``sep=None`` means no prefix."""
    prefix = "" if sep is None else sep
    return f"{prefix}{line}"
def format_file_line(file: str) -> str:
    """Return ``file`` without its last extension (``"utils.py"`` -> ``"utils"``)."""
    root, _extension = os.path.splitext(file)
    return root
def format_import(file: str, pkgs: List[str], pkg_base: Optional[str] = None):
    """Render one parenthesised ``from <pkg_base>.<module> import (...)`` statement.

    Args:
        file: File name of the module; ``format_file_line`` strips its
            extension to obtain the module name.
        pkgs: Names to import. Each goes on its own line, prefixed by
            ``format_import_line`` with a single space.
        pkg_base: Package prefix. ``None`` or ``""`` yields a relative
            ``from .<module>`` import.

    Returns:
        A one-entry dict mapping the module name to the statement text. The
        opening line carries ``# noqa: F401`` so linters do not flag the
        re-exported names as unused imports.
    """
    module = format_file_line(file)
    names = ",\n".join(format_import_line(pkg, sep=" ") for pkg in pkgs)
    statement = f"from {pkg_base or ''}.{module} import ( # noqa: F401\n{names}\n)"
    return {module: statement}
def main(
    data_dir: str,
    output: Optional[str] = None,
    pkg_base: Optional[str] = None,
    pkg_excludes: Optional[List[str]] = None,
    var_excludes: Optional[List[str]] = None,
):
    """Build the ``__init__.py`` source for ``data_dir`` and emit it.

    Each file returned by ``get_files`` is parsed with ``get_exports``. Files
    that export nothing are skipped. Each remaining file becomes one
    ``from ... import (...)`` statement built by ``format_import``. The
    statements are ordered by module name and separated by a blank line.

    Args:
        data_dir: Folder whose module files are scanned.
        output: Destination file path; ``None`` prints to stdout instead.
        pkg_base: Package prefix forwarded to ``format_import``.
        pkg_excludes: File names forwarded to ``get_files`` to skip.
        var_excludes: Names forwarded to ``get_exports`` to drop.
    """
    imports = {}
    for file in get_files(data_dir, pkg_excludes):
        exported = get_exports(os.path.join(data_dir, file), var_excludes)
        if exported:
            imports.update(format_import(file, exported, pkg_base))

    exports_str = "\n\n".join(imports[key] for key in sorted(imports))
    if output is None:
        print(exports_str)
    else:
        # print() ends stdout output with a newline. Do the same for files so
        # the generated module ends with one (flake8 W292), but keep an empty
        # result empty rather than a lone blank line (W391).
        with open(output, "w", encoding="utf-8") as fout:
            fout.write((exports_str + "\n") if exports_str else "")
if __name__ == "__main__":
    # Script entry point: parse the CLI flags and pass them to main() by name.
    cli_options = parse_flags()
    main(**vars(cli_options))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment