Skip to content

Instantly share code, notes, and snippets.

@rob-smallshire
Created May 18, 2015 15:16
Show Gist options
  • Select an option

  • Save rob-smallshire/b7eaa80344f1a12de52e to your computer and use it in GitHub Desktop.

Select an option

Save rob-smallshire/b7eaa80344f1a12de52e to your computer and use it in GitHub Desktop.
A decorator for checking that class invariants are established and maintained.
import functools
import inspect
def invariant(predicate):
    """Create a class decorator which checks a class invariant.

    The decorator wraps the plain functions defined directly in the class
    body (including ``__init__``, so the invariant must be established by
    construction) and the ``property`` objects defined directly in the class
    body, so that ``predicate`` is evaluated after each call or access.
    ``__repr__`` is not wrapped (see the comment in the inner decorator).

    Args:
        predicate: A callable to which, after every method invocation,
            the object on which the method was called will be passed.
            The predicate should evaluate to True if the class invariant
            has been maintained, or False if it has been violated. Its
            docstring is quoted in the RuntimeError raised on violation.

    Returns:
        A class decorator for checking the class invariant tested by
        the supplied predicate function.
    """
    def invariant_checking_class_decorator(cls):
        """Wrap the methods and properties of *cls* in place and return it."""
        # Select plain functions only. A bare ``callable()`` test also matches
        # nested classes and -- since Python 3.10 -- ``staticmethod`` objects
        # (including the implicit staticmethod ``__new__``); re-binding those
        # as self-taking wrappers breaks them.
        # ``__repr__`` is deliberately left unwrapped: the violation message
        # formats the object with ``!r``, so a checked ``__repr__`` would
        # re-run the failing check while reporting it and recurse forever.
        # NOTE: a ``__repr__`` that reads a checked property can still recurse.
        method_names = [
            name
            for name, attr in vars(cls).items()
            if inspect.isfunction(attr) and name != "__repr__"
        ]
        for name in method_names:
            _wrap_method_with_invariant_checking_proxy(cls, name, predicate)
        property_names = [
            name for name, attr in vars(cls).items() if isinstance(attr, property)
        ]
        for name in property_names:
            _wrap_property_with_invariant_checking_proxy(cls, name, predicate)
        return cls

    return invariant_checking_class_decorator
def _wrap_method_with_invariant_checking_proxy(cls, name, predicate):
    """Replace ``cls.<name>`` with a version that verifies ``predicate`` afterwards.

    The replacement calls the original method with the same arguments, then
    hands the instance to ``predicate``. A truthy verdict passes the method's
    return value through unchanged; a falsy one raises ``RuntimeError``
    quoting the predicate's docstring and the instance's repr. The wrapper
    copies the original's metadata via ``functools.wraps``.
    """
    original = getattr(cls, name)
    assert callable(original)

    @functools.wraps(original)
    def checked_method(self, *args, **kwargs):
        outcome = original(self, *args, **kwargs)
        if predicate(self):
            return outcome
        raise RuntimeError("Class invariant {!r} violated for {!r}".format(predicate.__doc__, self))

    setattr(cls, name, checked_method)
class InvariantCheckingPropertyProxy:
    """Data descriptor wrapping a property and checking an invariant on each access.

    After every get, set or delete performed through an instance, the
    instance is passed to ``predicate``; a falsy result raises
    ``RuntimeError``. Class-level access (``instance is None``) returns the
    wrapped property itself, so attributes such as ``Cls.prop.fget`` remain
    reachable.
    """

    def __init__(self, referent, predicate):
        self._referent = referent
        self._predicate = predicate

    def _check(self, instance):
        """Raise RuntimeError if the invariant does not hold for *instance*."""
        if not self._predicate(instance):
            raise RuntimeError("Class invariant {!r} violated for {!r}".format(self._predicate.__doc__, instance))

    def __get__(self, instance, owner=None):
        # ``owner`` is optional in the descriptor protocol; callers such as
        # ``proxy.__get__(obj)`` may omit it.
        if instance is None:
            return self._referent
        result = self._referent.__get__(instance, owner)
        self._check(instance)
        return result

    def __set__(self, instance, value):
        result = self._referent.__set__(instance, value)
        self._check(instance)
        return result

    def __delete__(self, instance):
        result = self._referent.__delete__(instance)
        self._check(instance)
        return result
def _wrap_property_with_invariant_checking_proxy(cls, name, predicate):
    """Swap the property ``cls.<name>`` for an invariant-checking proxy.

    The existing ``property`` object is looked up on the class and replaced
    by an ``InvariantCheckingPropertyProxy`` that delegates to it and then
    evaluates ``predicate`` on the instance.
    """
    original_property = getattr(cls, name)
    assert isinstance(original_property, property)
    setattr(cls, name, InvariantCheckingPropertyProxy(original_property, predicate))
def not_below_absolute_zero(temperature):
    """Temperature not below absolute zero"""
    # Invariant predicate: truthy while the stored kelvin value is >= 0.
    # ``invariant`` quotes a predicate's ``__doc__`` in its violation message,
    # so the docstring above is user-visible text.
    kelvin = temperature._kelvin
    return kelvin >= 0
@invariant(not_below_absolute_zero)
class Temperature:
    """A temperature whose value is stored in kelvin.

    It can be read and written in kelvin (``get_kelvin``/``set_kelvin``),
    and in Celsius and Fahrenheit (properties). The class decorator evaluates
    ``not_below_absolute_zero`` after every method call and property access.
    """

    def __init__(self, kelvin):
        self._kelvin = kelvin

    def get_kelvin(self):
        """Return the temperature in kelvin."""
        return self._kelvin

    def set_kelvin(self, value):
        """Store a new temperature given in kelvin."""
        self._kelvin = value

    @property
    def celsius(self):
        """The temperature in degrees Celsius."""
        kelvin = self._kelvin
        return kelvin - 273.15

    @celsius.setter
    def celsius(self, degrees):
        self._kelvin = degrees + 273.15

    @property
    def fahrenheit(self):
        """The temperature in degrees Fahrenheit."""
        kelvin = self._kelvin
        return kelvin * 9 / 5 - 459.67

    @fahrenheit.setter
    def fahrenheit(self, degrees):
        self._kelvin = (degrees + 459.67) * 5 / 9
def main():
    """Demonstrate the invariant check on ``Temperature``.

    The final assignment drives the temperature to roughly -26.85 K, so the
    invariant-checking property raises ``RuntimeError``. Nothing catches it,
    so the script exits with a traceback; that failure is the point of the demo.
    """
    temperature = Temperature(42.0)
    temperature.celsius = 30  # about 303.15 K: invariant holds
    temperature.celsius = -300  # about -26.85 K: expected to raise RuntimeError


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment