Skip to content

Instantly share code, notes, and snippets.

@treyhunner
Created April 9, 2018 05:25
Show Gist options
  • Save treyhunner/7d4c946ef1a2963cf55951d97ebcdbb8 to your computer and use it in GitHub Desktop.
Save treyhunner/7d4c946ef1a2963cf55951d97ebcdbb8 to your computer and use it in GitHub Desktop.
Code taken from David Beazley's PyCon Israel 2017 keynote
"""
Code based on David Beazley's keynote talk at PyCon Israel 2017.
Watch the talk at https://www.youtube.com/watch?v=Je8TcRQcUgA
Usage::

    from contract import Base, NonemptyString, PositiveInteger

    dx: PositiveInteger

    class Player(Base):
        name: NonemptyString
        x: PositiveInteger
        y: PositiveInteger

        def left(self, dx):
            self.x -= dx

        def right(self, dx):
            self.x += dx

    p = Player('Guido', 5, 6)
    p.x = 23
    p.left(5)
    p.left(-5)  # Raises an exception
"""
from collections import ChainMap
from functools import wraps
from inspect import isfunction, signature
# Registry of every Contract subclass, keyed by class name. Contract
# subclasses add themselves here, and BaseMeta.__prepare__ layers it into
# class-body namespaces so contract names resolve without an import.
_contracts = dict()
class Contract:
    """Base class for value contracts that double as data descriptors.

    Every subclass is registered in ``_contracts`` under its class name.
    An instance placed on a class validates assignments with ``check`` and
    stores the accepted value in the owning instance's ``__dict__`` under
    the attribute name it was bound to.
    """

    def __init_subclass__(cls, **kwargs):
        # Cooperate with any other __init_subclass__ in the MRO and forward
        # class keywords instead of silently breaking the chain.
        super().__init_subclass__(**kwargs)
        _contracts[cls.__name__] = cls

    def __set__(self, instance, value):
        """Validate *value* with ``check``, then store it on *instance*."""
        self.check(value)
        # No __get__ is defined, so attribute reads fall through to this
        # instance-__dict__ entry.
        instance.__dict__[self.name] = value

    def __set_name__(self, cls, name):
        """Remember the attribute name this descriptor is bound to."""
        self.name = name

    @classmethod
    def check(cls, value):
        """Accept any value; subclasses add checks and chain via ``super()``."""
        pass
class Typed(Contract):
    """Contract requiring ``isinstance(value, cls.type)``.

    Subclasses set ``type`` to the class (or tuple of classes) to accept.
    """

    type = None

    @classmethod
    def check(cls, value):
        # Raise explicitly rather than via ``assert`` so validation is not
        # stripped under ``python -O``; AssertionError is kept so existing
        # callers that catch it keep working.
        if not isinstance(value, cls.type):
            raise AssertionError(f'Expected {cls.type}')
        super().check(value)
class Positive(Contract):
    """Contract requiring ``value > 0``."""

    @classmethod
    def check(cls, value):
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for existing callers.
        # ``not (value > 0)`` (rather than ``value <= 0``) still rejects NaN.
        if not (value > 0):
            raise AssertionError('Must be > 0')
        super().check(value)
class Nonempty(Contract):
    """Contract requiring ``len(value) > 0``."""

    @classmethod
    def check(cls, value):
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for existing callers.
        if len(value) == 0:
            raise AssertionError('Must be nonempty')
        super().check(value)
class Integer(Typed):
    """Contract accepting ``int`` instances.

    ``bool`` values pass as well, since ``bool`` subclasses ``int``.
    """

    type = int
class String(Typed):
    """Contract accepting ``str`` instances (including ``str`` subclasses)."""

    type = str
class NonemptyString(String, Nonempty):
    """Contract for ``str`` values of nonzero length.

    Through the MRO, String's type check runs before Nonempty's length check.
    """
class PositiveInteger(Integer, Positive):
    """Contract for ``int`` values greater than zero.

    Through the MRO, Integer's type check runs before Positive's sign check.
    """
def checked(func):
    """Wrap *func* so its arguments are validated against contracts.

    Contracts are looked up by parameter name, first in *func*'s own
    annotations and then in the module-level ``__annotations__`` of the
    module *func* was defined in. A module-level ``dx: PositiveInteger``
    therefore applies to every ``dx`` parameter in that module.

    Only annotations that provide a ``check`` attribute are enforced. Others,
    such as plain ``int`` or string annotations, are ignored rather than
    raising AttributeError on every call.

    Returns a ``functools.wraps``-decorated wrapper that calls *func* with
    the original arguments once every applicable check has passed.
    """
    sig = signature(func)
    ann = ChainMap(
        getattr(func, '__annotations__', {}),
        # Not every callable has __globals__ (classes, builtins, partials).
        getattr(func, '__globals__', {}).get('__annotations__', {}),
    )

    @wraps(func)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, val in bound.arguments.items():
            check = getattr(ann.get(name), 'check', None)
            if check is not None:
                check(val)
        return func(*args, **kwargs)

    return wrapper
class BaseMeta(type):
    """Metaclass that lets class bodies refer to contracts by bare name.

    The class body executes in a ChainMap. Its first map collects the
    body's own definitions and its second map is the ``_contracts``
    registry, so contract names resolve there before module globals.
    """

    @classmethod
    def __prepare__(cls, *args, **kwargs):
        # Only the body's own map and the contract registry are layered;
        # no other names are injected into the class body.
        return ChainMap({}, _contracts)

    def __new__(meta, name, bases, methods, **kwargs):
        # Keep only what the class body itself defined: the registry layer
        # must not be copied into the new class's namespace.
        methods = methods.maps[0]
        # Forward class keywords so they reach __init_subclass__.
        return super().__new__(meta, name, bases, methods, **kwargs)
class Base(metaclass=BaseMeta):
    """Base class for records whose annotated fields are contract-checked.

    For each subclass:

    - every annotation in the body becomes a contract descriptor;
    - every plain function defined in the body is wrapped with ``checked``;
    - the inherited ``__init__`` takes one positional argument per annotated
      field, in annotation order.
    """

    @classmethod
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Wrap only plain functions. Nested classes, staticmethod objects
        # (callable since Python 3.10), builtins and partials either crash in
        # ``checked`` or lose their semantics when replaced by a function.
        # Snapshot the items because setattr rebinds entries of cls.__dict__.
        for name, val in list(cls.__dict__.items()):
            if isfunction(val):
                setattr(cls, name, checked(val))
        # Instantiate the contracts
        for name, val in cls.__annotations__.items():
            contract = val()
            contract.__set_name__(cls, name)
            setattr(cls, name, contract)

    def __init__(self, *args):
        """Assign positional *args* to the annotated fields, in order."""
        # Read annotations from the class: instance lookup only finds an
        # ``__annotations__`` entry stored directly in a class __dict__,
        # whereas ``type.__annotations__`` is the supported accessor.
        ann = type(self).__annotations__
        # Explicit raise (not assert) so the arity check survives -O;
        # AssertionError is kept for existing callers.
        if len(args) != len(ann):
            raise AssertionError(f'Expected {len(ann)} arguments')
        for name, val in zip(ann, args):
            setattr(self, name, val)

    def __repr__(self):
        """Return ``ClassName(v1,v2,...)`` built from the annotated fields."""
        args = ','.join(
            repr(getattr(self, name))
            for name in type(self).__annotations__
        )
        return f'{type(self).__name__}({args})'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment