Created April 4, 2015 22:12
-
-
Save miikka/71a82ee8c13ecc3cd868 to your computer and use it in GitHub Desktop.
Pattern matching in Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""This is a sketch of implementing pattern matching of lists. | |
""" | |
from collections import namedtuple | |
# Destructuring lists
#
# In a destructuring pattern (a list):
#
# * V('foo') matches a single value and binds it to the name foo
# * VV('foo') matches the rest of the list and binds it to foo; it must be
#   the last element of the pattern
# * X matches and ignores a single value
# * XX matches and ignores the rest of the list; it must be the last element
#   of the pattern
# * a nested list pattern matches a nested list value recursively
# * any other object is a constant and matches a value equal to it
V = namedtuple('V', ['name'])
VV = namedtuple('VV', ['name'])
# Sentinels: compared by identity, so they never collide with user constants.
X = object()
XX = object()


def destruct(pattern, value):
    """Match *value* against the destructuring *pattern*.

    *pattern* is a list built from the markers described above.  *value* is
    a sequence supporting len(), indexing and slicing (normally a list);
    nested list patterns only match nested list values.

    Returns a dict mapping each name bound by V/VV markers to the value it
    matched, or None if *value* does not match *pattern*.  A rest marker
    (VV or XX) that is not the last element of *pattern* never matches.
    """
    result = {}
    for index, p in enumerate(pattern):
        if isinstance(p, VV) or p is XX:
            # Rest markers consume everything that is left, so they are only
            # meaningful as the final pattern element.
            if index != len(pattern) - 1:
                return None
            if isinstance(p, VV):
                result[p.name] = value[index:]
            return result
        if index >= len(value):
            # The pattern still has fixed elements but the value ran out.
            return None
        v = value[index]
        if isinstance(p, V):
            result[p.name] = v
        elif p is X:
            pass
        elif isinstance(p, list) and isinstance(v, list):
            inner = destruct(p, v)
            if inner is None:
                # Propagate a failed nested match instead of crashing on
                # dict.update(None).
                return None
            result.update(inner)
        elif p != v:
            return None
    # Without a rest marker the value must not have extra trailing elements.
    if len(pattern) != len(value):
        return None
    return result
# Pattern matching | |
class EmptyIterator(object):
    """An iterator that is already exhausted: iterating it yields nothing."""

    def __iter__(self):
        # Iterators are their own iterables.
        return self

    def __next__(self):
        raise StopIteration()
class MatchResult(object):
    """Attribute-style view of the names bound by a successful match.

    Each key of the bindings dict becomes an attribute; e.g. bindings
    ``{'n': 3}`` give a result whose ``.n`` is 3.
    """

    def __init__(self, result):
        # Adopt the dict itself as the instance namespace (no copy is made).
        self.__dict__ = result

    def __repr__(self):
        return 'MatchResult(%r)' % (vars(self),)
class NoMatchingPattern(Exception):
    """Raised when a ``pattern`` block is left without any pattern matching.

    It is raised with the unmatched value as its argument.
    """
class pattern(object):
    """Context manager driving a sequence of ``match`` attempts on one value.

    Usage::

        with pattern(value) as p:
            for m in p.match([...]):
                ...  # runs only for the first pattern that matches
            for m in p.match([...]):
                ...

    Leaving the ``with`` block normally without any pattern having matched
    raises NoMatchingPattern(value).
    """

    def __init__(self, value):
        self.matched = False  # becomes True once the first pattern succeeds
        self.value = value

    def __enter__(self):
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        # Only complain about a missing match when the block finished
        # normally; if an exception is already propagating, let it through
        # unchanged instead of masking it with NoMatchingPattern.
        if exc_type is None and not self.matched:
            raise NoMatchingPattern(self.value)

    def match(self, pattern):
        """Try *pattern* against the stored value.

        Returns a one-element list holding a MatchResult if this is the
        first successful match, otherwise an empty iterator, so the body of
        ``for m in p.match(...)`` runs at most once per ``with`` block.

        The parameter name shadows the class name inside this method; it is
        kept for compatibility with keyword-argument callers.
        """
        if self.matched:
            # An earlier pattern already won; skip the destructuring work.
            return EmptyIterator()
        bindings = destruct(pattern, self.value)
        if bindings is None:
            return EmptyIterator()
        self.matched = True
        return [MatchResult(bindings)]
# Example | |
def fibo(n):
    """Return the n-th Fibonacci number (fibo(0) == 0, fibo(1) == 1).

    Example of pattern matching: each ``for`` loop below is one case, and
    only the first case whose pattern matches ``[n]`` runs.  Uses naive
    double recursion, so the cost grows exponentially with *n*; demo only.
    NOTE(review): a negative n never reaches a base case and recurses until
    the interpreter's recursion limit is hit.
    """
    with pattern([n]) as p:
        for _ in p.match([0]):
            return 0
        for _ in p.match([1]):
            return 1
        for case in p.match([V('n')]):
            return fibo(case.n - 1) + fibo(case.n - 2)
def flatten(xs):
    """Concatenate a list of lists into a single list (one level deep).

    Each element of *xs* must itself be a list; otherwise no pattern matches
    and NoMatchingPattern is raised.  Inner lists are not flattened further.
    """
    with pattern(xs) as p:
        for _ in p.match([]):
            return []
        for case in p.match([[VV('inner')], VV('outer')]):
            head, tail = case.inner, case.outer
            return head + flatten(tail)
if __name__ == '__main__':
    # Demo: two successful matches, then a deliberate failure.
    fib_result = fibo(7)
    print('fibo(7) =', fib_result)
    flat = flatten([[1, 2], [3, 4], [5]])
    print(flat)
    print("\nThere will be an NoMatchingPattern exception:")
    # No match() is attempted, so leaving the block raises NoMatchingPattern.
    with pattern([1]):
        pass
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment