Last active
July 8, 2022 08:35
-
-
Save internetimagery/e56e60bb59ba6a0aa8654919122bebee to your computer and use it in GitHub Desktop.
Localize imports in batch, so that modules are imported only when they are needed. Useful when trying to improve import times.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Permission to use, copy, modify, and/or distribute this software for any purpose with or without | |
# fee is hereby granted. | |
# | |
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO | |
# THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE | |
# AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER | |
# RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE | |
# OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. | |
# Example: | |
# python /path/to/local_import.py --write /path/to/source.py | |
import os | |
import ast | |
import token | |
import asttokens | |
from fnmatch import fnmatch | |
from functools import partial | |
from contextlib import closing | |
from multiprocessing import Pool | |
def main(files, include="", exclude="", write=False, processes=3):
    """Move imports into the functions that use them, for many files at once.

    Every path produced by ``_expand_paths(files)`` is processed in a
    ``multiprocessing`` pool. Files for which no new source is produced
    are skipped.

    Args:
        files: iterable of file and/or directory paths. Directories are
            searched recursively for ``.py`` files.
        include: optional ``fnmatch`` pattern. When given, only imports whose
            dotted name matches it are moved.
        exclude: optional ``fnmatch`` pattern. Matching imports are never moved.
        write: when true, overwrite each changed file in place. Otherwise
            print the rewritten source between ``>>>`` and ``<<<`` markers.
        processes: number of worker processes in the pool (defaults to 3).
    """
    filter_ = partial(_filter_import, include, exclude)
    generate = partial(_generate_new_source, filter_)
    with closing(Pool(processes)) as pool:
        for path, new_source in pool.imap(generate, _expand_paths(files)):
            if not path:
                continue
            # Honour the ``write`` argument itself (not command-line state),
            # so main() behaves the same when called as a library function.
            if write:
                print("Writing:", path)
                with open(path, "w") as h:
                    h.write(new_source)
            else:
                print(">>>", path)
                print(new_source)
                print("<<<")
def _generate_new_source(filter_, filepath):
    """Produce the import-localized source text for a single file.

    Args:
        filter_: callable ``(node, name) -> bool`` forwarded to
            ``_get_modifications`` to decide which imports may move.
        filepath: path of the Python file to read.

    Returns:
        ``(filepath, new_source)`` when at least one edit was generated.
        ``(None, None)`` when there is nothing to change, or when the file
        could not be read or parsed. In that case the error is printed
        rather than raised, so one bad file does not abort a batch run.
    """
    try:
        with open(filepath, "r") as handle:
            original_text = handle.read()
        annotated = asttokens.ASTTokens(original_text, parse=True)
    except Exception as err:
        print("Failed to load {}: {}".format(filepath, err))
        return None, None

    edits = tuple(_get_modifications(annotated.tree, {}, annotated, filter_))
    if edits:
        return filepath, asttokens.util.replace(original_text, edits)
    return None, None
def _get_modifications(parent, global_unused, parser, filter_):
    """Yield text edits that move outer-scope imports into ``parent``.

    Walks the statements of ``parent`` together with the bodies of classes
    nested in it (via ``_iter_scope``). It collects the imports declared there
    and every dotted name referenced there. Then:

    * for each import in ``global_unused`` that this scope references, that
      ``filter_`` accepts, and that binds a single name, it yields one edit
      deleting the original import statement and one edit inserting an
      equivalent import at the top of ``parent``'s body (after a leading
      docstring);
    * it recurses into every function found, passing down the imports that
      are still unused at this level.

    Args:
        parent: ``ast`` node with a ``body`` (the module, or a function),
            annotated with token positions by asttokens.
        global_unused: mapping of bound import name -> ``ast.Import`` /
            ``ast.ImportFrom`` node, for imports from enclosing scopes that
            none of those scopes used directly.
        parser: the ``asttokens.ASTTokens`` instance that produced the tree.
        filter_: callable ``(node, name) -> bool``. It decides whether an import
            may be moved.

    Yields:
        ``(start, end, replacement)`` character-offset triples, as consumed by
        ``asttokens.util.replace``.
    """
    local_imports = {}
    local_names = set()
    functions = []
    skip = set()
    # ``async def`` opens a new scope just like ``def`` (when this Python has it).
    async_def = getattr(ast, "AsyncFunctionDef", None)
    function_types = (ast.FunctionDef, async_def) if async_def else (ast.FunctionDef,)

    for child in _iter_scope(parent):
        # Collect all imports declared in this scope
        if isinstance(child, (ast.ImportFrom, ast.Import)):
            for name in child.names:
                local_imports[name.asname or name.name] = child
        # Function and class bodies are separate scopes, but decorators,
        # default values, annotations, base classes and class keywords are
        # evaluated when the def/class statement runs, i.e. in THIS scope.
        # Missing them let an import be moved into a body while a default or
        # annotation still needed it at definition time.
        if isinstance(child, function_types):
            functions.append(child)
            tosearch = list(child.decorator_list) + [child.args]
            returns = getattr(child, "returns", None)
            if returns is not None:
                tosearch.append(returns)
        elif isinstance(child, ast.ClassDef):
            tosearch = (
                list(child.bases)
                + list(getattr(child, "keywords", []))
                + list(child.decorator_list)
            )
        else:
            tosearch = [child]
        # Record every referenced name with all its dotted prefixes
        # ("a.b.c" -> "a", "a.b", "a.b.c") so "import a.b" style keys match.
        for check in tosearch:
            for n in asttokens.util.walk(check):
                if n in skip:
                    continue
                if isinstance(n, (ast.Attribute, ast.Name)):
                    text = parser.get_text(n).split(".")
                    for i in range(len(text)):
                        local_names.add(".".join(text[0 : i + 1]))
                if isinstance(n, ast.Attribute):
                    # The full attribute text was recorded above; the inner
                    # pieces of the attribute chain add nothing new.
                    skip.add(n.value)

    # Bring in outer imports that are used here; remember the rest.
    local_unused = {}
    for name, node in global_unused.items():
        if name not in local_names:
            local_unused[name] = node
            continue
        # Deleting ``import a, b`` (or ``from x import a, b``) would also drop
        # the other names it binds, which may still be needed at module level
        # or be excluded by the filter. Leave such statements in place.
        if len(node.names) != 1:
            continue
        if not filter_(node, name):
            continue
        # Remove original import
        start, end = parser.get_text_range(node)
        yield start, end, ""
        # Add import locally, after a docstring if the body starts with one.
        tok = parent.body[0].first_token
        if tok[0] == token.STRING:  # tok[0] is the token type
            tok = parent.body[1].first_token
        start = tok.startpos
        # Repeat the character preceding the first statement once per
        # column (tok[2] is the (row, col) start) to reproduce its indent.
        indent = parser.text[start - 1] * tok[2][1]
        yield start, start, "\n{0}# Import automatically moved local\n{0}{1}\n{0}\n{0}".format(
            indent, _format_import(node, name)
        )

    # Imports declared in this scope but not used here are candidates for
    # the nested functions too.
    local_unused.update(
        (name, node) for name, node in local_imports.items() if name not in local_names
    )
    for func in functions:
        for item in _get_modifications(func, local_unused, parser, filter_):
            yield item
def _expand_paths(paths): | |
for path in paths: | |
if os.path.isfile(path): | |
yield path | |
elif os.path.isdir(path): | |
for root, _, files in os.walk(path): | |
for f in files: | |
if f.endswith(".py"): | |
yield os.path.join(root, f) | |
def _filter_import(include, exclude, node, alias): | |
text = [n.name for n in node.names if (n.asname or n.name) == alias][0] | |
if isinstance(node, ast.ImportFrom): | |
text = "{}.{}".format(node.module, text) | |
if exclude and fnmatch(text, exclude): | |
return False | |
if include and not fnmatch(text, include): | |
return False | |
return True | |
def _format_import(node, alias): | |
name = [n for n in node.names if (n.asname or n.name) == alias][0] | |
text = "import {}".format(_format_name(name)) | |
if isinstance(node, ast.ImportFrom): | |
text = "from {} {}".format(node.module, text) | |
return text | |
def _format_name(node): | |
if node.asname: | |
return "{} as {}".format(node.name, node.asname) | |
return node.name | |
def _iter_scope(parent): | |
for child in parent.body: | |
yield child | |
if isinstance(child, ast.ClassDef): | |
for node in _iter_scope(child): | |
yield node | |
if __name__ == "__main__":
    # Command-line entry point. The parsed namespace stays bound to the
    # module-level name ``args`` in case other code in this file reads it.
    import argparse

    cli = argparse.ArgumentParser(
        description="Move imports locally in python source files."
    )
    cli.add_argument("files", nargs="+", help="Paths to source files")
    for long_flag, short_flag, help_text in (
        ("--include", "-i", "Include only specified imports. Can use * wildcard."),
        ("--exclude", "-e", "Exclude specified imports. Can use * wildcard."),
    ):
        cli.add_argument(long_flag, short_flag, help=help_text)
    cli.add_argument(
        "--write",
        "-w",
        action="store_true",
        default=False,
        help="Write changes to the source files instead of printing",
    )
    args = cli.parse_args()
    main(args.files, include=args.include, exclude=args.exclude, write=args.write)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment