Skip to content

Instantly share code, notes, and snippets.

@barbuza
Created May 11, 2011 14:05
Show Gist options
  • Save barbuza/966501 to your computer and use it in GitHub Desktop.
Save barbuza/966501 to your computer and use it in GitHub Desktop.
Validate arbitrary data structures in Python
# -*- coding: utf-8 -*-
import functools
import inspect
"""
Contract is a tiny library for data validation.
It provides several primitives for validating complex data structures.
See the doctests for usage examples.
"""
# Public API. GuardValidationError is exported as well, because it is the
# exception raised by functions wrapped with ``guard``.
__all__ = ("ContractValidationError", "Contract", "AnyC", "IntC", "StringC",
           "ListC", "DictC", "OrC", "NullC", "FloatC", "EnumC", "CallableC",
           "CallC", "ForwardC", "BoolC", "guard", "GuardValidationError", )
class ContractValidationError(Exception):

    """
    Error raised when a value does not satisfy a contract.

    The exception text is ``msg`` alone, or ``"<name>: <msg>"`` when a
    non-empty ``name`` is given. Both parts stay available afterwards as
    the ``msg`` and ``name`` attributes.
    """

    def __init__(self, msg, name=None):
        self.msg = msg
        self.name = name
        if name:
            text = "%s: %s" % (name, msg)
        else:
            text = msg
        super(ContractValidationError, self).__init__(text)
class ContractMeta(type):

    """
    Metaclass that makes the "|" operator usable on contract classes
    themselves, not only on contract instances.

    >>> IntC | StringC
    <OrC(<IntC>, <StringC>)>
    >>> IntC | StringC | NullC
    <OrC(<IntC>, <StringC>, <NullC>)>
    """

    def __or__(cls, other):
        # Instantiate the class with no arguments and let the instance-level
        # "|" implementation do the actual work.
        instance = cls()
        return instance | other
class Contract(object):

    """
    Base class for contracts.

    Subclasses implement ``check``. This class supplies the helper used to
    report validation failures, the helper that normalises "contract class
    or contract instance" arguments, and the "|" operator.
    """

    # Python 2 metaclass declaration; enables "|" on the classes themselves.
    __metaclass__ = ContractMeta

    def check(self, value):
        """
        Implement this method in Contract subclasses.

        The base implementation always raises NotImplementedError, naming
        the concrete class that failed to override it.
        """
        cls = "%s.%s" % (type(self).__module__, type(self).__name__)
        raise NotImplementedError("method check is not implemented in"
                                  " '%s'" % cls)

    def _failure(self, message):
        """
        Shortcut method for raising ContractValidationError with ``message``.
        """
        raise ContractValidationError(message)

    def _contract(self, contract):
        """
        Helper for complex contracts: take a contract instance or a contract
        class and return a contract instance.

        Instances are returned unchanged; classes are instantiated without
        arguments. Anything else raises RuntimeError.
        """
        if isinstance(contract, Contract):
            return contract
        # issubclass() raises TypeError when its first argument is not a
        # class, so make sure we really have a class before asking.
        if isinstance(contract, type) and issubclass(contract, Contract):
            return contract()
        # Wrap in a 1-tuple so that tuple arguments are formatted correctly.
        raise RuntimeError("%r should be instance or subclass"
                           " of Contract" % (contract, ))

    def __or__(self, other):
        """
        Return an OrC contract accepting whatever ``self`` or ``other`` accepts.
        """
        return OrC(self, other)
class AnyC(Contract):

    """
    Contract that accepts every value.

    >>> AnyC()
    <AnyC>
    >>> AnyC().check(object())
    """

    def __repr__(self):
        return "<AnyC>"

    def check(self, value):
        # Nothing to verify: any value is acceptable.
        return None
class OrCMeta(ContractMeta):

    """
    Metaclass that makes the "<<" operator usable on the OrC class itself.

    >>> OrC << IntC << StringC
    <OrC(<IntC>, <StringC>)>
    """

    def __lshift__(cls, other):
        # Start from an empty instance and delegate to its own "<<".
        empty = cls()
        return empty << other
class OrC(Contract):

    """
    Contract satisfied when at least one of its alternatives accepts the
    value.

    Alternatives may be given as contract instances or contract classes.
    "<<" appends an alternative to this instance in place and returns it;
    "|" builds a new OrC and leaves its operands untouched.

    >>> nullString = OrC(StringC, NullC)
    >>> nullString
    <OrC(<StringC>, <NullC>)>
    >>> nullString.check(None)
    >>> nullString.check("test")
    >>> nullString.check(1)
    Traceback (most recent call last):
    ...
    ContractValidationError: no one contract matches
    """

    __metaclass__ = OrCMeta

    def __init__(self, *contracts):
        # A real list is required because "<<" appends to it; a bare map()
        # result is not appendable where map() is lazy.
        self.contracts = [self._contract(contract) for contract in contracts]

    def check(self, value):
        """
        Accept ``value`` as soon as one alternative accepts it; raise
        ContractValidationError when none does.
        """
        for contract in self.contracts:
            try:
                contract.check(value)
            except ContractValidationError:
                continue
            return
        self._failure("no one contract matches")

    def __lshift__(self, contract):
        """
        Append ``contract`` (instance or class) in place and return ``self``.
        """
        self.contracts.append(self._contract(contract))
        return self

    def __or__(self, contract):
        """
        Return a new, flat OrC with ``contract`` added as an alternative.
        """
        # Do not append to ``self``: "a | b" must not silently change "a",
        # which may be shared by other contracts.
        return OrC(*(self.contracts + [contract]))

    def __repr__(self):
        return "<OrC(%s)>" % (", ".join(map(repr, self.contracts)))
class NullC(Contract):

    """
    Contract that accepts only ``None``.

    >>> NullC()
    <NullC>
    >>> NullC().check(None)
    >>> NullC().check(1)
    Traceback (most recent call last):
    ...
    ContractValidationError: value should be None
    """

    def __repr__(self):
        return "<NullC>"

    def check(self, value):
        if value is None:
            return
        self._failure("value should be None")
class BoolC(Contract):

    """
    Contract that accepts only ``bool`` instances.

    >>> BoolC()
    <BoolC>
    >>> BoolC().check(True)
    >>> BoolC().check(False)
    >>> BoolC().check(1)
    Traceback (most recent call last):
    ...
    ContractValidationError: value should be True or False
    """

    def __repr__(self):
        return "<BoolC>"

    def check(self, value):
        if isinstance(value, bool):
            return
        self._failure("value should be True or False")
class NumberCMeta(ContractMeta):

    """
    Metaclass providing slice syntax for the min and max arguments of
    number contracts.

    >>> IntC[1:]
    <IntC(min=1)>
    >>> IntC[1:10]
    <IntC(min=1, max=10)>
    >>> IntC[:10]
    <IntC(max=10)>
    >>> FloatC[1:]
    <FloatC(min=1)>
    """

    def __getitem__(self, slice_):
        # slice.start / slice.stop are None when omitted, which is exactly
        # how the number contracts spell "no bound".
        lower, upper = slice_.start, slice_.stop
        return self(min_=lower, max_=upper)
class IntC(Contract):

    """
    Contract for ``int`` values with optional inclusive bounds.

    ``min_`` and ``max_`` default to None, meaning "no bound". The type test
    is ``isinstance(value, int)``, so ``bool`` values pass it as well.

    >>> IntC()
    <IntC>
    >>> IntC(min_=1)
    <IntC(min=1)>
    >>> IntC(max_=10)
    <IntC(max=10)>
    >>> IntC(min_=1, max_=10)
    <IntC(min=1, max=10)>
    >>> IntC().check("foo")
    Traceback (most recent call last):
    ...
    ContractValidationError: value is not int
    >>> IntC(min_=1).check(1)
    >>> IntC(min_=2).check(1)
    Traceback (most recent call last):
    ...
    ContractValidationError: value is less than 2
    >>> IntC(max_=10).check(5)
    >>> IntC(max_=3).check(5)
    Traceback (most recent call last):
    ...
    ContractValidationError: value is greater than 3
    """

    __metaclass__ = NumberCMeta

    def __init__(self, min_=None, max_=None):
        self.min, self.max = min_, max_

    def check(self, value):
        if not isinstance(value, int):
            self._failure("value is not int")
        lower, upper = self.min, self.max
        if lower is not None and value < lower:
            self._failure("value is less than %s" % lower)
        if upper is not None and value > upper:
            self._failure("value is greater than %s" % upper)

    def __repr__(self):
        # Only bounds that are actually set appear in the representation.
        described = []
        for template, bound in (("min=%s", self.min), ("max=%s", self.max)):
            if bound is not None:
                described.append(template % bound)
        if not described:
            return "<IntC>"
        return "<IntC(%s)>" % (", ".join(described))
class FloatC(Contract):

    """
    Contract for ``float`` values with optional inclusive bounds.

    ``min_`` and ``max_`` default to None, meaning "no bound". Integers are
    not accepted, as the examples below show.

    >>> FloatC()
    <FloatC>
    >>> FloatC(min_=1)
    <FloatC(min=1)>
    >>> FloatC(max_=10)
    <FloatC(max=10)>
    >>> FloatC(min_=1, max_=10)
    <FloatC(min=1, max=10)>
    >>> FloatC().check(1.0)
    >>> FloatC().check(1)
    Traceback (most recent call last):
    ...
    ContractValidationError: value is not float
    >>> FloatC(min_=2).check(3.0)
    >>> FloatC(min_=2).check(1.0)
    Traceback (most recent call last):
    ...
    ContractValidationError: value is less than 2
    >>> FloatC(max_=10).check(5.0)
    >>> FloatC(max_=3).check(5.0)
    Traceback (most recent call last):
    ...
    ContractValidationError: value is greater than 3
    """

    __metaclass__ = NumberCMeta

    def __init__(self, min_=None, max_=None):
        self.min, self.max = min_, max_

    def check(self, value):
        if not isinstance(value, float):
            self._failure("value is not float")
        lower, upper = self.min, self.max
        if lower is not None and value < lower:
            self._failure("value is less than %s" % lower)
        if upper is not None and value > upper:
            self._failure("value is greater than %s" % upper)

    def __repr__(self):
        # Only bounds that are actually set appear in the representation.
        described = []
        for template, bound in (("min=%s", self.min), ("max=%s", self.max)):
            if bound is not None:
                described.append(template % bound)
        if not described:
            return "<FloatC>"
        return "<FloatC(%s)>" % (", ".join(described))
class StringC(Contract):

    """
    Contract for string values; empty strings are rejected unless
    ``allow_blank`` is true.

    >>> StringC()
    <StringC>
    >>> StringC(allow_blank=True)
    <StringC(blank)>
    >>> StringC().check("foo")
    >>> StringC().check("")
    Traceback (most recent call last):
    ...
    ContractValidationError: blank value is not allowed
    >>> StringC(allow_blank=True).check("")
    >>> StringC().check(1)
    Traceback (most recent call last):
    ...
    ContractValidationError: value is not string
    """

    # ``basestring`` only exists on Python 2; fall back to ``str`` elsewhere.
    try:
        _string_types = basestring
    except NameError:
        _string_types = str

    def __init__(self, allow_blank=False):
        self.allow_blank = allow_blank

    def check(self, value):
        if not isinstance(value, self._string_types):
            self._failure("value is not string")
        # Emptiness is tested by value; the former ``len(value) is 0`` was an
        # identity comparison that only worked through int caching.
        if not self.allow_blank and not value:
            self._failure("blank value is not allowed")

    def __repr__(self):
        return "<StringC(blank)>" if self.allow_blank else "<StringC>"
class SquareBracketsMeta(ContractMeta):

    """
    Metaclass allowing square brackets for ListC initialization.

    The subscript may contain a contract (instance or class) and a slice, in
    either order; the slice supplies ``min_length`` (start, 0 when omitted)
    and ``max_length`` (stop). Subscript items of any other kind are ignored,
    and a subscript without a contract raises RuntimeError.

    >>> ListC[IntC]
    <ListC(<IntC>)>
    >>> ListC[IntC, 1:]
    <ListC(min_length=1 | <IntC>)>
    >>> ListC[:10, IntC]
    <ListC(max_length=10 | <IntC>)>
    >>> ListC[1:10]
    Traceback (most recent call last):
    ...
    RuntimeError: Contract is required for ListC initialization
    """

    def __getitem__(self, args):
        slice_ = None
        contract = None
        if not isinstance(args, tuple):
            args = (args, )
        for arg in args:
            if isinstance(arg, slice):
                slice_ = arg
            elif isinstance(arg, Contract):
                contract = arg
            # issubclass() raises TypeError for non-class arguments, so only
            # ask when ``arg`` really is a class.
            elif isinstance(arg, type) and issubclass(arg, Contract):
                contract = arg
        if contract is None:
            raise RuntimeError("Contract is required for ListC initialization")
        if slice_ is not None:
            return self(contract, min_length=slice_.start or 0,
                        max_length=slice_.stop)
        return self(contract)
class ListC(Contract):

    """
    Contract for lists whose items all satisfy one item contract, with
    optional length limits.

    An item failure is re-raised with the item index prepended to the error
    name, so nested failures read like ``2: value is not int``.

    >>> ListC(IntC)
    <ListC(<IntC>)>
    >>> ListC(IntC, min_length=1)
    <ListC(min_length=1 | <IntC>)>
    >>> ListC(IntC, min_length=1, max_length=10)
    <ListC(min_length=1, max_length=10 | <IntC>)>
    >>> ListC(IntC).check(1)
    Traceback (most recent call last):
    ...
    ContractValidationError: value is not list
    >>> ListC(IntC).check([1, 2, 3])
    >>> ListC(StringC).check(["foo", "bar", "spam"])
    >>> ListC(IntC).check([1, 2, 3.0])
    Traceback (most recent call last):
    ...
    ContractValidationError: 2: value is not int
    >>> ListC(IntC, min_length=1).check([1, 2, 3])
    >>> ListC(IntC, min_length=1).check([])
    Traceback (most recent call last):
    ...
    ContractValidationError: list length is less than 1
    >>> ListC(IntC, max_length=2).check([1, 2])
    >>> ListC(IntC, max_length=2).check([1, 2, 3])
    Traceback (most recent call last):
    ...
    ContractValidationError: list length is greater than 2
    """

    __metaclass__ = SquareBracketsMeta

    def __init__(self, contract, min_length=0, max_length=None):
        self.contract = self._contract(contract)
        self.min_length = min_length
        self.max_length = max_length

    def check(self, value):
        if not isinstance(value, list):
            self._failure("value is not list")
        if len(value) < self.min_length:
            self._failure("list length is less than %s" % self.min_length)
        if self.max_length is not None and len(value) > self.max_length:
            self._failure("list length is greater than %s" % self.max_length)
        for index, item in enumerate(value):
            try:
                self.contract.check(item)
            except ContractValidationError as err:
                # Prefix the failing index to whatever path the item
                # contract already reported.
                name = "%i.%s" % (index, err.name) if err.name else str(index)
                raise ContractValidationError(err.msg, name)

    def __repr__(self):
        r = "<ListC("
        options = []
        # A min_length of 0 is the default ("no minimum") and is not shown.
        if self.min_length:
            options.append("min_length=%s" % self.min_length)
        # max_length=0 is a real constraint enforced by check(), so test
        # against None rather than truthiness.
        if self.max_length is not None:
            options.append("max_length=%s" % self.max_length)
        r += ", ".join(options)
        if options:
            r += " | "
        r += repr(self.contract)
        r += ")>"
        return r
class DictC(Contract):

    """
    Contract for dicts with a fixed set of keys, each validated by its own
    contract.

    All declared keys are required unless marked optional with
    ``allow_optionals``; undeclared keys are rejected unless permitted with
    ``allow_extra``. Passing "*" to either method applies it to every key.

    >>> contract = DictC(foo=IntC, bar=StringC)
    >>> contract.check({"foo": 1, "bar": "spam"})
    >>> contract.check({"foo": 1, "bar": 2})
    Traceback (most recent call last):
    ...
    ContractValidationError: bar: value is not string
    >>> contract.check({"foo": 1})
    Traceback (most recent call last):
    ...
    ContractValidationError: bar is required
    >>> contract.check({"foo": 1, "bar": "spam", "eggs": None})
    Traceback (most recent call last):
    ...
    ContractValidationError: eggs is not allowed key
    >>> contract.allow_extra("eggs")
    <DictC(extras=(eggs) | bar=<StringC>, foo=<IntC>)>
    >>> contract.check({"foo": 1, "bar": "spam", "eggs": None})
    >>> contract.check({"foo": 1, "bar": "spam"})
    >>> contract.check({"foo": 1, "bar": "spam", "ham": 100})
    Traceback (most recent call last):
    ...
    ContractValidationError: ham is not allowed key
    >>> contract.allow_extra("*")
    <DictC(any, extras=(eggs) | bar=<StringC>, foo=<IntC>)>
    >>> contract.check({"foo": 1, "bar": "spam", "ham": 100})
    >>> contract.check({"foo": 1, "bar": "spam", "ham": 100, "baz": None})
    >>> contract.check({"foo": 1, "ham": 100, "baz": None})
    Traceback (most recent call last):
    ...
    ContractValidationError: bar is required
    >>> contract.allow_optionals("bar")
    <DictC(any, extras=(eggs), optionals=(bar) | bar=<StringC>, foo=<IntC>)>
    >>> contract.check({"foo": 1, "ham": 100, "baz": None})
    >>> contract.check({"bar": 1, "ham": 100, "baz": None})
    Traceback (most recent call last):
    ...
    ContractValidationError: foo is required
    >>> contract.check({"foo": 1, "bar": 1, "ham": 100, "baz": None})
    Traceback (most recent call last):
    ...
    ContractValidationError: bar: value is not string
    """

    def __init__(self, **contracts):
        self.optionals = []
        self.extras = []
        self.allow_any = False
        self.contracts = {}
        for key, contract in contracts.items():
            self.contracts[key] = self._contract(contract)

    def allow_extra(self, *names):
        """
        Permit the given undeclared keys ("*" permits any key); returns self.
        """
        for name in names:
            if name == "*":
                self.allow_any = True
            else:
                self.extras.append(name)
        return self

    def allow_optionals(self, *names):
        """
        Mark the given declared keys as optional ("*" marks all of them);
        returns self.
        """
        for name in names:
            if name == "*":
                # Take a list copy of the keys: it must support append() on
                # later calls and must not be a live view of the contracts.
                self.optionals = list(self.contracts)
            else:
                self.optionals.append(name)
        return self

    def check(self, value):
        if not isinstance(value, dict):
            self._failure("value is not dict")
        self.check_presence(value)
        # An explicit loop guarantees every item is validated; a bare map()
        # call does nothing where map() is lazy.
        for item in value.items():
            self.check_item(item)

    def check_presence(self, value):
        """
        Fail if a declared, non-optional key is missing from ``value``.
        """
        for key in self.contracts:
            if key not in self.optionals and key not in value:
                self._failure("%s is required" % key)

    def check_item(self, item):
        """
        Validate one ``(key, value)`` pair against the key's contract, or
        fail if the key is neither declared nor allowed as an extra.
        """
        key, value = item
        if key in self.contracts:
            try:
                self.contracts[key].check(value)
            except ContractValidationError as err:
                # Prefix the key to whatever path the nested contract
                # already reported.
                name = "%s.%s" % (key, err.name) if err.name else key
                raise ContractValidationError(err.msg, name)
        elif not self.allow_any and key not in self.extras:
            self._failure("%s is not allowed key" % key)

    def __repr__(self):
        r = "<DictC("
        options = []
        if self.allow_any:
            options.append("any")
        if self.extras:
            options.append("extras=(%s)" % (", ".join(self.extras)))
        if self.optionals:
            options.append("optionals=(%s)" % (", ".join(self.optionals)))
        r += ", ".join(options)
        if options:
            r += " | "
        # Keys are sorted so the representation is stable.
        options = []
        for key in sorted(self.contracts.keys()):
            options.append("%s=%r" % (key, self.contracts[key]))
        r += ", ".join(options)
        r += ")>"
        return r
class EnumC(Contract):

    """
    Contract that accepts only values equal to one of the given variants.

    >>> contract = EnumC("foo", "bar", 1)
    >>> contract
    <EnumC('foo', 'bar', 1)>
    >>> contract.check("foo")
    >>> contract.check(1)
    >>> contract.check(2)
    Traceback (most recent call last):
    ...
    ContractValidationError: value doesn't match any variant
    """

    def __init__(self, *variants):
        # ``variants`` arrives as a tuple; keep it as one.
        self.variants = tuple(variants)

    def __repr__(self):
        listed = ", ".join(repr(variant) for variant in self.variants)
        return "<EnumC(%s)>" % listed

    def check(self, value):
        if value in self.variants:
            return
        self._failure("value doesn't match any variant")
class CallableC(Contract):

    """
    Contract that accepts any callable value.

    >>> CallableC().check(lambda: 1)
    >>> CallableC().check(1)
    Traceback (most recent call last):
    ...
    ContractValidationError: value is not callable
    """

    def __repr__(self):
        return "<CallableC>"

    def check(self, value):
        if callable(value):
            return
        self._failure("value is not callable")
class CallC(Contract):

    """
    Contract backed by a user-supplied validator function.

    The validator is called with the value; returning None means the value
    is valid, and any other return value is used as the failure message.
    The validator must be callable with a single argument, otherwise
    RuntimeError is raised at construction time.

    >>> def validator(value):
    ...     if value != "foo":
    ...         return "I want only foo!"
    ...
    >>> contract = CallC(validator)
    >>> contract
    <CallC(validator)>
    >>> contract.check("foo")
    >>> contract.check("bar")
    Traceback (most recent call last):
    ...
    ContractValidationError: I want only foo!
    """

    def __init__(self, fn):
        if not callable(fn):
            raise RuntimeError("CallC argument should be callable")
        # Prefer getfullargspec where it exists; getargspec is only looked
        # up as a fallback because newer Pythons no longer provide it.
        getspec = getattr(inspect, "getfullargspec", None) \
            or inspect.getargspec
        argspec = getspec(fn)
        required = len(argspec.args) - len(argspec.defaults or ())
        # For a bound method the spec still lists the already-bound first
        # argument, which the caller never supplies.
        if inspect.ismethod(fn) and getattr(fn, "__self__", None) is not None:
            required -= 1
        if required > 1:
            raise RuntimeError("CallC argument should be"
                               " one argument function")
        self.fn = fn

    def check(self, value):
        error = self.fn(value)
        if error is not None:
            self._failure(error)

    def __repr__(self):
        # Not every callable carries __name__; fall back to its repr.
        return "<CallC(%s)>" % getattr(self.fn, "__name__", repr(self.fn))
class ForwardC(Contract):

    """
    Placeholder contract for recursive definitions.

    Create it first, refer to it from inside other contracts, then bind the
    real contract with "<<". The binding can be done only once, and "<<"
    returns nothing.

    >>> nodeC = ForwardC()
    >>> nodeC << DictC(name=StringC, children=ListC[nodeC])
    >>> nodeC
    <ForwardC(<DictC(children=<ListC(<recur>)>, name=<StringC>)>)>
    >>> nodeC.check({"name": "foo", "children": []})
    >>> nodeC.check({"name": "foo", "children": [1]})
    Traceback (most recent call last):
    ...
    ContractValidationError: children.0: value is not dict
    >>> nodeC.check({"name": "foo", "children": [ \
    {"name": "bar", "children": []} \
    ]})
    """

    def __init__(self):
        self.contract = None
        self._recur_repr = False

    def __lshift__(self, contract):
        if self.contract:
            raise RuntimeError("contract for ForwardC is already specified")
        self.contract = self._contract(contract)

    def check(self, value):
        self.contract.check(value)

    def __repr__(self):
        # XXX not threadsafe
        # The flag marks "repr in progress" so that a nested reference back
        # to this contract prints "<recur>" instead of recursing forever.
        if self._recur_repr:
            return "<recur>"
        self._recur_repr = True
        try:
            return "<ForwardC(%r)>" % self.contract
        finally:
            # Always reset the flag, even if the nested repr raised;
            # otherwise every later repr would print "<recur>".
            self._recur_repr = False
class GuardValidationError(Exception):

    """
    Error raised when a guarded function is called with invalid arguments.

    Its message is taken from the ContractValidationError that caused it.
    """
def guard(**kwargs):
    """
    Decorator for protecting function with contracts.

    Each keyword names an argument of the decorated function and gives the
    contract (instance or class) its value must satisfy. Only arguments
    actually passed by the caller are checked; defaults are not. A failed
    check raises GuardValidationError carrying the contract's message. The
    wrapped function's docstring is prefixed with a "guarded with ..." line.

    >>> @guard(a=StringC)
    ... def fn(a, b, c=None):
    ...     return (a, b, c)
    ...
    >>> fn("foo", "bar")
    ('foo', 'bar', None)
    >>> fn(1, "bar")
    Traceback (most recent call last):
    ...
    GuardValidationError: value is not string
    >>> @guard(c=IntC)
    ... def fn(a, b, c=None):
    ...     return (a, b, c)
    ...
    >>> fn(1, 2)
    (1, 2, None)
    >>> fn(1, 2, "foo")
    Traceback (most recent call last):
    ...
    GuardValidationError: value is not int
    """
    contracts = {}
    for name, contract in kwargs.items():
        contracts[name] = contract if isinstance(contract, Contract) \
            else contract()

    # Prefer getfullargspec where it exists; getargspec is only looked up
    # as a fallback because newer Pythons no longer provide it.
    getspec = getattr(inspect, "getfullargspec", None) or inspect.getargspec

    def wrapper(fn):
        argnames = getspec(fn).args

        @functools.wraps(fn)
        def decor(*args, **kwargs):
            # Pair positional values with their parameter names, then add
            # the keyword arguments. Both parts are made real lists so they
            # can be concatenated whether or not zip()/items() are lazy.
            supplied = list(zip(argnames, args)) + list(kwargs.items())
            try:
                for argname, value in supplied:
                    if argname in contracts:
                        contracts[argname].check(value)
            except ContractValidationError as err:
                # args[0] is the fully formatted validation message.
                raise GuardValidationError(err.args[0])
            return fn(*args, **kwargs)

        guards = []
        for name, contract in contracts.items():
            guards.append("%s=%r" % (name, contract))
        decor.__doc__ = "guarded with %s\n\n" % (", ".join(guards)) + \
            (decor.__doc__ or "")
        return decor
    return wrapper
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment