Created March 18, 2017 15:34
-
-
Save rob-smallshire/f4eeea5ae54a708fc819c46f4b6d26c9 to your computer and use it in GitHub Desktop.
Multiple class decorators
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools
import inspect
from abc import ABC, abstractmethod
def invariant(predicate):
    """Create a class decorator which checks a class invariant.

    Args:
        predicate: A callable to which, after every method invocation,
            the object on which the method was called will be passed.
            The predicate should evaluate to True if the class invariant
            has been maintained, or False if it has been violated.
            Its docstring is embedded in the RuntimeError raised when the
            invariant is violated.

    Returns:
        A class decorator for checking the class invariant tested by
        the supplied predicate function.
    """
    def invariant_checking_class_decorator(cls):
        """A class decorator for checking invariants."""
        # Wrap only callables that behave as instance methods. Nested classes
        # and staticmethod objects (callable since Python 3.10, and including
        # the implicit staticmethod ``__new__``) are skipped: replacing them
        # with a function that expects ``self`` would break them.
        method_names = [
            name
            for name, attr in vars(cls).items()
            if callable(attr) and not isinstance(attr, (type, staticmethod))
        ]
        for name in method_names:
            _wrap_method_with_invariant_checking_proxy(cls, name, predicate)

        # Properties are not callable, so they are collected separately.
        property_names = [
            name
            for name, attr in vars(cls).items()
            if isinstance(attr, PropertyDataDescriptor)
        ]
        for name in property_names:
            _wrap_property_with_invariant_checking_proxy(cls, name, predicate)
        return cls

    return invariant_checking_class_decorator
def _wrap_method_with_invariant_checking_proxy(cls, name, predicate): | |
method = getattr(cls, name) | |
assert callable(method) | |
@functools.wraps(method) | |
def invariant_checking_method_decorator(self, *args, **kwargs): | |
result = method(self, *args, **kwargs) | |
if not predicate(self): | |
raise RuntimeError("Class invariant {!r} violated for {!r}".format(predicate.__doc__, self)) | |
return result | |
setattr(cls, name, invariant_checking_method_decorator) | |
pass | |
class PropertyDataDescriptor(ABC):
    """Abstract interface for property-like data descriptors.

    A data descriptor implements ``__get__``, ``__set__`` and ``__delete__``.
    ``__isabstractmethod__`` is part of the interface so that ``abc`` machinery
    can ask whether a descriptor wraps an abstract member. The built-in
    ``property`` type is registered as a virtual subclass after the class body.
    """

    # NotImplemented is a comparison sentinel, not an exception; raising it
    # produces a TypeError. NotImplementedError is the correct exception.

    @abstractmethod
    def __get__(self, instance, owner):
        raise NotImplementedError

    @abstractmethod
    def __set__(self, instance, value):
        raise NotImplementedError

    @abstractmethod
    def __delete__(self, instance):
        raise NotImplementedError

    @property
    @abstractmethod
    def __isabstractmethod__(self):
        raise NotImplementedError


# ``property`` does not inherit from this ABC; registering it makes
# isinstance(some_property, PropertyDataDescriptor) return True.
PropertyDataDescriptor.register(property)
class InvariantCheckingPropertyProxy(PropertyDataDescriptor):
    """Data descriptor that delegates to another one and then checks an invariant.

    Every instance-level get, set or delete is forwarded to the wrapped
    descriptor (the *referent*); the owning instance is then passed to the
    predicate and a RuntimeError is raised if it returns a falsy value.
    Class-level access (``instance is None``) returns the referent itself,
    without any check.
    """

    def __init__(self, referent, predicate):
        self._referent = referent
        self._predicate = predicate

    def _ensure_invariant(self, instance, outcome):
        """Return ``outcome`` if the predicate holds for ``instance``; else raise."""
        if self._predicate(instance):
            return outcome
        raise RuntimeError(
            "Class invariant {!r} violated for {!r}".format(self._predicate.__doc__, instance)
        )

    def __get__(self, instance, owner):
        if instance is not None:
            return self._ensure_invariant(instance, self._referent.__get__(instance, owner))
        return self._referent

    def __set__(self, instance, value):
        return self._ensure_invariant(instance, self._referent.__set__(instance, value))

    def __delete__(self, instance):
        return self._ensure_invariant(instance, self._referent.__delete__(instance))

    @property
    def __isabstractmethod__(self):
        # Mirror the referent so abc sees wrapped abstract properties correctly.
        return self._referent.__isabstractmethod__
def _wrap_property_with_invariant_checking_proxy(cls, name, predicate):
    """Replace the data descriptor ``name`` on ``cls`` with an invariant-checking proxy.

    Args:
        cls: The class being decorated.
        name: Attribute name of a PropertyDataDescriptor on ``cls``.
        predicate: Invariant predicate handed to the new proxy.
    """
    # Look the descriptor up statically. getattr(cls, name) would run the
    # descriptor protocol, and InvariantCheckingPropertyProxy.__get__ returns
    # its *referent* on class-level access -- so a second, stacked @invariant
    # would wrap the bare property and silently discard the first invariant.
    prop = inspect.getattr_static(cls, name)
    assert isinstance(prop, PropertyDataDescriptor)
    invariant_checking_proxy = InvariantCheckingPropertyProxy(prop, predicate)
    setattr(cls, name, invariant_checking_proxy)
    # Verify statically too: getattr() would again return the proxy's referent.
    assert inspect.getattr_static(cls, name) is invariant_checking_proxy
def not_below_absolute_zero(temperature):
    """Temperature not below absolute zero"""
    # The docstring above is embedded in invariant-violation error messages,
    # so it is kept short and exact.
    kelvin = temperature._kelvin
    return kelvin >= 0
def below_absolute_hot(temperature):
    """Temperature below absolute hot"""
    # Upper bound in kelvin; see http://en.wikipedia.org/wiki/Absolute_hot
    # The docstring above is embedded in invariant-violation error messages.
    absolute_hot_kelvin = 1.416785e32
    return temperature._kelvin <= absolute_hot_kelvin
# Decorators apply bottom-up: not_below_absolute_zero is applied to the class
# first and below_absolute_hot second. Both invariants are intended to hold.
@invariant(below_absolute_hot)
@invariant(not_below_absolute_zero)
class Temperature:
    """A temperature held internally in kelvin (``_kelvin``).

    Celsius and Fahrenheit views are exposed as read/write properties.
    """

    def __init__(self, kelvin):
        self._kelvin = kelvin

    def get_kelvin(self):
        """Return the temperature in kelvin."""
        return self._kelvin

    def set_kelvin(self, value):
        """Set the temperature in kelvin."""
        self._kelvin = value

    @property
    def celsius(self):
        """Temperature in degrees Celsius."""
        kelvin = self._kelvin
        return kelvin - 273.15

    @celsius.setter
    def celsius(self, value):
        kelvin = value + 273.15
        self._kelvin = kelvin

    @property
    def fahrenheit(self):
        """Temperature in degrees Fahrenheit."""
        kelvin = self._kelvin
        return kelvin * 9 / 5 - 459.67

    @fahrenheit.setter
    def fahrenheit(self, value):
        kelvin = (value + 459.67) * 5 / 9
        self._kelvin = kelvin
def main():
    """Exercise the Temperature invariants via the celsius property."""
    reading = Temperature(42.0)
    reading.celsius = 30
    # -300 degC is -26.85 K, below absolute zero: should violate
    # not_below_absolute_zero.
    reading.celsius = -300
    # 1e34 degC exceeds the 1.416785e32 K bound checked by below_absolute_hot.
    reading.celsius = 1e34


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment