Created
January 31, 2018 18:59
-
-
Save bheklilr/865c032700a24cabea4d9187380c5115 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import types | |
import typing | |
import ast | |
import inspect | |
class TypeViolationError(TypeError):
    """Raised when an annotated variable holds a value of the wrong type.

    Attributes:
        variable_name: Name of the annotated variable that failed the check.
        expected_type: The object named by the annotation (passed to isinstance).
        actual_type: ``type()`` of the value the variable actually held.
    """

    def __init__(self, variable_name, expected_type, actual_type):
        # Forward the fields to the base class so ``self.args`` is populated:
        # pickle/copy rebuild exceptions as ``cls(*self.args)`` and would fail
        # with the empty args tuple produced by a bare ``super().__init__()``.
        super().__init__(variable_name, expected_type, actual_type)
        self.variable_name = variable_name
        self.expected_type = expected_type
        self.actual_type = actual_type

    @staticmethod
    def _type_name(type_):
        # isinstance() also accepts tuples (and unions on newer Pythons), which
        # have no __name__; fall back to repr() instead of raising
        # AttributeError while formatting the error.
        return getattr(type_, '__name__', repr(type_))

    def __repr__(self):
        return '<{}({}, {}, {})>'.format(
            self.__class__.__name__,
            self.variable_name,
            self._type_name(self.expected_type),
            self._type_name(self.actual_type),
        )

    def __str__(self):
        return 'Expected {} to have type {}, got type {}'.format(
            self.variable_name,
            self._type_name(self.expected_type),
            self._type_name(self.actual_type),
        )
def load(name):
    """Return an ``ast.Name`` node that reads (loads) the variable *name*."""
    read_context = ast.Load()
    return ast.Name(id=name, ctx=read_context)
def isinstance_ast(name, type_):
    """Return the AST for the call expression ``isinstance(<name>, <type_>)``.

    Both arguments are identifiers; each is emitted as a by-name load.
    """
    call_args = [load(name), load(type_)]
    return ast.Call(func=load('isinstance'), args=call_args, keywords=[])
def raise_type_violation_error_ast(variable_name, expected_type):
    """Build ``raise TypeViolationError('<name>', <type>, type(<name>))`` as AST.

    Args:
        variable_name: Identifier of the checked variable. It is embedded both
            as a string literal (for the error message) and as a load of the
            variable itself (to compute its actual type).
        expected_type: Identifier of the expected type, loaded by name.

    Returns:
        An ``ast.Raise`` statement node without source locations; callers are
        expected to run ``ast.fix_missing_locations`` before compiling.
    """
    return ast.Raise(
        exc=ast.Call(
            func=load('TypeViolationError'),
            args=[
                # ast.Constant replaces ast.Str, which is deprecated since
                # Python 3.8 and removed in 3.14.
                ast.Constant(value=variable_name),
                load(expected_type),
                ast.Call(func=load('type'), args=[load(variable_name)], keywords=[]),
            ],
            keywords=[],
        ),
        cause=None,
    )
def not_ast(node):
    """Wrap the expression *node* in a logical negation, i.e. ``not node``."""
    negation = ast.Not()
    return ast.UnaryOp(operand=node, op=negation)
class TypeCheckingVisitor(ast.NodeTransformer):
    """AST transformer that appends a runtime type check to annotated assignments.

    ``x: int = value`` is rewritten to::

        x: int = value
        if not isinstance(x, int):
            raise TypeViolationError('x', int, type(x))

    Only the simple form is instrumented: a plain variable name as the target,
    a plain name as the annotation, and an assigned value. Any other form is
    returned unchanged instead of crashing. This covers ``self.x: int = 1``,
    ``x: List[int] = []`` and ``x: 'int' = 1``. It also covers a bare ``x: int``,
    where instrumenting would read an unbound name.
    """

    def visit_AnnAssign(self, node):
        target = node.target
        annotation = node.annotation
        if (
            node.value is None  # declaration only: the name may be unbound
            or not isinstance(target, ast.Name)
            or not isinstance(annotation, ast.Name)
        ):
            return node
        name = target.id
        type_ = annotation.id
        return [
            node,
            ast.If(
                test=not_ast(isinstance_ast(name, type_)),
                body=[raise_type_violation_error_ast(name, type_)],
                orelse=[],
            ),
        ]
def transform(visitor):
    """Return a decorator that recompiles a function after rewriting its AST.

    The decorated function's source is read with ``inspect``, parsed, passed
    through *visitor* (an ``ast.NodeTransformer``), compiled, and turned into a
    new function object. That object shares the original's module globals,
    defaults, name, qualname, docstring and ``__dict__``.

    Limitations inherent to recompiling from source:
    - Source must be available.
    - Other decorators stacked below this one are not re-applied.
    - Variables closed over from an enclosing function are looked up as
      globals instead.
    - Zero-argument ``super()`` has no ``__class__`` cell.

    Args:
        visitor: An ``ast.NodeTransformer`` applied to the parsed source.

    Returns:
        A decorator ``deco(function) -> function``.

    Raises (from the decorator):
        OSError/TypeError: if the function's source cannot be retrieved or its
            compiled code object cannot be located.
    """
    import textwrap

    def deco(function):
        source_lines, first_lineno = inspect.getsourcelines(function)
        # Methods and nested functions come back indented; dedent so the
        # definition parses as a top-level statement.
        module = ast.parse(textwrap.dedent(''.join(source_lines)))
        # Shift line numbers so tracebacks point at the real source lines.
        ast.increment_lineno(module, first_lineno - 1)
        # Visit the whole module and keep the visitor's return value, so a
        # transformer that replaces nodes (not just mutates them) still works.
        module = ast.fix_missing_locations(visitor.visit(module))
        filename = inspect.getsourcefile(function) or '<string>'
        module_code = compile(module, filename, 'exec')
        # The function's code object is not necessarily co_consts[0]: default
        # values and parameter annotations are compiled as constants ahead of it.
        code = next(
            (
                const
                for const in module_code.co_consts
                if isinstance(const, types.CodeType)
                and const.co_name == function.__name__
            ),
            None,
        )
        if code is None:
            raise TypeError(
                'could not locate compiled code for {!r}'.format(function)
            )
        # Resolve globals in the decorated function's own module (not this
        # one) so it keeps access to its imports, helpers and its own name.
        globs = function.__globals__
        # The injected checks raise TypeViolationError by name.
        # NOTE(review): this writes into the decorated function's module
        # namespace and overrides any same-named object there.
        globs['TypeViolationError'] = TypeViolationError
        new_function = types.FunctionType(
            code, globs, function.__name__, function.__defaults__
        )
        # FunctionType only accepts positional defaults; copy the remaining
        # metadata a plain ``def`` would have set.
        new_function.__kwdefaults__ = function.__kwdefaults__
        new_function.__qualname__ = function.__qualname__
        new_function.__doc__ = function.__doc__
        new_function.__dict__.update(function.__dict__)
        return new_function

    return deco
# Public decorator: recompiles the decorated function so that every simple
# annotated assignment is followed by an isinstance() check.
enforce = transform(visitor=TypeCheckingVisitor())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment