-
-
Save ericgj/8af30c2a89278a2442625aa7c6bd18dc to your computer and use it in GitHub Desktop.
Simple tagged-union type matching in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Derived from [fn.py](https://github.com/kachayev/fn.py) function 'curried'
Amended to fix wrapping error: cf. https://github.com/kachayev/fn.py/pull/75
Copyright 2013 Alexey Kachayev
Under the Apache License, Version 2.0
http://www.apache.org/licenses/LICENSE-2.0
"""
from functools import partial, wraps, update_wrapper | |
from inspect import getargspec | |
def curry(func):
    """A decorator that makes the function curried.

    The decorated function may be called with any number of its arguments
    at a time. While fewer arguments than the function has named positional
    parameters have been collected, another curried callable is returned;
    once enough have been collected, the underlying function is called with
    everything gathered so far and its result is returned.

    Every named positional parameter counts towards the required total,
    including those that declare a default value. Keyword arguments count
    when they name one of those parameters. Supplying too many arguments
    calls the function immediately, so the usual ``TypeError`` is raised
    rather than a curried callable that could never complete.

    Usage example:

    >>> @curry
    ... def sum5(a, b, c, d, e):
    ...     return a + b + c + d + e
    ...
    >>> sum5(1)(2)(3)(4)(5)
    15
    >>> sum5(1, 2, 3)(4, 5)
    15
    >>> sum5(1, 2, 3, 4)(e=5)
    15
    """
    # `getargspec` was removed from `inspect` in Python 3.11; resolve the
    # replacement locally so this function does not depend on the name the
    # module happens to import.
    try:
        from inspect import getfullargspec as _argspec
    except ImportError:  # Python 2
        from inspect import getargspec as _argspec

    @wraps(func)
    def _curry(*args, **kwargs):
        # Walk back through the chain of partials built by earlier calls,
        # tallying what has already been bound, until the plain function
        # underneath is reached.
        f = func
        count = 0
        named = set(kwargs)
        while isinstance(f, partial):
            count += len(f.args)
            if f.keywords:
                named.update(f.keywords)
            f = f.func

        params = _argspec(f).args
        supplied = count + len(args) + sum(1 for p in params if p in named)
        # `>=` rather than `==`: an over-applied call must reach the function
        # (and fail there) instead of being deferred forever.
        if supplied >= len(params):
            return func(*args, **kwargs)

        para_func = partial(func, *args, **kwargs)
        # Copy the metadata of the underlying function onto the partial so
        # that `wraps` in the recursive `curry` call finds `__name__` etc.
        update_wrapper(para_func, f)
        return curry(para_func)

    return _curry
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from f import curry | |
@curry
def match(uniontype, cases, target):
    """
    Return the result of the case matching a target instance of a union type.

    Cases are expressed as a dict with types as keys and functions as values.
    The function of the first case key that `target` is an instance of is
    called with the target's items unpacked as positional arguments, so
    target values must be iterable. Typically target is a named tuple.

    Union types are from the `typing` library.

    Usage example:

        from typing import NamedTuple, Union, List

        Ok = NamedTuple('Ok', [('message', str)])
        ClientErr = NamedTuple('ClientErr', [('message', str), ('code', int)])
        ServerErr = NamedTuple('ServerErr', [('message', str), ('code', int), ('backtrace', List[str])])

        Response = Union[Ok, ClientErr, ServerErr]

        display_response = (
            match(Response, {
                Ok: (lambda msg: "Everything ok: %s" % msg),
                ClientErr: (lambda msg, code: "Oops, did you mean to do that? (%d %s)" % (code, msg)),
                ServerErr: (lambda msg, code, backtrace: "Something bad happened: (%d %s)\n\n%s" % (code, msg, "\n".join(backtrace)))
            })
        )

        #...
        response = Ok("beautiful")
        display_response(response)   # "Everything ok: beautiful"

    You must either specify a case for every type in the union type, or
    include a case for `type(None)`, which will be used as a fallback if no
    cases match the target (called with no parameters):

        match(Response, {
            Ok: (lambda msg: msg),
            type(None): (lambda: "Something went wrong")
        })

    Raises AssertionError if the target's class is not in the union type, if
    cases are missing, or if the selected case is not callable. These checks
    are `assert` statements, so they are skipped when Python runs with -O.
    Raises TypeError if no usable case is found.
    """
    # Member types of the union. Very old `typing` releases exposed them as
    # `__union_set_params__`; newer ones use `__args__`.
    # NOTE(review): a non-union generic alias also has `__args__`; this
    # function assumes `uniontype` is a Union or a plain class.
    utypes = getattr(uniontype, '__union_set_params__', None)
    if utypes is None:
        utypes = getattr(uniontype, '__args__', None)
    if not utypes:
        utypes = [uniontype]  # in case where union type is flattened to single type
    utypes = tuple(utypes)

    # Check against the member types rather than the union object itself:
    # some Python 3 releases refuse `issubclass` on a subscripted Union.
    assert issubclass(target.__class__, utypes), \
        "%s is not in union type" % target.__class__.__name__

    # A `type(None)` fallback case makes the cases exhaustive by definition.
    if type(None) not in cases:
        missing = [t.__name__ for t in utypes if t not in cases]
        assert len(missing) == 0, \
            "No case found for the following type(s): %s" % ", ".join(missing)

    for klass in cases:
        if isinstance(target, klass):
            fn = cases[klass]
            # A None target selected through the `type(None)` key has nothing
            # to unpack, so it is treated like the fallback.
            wildcard = target is None and bool(fn)
            break
    else:
        fn = cases.get(type(None), None)
        wildcard = bool(fn)

    # note should never happen due to type assertions above
    if fn is None:
        raise TypeError("No cases match %s" % target.__class__.__name__)

    assert callable(fn), \
        "Matched case is not callable; check your cases"

    return fn() if wildcard else fn(*target)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment