Skip to content

Instantly share code, notes, and snippets.

@datavudeja
Forked from dillonhicks/contract.py
Created August 13, 2025 12:36
Show Gist options
  • Save datavudeja/bd492a0dd023b0d1c3ec9fb3ce7fa17a to your computer and use it in GitHub Desktop.
Save datavudeja/bd492a0dd023b0d1c3ec9fb3ce7fa17a to your computer and use it in GitHub Desktop.
Contracts, Interfaces, and Implementations
from __future__ import absolute_import, print_function
from abc import ABCMeta
import functools
import inspect
import logging
from collections import defaultdict, namedtuple
import boltons.typeutils
import six
# Public API exported by ``from <module> import *``.  ``returns`` (the
# return-type counterpart of ``signature``) and ``GenericInterface`` are
# public building blocks and are exported alongside the rest.
__all__ = (
    'implemented_by',
    'implements',
    'Interface',
    'GenericInterface',
    'signature',
    'returns',
    'enforced',
    'IllegalArgumentTypeError',
    'IllegalReturnTypeError',
    'IllegalGenericTypeError',
    'Enforcement',
    'check_arguments',
)

logger = logging.getLogger(__name__)
def immutable(clazz_name, **attrs):
    """Build a single immutable record object backed by a namedtuple.

    The record is only truly immutable if its attribute values are
    immutable as well. Because the backing type is a namedtuple, the
    returned object comes with a lot of functionality for free. Field
    order follows the order of ``attrs`` (keyword order on Python 3.7+).

    Assume: ``Color = immutable('Color', blue=1, red=2, green=3, yellow=4)``

    - Dot access to attributes: ``Color.blue``
    - A sane default ``__str__`` and ``__repr__``::

        >>> print(Color)
        Color(blue=1, red=2, green=3, yellow=4)

    - Iterable and indexable::

        >>> for color in Color: print(color)
        1
        2
        3
        4

    - Attributes cannot be reassigned::

        >>> Color.blue = 'blue'
        AttributeError: ...

    - Shallow (non-recursive) dict conversion via ``Color._asdict()``.

    :param clazz_name: Name given to the generated namedtuple type. It shows
        up in ``str(obj)`` and while debugging, making it clear what the
        object is instead of it being just a super tuple.
    :param attrs: The attributes (and values) of the immutable object.
    :returns: An instance of the freshly created namedtuple type.
    """
    field_names = list(attrs)
    record_type = namedtuple(clazz_name, field_names)
    return record_type(**attrs)
def sentinel(name):
    """Return a named sentinel object.

    Handy for ``_missing`` markers or implementation-specific nulls where
    ``None`` is itself a legitimate value::

        >>> data = dict(value=None)
        >>> _missing = sentinel('Missing')
        >>> print(data.get('value', _missing) is _missing)
        False
        >>> print(data.get('value', _missing) is None)
        True

    :param name: Forwarded to ``boltons.typeutils.make_sentinel`` as its
        ``var_name`` argument.
    """
    make_sentinel = boltons.typeutils.make_sentinel
    return make_sentinel(var_name=name)
# Attribute names under which the decorators and generic metaclass in this
# module stash their metadata on functions / generated classes.
SIGNATURE_TYPES_ANNOTATION = '__signature_types_annotation__'
GENERIC_TYPE_ANNOTATION = '__generic_types_annotation__'
RETURN_TYPE_ANNOTATION = '__return_type_annotation__'
_GENERIC_FINALIZATION_FIELD = '__finalized_generic__'

# Placeholder type spec for parameters that have no declared type; such
# parameters are skipped by the checks.
_MissingTypeSpec = sentinel('MissingTypeSpec')

# How strictly @enforced reacts to a type mismatch: none disables the
# wrapper, debug logs a warning, exception raises.
Enforcement = immutable('Enforcement', none=0, debug=10, exception=100)
_DEFAULT_ENFORCEMENT_POLICY = Enforcement.exception
def implemented_by(IClass):
    """Return a predicate that tests whether an object implements ``IClass``.

    :param IClass: ``Interface`` or a subclass of it.
    :returns: A one-argument callable returning ``isinstance(obj, IClass)``.
    :raises TypeError: If ``IClass`` is not an interface type, including
        when it is not a class at all.
    """
    # Check for a class first: issubclass() raises its own, less helpful
    # TypeError for non-class arguments. issubclass(Interface, Interface)
    # is already True, so no separate identity test is needed.
    if not (isinstance(IClass, type) and issubclass(IClass, Interface)):
        raise TypeError('IClass is not an interface type')

    def is_implemented_by(obj):
        return isinstance(obj, IClass)

    return is_implemented_by
def implements(obj, IClass):
    """Return whether ``obj`` implements the interface ``IClass``.

    :param obj: Any object.
    :param IClass: ``Interface`` or a subclass of it.
    :returns: ``isinstance(obj, IClass)``.
    :raises TypeError: If ``IClass`` is not an interface type, including
        when it is not a class at all.
    """
    # Check for a class first: issubclass() raises its own, less helpful
    # TypeError for non-class arguments. issubclass(Interface, Interface)
    # is already True, so no separate identity test is needed.
    if not (isinstance(IClass, type) and issubclass(IClass, Interface)):
        raise TypeError('IClass is not an interface type')
    return isinstance(obj, IClass)
class InterfaceMeta(ABCMeta):
    """Metaclass shared by all interfaces.

    It adds nothing beyond ``ABCMeta`` today; it exists as an extension
    point for customizing ``Interface`` behavior.
    """
class Interface(six.with_metaclass(InterfaceMeta)):
    """Marker base class for interfaces.

    Deriving from it tags a class as an interface; ``implements`` and
    ``implemented_by`` only accept ``Interface`` subclasses.
    """
class GenericInterfaceMeta(InterfaceMeta):
    """Extend the InterfaceMeta metaclass with generic type parameters.

    Example::

        class IClient(GenericInterface):
            @abstractmethod
            def call(self, request):
                pass

        class JSONClient(IClient[str]):
            ...  # implementation requiring call and returning str

        class BinaryClient(IClient[bytes]):
            ...  # implementation requiring call and returning bytes

        @contract.enforced()
        @contract.signature(client=IClient)
        def inject_headers(client):
            ...  # accepts all IClient subclasses

        @contract.enforced()
        @contract.signature(client=IClient[str])
        def json_to_dict(client):
            ...  # accepts only JSONClient
    """

    # Cache of generated generic classes keyed by (base, params), so that
    # repeated IGeneric[<type>] lookups return the very same class object.
    _generic_registry = {}

    def __getitem__(self, params):
        """Emulate py3.5 Generic type semantics via the ``[]`` operator.

        Example::

            class Starchy(object): pass
            class IVegetable(GenericInterface): pass
            class Potato(IVegetable[Starchy]): pass

        :param params: A single type or a tuple of types.
        :returns: A cached subclass of ``self`` and ``Interface`` named
            ``ClassName<type1,...>`` whose parameters are stored under
            ``GENERIC_TYPE_ANNOTATION``.
        :raises IllegalGenericTypeError: If ``self`` is a generated generic
            class or inherits from one.
        """
        # Generated classes carry the finalization flag and subclasses
        # inherit it, so IVegetable[<type>] is allowed but Potato[<type>]
        # is rejected.
        if getattr(self, _GENERIC_FINALIZATION_FIELD, False) is True:
            raise IllegalGenericTypeError('Cannot create generic type from concrete class')

        if not isinstance(params, tuple):
            params = (params,)

        key = (self, params)
        cached = self._generic_registry.get(key)
        if cached is not None:
            return cached

        # name = ClassName<type1,...>; parameters without a __name__ (e.g.
        # None or a string) fall back to repr() instead of raising
        # AttributeError.
        name = '{}<{}>'.format(
            self.__name__,
            ','.join(getattr(param, '__name__', repr(param)) for param in params))
        Generic = type(name, (self, Interface), {})
        setattr(Generic, _GENERIC_FINALIZATION_FIELD, True)
        setattr(Generic, GENERIC_TYPE_ANNOTATION, params)
        self._generic_registry[key] = Generic
        return Generic
class GenericInterface(six.with_metaclass(GenericInterfaceMeta)):
    """Base class for interfaces that take generic type parameters, e.g.
    ``class IntList(IList[int]): pass``.

    See also:: GenericInterfaceMeta
    """

    @classmethod
    def generic_types(cls):
        """Return the tuple of generic type parameters (empty if none)."""
        return getattr(cls, GENERIC_TYPE_ANNOTATION, ())
class signature(object):
    """Decorator that attaches per-parameter types to a function.

    The types are read back by ``typed_signature`` and validated at call
    time by ``@enforced`` (or by ``check_arguments``). Parameters that are
    not named here are not checked.

    Example::

        @enforced()
        @signature(client=IClient)
        def inject_headers(client): ...
    """

    def __init__(self, **param_types):
        # Mapping of parameter name -> expected type (anything isinstance()
        # accepts as its second argument).
        self.param_types = param_types

    def __call__(self, func):
        """Record the parameter types on ``func`` and return it unchanged."""
        annotation = self.param_types
        setattr(func, SIGNATURE_TYPES_ANNOTATION, annotation)
        return func
class returns(object):
    """Decorator that records the single expected return type of a callable.

    ``typed_signature`` exposes it as ``return_type`` so that ``@enforced``
    can check the returned value against it.
    """

    def __init__(self, Type):
        # Expected return type (anything isinstance() accepts).
        self.Type = Type

    def __call__(self, func):
        """Attach the return type to ``func`` and hand ``func`` back."""
        expected = self.Type
        setattr(func, RETURN_TYPE_ANNOTATION, expected)
        return func
# Record describing a call argument that failed its type check.
_IllegalArgument = namedtuple(
    'IllegalArgument', (
        'pos',
        'value',
        'type',
        'expected',
    ))

# Record describing a return value that failed its type check.
_IllegalReturnValue = namedtuple(
    'IllegalReturnValue', (
        'value',
        'type',
        'expected',
    ))

# ArgSpec-compatible signature (args, varargs, keywords, defaults) extended
# with the declared parameter types and return type.
_TypedSignature = namedtuple(
    'TypedArgSpec', (
        'args',
        'varargs',
        'keywords',
        'defaults',
        'param_types',
        'return_type',
    ))
class IllegalArgumentTypeError(Exception):
    """Error raised when a call argument does not match its declared type.

    :param pos: Index of the parameter in the function's positional
        signature (as passed by ``enforced`` / ``check_arguments``).
    :param value: The offending argument value.
    :param type: The value's type, ``type(value)``.
    :param expected: The declared type spec.
    """

    def __init__(self, pos, value, type, expected):
        arg = _IllegalArgument(pos, value, type, expected)
        super(IllegalArgumentTypeError, self).__init__(str(arg))
        # Keep the individual fields so handlers need not parse the message.
        self.pos = pos
        self.value = value
        self.type = type
        self.expected = expected

    def __reduce__(self):
        # Default exception pickling replays self.args (only the message)
        # into __init__, which needs four arguments; replay the fields.
        return (self.__class__, (self.pos, self.value, self.type, self.expected))
class IllegalReturnTypeError(Exception):
    """Error raised when a return value does not match its declared type.

    :param value: The offending return value.
    :param type: The value's type, ``type(value)``.
    :param expected: The declared return type.
    """

    def __init__(self, value, type, expected):
        record = _IllegalReturnValue(value, type, expected)
        super(IllegalReturnTypeError, self).__init__(str(record))
        # Keep the individual fields so handlers need not parse the message.
        self.value = value
        self.type = type
        self.expected = expected

    def __reduce__(self):
        # Default exception pickling replays self.args (only the message)
        # into __init__, which needs three arguments; replay the fields.
        return (self.__class__, (self.value, self.type, self.expected))
class IllegalGenericTypeError(Exception):
    """Error raised when generic type parameters are applied to a class
    that is already a generated generic class (or inherits from one)."""
class enforced(object):
    """Decorator that checks a function's arguments and return value at call
    time against the types recorded by ``@signature`` and ``@returns``
    (read through ``typed_signature``).

    :param level: An ``Enforcement`` value.
        ``Enforcement.none`` returns the function undecorated.
        ``Enforcement.debug`` logs mismatches as warnings and lets the call
        proceed.
        ``Enforcement.exception`` raises ``IllegalArgumentTypeError`` /
        ``IllegalReturnTypeError``.
        Any other value raises ``RuntimeError`` when a mismatch is found.

    Only parameters of the function's positional signature are checked,
    whether the caller passes them positionally or by keyword. Parameters
    the caller omits (so their default is used) are not checked.
    """

    def __init__(self, level=_DEFAULT_ENFORCEMENT_POLICY):
        self.level = level

    def _on_mismatch(self, error_cls, record_cls, *fields):
        """Raise ``error_cls(*fields)`` or log ``record_cls(*fields)``
        according to the enforcement level."""
        if self.level == Enforcement.exception:
            raise error_cls(*fields)
        elif self.level == Enforcement.debug:
            logger.warning('Type mismatch %s', record_cls(*fields))
        else:
            raise RuntimeError(
                'Unknown enforcement level: {!r}'.format(self.level))

    def __call__(self, func):
        if self.level == Enforcement.none:
            return func

        @functools.wraps(func)
        def on_call_type_validation_wrapper(*args, **kwargs):
            sig = typed_signature(func)
            if sig.args is not None:
                for i, arg in enumerate(sig.args):
                    typespec = sig.param_types[arg]
                    if typespec is _MissingTypeSpec:
                        continue
                    # The argument may arrive positionally or by keyword.
                    if i < len(args):
                        value = args[i]
                    elif arg in kwargs:
                        value = kwargs[arg]
                    else:
                        # Omitted by the caller: the default applies and is
                        # not type checked.
                        continue
                    if not _arg_matches(value, typespec):
                        self._on_mismatch(
                            IllegalArgumentTypeError, _IllegalArgument,
                            i, value, type(value), typespec)

            return_value = func(*args, **kwargs)
            if (sig.return_type is not None
                    and not _arg_matches(return_value, sig.return_type)):
                # Report the return value itself; in debug mode the value is
                # still handed back to the caller below.
                self._on_mismatch(
                    IllegalReturnTypeError, _IllegalReturnValue,
                    return_value, type(return_value), sig.return_type)
            return return_value

        return on_call_type_validation_wrapper
def check_arguments():
    """Validate the calling function's arguments against its ``@signature``.

    If you are unsure, use ``@enforced`` instead. This validates arguments
    by inspecting the caller's frame. Note: it is roughly ~1000x slower than
    ``@enforced`` because it needs to inspect the stack.

    The caller is resolved by looking up its code object's name in the
    caller's module globals. It therefore only works for module-level
    functions whose global name refers to the function itself; otherwise
    the lookup raises ``KeyError`` or finds the wrong object. Values are
    read from the caller's current locals, so call it before reassigning
    any parameter.

    :raises IllegalArgumentTypeError: For the first argument whose value
        does not match its declared type.
    """
    frame = inspect.currentframe().f_back
    try:
        func = frame.f_globals[frame.f_code.co_name]
        arginfo = inspect.getargvalues(frame)
    finally:
        # Drop the frame reference promptly, as the inspect docs recommend,
        # so a traceback of an error raised below does not keep the
        # caller's frame alive.
        del frame
    sig = typed_signature(func)
    if sig.args is not None:
        for i, arg in enumerate(sig.args):
            typespec = sig.param_types[arg]
            if typespec is _MissingTypeSpec:
                continue
            value = arginfo.locals[arg]
            if not _arg_matches(value, typespec):
                raise IllegalArgumentTypeError(i, value, type(value), typespec)
def _arg_matches(arg, Type):
    """Return whether ``arg`` satisfies the type spec ``Type``.

    ``Type`` may be a class or a tuple of classes, exactly as accepted by
    ``isinstance``.
    """
    matches = isinstance(arg, Type)
    return matches
def typed_signature(func):
    """Return the type-enhanced signature of ``func``.

    :returns: A ``_TypedSignature`` containing:
        the ArgSpec-style fields ``args``, ``varargs``, ``keywords`` and
        ``defaults``;
        ``param_types``, a ``defaultdict`` mapping parameter names to the
        types recorded by ``@signature`` and defaulting to
        ``_MissingTypeSpec``;
        ``return_type``, the type recorded by ``@returns``, or ``None``.
    """
    # inspect.getargspec() rejects functions with annotations or
    # keyword-only parameters on Python 3 and was removed in 3.11. Use
    # getfullargspec() where it exists, and keep getargspec() for Python 2.
    getfullargspec = getattr(inspect, 'getfullargspec', None)
    if getfullargspec is not None:
        spec = getfullargspec(func)
        args, varargs, keywords, defaults = (
            spec.args, spec.varargs, spec.varkw, spec.defaults)
    else:
        args, varargs, keywords, defaults = inspect.getargspec(func)
    types = getattr(func, SIGNATURE_TYPES_ANNOTATION, {})
    return_type = getattr(func, RETURN_TYPE_ANNOTATION, None)
    return _TypedSignature(
        args=args,
        varargs=varargs,
        keywords=keywords,
        defaults=defaults,
        param_types=defaultdict(lambda: _MissingTypeSpec, types),
        return_type=return_type,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment