Skip to content

Instantly share code, notes, and snippets.

@rob-smallshire
Created March 18, 2017 15:34
Show Gist options
  • Save rob-smallshire/f4eeea5ae54a708fc819c46f4b6d26c9 to your computer and use it in GitHub Desktop.
Save rob-smallshire/f4eeea5ae54a708fc819c46f4b6d26c9 to your computer and use it in GitHub Desktop.
Multiple class decorators
import functools
import inspect
from abc import ABC, abstractmethod
def invariant(predicate):
    """Create a class decorator which checks a class invariant.

    Args:
        predicate: A callable to which, after every method invocation and
            every property get, set or delete, the object on which the
            operation was performed will be passed. The predicate should
            evaluate to True if the class invariant has been maintained, or
            False if it has been violated. Its docstring is used to describe
            the invariant in the RuntimeError raised on violation.

    Returns:
        A class decorator for checking the class invariant tested by
        the supplied predicate function.
    """
    def invariant_checking_class_decorator(cls):
        """A class decorator for checking invariants."""
        # Snapshot the namespace once; the loops below mutate the class.
        namespace = list(vars(cls).items())
        # Only plain functions are wrapped. ``callable`` would also select
        # nested classes and (since Python 3.10) staticmethod objects such as
        # an implicit ``__new__``; the wrapper would turn those into instance
        # methods and break them.
        method_names = [name for name, attr in namespace if inspect.isfunction(attr)]
        property_names = [
            name for name, attr in namespace if isinstance(attr, PropertyDataDescriptor)
        ]
        for name in method_names:
            _wrap_method_with_invariant_checking_proxy(cls, name, predicate)
        for name in property_names:
            _wrap_property_with_invariant_checking_proxy(cls, name, predicate)
        return cls

    return invariant_checking_class_decorator
def _wrap_method_with_invariant_checking_proxy(cls, name, predicate):
    """Replace method ``name`` on *cls* with an invariant-checking wrapper.

    The wrapper calls the original method, then passes ``self`` to
    *predicate*. If the predicate is falsy a RuntimeError is raised whose
    message contains the predicate's docstring and the repr of the instance;
    otherwise the method's return value is passed through unchanged. Note the
    method has already run (and any state change persists) when the error
    is raised.
    """
    original = getattr(cls, name)
    assert callable(original)

    def checked(self, *args, **kwargs):
        outcome = original(self, *args, **kwargs)
        if predicate(self):
            return outcome
        message = "Class invariant {!r} violated for {!r}".format(predicate.__doc__, self)
        raise RuntimeError(message)

    # functools.wraps copies __name__, __doc__, etc. and sets __wrapped__.
    setattr(cls, name, functools.wraps(original)(checked))
class PropertyDataDescriptor(ABC):
    """Abstract interface for property-like data descriptors.

    Implementations provide ``__get__``, ``__set__`` and ``__delete__`` plus
    the ``__isabstractmethod__`` attribute that the builtin ``property``
    exposes. ``property`` itself is registered below as a virtual subclass so
    that ``isinstance(attr, PropertyDataDescriptor)`` recognises both plain
    properties and concrete subclasses of this ABC.

    The abstract bodies raise NotImplementedError (``NotImplemented`` is a
    comparison sentinel, not an exception, and raising it is a TypeError).
    """

    @abstractmethod
    def __get__(self, instance, owner):
        raise NotImplementedError

    @abstractmethod
    def __set__(self, instance, value):
        raise NotImplementedError

    @abstractmethod
    def __delete__(self, instance):
        raise NotImplementedError

    @property
    @abstractmethod
    def __isabstractmethod__(self):
        raise NotImplementedError


# ``property`` does not inherit from this ABC; registering it makes
# isinstance()/issubclass() checks against PropertyDataDescriptor succeed.
PropertyDataDescriptor.register(property)
class InvariantCheckingPropertyProxy(PropertyDataDescriptor):
    """Data descriptor that checks a class invariant around another descriptor.

    Every get, set or delete performed through an instance is delegated to the
    wrapped descriptor (*referent*); afterwards *predicate* is called with the
    instance and a RuntimeError is raised if it returns a falsy value. The
    delegated operation has already taken effect when the error is raised.

    Proxies nest: when the referent is itself a proxy, the inner predicate is
    checked first (inside the delegated call) and then this proxy's predicate.
    """

    def __init__(self, referent, predicate):
        self._referent = referent
        self._predicate = predicate
        # Mirror the wrapped descriptor's docstring, as property does for fget.
        self.__doc__ = getattr(referent, "__doc__", None)

    def _check_invariant(self, instance):
        """Raise RuntimeError if the predicate does not hold for *instance*."""
        if not self._predicate(instance):
            raise RuntimeError(
                "Class invariant {!r} violated for {!r}".format(self._predicate.__doc__, instance))

    def __get__(self, instance, owner=None):
        if instance is None:
            # Class-level access returns the proxy itself, as property does.
            # Returning the referent here made getattr(cls, name) unwrap the
            # proxy, so a second (outer) invariant decorator wrapped the bare
            # property and silently dropped the inner invariant.
            return self
        result = self._referent.__get__(instance, owner)
        self._check_invariant(instance)
        return result

    def __set__(self, instance, value):
        result = self._referent.__set__(instance, value)
        self._check_invariant(instance)
        return result

    def __delete__(self, instance):
        result = self._referent.__delete__(instance)
        self._check_invariant(instance)
        return result

    @property
    def __isabstractmethod__(self):
        return self._referent.__isabstractmethod__

    def __getattr__(self, name):
        # Forward other attributes (fget, fset, setter, ...) to the referent so
        # class-level access still looks like the wrapped property. Guard
        # _referent to avoid infinite recursion on a partially built instance.
        if name == "_referent":
            raise AttributeError(name)
        return getattr(self._referent, name)
def _wrap_property_with_invariant_checking_proxy(cls, name, predicate):
    """Replace data descriptor ``name`` on *cls* with an invariant-checking proxy.

    The descriptor is fetched with ``inspect.getattr_static`` rather than
    ``getattr``. ``getattr(cls, name)`` invokes the descriptor's ``__get__``
    with ``instance=None``, and a descriptor may return something other than
    itself there (for example its wrapped descriptor). That would unwrap a
    proxy installed by an earlier invariant decorator and lose its invariant.
    """
    prop = inspect.getattr_static(cls, name)
    assert isinstance(prop, PropertyDataDescriptor)
    invariant_checking_proxy = InvariantCheckingPropertyProxy(prop, predicate)
    setattr(cls, name, invariant_checking_proxy)
    # Check the class dict directly; getattr() would go through __get__ again.
    assert vars(cls)[name] is invariant_checking_proxy
def not_below_absolute_zero(temperature):
    """Temperature not below absolute zero"""
    # The docstring above is quoted in invariant-violation messages, so it is
    # kept exactly as is. Reads the private kelvin value directly.
    kelvin = temperature._kelvin
    return 0 <= kelvin
def below_absolute_hot(temperature):
    """Temperature below absolute hot"""
    # Upper bound in kelvin; see http://en.wikipedia.org/wiki/Absolute_hot
    # (the docstring above is quoted in invariant-violation messages).
    absolute_hot_kelvin = 1.416785e32
    return temperature._kelvin <= absolute_hot_kelvin
# Two stacked invariant decorators, each adding one predicate check.
# NOTE(review): whether the inner predicate (not_below_absolute_zero) is still
# enforced on properties depends on how ``invariant`` handles a property that
# an earlier decorator already wrapped -- verify when changing that machinery.
@invariant(below_absolute_hot)
@invariant(not_below_absolute_zero)
class Temperature:
    """A temperature whose state is held as a single kelvin value."""

    def __init__(self, kelvin):
        self._kelvin = kelvin

    def get_kelvin(self):
        """Return the stored kelvin value."""
        return self._kelvin

    def set_kelvin(self, value):
        """Overwrite the stored kelvin value."""
        self._kelvin = value

    @property
    def celsius(self):
        """Temperature in degrees Celsius (kelvin - 273.15)."""
        kelvin = self._kelvin
        return kelvin - 273.15

    @celsius.setter
    def celsius(self, value):
        self._kelvin = value + 273.15

    @property
    def fahrenheit(self):
        """Temperature in degrees Fahrenheit (kelvin * 9/5 - 459.67)."""
        kelvin = self._kelvin
        return (kelvin * 9) / 5 - 459.67

    @fahrenheit.setter
    def fahrenheit(self, value):
        self._kelvin = ((value + 459.67) * 5) / 9
def main():
    """Exercise the invariants declared on ``Temperature``.

    Sets a valid Celsius value, then tries two values that should each
    violate one invariant, printing whether each assignment was rejected.
    """
    t = Temperature(42.0)
    t.celsius = 30
    # (celsius value, predicate it is expected to violate)
    cases = [
        (-300, not_below_absolute_zero),
        (1e34, below_absolute_hot),
    ]
    for celsius, predicate in cases:
        # One try per case so every case is reported instead of the first
        # RuntimeError ending the script.
        try:
            t.celsius = celsius
        except RuntimeError as error:
            print("celsius = {!r}: rejected: {}".format(celsius, error))
        else:
            print("celsius = {!r}: accepted, but should violate {!r}".format(
                celsius, predicate.__doc__))


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment