Skip to content

Instantly share code, notes, and snippets.

@MathiasSven
Created November 7, 2025 22:00
Show Gist options
  • Select an option

  • Save MathiasSven/6dab55936a51ded60b6cee384f5d7afa to your computer and use it in GitHub Desktop.

Select an option

Save MathiasSven/6dab55936a51ded60b6cee384f5d7afa to your computer and use it in GitHub Desktop.
I later learned that the problem I was trying to solve could have been fixed by running `2to3`, whose `import` fixer converts implicit sibling imports into explicit relative imports. Still, this may serve as an example of how to use `libcst` in the future.
from pathlib import Path
from typing import cast
import libcst as cst
import libcst.matchers as m
class ImportRewriter(cst.CSTTransformer):
    """Rewrite absolute imports of sibling modules into package-relative imports.

    Given the names of the top-level modules that live side by side:

    * ``from foo import x``  becomes ``from .foo import x``
    * ``import foo``         becomes ``from . import foo``
    * ``import foo, bar``    (only ``foo`` a sibling) is split into
      ``import bar; from . import foo``

    Imports of any other module are left untouched.

    Inputs the rewriter was not designed for raise ``AssertionError``:

    * ``from`` imports that are already relative;
    * dotted references rooted at a sibling module (``foo.bar``, ``foo.bar.baz``),
      which would imply subpackages.
    """

    def __init__(self, module_names: set[str]):
        super().__init__()
        # Matches a bare Name whose identifier is one of the sibling modules.
        self.mod_match = m.Name(value=m.MatchIfTrue(module_names.__contains__))

    @staticmethod
    def _root_name(node: cst.BaseExpression) -> cst.BaseExpression:
        """Return the leftmost component of a dotted name (``a`` for ``a.b.c``)."""
        while isinstance(node, cst.Attribute):
            node = node.value
        return node

    def leave_ImportFrom(
        self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
    ) -> cst.ImportFrom:
        """Make ``from <sibling> import ...`` relative.

        Raises:
            AssertionError: if the import is already relative, or if it names a
                dotted path rooted at a sibling module.
        """
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if original_node.relative:
            raise AssertionError("Patcher assumes all imports are absolute")
        module = updated_node.module
        # With no leading dots, ``module`` is a Name or Attribute; the None guard is
        # purely defensive.
        if module is None or not m.matches(self._root_name(module), self.mod_match):
            return updated_node
        # Check the root of the dotted path, not just one level, so that
        # ``from foo.a.b import x`` is rejected like ``from foo.a import x``.
        if not isinstance(module, cst.Name):
            raise AssertionError("Patcher assumes there are no subpackages")
        return updated_node.with_changes(relative=[cst.Dot()])

    def leave_Import(
        self, original_node: cst.Import, updated_node: cst.Import
    ) -> cst.Import | cst.ImportFrom | cst.FlattenSentinel[cst.Import | cst.ImportFrom]:
        """Move sibling modules out of ``import a, b, ...`` into ``from . import ...``.

        Non-sibling aliases stay in the original ``import`` statement, which is
        emitted first. The relative import follows on the same logical line.

        Raises:
            AssertionError: if an alias is a dotted path rooted at a sibling module.
        """
        rel_aliases: list[cst.ImportAlias] = []
        skip_aliases: list[cst.ImportAlias] = []
        for import_alias in updated_node.names:
            name = import_alias.name
            if m.matches(name, self.mod_match):
                rel_aliases.append(import_alias)
            elif m.matches(self._root_name(name), self.mod_match):
                # e.g. ``import foo.bar`` or ``import foo.bar.baz``
                raise AssertionError("Patcher assumes there are no subpackages")
            else:
                skip_aliases.append(import_alias)

        if not rel_aliases:
            return updated_node

        statements: list[cst.Import | cst.ImportFrom] = []
        if skip_aliases:
            # The alias that ends each new statement must not carry the trailing
            # comma it had in the original, longer list.
            skip_aliases[-1] = skip_aliases[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)
            statements.append(updated_node.with_changes(names=skip_aliases))

        rel_aliases[-1] = rel_aliases[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)
        statements.append(
            cst.ImportFrom(
                module=None,
                names=rel_aliases,
                semicolon=updated_node.semicolon,
                whitespace_after_import=updated_node.whitespace_after_import,
                relative=[cst.Dot()],
            )
        )
        return cst.FlattenSentinel(statements)
if __name__ == "__main__":
    import sys

    # Directory holding the flat collection of modules to patch. An optional
    # first command-line argument overrides the default ./PADS.
    package_dir = Path(sys.argv[1] if len(sys.argv) > 1 else "./PADS")
    module_paths = sorted(package_dir.glob("*.py"))
    module_names = {path.stem for path in module_paths}
    transformer = ImportRewriter(module_names)

    # Transform every module before writing anything. That way an AssertionError
    # on an unsupported import leaves the directory untouched rather than
    # half-patched. Working on bytes lets libcst detect and re-apply each file's
    # own encoding and newline style.
    rewritten = {
        path: cst.parse_module(path.read_bytes()).visit(transformer).bytes
        for path in module_paths
    }
    for path, new_source in rewritten.items():
        path.write_bytes(new_source)
import unittest
from typing import NamedTuple
from textwrap import dedent
# One table-driven scenario for ImportRewriter:
#   module_names -- the sibling module names handed to the rewriter,
#   input_src    -- the source text to transform,
#   output_src   -- the expected result (not read by the failure-condition cases).
Case = NamedTuple(
    "Case",
    [
        ("module_names", set[str]),
        ("input_src", str),
        ("output_src", str),
    ],
)
class TestImportRewriter(unittest.TestCase):
    """Table-driven checks of ImportRewriter on small source snippets."""

    @staticmethod
    def _rewrite(module_names: set[str], src: str) -> str:
        """Parse *src*, run an ImportRewriter for *module_names*, return the new code."""
        tree = cst.parse_module(src)
        return tree.visit(ImportRewriter(module_names)).code

    def test_rewrites(self):
        """Sibling-module imports become relative; all other imports are untouched."""
        mixed_input = dedent(
            """
            from foo import a as a1
            import foo as foo1, bar as bar1, baz
            import itertools
            import baz.foo
            from math import log10
            """
        )
        mixed_expected = dedent(
            """
            from .foo import a as a1
            import baz; from . import foo as foo1, bar as bar1
            import itertools
            import baz.foo
            from math import log10
            """
        )
        cases: tuple[Case, ...] = (
            Case({"foo"}, "from foo import bar", "from .foo import bar"),
            Case({"bar"}, "from foo import bar", "from foo import bar"),
            Case({"bar"}, "from foo.bar import baz", "from foo.bar import baz"),
            Case({"bar"}, "import foo", "import foo"),
            Case({"foo"}, "import foo", "from . import foo"),
            Case({"foo"}, "import foo, bar, baz", "import bar, baz; from . import foo"),
            Case({"foo"}, "import foo, bar # comment", "import bar; from . import foo # comment"),
            Case({"bar", "baz"}, "import foo, bar, baz", "import foo; from . import bar, baz"),
            Case({"foo", "baz"}, "import foo, bar, baz", "import bar; from . import foo, baz"),
            Case({"foo", "bar"}, mixed_input, mixed_expected),
        )
        for case in cases:
            with self.subTest(case):
                actual = self._rewrite(case.module_names, case.input_src)
                self.assertEqual(case.output_src, actual)

    def test_failure_conditions(self):
        """Already-relative imports and sibling subpackage imports are rejected."""
        cases: tuple[Case, ...] = (
            # {""} matches no module: the relative-import check fires on its own.
            Case({""}, "from .foo import bar", ""),
            Case({""}, "from ..foo import *", ""),
            Case({"foo"}, "import foo.bar", ""),
            Case({"foo"}, "import bar.baz, foo.baz", ""),
        )
        for index, case in enumerate(cases):
            with self.subTest(index), self.assertRaises(AssertionError):
                self._rewrite(case.module_names, case.input_src)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment