Created
November 7, 2025 22:00
-
-
Save MathiasSven/6dab55936a51ded60b6cee384f5d7afa to your computer and use it in GitHub Desktop.
I learned afterwards that the problem I was trying to solve could be fixed by running `2to3`. But this may serve as an example for using `libcst` in the future.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from pathlib import Path | |
| from typing import cast | |
| import libcst as cst | |
| import libcst.matchers as m | |
class ImportRewriter(cst.CSTTransformer):
    """Turn absolute imports of sibling modules into explicit relative imports.

    ``from foo import bar`` becomes ``from .foo import bar`` and ``import foo``
    becomes ``from . import foo`` whenever ``foo`` is one of *module_names*.
    Imports of any other module are left untouched.

    The transformer deliberately refuses (with ``AssertionError``) to handle
    input it was not designed for: imports that are already relative, and
    dotted imports whose leftmost component is one of *module_names*.
    """

    def __init__(self, module_names: set[str]):
        """Store a matcher for ``Name`` nodes whose value is in *module_names*."""
        # The base initializer sets up state that libcst expects on every
        # visitor (e.g. for metadata resolution), so it must not be skipped.
        super().__init__()
        self.mod_match = m.Name(value=m.MatchIfTrue(module_names.__contains__))

    def _is_local_subpackage(self, name: cst.BaseExpression | None) -> bool:
        """Return True if *name* is a dotted path rooted at a known module.

        The check walks down to the leftmost component, so ``foo.bar`` and
        ``foo.bar.baz`` are both recognised when ``foo`` is a known module.
        """
        if not isinstance(name, cst.Attribute):
            return False
        root: cst.BaseExpression = name
        while isinstance(root, cst.Attribute):
            root = root.value
        return m.matches(root, self.mod_match)

    def leave_ImportFrom(
        self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
    ) -> cst.ImportFrom:
        """Prefix ``from <module> import ...`` with a dot for known modules.

        Raises:
            AssertionError: if the import is already relative, or if it names
                a subpackage of a known module.
        """
        # Raised explicitly (rather than via ``assert``) so the guards still
        # fire when Python runs with -O.
        if updated_node.relative:
            raise AssertionError("Patcher assumes all imports are absolute")

        module = updated_node.module
        if self._is_local_subpackage(module):
            raise AssertionError("Patcher assumes there are no subpackages")

        if module is not None and m.matches(module, self.mod_match):
            return updated_node.with_changes(relative=[cst.Dot()])
        return updated_node

    def leave_Import(
        self, original_node: cst.Import, updated_node: cst.Import
    ) -> cst.Import | cst.ImportFrom | cst.FlattenSentinel[cst.Import | cst.ImportFrom]:
        """Split ``import a, b, c`` into an absolute and a relative statement.

        Aliases naming a known module are moved into a ``from . import ...``
        statement placed after the remaining plain ``import``; if no alias is
        left for the plain ``import`` it is dropped entirely.

        Raises:
            AssertionError: if an alias names a subpackage of a known module.
        """
        rel_aliases: list[cst.ImportAlias] = []
        skip_aliases: list[cst.ImportAlias] = []
        for import_alias in updated_node.names:
            if self._is_local_subpackage(import_alias.name):
                raise AssertionError("Patcher assumes there are no subpackages")
            if m.matches(import_alias.name, self.mod_match):
                rel_aliases.append(import_alias)
            else:
                skip_aliases.append(import_alias)

        if not rel_aliases:
            return updated_node

        # Whichever alias ends up last in a statement must not keep the comma
        # it carried in the original, longer list.
        rel_aliases[-1] = rel_aliases[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)
        rel_import = cst.ImportFrom(
            module=None,
            names=rel_aliases,
            semicolon=updated_node.semicolon,
            whitespace_after_import=updated_node.whitespace_after_import,
            relative=[cst.Dot()],
        )

        statements: list[cst.Import | cst.ImportFrom] = []
        if skip_aliases:
            skip_aliases[-1] = skip_aliases[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)
            statements.append(updated_node.with_changes(names=skip_aliases))
        statements.append(rel_import)
        return cst.FlattenSentinel(statements)
if __name__ == "__main__":
    # Every top-level module in ./PADS is both a rewrite target and a name
    # that should be imported relatively by its siblings.
    module_paths = list(Path("./PADS").glob("*.py"))
    module_names = {path.stem for path in module_paths}
    transformer = ImportRewriter(module_names)

    # Transform everything in memory before touching the disk: if any module
    # violates the patcher's assumptions the run aborts with no file modified,
    # instead of leaving the package half-rewritten.
    # Working on bytes lets libcst detect each file's declared source encoding
    # and newline style rather than relying on the platform's locale default.
    rewritten = {
        path: cst.parse_module(path.read_bytes()).visit(transformer).bytes
        for path in module_paths
    }
    for path, new_source in rewritten.items():
        path.write_bytes(new_source)
| import unittest | |
| from typing import NamedTuple | |
| from textwrap import dedent | |
class Case(NamedTuple):
    """One table-driven scenario for ``ImportRewriter``."""

    # Names handed to ImportRewriter as the modules to import relatively.
    module_names: set[str]
    # Source code fed to the transformer.
    input_src: str
    # Source code expected after the rewrite (unused by the failure tests).
    output_src: str
class TestImportRewriter(unittest.TestCase):
    """Table-driven tests for ``ImportRewriter``."""

    @staticmethod
    def _rewrite(case: Case):
        """Parse ``case.input_src`` and return the module after the rewrite."""
        parsed = cst.parse_module(case.input_src)
        rewriter = ImportRewriter(case.module_names)
        return parsed.visit(rewriter)

    def test_rewrites(self):
        """Each input must be rewritten to exactly the expected source."""
        multi_line_input = dedent(
            """
            from foo import a as a1
            import foo as foo1, bar as bar1, baz
            import itertools
            import baz.foo
            from math import log10
            """
        )
        multi_line_expected = dedent(
            """
            from .foo import a as a1
            import baz; from . import foo as foo1, bar as bar1
            import itertools
            import baz.foo
            from math import log10
            """
        )
        scenarios: tuple[Case, ...] = (
            Case({"foo"}, "from foo import bar", "from .foo import bar"),
            Case({"bar"}, "from foo import bar", "from foo import bar"),
            Case({"bar"}, "from foo.bar import baz", "from foo.bar import baz"),
            Case({"bar"}, "import foo", "import foo"),
            Case({"foo"}, "import foo", "from . import foo"),
            Case({"foo"}, "import foo, bar, baz", "import bar, baz; from . import foo"),
            Case({"foo"}, "import foo, bar # comment", "import bar; from . import foo # comment"),
            Case({"bar", "baz"}, "import foo, bar, baz", "import foo; from . import bar, baz"),
            Case({"foo", "baz"}, "import foo, bar, baz", "import bar; from . import foo, baz"),
            Case({"foo", "bar"}, multi_line_input, multi_line_expected),
        )
        for scenario in scenarios:
            with self.subTest(scenario):
                rewritten = self._rewrite(scenario)
                self.assertEqual(scenario.output_src, rewritten.code)

    def test_failure_conditions(self):
        """Inputs outside the patcher's assumptions must raise AssertionError."""
        scenarios: tuple[Case, ...] = (
            # Already-relative imports.
            Case({""}, "from .foo import bar", ""),
            Case({""}, "from ..foo import *", ""),
            # Dotted imports rooted at a known module.
            Case({"foo"}, "import foo.bar", ""),
            Case({"foo"}, "import bar.baz, foo.baz", ""),
        )
        for index, scenario in enumerate(scenarios):
            with self.subTest(index), self.assertRaises(AssertionError):
                self._rewrite(scenario)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment