Created
September 26, 2023 11:45
-
-
Save mbillingr/305577546c3ec6d2e5d6ca511dc8f77a to your computer and use it in GitHub Desktop.
simple evaluator with pattern matching
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
import dataclasses | |
from typing import Any, Optional, TypeAlias | |
# Expressions of the tiny language are plain Python data: ints (literals),
# strs (variable names), lists (evaluated element-wise) and tuples whose first
# element is a form keyword ("apply", "lambda", "tag", "matchfn").  The alias
# is Any because that union is not spelled out as a precise type.
Expr: TypeAlias = Any
class Context:
    """Name-to-value environment consulted by ``evaluate`` for variable lookups.

    ``extend`` never modifies the receiver: it returns a new ``Context`` whose
    bindings are the existing ones overlaid with the new ones.
    """

    def __init__(self, env=None):
        # Any falsy argument (None, an empty dict) yields a fresh empty dict.
        self.env = env if env else {}

    def extend(self, env):
        """Return a new context in which *env*'s bindings shadow existing ones."""
        merged = self.env | env
        return Context(env=merged)

    def lookup(self, name):
        """Return the value bound to *name*; raises ``KeyError`` if it is unbound."""
        return self.env[name]
def evaluate(expr: Expr, ctx: Context) -> Any: | |
while True: | |
match expr: | |
case list(): | |
return [evaluate(x, ctx) for x in expr] | |
case ("apply", rator, *rands): | |
fun = evaluate(rator, ctx) | |
args = [evaluate(a, ctx) for a in rands] | |
expr, ctx = fun.lazy_apply(args) | |
case ("lambda", params, body): | |
return Lambda(params, body, ctx) | |
case ("tag", name, *args): | |
return (name, *(evaluate(a, ctx) for a in args)) | |
case ("matchfn", *arms): | |
arms = [(Pattern.build(p), b) for p, b in arms] | |
return MatchFn(arms, ctx) | |
case int(x): | |
return x | |
case str(s): | |
return ctx.lookup(s) | |
case _: | |
raise TypeError(expr) | |
@dataclasses.dataclass
class Lambda:
    """Closure produced by evaluating a ``("lambda", params, body)`` form.

    ``params`` are the parameter names, ``body`` the unevaluated body
    expression, and ``captured_context`` the context that was active where
    the lambda form was evaluated.
    """

    params: list[str]
    body: Expr
    captured_context: Context

    def lazy_apply(self, args: list[Any]) -> tuple[Expr, Context]:
        """Bind *args* to the parameters and return ``(body, context)``.

        The body is not evaluated here.  ``evaluate`` continues its loop with
        the returned pair, so calls in tail position do not grow the stack.

        Raises:
            TypeError: if the number of arguments differs from the number of
                parameters.
        """
        # zip() would silently truncate here.  Too few args would leave a
        # parameter unbound, so a same-named outer binding from the captured
        # context would be used.  Too many args would be silently dropped.
        if len(args) != len(self.params):
            raise TypeError(
                f"lambda expects {len(self.params)} argument(s), got {len(args)}"
            )
        bindings = dict(zip(self.params, args))
        return self.body, self.captured_context.extend(bindings)
@dataclasses.dataclass
class Pattern:
    """A match-arm pattern holding the raw template it was built from.

    ``pat`` is the template exactly as written in the ``matchfn`` arm; it is
    matched against the argument list of a call.
    """

    pat: list[Any]

    @staticmethod
    def build(template):
        """Wrap *template* unchanged in a ``Pattern``; no preprocessing is done."""
        return Pattern(pat=template)

    def matched(self, values: list[Any]):
        """Match the call arguments *values* against this pattern.

        A length mismatch fails immediately with ``False``.  Otherwise the
        result of ``pattern_match`` is returned: a bindings dict on success,
        ``False`` on failure.
        """
        if len(values) == len(self.pat):
            return pattern_match(self.pat, values, {})
        return False
def pattern_match(pat, val, env): | |
match pat, val: | |
case "_", _: | |
return env | |
case int(), _: | |
if pat == val: | |
return env | |
else: | |
return False | |
case str(), _: | |
if pat not in env: | |
return env | {pat: val} | |
if env[pat] != val: | |
return False | |
return env | |
case [*pa], [*va] if len(pa) == len(va): | |
env = {} | |
for p, v in zip(pa, va): | |
env = pattern_match(p, v, env) | |
if env is False: | |
return False | |
return env | |
case _: | |
raise NotImplementedError(pat, val) | |
@dataclasses.dataclass
class MatchFn:
    """Function defined by pattern-matching arms, from a ``("matchfn", ...)`` form.

    ``arms`` holds ``(Pattern, body)`` pairs that are tried in order.
    ``captured_context`` is the context in which the matchfn form was evaluated.
    """

    arms: list[tuple[Pattern, Expr]]
    captured_context: Context

    def lazy_apply(self, args: list[Any]) -> tuple[Expr, Context]:
        """Select the first arm whose pattern matches *args*.

        Returns that arm's unevaluated body, together with the captured context
        extended by the pattern's bindings.

        Raises:
            ValueError: if no arm matches.
        """
        for candidate, arm_body in self.arms:
            bindings = candidate.matched(args)
            if bindings is False:
                continue  # note: {} is a successful match with no bindings
            return arm_body, self.captured_context.extend(bindings)
        raise ValueError("no pattern matched")
# Demo programs, each evaluated in a fresh empty Context and printed in order.
_DEMO_PROGRAMS = [
    # Identity lambda applied to 42 -> prints 42.
    ("apply", ("lambda", ["x"], "x"), 42),
    # Tagged value -> prints ('Foo', 123).
    ("tag", "Foo", 123),
    # An unapplied matchfn -> prints the MatchFn object's repr.
    ("matchfn", ([1], 10), ([2], 20), (["_"], 0)),
    # Two arguments, so the one-element arms fail on length.  The third arm's
    # nested, non-linear pattern binds x=5 and y=1, and the body ["x", "y"]
    # evaluates to [5, 1].
    (
        "apply",
        ("matchfn", ([1], 10), ([2], 20), (["x", ["x", "y", "y", "x"]], ["x", "y"])),
        5,
        [5, 1, 1, 5],
    ),
]

for _program in _DEMO_PROGRAMS:
    print(evaluate(_program, Context()))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment