Created
September 15, 2019 20:11
-
-
Save agronholm/aa510a6155503165bef9ba1ae90bb1f1 to your computer and use it in GitHub Desktop.
Typeguard import hook for automatic code instrumentation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
import re | |
import sys | |
from importlib.machinery import SourceFileLoader | |
from importlib.abc import MetaPathFinder | |
from importlib.util import decode_source, cache_from_source | |
from typing import Iterable | |
from unittest.mock import patch | |
def _call_with_frames_removed(f, *args, **kwargs):
    """Invoke ``f`` with the given arguments and return whatever it returns.

    NOTE: the name of this function is deliberate ("magical", per the
    original author) and must not be changed. Presumably it mirrors
    importlib's own helper of the same name, whose frames are trimmed from
    import tracebacks -- confirm against the importlib sources.
    """
    result = f(*args, **kwargs)
    return result
def optimized_cache_from_source(path, debug_override=None):
    """Return the bytecode cache path for *path*, tagged ``typeguard``.

    Same call shape as ``importlib.util.cache_from_source(path,
    debug_override)``, but always requests the ``'typeguard'`` optimization
    marker so instrumented bytecode gets its own cache file name.
    """
    return cache_from_source(
        path,
        debug_override=debug_override,
        optimization='typeguard',
    )
class TypeguardTransformer(ast.NodeVisitor):
    """AST pass that instruments a module for typeguard.

    It adds ``import typeguard`` to the module and appends a
    ``typeguard.typechecked`` decorator to the function definitions it
    visits. Nodes are mutated in place and returned, so the result of
    ``visit(tree)`` is the same (modified) tree.

    Functions nested inside other functions are not visited, because
    ``visit_FunctionDef`` does not descend into the function body.
    """

    @staticmethod
    def _is_future_import(stmt):
        """Return True if *stmt* is a ``from __future__ import ...`` statement."""
        return isinstance(stmt, ast.ImportFrom) and stmt.module == '__future__'

    def visit_Module(self, node):
        """Insert ``import typeguard`` into the module, then visit its children.

        The import is placed after the module docstring (if any) and after
        any ``from __future__`` imports. Putting it first would stop the
        docstring from being recognized as such and would make the module
        fail to compile, since future statements must come before any other
        statement.
        """
        body = node.body
        index = 1 if ast.get_docstring(node, clean=False) is not None else 0
        while index < len(body) and self._is_future_import(body[index]):
            index += 1

        body.insert(index, ast.Import(names=[ast.alias(name='typeguard', asname=None)]))
        self.generic_visit(node)
        return node

    def visit_FunctionDef(self, node):
        """Append ``typeguard.typechecked`` to the function's decorators.

        Being last in ``decorator_list`` makes it the innermost decorator,
        i.e. the one applied to the function first.
        """
        node.decorator_list.append(
            ast.Attribute(ast.Name(id='typeguard', ctx=ast.Load()), 'typechecked', ast.Load())
        )
        return node
class TypeguardLoader(SourceFileLoader):
    """Source file loader that runs modules through ``TypeguardTransformer``."""

    def source_to_code(self, data, path, *, _optimize=-1):
        """Compile *data* to a code object after instrumenting its AST.

        The source is decoded and parsed to an AST, then passed through
        ``TypeguardTransformer``. Missing node locations are filled in and
        the resulting tree is compiled in ``'exec'`` mode.
        """
        compile_options = {'dont_inherit': True, 'optimize': _optimize}

        text = decode_source(data)
        parsed = _call_with_frames_removed(
            compile, text, path, 'exec', ast.PyCF_ONLY_AST, **compile_options
        )

        instrumented = TypeguardTransformer().visit(parsed)
        ast.fix_missing_locations(instrumented)

        return _call_with_frames_removed(
            compile, instrumented, path, 'exec', **compile_options
        )

    def exec_module(self, module):
        """Execute *module* with a typeguard-specific bytecode cache name."""
        # Swap in a custom optimization marker while the module executes; the
        # import lock should make this monkey patch safe.
        cache_patch = patch(
            'importlib._bootstrap_external.cache_from_source',
            optimized_cache_from_source,
        )
        with cache_patch:
            return super().exec_module(module)
class TypeguardFinder(MetaPathFinder):
    """Meta path finder that swaps in ``TypeguardLoader`` for chosen packages.

    Spec lookup is delegated to *original_pathfinder*. When the spec it
    returns uses a ``SourceFileLoader`` and the module belongs to one of
    *packages*, the loader is replaced with a ``TypeguardLoader`` for the
    same module name and path.

    :param packages: iterable of package/module names to instrument; a name
        covers the module itself and all of its submodules
    :param original_pathfinder: the finder to delegate ``find_spec`` to
    """

    def __init__(self, packages, original_pathfinder):
        # Escape each name so its dots match literally, and require the match
        # to end at a dot or at the end of the string so that "foo" covers
        # "foo" and "foo.bar" but not "foobar".
        self._package_exprs = [
            re.compile(r'^%s(\.|$)' % re.escape(pkg)) for pkg in packages
        ]
        self._original_pathfinder = original_pathfinder

    def find_spec(self, fullname, path=None, target=None):
        """Return the delegate's spec, with the loader replaced if applicable.

        Returns ``None`` when the wrapped finder cannot find the module, so
        the import system can carry on and report the failure normally.
        """
        spec = self._original_pathfinder.find_spec(fullname, path, target)
        if spec is None:
            return None

        loader = spec.loader
        if isinstance(loader, SourceFileLoader) and self.should_instrument(loader.name):
            spec.loader = TypeguardLoader(loader.name, loader.path)

        return spec

    def should_instrument(self, module_name: str) -> bool:
        """Return True if *module_name* is, or lives under, a configured package."""
        return any(expr.match(module_name) for expr in self._package_exprs)
def install_import_hook(packages: Iterable[str]) -> None:
    """Install the typeguard import hook for the given packages.

    Wraps the path-based finder in ``sys.meta_path`` with a
    ``TypeguardFinder``, replacing it in place.

    :param packages: a package/module name, or an iterable of such names,
        whose modules should be instrumented on import
    """
    # Local import keeps this function independent of the file's import block.
    from importlib.machinery import PathFinder

    if isinstance(packages, str):
        packages = [packages]

    meta_path = sys.meta_path
    # Other import hooks may have been appended after the path-based finder,
    # so look for it (or for a previously installed hook wrapping it) rather
    # than assuming it is the last entry. Fall back to the last entry if
    # neither is present.
    index = len(meta_path) - 1
    for position, finder in enumerate(meta_path):
        if finder is PathFinder or isinstance(finder, TypeguardFinder):
            index = position
            break

    meta_path[index] = TypeguardFinder(packages, meta_path[index])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment