-
-
Save akirayu101/8a213d439428ed30b8cf to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict | |
class BadMatch(NameError):
    """Raised when the passed arguments fit none of the registered patterns."""
class Any(object):
    """Wildcard pattern that compares equal to every value.

    The class name is rebound to a single shared instance right after the
    class body, so the pattern is written as ``Any`` (without a call):

    >>> 'wutup' == Any
    True
    >>> [1, 2] == [Any, Any]
    True
    >>> 'wutup' != Any
    False
    """

    def __eq__(self, _other):
        return True

    def __ne__(self, _other):
        # Python 2 does not derive ``!=`` from ``__eq__``; keep the two consistent.
        return False

    def __repr__(self):
        return "Any"


# Replace the class with its only instance so ``Any`` can be used directly
# as a default-argument pattern.
Any = Any()
class OfType:
    """Pattern that equals any value that is an instance of the given types.

    >>> 3 == OfType(int, str, bool)
    True
    >>> 'ok' == OfType(int)
    False
    """

    def __init__(self, *types):
        # ``types`` is collected as a tuple, the form ``isinstance`` accepts.
        self.type = types

    def __eq__(self, other):
        accepted = self.type
        return isinstance(other, accepted)
class Where:
    """Pattern that equals a value when ``predicate(value)`` is truthy.

    A predicate that raises an ordinary exception (for example because the
    value has the wrong type for it) is treated as "no match".

    >>> 'ok' == Where(str.isupper)
    False
    >>> [] == Where(len)
    False
    >>> 'OK' == Where(str.isupper)
    True
    """

    def __init__(self, predicate):
        self.predicate = predicate

    def __eq__(self, other):
        try:
            return bool(self.predicate(other))
        except Exception:
            # Only ordinary errors mean "no match"; KeyboardInterrupt and
            # SystemExit must still propagate to the caller.
            return False
class WhereNot(Where):
    """Negated ``Where``: equal exactly when ``Where.__eq__`` reports no match."""

    def __eq__(self, other):
        matched = Where.__eq__(self, other)
        return not matched
class PatternMatcher(object):
    """Decorator that dispatches a call by matching arguments against defaults.

    Keeps a dict of lists where the key is a function name and the value is
    the list of functions registered under that name. Decorating a function
    appends it to that list and replaces it with a dispatcher that calls the
    first registered function whose default values all compare equal to the
    passed positional arguments.
    """

    def __init__(self):
        # function name -> candidate implementations, in registration order
        self.funcs = defaultdict(list)

    def find_match(self, name):
        """Return a dispatcher for the functions registered under ``name``.

        The dispatcher accepts positional arguments only. It tries the
        candidates in registration order and calls the first one that has as
        many defaults as there are arguments and whose defaults all compare
        equal to them. It raises ``BadMatch`` when no candidate fits.
        """
        candidates = self.funcs[name]

        def wrapper(*args):
            # TODO: can we handle **kwargs??
            for function in candidates:
                # ``__defaults__`` is None for a function without parameters.
                patterns = function.__defaults__ or ()
                if len(args) != len(patterns):
                    continue
                if all(passed == spec for passed, spec in zip(args, patterns)):
                    return function(*args)
            raise BadMatch("function `{0}` has no match for args ({1})".format(name, args))

        return wrapper

    def __call__(self, func):
        """Decorator: register ``func`` and return the dispatcher for its name.

        Raises ``SyntaxError`` when some positional parameter of ``func`` has
        no default value, since the defaults are the patterns to match.
        """
        # ``__defaults__`` is None when nothing has a default; use an empty
        # tuple so the length check below raises the intended SyntaxError
        # instead of failing on ``len()`` of a non-sequence.
        patterns = func.__defaults__ or ()
        if func.__code__.co_argcount != len(patterns):
            raise SyntaxError("Every argument must have default parameter for pattern matching")
        name = func.__code__.co_name
        self.funcs[name].append(func)
        return self.find_match(name)
# Shared matcher instance, applied below as the ``@ifmatches`` decorator.
ifmatches = PatternMatcher()

if __name__ == "__main__":
    # pylint has NO IDEA what I'm doing here
    # it's all like 'OMGWTF?!'
    #
    # Every redefinition below registers one more pattern under the same
    # function name; candidates are tried in the order they are defined.

    @ifmatches
    def my_func(test="hey"):
        print("this is my_func with test=hey")

    @ifmatches
    def my_func(test=2):
        print("this is my_func with test=2")

    @ifmatches
    def my_func(test=OfType(int)):
        print("this is my_func with int test!=2")

    @ifmatches
    def my_func(test=Any):
        print("this is my_func with a non-int")

    # Recursive sum over a list: empty list, one element, longer list.
    @ifmatches
    def psum(l=[]):
        return 0

    @ifmatches
    def psum(l=[Any]):
        return l[0]

    @ifmatches
    def psum(l=Where(lambda seq: len(seq) > 1)):
        return l[0] + psum(l[1:])

    # Count alphabetic characters by peeling off one character at a time.
    @ifmatches
    def count_letters(string=""):
        return 0

    @ifmatches
    def count_letters(string=OfType(str)):
        first, remainder = string[0], string[1:]
        return count_letters(first, remainder)

    @ifmatches
    def count_letters(x=Where(str.isalpha), xs=OfType(str)):
        return 1 + count_letters(xs)

    @ifmatches
    def count_letters(x=WhereNot(str.isalpha), xs=OfType(str)):
        return count_letters(xs)

    assert count_letters("ACD123") == 3
    assert count_letters("") == 0
    assert count_letters("123424") == 0

    # A non-string argument fits no pattern and must raise BadMatch.
    try:
        count_letters(1)
    except BadMatch:
        pass
    else:
        assert False

    import doctest
    doctest.testmod()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
* Generate classes like WhereNot automatically with a class decorator?
* Haskell-style deconstruction, perhaps like:

      @pmatch
      def count(first_arg=Match(x=Head, *xs=Rest)):
          print x, xs

  Will require some nasty hacks like https://github.com/smcq/python-inject
* Implement boolean operators for stuff like:

      @pmatch
      def example(name=OfType(str) & Where(str.isalpha)):
          ...
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment