Created
April 9, 2018 05:25
-
-
Save treyhunner/7d4c946ef1a2963cf55951d97ebcdbb8 to your computer and use it in GitHub Desktop.
Code taken from David Beazley's PyCon Israel 2017 keynote
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Code based on talk by David Beazley's PyCon Israel keynote in 2017 | |
Watch the talk at https://www.youtube.com/watch?v=Je8TcRQcUgA | |
Usage:: | |
from contract import Base, PositiveInteger | |
dx: PositiveInteger | |
class Player(Base): | |
name: AnotherContract | |
x: PositiveInteger | |
y: PositiveInteger | |
def left(self, dx): | |
self.x -= dx | |
def right(self, dx): | |
self.x += dx | |
p = Player('Guido', 5, 6) | |
p.x = 23 | |
p.left(5) | |
p.left(-5) # Raises an exception | |
""" | |
from collections import ChainMap
from functools import wraps
from inspect import isfunction, signature
_contracts = {} | |
class Contract: | |
def __init_subclass__(cls): | |
_contracts[cls.__name__] = cls | |
def __set__(self, instance, value): | |
self.check(value) | |
instance.__dict__[self.name] = value | |
def __set_name__(self, cls, name): | |
self.name = name | |
@classmethod | |
def check(cls, value): | |
pass | |
class Typed(Contract):
    """Contract requiring values to be instances of ``type``.

    Concrete subclasses set ``type`` to the required class (anything accepted
    as the second argument of ``isinstance``).
    """

    type = None  # set by concrete subclasses such as Integer or String

    @classmethod
    def check(cls, value):
        """Raise AssertionError unless *value* is an instance of ``cls.type``."""
        # Raise explicitly instead of using ``assert`` so validation is not
        # stripped when Python runs with -O; AssertionError is kept so
        # existing callers that catch it keep working.
        if not isinstance(value, cls.type):
            raise AssertionError(f'Expected {cls.type}')
        super().check(value)
class Positive(Contract):
    """Contract requiring values to compare greater than zero."""

    @classmethod
    def check(cls, value):
        """Raise AssertionError unless ``value > 0``."""
        # Explicit raise so the check survives -O.  ``not value > 0`` (rather
        # than ``value <= 0``) keeps rejecting values such as NaN for which
        # both comparisons are false.
        if not value > 0:
            raise AssertionError('Must be > 0')
        super().check(value)
class Nonempty(Contract):
    """Contract requiring values to have a length of at least one."""

    @classmethod
    def check(cls, value):
        """Raise AssertionError if ``len(value)`` is zero."""
        # Explicit raise so the check survives -O.  len() is used (not
        # truthiness) so objects whose __bool__ disagrees with __len__ are
        # judged by their length.
        if len(value) == 0:
            raise AssertionError('Must be nonempty')
        super().check(value)
class Integer(Typed):
    """Values must be ``int`` instances (``bool`` passes, as an int subclass)."""
    type = int


class String(Typed):
    """Values must be ``str`` instances."""
    type = str


class NonemptyString(String, Nonempty):
    """A ``str`` with at least one character; the type is checked first."""


class PositiveInteger(Integer, Positive):
    """An ``int`` greater than zero; the type is checked first."""
def checked(func):
    """Wrap *func* so its contract-annotated arguments are validated per call.

    Annotations are looked up first on *func* itself and then in the
    module-level ``__annotations__`` of the module *func* was defined in.
    A global declaration such as ``dx: PositiveInteger`` therefore applies to
    every checked function's parameter named ``dx``.

    Only annotations exposing a callable ``check`` attribute are enforced.
    Ordinary type hints such as ``int``, and string annotations, are ignored
    rather than failing with AttributeError.

    Returns the wrapper, which calls ``func(*args, **kwargs)`` after all
    checks pass.  It raises TypeError if the arguments don't fit the
    signature, plus whatever a contract's ``check`` raises.
    """
    sig = signature(func)
    # A live view: names annotated at module level after decoration are
    # still picked up on later calls.
    ann = ChainMap(
        getattr(func, '__annotations__', {}),
        func.__globals__.get('__annotations__', {}),
    )

    @wraps(func)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, val in bound.arguments.items():
            # Missing annotations yield None, which has no ``check`` either.
            check = getattr(ann.get(name), 'check', None)
            if callable(check):
                check(val)
        return func(*args, **kwargs)

    return wrapper
class BaseMeta(type):
    """Metaclass that lets Base subclasses refer to contracts by bare name.

    During execution of a class body, name lookups that miss the body's own
    definitions fall back to the contract registry (before module globals).
    This lets ``x: PositiveInteger`` resolve inside the class body.
    """

    @classmethod
    def __prepare__(cls, *args):
        # Writes go to the first (empty) mapping; lookups fall back to the
        # registry.  Keep this chain minimal: every mapping added here is
        # visible in every subclass body and shadows same-named globals.
        return ChainMap({}, _contracts)

    def __new__(meta, name, bases, methods):
        # type.__new__ requires a real dict.  Passing only the body's own
        # definitions also keeps registry entries from becoming class
        # attributes.
        methods = methods.maps[0]
        return super().__new__(meta, name, bases, methods)
class Base(metaclass=BaseMeta):
    """Base class whose annotated fields are enforced by contracts.

    Defining a subclass does two things:

    * Each annotation (e.g. ``x: PositiveInteger``) is instantiated and
      installed as a descriptor that validates assignments.
    * Each plain function in the subclass body is wrapped with ``checked``,
      so its contract-annotated parameters are validated on every call.

    Instances are built positionally in annotation order.
    """

    @classmethod
    def __init_subclass__(cls, **kwargs):
        # Let mixins later in the MRO run their own __init_subclass__.
        super().__init_subclass__(**kwargs)
        # Wrap only plain functions.  callable() also matches nested classes
        # and, since Python 3.10, staticmethod objects.  Wrapping those would
        # replace a nested class with a function and break staticmethods.
        # Iterate over a snapshot because setattr rewrites entries of the
        # dict being walked.
        for name, val in list(cls.__dict__.items()):
            if isfunction(val):
                setattr(cls, name, checked(val))
        # Instantiate the contracts.  setattr() after class creation does not
        # trigger __set_name__, so call it explicitly.
        for name, val in cls.__annotations__.items():
            contract = val()
            contract.__set_name__(cls, name)
            setattr(cls, name, contract)

    def __init__(self, *args):
        """Assign positional *args* to the annotated fields, in order.

        Raises AssertionError if the argument count doesn't match the number
        of annotations, or if a contract rejects a value.
        """
        ann = self.__annotations__
        # Explicit raise (not assert) so the arity check survives -O; without
        # it, zip() would silently drop or leave unset fields.
        if len(args) != len(ann):
            raise AssertionError(f'Expected {len(ann)} arguments')
        for name, val in zip(ann, args):
            setattr(self, name, val)

    def __repr__(self):
        """Return ``ClassName(v1,v2,...)`` using field values in annotation order."""
        args = ','.join(
            repr(getattr(self, name))
            for name in self.__annotations__
        )
        return f'{type(self).__name__}({args})'
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment