Skip to content

Instantly share code, notes, and snippets.

@JnyJny
Last active July 27, 2023 08:46
Show Gist options
  • Save JnyJny/05d33a4b89d120c6fb50ed7fa2ce5978 to your computer and use it in GitHub Desktop.
Save JnyJny/05d33a4b89d120c6fb50ed7fa2ce5978 to your computer and use it in GitHub Desktop.
Find Python Import Statements
"""Find import and import from statements
"""
import ast
import tokenize
from functools import cached_property
from pathlib import Path
class ImportRecord:
@classmethod
def toplevel_imports(cls, root: Path, prune: list[str] = None) -> list[str]:
prune = set(prune or [])
toplevels = []
for record in cls.from_root(root):
for name in record.imports:
toplevels.append(name.split(".")[0])
return list(set(toplevels) - prune)
@classmethod
def from_root(cls, root: Path) -> list["ImportRecord"]:
return [cls(path) for path in root.rglob("*.py")]
def __init__(self, path: Path) -> None:
self.path = path
def __str__(self) -> str:
""" """
out = [str(self.path)]
for node in self.import_nodes:
if isinstance(node, ast.Import):
if not node.names[0].asname:
out.append(f"\t {node.lineno:3d} import {node.names[0].name}")
else:
out.append(
f"\t {node.lineno:3d} import {node.names[0].name} as {node.names[0].asname}"
)
continue
if isinstance(node, ast.ImportFrom):
module = "." if not node.module else node.module
out.append(
f"\t{node.level:2d} {node.lineno:3d} from {module} import {', '.join(str(alias.name) for alias in node.names)}"
)
continue
raise ValueError(f"Expected ast.Import or ast.ImportFrom, got {node}")
return "\n".join(out)
@staticmethod
def is_toplevel_import(node: "ast.Node") -> bool:
if not isinstance(node, (ast.Import, ast.ImportFrom)):
return False
try:
return node.level == 0
except AttributeError:
pass
return True
@property
def text(self) -> str:
try:
return self._text
except AttributeError:
pass
self._text = self.path.read_text()
return self._text
@property
def tree(self) -> "ast.tree":
try:
return self._tree
except AttributeError:
pass
self._tree = ast.parse(self.text, str(self.path))
return self._tree
@property
def import_nodes(self) -> list[ast.Import | ast.ImportFrom]:
try:
return self._import_nodes
except AttributeError:
pass
self._import_nodes = [
node
for node in ast.walk(self.tree)
if ImportRecord.is_toplevel_import(node)
]
return self._import_nodes
@property
def imports(self) -> list[str]:
try:
return self._imports
except AttributeError:
pass
self._imports = []
for node in self.import_nodes:
try:
self._imports.append(node.module)
except AttributeError:
for alias in node.names:
self._imports.append(alias.name)
return self._imports
if __name__ == "__main__":
    # Script entry point: list every distinct absolute top-level package
    # imported by the Python files under the current working directory,
    # one name per line.
    packages = ImportRecord.toplevel_imports(root=Path.cwd())
    for package in packages:
        print(package)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment