Created
April 7, 2021 10:47
-
-
Save pfmoore/858b6b5b5f40e02f4693626ab7226ce0 to your computer and use it in GitHub Desktop.
Test Python import redirects
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import contextlib | |
import runpy | |
import site | |
import sys | |
import pytest | |
# ---------------------------------------------------------------------- | |
import sys | |
import importlib.abc | |
import importlib.util | |
class RedirectingFinder(importlib.abc.MetaPathFinder):
    """Meta path finder that maps import names to explicitly registered locations.

    ``redirect()`` records where a name should be loaded from. ``find_spec()``
    builds a module spec for recorded names and declines (returns ``None``)
    for every other name, leaving those to the other finders on
    ``sys.meta_path``.
    """

    def __init__(self):
        # fullname -> str        (a module file, or a package's __init__.py)
        # fullname -> list[str]  (directories forming a namespace package)
        self.redirects = {}

    def redirect(self, fullname, path):
        """Register *path* (a ``pathlib.Path``) as the source of *fullname*.

        * A directory containing ``__init__.py`` is a regular package and is
          redirected to that ``__init__.py``.
        * A directory without ``__init__.py`` is a namespace portion; repeated
          calls for the same name accumulate portions in registration order.
        * Any other path is stored as a module file location.
        """
        if path.is_dir():
            init = path / "__init__.py"
            if init.is_file():
                self.redirects[fullname] = str(init)
            else:
                self.redirects.setdefault(fullname, []).append(str(path))
        else:
            self.redirects[fullname] = str(path)

    def find_spec(self, fullname, path, target=None):
        """Return a spec for a redirected *fullname*, or ``None`` if it is not ours.

        *path* and *target* belong to the ``MetaPathFinder`` protocol and are
        not used: lookups are keyed purely on the full module name.
        """
        redir = self.redirects.get(fullname)
        # Debug trace of every lookup that reaches this finder.
        print(f"{fullname} -> {redir}")
        if redir is None:
            # Decline explicitly instead of handing str(None) to
            # spec_from_file_location and relying on "None" having no
            # recognised module suffix.
            return None
        if isinstance(redir, list):
            # Namespace package: no loader, only search locations. The spec
            # holds the same list object stored in self.redirects (not a copy).
            spec = importlib.util.spec_from_loader(fullname, None, is_package=True)
            spec.submodule_search_locations = redir
            return spec
        return importlib.util.spec_from_file_location(fullname, redir)

    def install(self):
        """Append this finder to ``sys.meta_path`` unless it is already there.

        Because it is appended, finders already on ``sys.meta_path`` are
        consulted before this one.
        """
        if self not in sys.meta_path:
            sys.meta_path.append(self)
# Module-level RedirectingFinder shared by every MetaRedir instance
# (MetaRedir.__init__ assigns this object). Nothing in this file clears its
# redirects, so registrations persist for the life of the process.
finder: RedirectingFinder = RedirectingFinder()
# ----------------------------------------------------------------------
@contextlib.contextmanager
def save_import_state():
    """Snapshot import-system state and restore it when the ``with`` block exits.

    Restored on exit, even if the block raises:

    * ``sys.modules``: entries added inside the block are removed. Entries
      that existed beforehand are left as they are, even if rebound.
    * ``sys.path``, ``sys.meta_path`` and ``sys.path_hooks``: restored in
      place, so other references to these lists stay valid.
    * ``sys.path_importer_cache``: reset to its previous contents.
    """
    orig_modules = set(sys.modules)
    orig_path = list(sys.path)
    orig_meta_path = list(sys.meta_path)
    orig_path_hooks = list(sys.path_hooks)
    # Must be a copy. Keeping a reference to the live dict meant the clear()
    # below also emptied the "saved" state, so the cache was never restored.
    orig_path_importer_cache = dict(sys.path_importer_cache)
    try:
        yield
    finally:
        # Collect first: sys.modules cannot change size while it is iterated.
        added = [key for key in sys.modules if key not in orig_modules]
        for key in added:
            del sys.modules[key]
        sys.path[:] = orig_path
        sys.meta_path[:] = orig_meta_path
        sys.path_hooks[:] = orig_path_hooks
        sys.path_importer_cache.clear()
        sys.path_importer_cache.update(orig_path_importer_cache)
def build(target, structure):
    """Materialise *structure* as files and directories under *target*.

    Each key is an entry name. A ``str`` value becomes a UTF-8 text file with
    that content. Any other value is treated as a nested structure and built
    recursively as a subdirectory. *target* (and missing parents) is created
    if needed.
    """
    target.mkdir(parents=True, exist_ok=True)
    for entry_name, entry in structure.items():
        destination = target / entry_name
        if not isinstance(entry, str):
            build(destination, entry)
            continue
        destination.write_text(entry, encoding="utf-8")
class NoRedir:
    """Control case: make a project importable through a plain ``sys.path`` entry.

    No redirection is performed. ``redirect()`` only asserts that the
    requested mapping is one that a ``sys.path`` entry for the project would
    satisfy anyway: *path* sits directly in the project and is named *name*.
    """

    def __init__(self, project):
        self.project = project  # project root directory

    def redirect(self, name, path):
        """Record nothing; fail unless *path* is ``<project>/<name>[suffix]``."""
        expected_parent = self.project
        assert path.parent == expected_parent
        assert path.stem == name

    def activate(self):
        """Expose the project by appending its directory to ``sys.path``."""
        entry = str(self.project)
        sys.path.append(entry)
class MetaRedir:
    """Make a project importable through the shared module-level ``finder``.

    NOTE(review): ``finder`` is one module-level object, so every redirect
    registered through any ``MetaRedir`` (including ones from earlier tests)
    stays in it. Nothing in this class removes them.
    """

    def __init__(self, project):
        self.project = project
        # All instances must share one finder: namespace portions coming from
        # different projects are only merged if registered on the same finder.
        self.finder = finder

    def redirect(self, name, path):
        """Forward the mapping ``name -> path`` to the shared finder."""
        shared = self.finder
        shared.redirect(name, path)

    def activate(self):
        """Install the shared finder on ``sys.meta_path`` (no-op if present)."""
        shared = self.finder
        shared.install()
@pytest.mark.parametrize("method", [NoRedir, MetaRedir])
def test_simple_imports(tmp_path, method):
    """A plain module, a package and a package submodule all import via *method*."""
    root = tmp_path / "project"
    layout = {
        "mod.py": "val = 42",
        "pkg": {"__init__.py": "val = 42", "sub.py": "val = 42"},
    }
    build(root, layout)

    redirector = method(root)
    redirector.redirect("mod", root / "mod.py")
    redirector.redirect("pkg", root / "pkg")

    with save_import_state():
        redirector.activate()

        import mod
        assert mod.val == 42

        import pkg
        assert pkg.val == 42

        import pkg.sub
        assert pkg.sub.val == 42
@pytest.mark.parametrize("method", [NoRedir, MetaRedir])
def test_main_call(tmp_path, method):
    """``runpy.run_module`` on a redirected package executes its ``__main__``."""
    root = tmp_path / "project"
    build(root, {"pkg": {"__init__.py": "val = 0", "__main__.py": "val = 42"}})

    redirector = method(root)
    redirector.redirect("pkg", root / "pkg")

    with save_import_state():
        redirector.activate()

        import pkg
        assert pkg.val == 0  # the package's __init__, not its __main__

        main_globals = runpy.run_module("pkg")
        assert main_globals["val"] == 42
@pytest.mark.parametrize("method", [NoRedir, MetaRedir])
def test_namespace(tmp_path, method):
    """Two projects each contribute portions to a single ``ns`` namespace package."""

    def portion(n):
        # One project's share of ``ns``; its subpackage names end in *n*.
        return {
            "ns": {
                f"foo{n}": {"__init__.py": "val = 42"},
                f"bar{n}": {"__init__.py": "val = 42"},
                f"baz{n}": {
                    "__init__.py": "val = 0",
                    "__main__.py": "val = 42",
                },
            },
        }

    first = tmp_path / "project1"
    build(first, portion(1))
    second = tmp_path / "project2"
    build(second, portion(2))

    redirector1 = method(first)
    redirector2 = method(second)
    redirector1.redirect("ns", first / "ns")
    redirector2.redirect("ns", second / "ns")

    with save_import_state():
        redirector1.activate()
        redirector2.activate()

        # Portion from project1.
        import ns.foo1
        assert ns.foo1.val == 42
        import ns.bar1
        assert ns.bar1.val == 42
        import ns.baz1
        assert ns.baz1.val == 0
        main_globals = runpy.run_module("ns.baz1")
        assert main_globals["val"] == 42

        # Portion from project2.
        import ns.foo2
        assert ns.foo2.val == 42
        import ns.bar2
        assert ns.bar2.val == 42
        import ns.baz2
        assert ns.baz2.val == 0
        main_globals = runpy.run_module("ns.baz2")
        assert main_globals["val"] == 42
@pytest.mark.parametrize("method", [NoRedir, MetaRedir])
def test_namespace_can_be_extended(tmp_path, method):
    """A redirected ``ns`` namespace merges with a portion found on ``sys.path``."""

    def portion(n):
        # One directory's share of ``ns``; its subpackage names end in *n*.
        return {
            "ns": {
                f"foo{n}": {"__init__.py": "val = 42"},
                f"bar{n}": {"__init__.py": "val = 42"},
                f"baz{n}": {
                    "__init__.py": "val = 0",
                    "__main__.py": "val = 42",
                },
            },
        }

    root = tmp_path / "project"
    build(root, portion(1))
    extra = tmp_path / "extra"
    build(extra, portion(2))

    redirector = method(root)
    redirector.redirect("ns", root / "ns")

    with save_import_state():
        redirector.activate()
        # The second portion is reachable only through an ordinary sys.path entry.
        sys.path.append(str(extra))

        import ns.foo1
        assert ns.foo1.val == 42
        import ns.bar1
        assert ns.bar1.val == 42
        import ns.baz1
        assert ns.baz1.val == 0
        main_globals = runpy.run_module("ns.baz1")
        assert main_globals["val"] == 42

        import ns.foo2
        assert ns.foo2.val == 42
        import ns.bar2
        assert ns.bar2.val == 42
        import ns.baz2
        assert ns.baz2.val == 0
        main_globals = runpy.run_module("ns.baz2")
        assert main_globals["val"] == 42

        if method is MetaRedir:
            # NOTE(review): deliberate diagnostic. Whenever execution reaches
            # this point with MetaRedir, the test fails unconditionally and
            # reports the loader that was chosen for ``ns``.
            assert False, sys.modules["ns"].__loader__
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment