Skip to content

Instantly share code, notes, and snippets.

@agronholm
Created September 15, 2019 20:11
Show Gist options
  • Save agronholm/aa510a6155503165bef9ba1ae90bb1f1 to your computer and use it in GitHub Desktop.
Save agronholm/aa510a6155503165bef9ba1ae90bb1f1 to your computer and use it in GitHub Desktop.
Typeguard import hook for automatic code instrumentation
import ast
import re
import sys
from importlib.machinery import SourceFileLoader
from importlib.abc import MetaPathFinder
from importlib.util import decode_source, cache_from_source
from typing import Iterable
from unittest.mock import patch
# The name of this function is magical
def _call_with_frames_removed(f, *args, **kwargs):
return f(*args, **kwargs)
def optimized_cache_from_source(path, debug_override=None):
    """Return the bytecode cache path used for typeguard-instrumented modules.

    Drop-in replacement for ``importlib.util.cache_from_source`` that tags the
    cache file with a ``typeguard`` optimization marker, so instrumented
    bytecode never shares a ``.pyc`` file with the regular bytecode.

    The interpreter's optimization level is folded into the marker
    (``typeguard``, ``typeguard1``, ``typeguard2``). Bytecode compiled under
    ``-O``/``-OO`` therefore gets its own cache file and is never loaded in an
    unoptimized run, or vice versa.

    :param path: path to the ``.py`` source file
    :param debug_override: legacy flag kept for signature compatibility with
        ``cache_from_source``. ``True`` selects the unoptimized marker,
        ``False`` the level-1 marker, and ``None`` uses ``sys.flags.optimize``.
    :return: path of the corresponding cached bytecode file
    """
    if debug_override is None:
        level = sys.flags.optimize
    else:
        # Same translation cache_from_source applies to debug_override. It is
        # done here because cache_from_source rejects debug_override combined
        # with an explicit optimization argument.
        level = 0 if debug_override else 1
    optimization = 'typeguard' if level == 0 else 'typeguard%d' % level
    return cache_from_source(path, optimization=optimization)
class TypeguardTransformer(ast.NodeTransformer):
    """Instrument a module AST so its functions are checked by typeguard.

    Adds ``import typeguard`` to the module and appends
    ``typeguard.typechecked`` to the decorator list of every function
    definition reached during traversal. Bodies of functions are not descended
    into, so nested functions are left undecorated.
    """

    def visit_Module(self, node):
        """Visit the module's statements, then insert ``import typeguard``.

        The import goes after the module docstring and any
        ``from __future__`` imports. A ``__future__`` import that is no
        longer at the top of the file is a SyntaxError, and a string that is
        no longer the first statement stops being the module's ``__doc__``.

        :param node: the module AST (modified in place)
        :return: the same module node
        """
        self.generic_visit(node)
        body = node.body
        position = 0 if ast.get_docstring(node, clean=False) is None else 1
        while (position < len(body)
               and isinstance(body[position], ast.ImportFrom)
               and body[position].module == '__future__'):
            position += 1
        body.insert(position, ast.Import(names=[ast.alias('typeguard', None)]))
        return node

    def visit_FunctionDef(self, node):
        """Append ``@typeguard.typechecked`` to a function's decorators.

        Because it is appended last, it is the innermost decorator, i.e. the
        first one applied to the raw function.

        :param node: the function definition (modified in place)
        :return: the same function node
        """
        node.decorator_list.append(
            ast.Attribute(ast.Name(id='typeguard', ctx=ast.Load()), 'typechecked', ast.Load())
        )
        return node
class TypeguardLoader(SourceFileLoader):
    """Source file loader that compiles modules with typeguard instrumentation.

    Bytecode is read from and written to the cache path produced by
    ``optimized_cache_from_source``. That path carries a separate
    optimization marker, so instrumented code does not share a ``.pyc`` with
    the uninstrumented module.
    """

    def source_to_code(self, data, path, *, _optimize=-1):
        """Parse *data*, instrument the AST and compile it to a code object.

        :param data: raw source bytes (decoded with ``decode_source``)
        :param path: filename recorded in the resulting code object
        :param _optimize: optimization level forwarded to ``compile``
        :return: the compiled module code object
        """
        source = decode_source(data)
        tree = _call_with_frames_removed(compile, source, path, 'exec', ast.PyCF_ONLY_AST,
                                         dont_inherit=True, optimize=_optimize)
        tree = TypeguardTransformer().visit(tree)
        # Nodes added by the transformer have no line numbers yet.
        ast.fix_missing_locations(tree)
        return _call_with_frames_removed(compile, tree, path, 'exec',
                                         dont_inherit=True, optimize=_optimize)

    def get_code(self, fullname):
        """Return the instrumented code object for *fullname*.

        Uses the typeguard-specific bytecode cache path.

        :param fullname: fully qualified module name
        :return: the module code object
        """
        # SourceLoader.get_code() is where the bytecode cache path is computed,
        # so the cache naming is swapped only for that call. The previous
        # exec_module() patch also covered execution of the module body, so
        # every module imported from that body (instrumented or not) used the
        # typeguard cache name. Callers that invoke get_code() directly (e.g.
        # runpy) also bypassed that patch and wrote instrumented bytecode to
        # the regular cache file.
        # NOTE: this is still a process-wide monkey patch. Import locks are
        # per module, so imports running in other threads are not excluded
        # from this window.
        with patch('importlib._bootstrap_external.cache_from_source', optimized_cache_from_source):
            return super().get_code(fullname)
class TypeguardFinder(MetaPathFinder):
    """Meta path finder that substitutes ``TypeguardLoader`` for chosen packages.

    Delegates the actual lookup to a wrapped finder. For any source module
    that belongs to one of the configured packages, it replaces the spec's
    loader with an instrumenting one.
    """

    def __init__(self, packages, original_pathfinder):
        """
        :param packages: names of packages/modules to instrument; each name
            also covers its dotted submodules
        :param original_pathfinder: finder that performs the real lookup
        """
        # Match the name itself or a dotted submodule of it, but not a sibling
        # that merely shares the prefix ("foo" must not match "foobar").
        # Escape the name so its dots are literal.
        self._package_exprs = [re.compile(r'%s(?:\.|$)' % re.escape(pkg)) for pkg in packages]
        self._original_pathfinder = original_pathfinder

    def find_spec(self, fullname, path=None, target=None):
        """Find *fullname* via the wrapped finder, swapping in the typeguard loader.

        :return: the wrapped finder's spec (possibly with a replaced loader),
            or ``None`` if it could not find the module
        """
        spec = self._original_pathfinder.find_spec(fullname, path, target)
        # None means "not found here" and must be passed through, so the
        # import system can try other finders or raise ModuleNotFoundError.
        if spec is not None and isinstance(spec.loader, SourceFileLoader):
            if self.should_instrument(spec.loader.name):
                spec.loader = TypeguardLoader(spec.loader.name, spec.loader.path)
        return spec

    def invalidate_caches(self):
        """Forward cache invalidation to the wrapped finder.

        ``importlib.invalidate_caches()`` only reaches finders listed on
        ``sys.meta_path``. The wrapped finder was replaced by this one, so the
        call has to be relayed here.
        """
        invalidate = getattr(self._original_pathfinder, 'invalidate_caches', None)
        if invalidate is not None:
            invalidate()

    def should_instrument(self, module_name: str) -> bool:
        """Return True if *module_name* is, or lives under, a configured package."""
        return any(expr.match(module_name) for expr in self._package_exprs)
def install_import_hook(packages: Iterable[str]):
    """Instrument the given packages with typeguard when they are imported.

    Wraps the path-based finder on ``sys.meta_path`` in a ``TypeguardFinder``.
    Modules already present in ``sys.modules`` are not affected.

    :param packages: a single package name, or an iterable of package names
    """
    from importlib.machinery import PathFinder

    if isinstance(packages, str):
        packages = [packages]

    # Locate the path-based finder explicitly instead of assuming it is the
    # last entry, because other libraries may append their own finders after
    # it. Fall back to the last entry, the previous behavior, if it is absent
    # (e.g. already wrapped by an earlier call).
    index = len(sys.meta_path) - 1
    for position, finder in enumerate(sys.meta_path):
        if finder is PathFinder:
            index = position
            break
    sys.meta_path[index] = TypeguardFinder(packages, sys.meta_path[index])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment