Skip to content

Instantly share code, notes, and snippets.

@LeeeeT
Last active March 2, 2025 17:05
Show Gist options
  • Save LeeeeT/4a2a0b88caaf8cde11766da95de5dad7 to your computer and use it in GitHub Desktop.
Save LeeeeT/4a2a0b88caaf8cde11766da95de5dad7 to your computer and use it in GitHub Desktop.
Clean (in Uncle Bob's voice) implementation of λ-calculus (doesn't work)
from dataclasses import dataclass
# Demo program written in the surface syntax understood by ``Parser``.
#
# The whole program is ``(define -> ...) (definition -> scope -> scope definition)``,
# so ``define`` acts as a let-binding helper: ``define VALUE (name -> SCOPE)``
# evaluates SCOPE with ``name`` bound to VALUE.  It is used to introduce
# Church-encoded lists (``nil``, ``cons``), ``append`` (adds one item to the end
# of a list), and ``reverse``, and finally applies ``reverse`` to the list 1, 2, 3.
program = r"""
(define ->
define (nil -> cons ->
nil
)(nil ->
define (head -> tail -> nil -> cons ->
cons head (tail nil cons)
)(cons ->
define (item -> list ->
list (cons item nil) cons
)(append ->
define (list ->
list nil append
)(reverse ->
reverse (cons 1 (cons 2 (cons 3 nil)))
)))))
(definition -> scope -> scope definition)
"""
def main() -> None:
    """Parse the embedded ``program``, normalise it, and print the result."""
    parsed = Parser(program).parse()
    normal_form = nf(parsed)
    print(show(normal_form))
# A λ-term: a variable, a one-parameter function, or a function application.
# (The ``type`` alias statement requires Python 3.12+.)
type Term = Identifier | Abstraction | Application
@dataclass(frozen=True)
class Identifier:
    """A variable occurrence such as ``x`` or ``1``.

    The evaluator never reduces an identifier; free identifiers are left as-is.
    Instances are immutable (``frozen=True``).
    """

    identifier: str  # the variable's name
@dataclass(frozen=True)
class Abstraction:
    """A one-parameter function, written ``binder -> body``.

    Instances are immutable (``frozen=True``).
    """

    binder: str  # name of the parameter; occurrences of it inside ``body`` are bound
    body: Term  # the function's body
@dataclass(frozen=True)
class Application:
    """Application of ``function`` to ``argument``, written ``function argument``.

    Instances are immutable (``frozen=True``).
    """

    function: Term  # the term being called
    argument: Term  # the term passed to it
def nf(term: Term) -> Term:
    """Return the normal form of ``term``.

    The term is first brought to weak head normal form with ``whnf``.  The
    result is then normalised recursively: under the binder of an abstraction,
    or in both the function and the argument of an application whose head is
    stuck.

    Args:
        term: The term to normalise.

    Returns:
        A new term containing no reducible head redexes.
    """
    reduced = whnf(term)
    if isinstance(reduced, Identifier):
        return Identifier(reduced.identifier)
    if isinstance(reduced, Abstraction):
        return Abstraction(reduced.binder, nf(reduced.body))
    if isinstance(reduced, Application):
        return Application(nf(reduced.function), nf(reduced.argument))
def whnf(term: Term) -> Term:
match term:
case Application(function, argument):
match whnf(function):
case Abstraction(binder, body):
return whnf(substitute(body, binder, argument))
case _:
return term
case _:
return term
def substitute(term: Term, name: str, new: Term) -> Term:
match term:
case Identifier(identifier):
return new if identifier == name else term
case Abstraction(binder, body):
return Abstraction(binder, substitute(body, name, new)) if binder != name else term
case Application(function, argument):
return Application(substitute(function, name, new), substitute(argument, name, new))
def show(term: Term) -> str:
    """Render ``term`` in the same surface syntax the parser reads.

    Application associates to the left and an abstraction's body extends as
    far right as possible. Parentheses are therefore added only around an
    abstraction used as a function, and around any argument that is not a
    bare identifier.

    Args:
        term: The term to render.

    Returns:
        The textual form of ``term``.
    """
    if isinstance(term, Identifier):
        return term.identifier
    if isinstance(term, Abstraction):
        return f"{term.binder} -> {show(term.body)}"
    if isinstance(term, Application):
        head = show(term.function)
        if isinstance(term.function, Abstraction):
            head = f"({head})"
        operand = show(term.argument)
        if not isinstance(term.argument, Identifier):
            operand = f"({operand})"
        return f"{head} {operand}"
class Parser:
    """Recursive-descent parser that turns source text into a ``Term``.

    Grammar accepted (whitespace between tokens is skipped)::

        term   ::= factor factor*          # application, left-associative
        factor ::= "(" term ")"
                 | identifier "->" term    # abstraction; body runs as far right as it can
                 | identifier

    An identifier is a maximal run of characters that are neither whitespace
    nor one of ``(``, ``)``, ``\\``, ``-``, ``>``.  Syntax errors are raised as
    ``ValueError``.
    """

    def __init__(self, input: str):
        # Source text plus a cursor into it; the cursor only ever moves forward.
        self.input = input
        self.position = 0

    def _peek(self) -> str | None:
        """Return the character under the cursor, or ``None`` at end of input."""
        if self.position < len(self.input):
            return self.input[self.position]
        return None

    def skip_spaces(self) -> None:
        """Advance the cursor past any whitespace."""
        while (current := self._peek()) is not None and current.isspace():
            self.position += 1

    def consume(self, character: str) -> None:
        """Skip whitespace, then require ``character`` at the cursor and step over it.

        Raises:
            ValueError: if input has ended or the next character differs.
        """
        self.skip_spaces()
        current = self._peek()
        if current is None or current != character:
            got = "end of input" if current is None else current
            raise ValueError(f"Expected '{character}', got '{got}'")
        self.position += 1

    def is_identifier_character(self, character: str) -> bool:
        """Tell whether ``character`` may appear inside an identifier."""
        return not (character.isspace() or character in {"(", ")", "\\", "-", ">"})

    def parse_identifier(self) -> str:
        """Read an identifier, skipping leading whitespace.

        Raises:
            ValueError: if no identifier character is found at the cursor.
        """
        self.skip_spaces()
        begin = self.position
        while (current := self._peek()) is not None and self.is_identifier_character(current):
            self.position += 1
        if self.position == begin:
            raise ValueError(f"Expected identifier at position {self.position}")
        return self.input[begin:self.position]

    def parse_factor(self) -> Term:
        """Parse a parenthesised term, an abstraction, or a single identifier.

        Raises:
            ValueError: on end of input or a character that cannot start a term.
        """
        self.skip_spaces()
        current = self._peek()
        if current is None:
            raise ValueError("Unexpected end of input while expecting a term")
        if current == "(":
            self.consume("(")
            inner = self.parse_term()
            self.consume(")")
            return inner
        if not self.is_identifier_character(current):
            raise ValueError(f"Unexpected character '{current}' at position {self.position}")
        name = self.parse_identifier()
        self.skip_spaces()
        if not self.input.startswith("->", self.position):
            return Identifier(name)
        self.position += 2  # step over the "->" arrow
        return Abstraction(name, self.parse_term())

    def _starts_factor(self) -> bool:
        """Tell whether the cursor sits on the first character of a factor."""
        current = self._peek()
        return current is not None and (current == "(" or self.is_identifier_character(current))

    def parse_term(self) -> Term:
        """Parse one or more factors, folding them into left-nested applications."""
        self.skip_spaces()
        result = self.parse_factor()
        self.skip_spaces()
        while self._starts_factor():
            result = Application(result, self.parse_factor())
            self.skip_spaces()
        return result

    def parse(self) -> Term:
        """Parse the whole input as a single term.

        Raises:
            ValueError: on any syntax error, including trailing unparsed text.
        """
        term = self.parse_term()
        self.skip_spaces()
        if len(self.input) != self.position:
            raise ValueError("Unexpected characters at end of input")
        return term
# Run the demo only when executed as a script, so the module can be imported
# (e.g. to reuse Parser / nf / show) without printing as a side effect.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment