Created
August 6, 2024 20:39
-
-
Save eliasdorneles/498efbb2b5e1007bd2db20baeac591d7 to your computer and use it in GitHub Desktop.
Simple calculator demonstrating Pratt parsing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re | |
from enum import StrEnum | |
class TokenType(StrEnum): | |
PLUS = "PLUS" | |
MINUS = "MINUS" | |
MUL = "MUL" | |
DIV = "DIV" | |
LPAREN = "LPAREN" | |
RPAREN = "RPAREN" | |
NUMBER = "NUMBER" | |
IDENTIFIER = "IDENTIFIER" | |
ASSIGNMENT = "ASSIGNMENT" | |
class Token:
    """A single lexical token: its source text plus its classified type.

    Attributes:
        value: The raw token text (never empty).
        type: The ``TokenType`` member derived from ``value``.
    """

    def __init__(self, value):
        """Store *value* and classify it.

        Raises:
            ValueError: If *value* is empty or cannot be classified.
        """
        if not value:
            raise ValueError("Token value cannot be empty")
        self.value = value
        self.type = self._get_type()

    def _get_type(self):
        """Return the ``TokenType`` matching ``self.value``.

        Raises:
            ValueError: If the text is neither an operator, a decimal
                integer, nor something starting with a letter.
        """
        operators = {
            "+": TokenType.PLUS,
            "-": TokenType.MINUS,
            "*": TokenType.MUL,
            "/": TokenType.DIV,
            "(": TokenType.LPAREN,
            ")": TokenType.RPAREN,
            "=": TokenType.ASSIGNMENT,
        }
        if self.value in operators:
            return operators[self.value]
        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as superscript digits that int() cannot convert, which would
        # let an unevaluable "number" through to the evaluator.
        if self.value.isdecimal():
            return TokenType.NUMBER
        # Only the first character is checked, so "x1" is an identifier.
        if self.value[0].isalpha():
            return TokenType.IDENTIFIER
        raise ValueError(f"Unknown token type for {self.value}")

    def __repr__(self):
        return f"Token(value={self.value}, type={self.type})"
def tokenize(text):
    """Split *text* into a list of ``Token`` objects.

    The text is cut at whitespace runs and at each of the single-character
    operators ``- + * / = ( )``. The operators are kept as tokens of their
    own; empty and whitespace-only pieces are dropped.
    """
    tokens = []
    for piece in re.split(r"(\s+|[-+*/=()])", text):
        # The capturing group makes re.split return the separators too, so
        # whitespace and empty fragments have to be filtered out here.
        if piece.strip():
            tokens.append(Token(piece))
    return tokens
class PrattParser:
    """
    Parser implementing top-down operator precedence parsing, also known as
    Pratt parsing.

    The grammar is supplied by registering parselets per token type:

    * ``prefix_parselets[token_type]`` is called as ``(parser, token)``.
    * ``infix_parselets[token_type]`` is called as ``(parser, left, token)``
      and must expose a ``precedence`` attribute.
    """

    tokens: list

    def __init__(self):
        self.tokens = []
        self.prefix_parselets = {}
        self.infix_parselets = {}

    def parse(self, input_text):
        """Tokenize *input_text* and parse it as one complete expression.

        Raises:
            ValueError: If the input is empty, malformed, or has tokens left
                over after a complete expression (e.g. ``"1 2"``).
        """
        self.tokens = tokenize(input_text)
        expr = self.parse_next()
        # Without this check, trailing input would be silently discarded.
        self._expect_end()
        return expr

    def _expect_end(self):
        """Raise ``ValueError`` if any token is still waiting to be parsed."""
        if self.tokens:
            raise ValueError(
                f"Unexpected token {self.tokens[0]!r} after end of expression"
            )

    def parse_next(self, precedence=0):
        """Parse one expression whose operators bind tighter than *precedence*.

        Takes the next token, runs its prefix parselet, then keeps folding in
        infix operators for as long as the upcoming operator's precedence is
        strictly greater than *precedence*.

        Raises:
            ValueError: If there are no tokens left, or no parselet is
                registered for the token at hand.
        """
        if not self.tokens:
            raise ValueError("Unexpected end of input")

        cur_token = self.tokens.pop(0)
        prefix_parser = self.prefix_parselets.get(cur_token.type)
        if not prefix_parser:
            raise ValueError(f"Could not parse {cur_token!r} (prefix)")

        # prefix parselets signature: (parser, token):
        left = prefix_parser(self, cur_token)

        while precedence < self.get_precedence():
            cur_token = self.tokens.pop(0)
            infix_parser = self.infix_parselets.get(cur_token.type)
            if not infix_parser:
                raise ValueError(f"Could not parse {cur_token!r} (infix)")
            # infix parselets signature: (parser, left_expr, token):
            left = infix_parser(self, left, cur_token)

        return left

    def consume(self, token_type):
        """Remove and return the next token, requiring it to be *token_type*.

        Raises:
            ValueError: If no tokens remain or the next token has another type.
        """
        if not self.tokens:
            raise ValueError(f"Expected {token_type!r} but got nothing")
        token = self.tokens.pop(0)
        if token.type != token_type:
            raise ValueError(f"Expected {token_type!r} but got {token!r}")
        return token

    def get_precedence(self):
        """Return the infix precedence of the next token without consuming it.

        Returns 0 when no tokens remain or the next token has no infix
        parselet registered.
        """
        if not self.tokens:
            return 0
        infix_parser = self.infix_parselets.get(self.tokens[0].type)
        return infix_parser.precedence if infix_parser else 0
# Prefix Parselets:
def parse_scalar(_parser, token):
    """Prefix parselet for leaf tokens.

    Returns a ``(type_name, text)`` pair built from the token; the parser
    argument is not used.
    """
    kind = token.type.value
    text = token.value
    return (kind, text)
def build_prefix_op_parselet(precedence):
    """Create a prefix parselet for a unary operator.

    The returned callable parses the operand by calling
    ``parser.parse_next(precedence)`` and produces
    ``("UNOP_<TYPE>", operand)``, where ``<TYPE>`` is the operator token's
    type value.
    """

    def parse_operator(parser, token):
        operand = parser.parse_next(precedence)
        node_tag = "UNOP_" + token.type.value
        return (node_tag, operand)

    return parse_operator
def parse_group(parser, _token):
    """Prefix parselet for a parenthesised sub-expression.

    Parses the expression that follows the opening token, then requires an
    ``RPAREN`` token. The grouping leaves no node of its own: the inner
    expression is returned as-is.
    """
    inner = parser.parse_next()
    # consume() raises ValueError if the closing parenthesis is missing.
    parser.consume(TokenType.RPAREN)
    return inner
# Infix Parselets:
class BinaryOperatorParselet:
    """Infix parselet for a binary operator.

    The right operand is parsed at the operator's own precedence and the
    result is ``("OP_<TYPE>", left, right)``.
    """

    def __init__(self, precedence):
        self.precedence = precedence

    def __call__(self, parser, left, token):
        rhs = parser.parse_next(self.precedence)
        node_tag = "OP_" + token.type.value
        return (node_tag, left, rhs)
class AssignmentParselet:
    """Infix parselet for ``name = expression``.

    The right-hand side is parsed at one less than this parselet's
    precedence, and the result is ``("ASSIGN", left, right)``.
    """

    def __init__(self, precedence):
        self.precedence = precedence

    def __call__(self, parser, left, _token):
        value_expr = parser.parse_next(self.precedence - 1)
        # The target must be a parsed identifier: ("IDENTIFIER", "name").
        target_kind = left[0]
        if target_kind == "IDENTIFIER":
            return ("ASSIGN", left, value_expr)
        raise ValueError("Expected identifier on the left side of assignment")
class CalcParser(PrattParser):
    """Pratt parser configured with the calculator's grammar.

    Registers scalars, unary plus/minus and parentheses as prefix parselets,
    and assignment plus the four arithmetic operators as infix parselets.
    """

    def __init__(self):
        super().__init__()

        self.prefix_parselets.update(
            {
                TokenType.IDENTIFIER: parse_scalar,
                TokenType.NUMBER: parse_scalar,
                TokenType.MINUS: build_prefix_op_parselet(10),
                TokenType.PLUS: build_prefix_op_parselet(10),
                TokenType.LPAREN: parse_group,
            }
        )

        # Higher precedence binds tighter: "=" < "+ -" < "* /".
        self.infix_parselets.update(
            {
                TokenType.ASSIGNMENT: AssignmentParselet(1),
                TokenType.PLUS: BinaryOperatorParselet(5),
                TokenType.MINUS: BinaryOperatorParselet(5),
                TokenType.MUL: BinaryOperatorParselet(7),
                TokenType.DIV: BinaryOperatorParselet(7),
            }
        )
class Calculator:
    """Evaluate calculator expressions, remembering assigned variables.

    Attributes:
        variables: Mapping of variable name to its last assigned value.
        parser: The ``CalcParser`` used to turn text into a parse tree.
    """

    def __init__(self):
        self.variables = {}
        self.parser = CalcParser()

    def evaluate(self, expr):
        """Parse the text *expr* and return its value."""
        parsed = self.parser.parse(expr)
        return self._evaluate_expr(parsed)

    def _evaluate_expr(self, parsed_expr):
        """Evaluate one parse-tree node, dispatching on its leading tag.

        Nodes are tuples: ``("OP_<NAME>", left, right)``,
        ``("UNOP_<NAME>", operand)``, ``("ASSIGN", target, value)`` or a
        scalar ``(kind, text)``.
        """
        op, *args = parsed_expr
        if op.startswith("OP_"):
            return self._evaluate_op(op[3:], args)
        if op == "ASSIGN":
            return self._evaluate_assign(args)
        if op.startswith("UNOP_"):
            return self._evaluate_unop(op[5:], args[0])
        return self._evaluate_scalar(parsed_expr)

    def _evaluate_scalar(self, scalar):
        """Return the value of a ``("NUMBER", text)`` or ``("IDENTIFIER", name)`` node.

        Raises:
            ValueError: For an unassigned variable or an unknown scalar kind.
        """
        if scalar[0] == "NUMBER":
            return int(scalar[1])
        if scalar[0] == "IDENTIFIER":
            if scalar[1] not in self.variables:
                raise ValueError(f"Unknown variable: '{scalar[1]}'")
            return self.variables[scalar[1]]
        raise ValueError(f"Unknown scalar type: {scalar}")

    def _evaluate_assign(self, args):
        """Evaluate the right-hand side, store it under the target name, return it."""
        _, name = args[0]
        value = self._evaluate_expr(args[1])
        self.variables[name] = value
        return value

    def _evaluate_op(self, op, args):
        """Evaluate both operands of a binary node and apply operator *op*.

        ``DIV`` uses true division (``/``).

        Raises:
            ValueError: If *op* is not PLUS, MINUS, MUL or DIV.
        """
        left = self._evaluate_expr(args[0])
        right = self._evaluate_expr(args[1])
        if op == "PLUS":
            return left + right
        if op == "MINUS":
            return left - right
        if op == "MUL":
            return left * right
        if op == "DIV":
            return left / right
        raise ValueError(f"Unknown operator: {op}")

    def _evaluate_unop(self, op, arg):
        """Evaluate the operand node *arg* and apply unary operator *op*.

        Raises:
            ValueError: If *op* is not PLUS or MINUS.
        """
        # arg is an unevaluated parse-tree node; it must be evaluated before
        # the sign is applied, otherwise "-5" would try to negate a tuple.
        value = self._evaluate_expr(arg)
        if op == "PLUS":
            return value
        if op == "MINUS":
            return -value
        raise ValueError(f"Unknown unary operator: {op}")
if __name__ == "__main__":
    # Interactive read-eval-print loop. A single Calculator is reused so that
    # variables assigned on one line stay available on later lines.
    calc = Calculator()
    while True:
        try:
            line = input("pratt-calc % ")
            if not line:
                continue
            print(calc.evaluate(line))
        except (KeyboardInterrupt, EOFError):
            # Ctrl-C / Ctrl-D ends the session.
            print("Bye")
            break
        except (ValueError, ZeroDivisionError) as err:
            # Bad input (syntax error, unknown variable, division by zero)
            # is reported and the loop keeps going instead of crashing.
            print(f"Error: {err}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment