Skip to content

Instantly share code, notes, and snippets.

@vastus
Created March 21, 2025 20:25
Show Gist options
  • Save vastus/08eb085d5856840521de515c0b6e8c66 to your computer and use it in GitHub Desktop.
Save vastus/08eb085d5856840521de515c0b6e8c66 to your computer and use it in GitHub Desktop.
from enum import Enum
from typing import Any, Optional
class TokenType(Enum):
    """Kinds of token the lexer can produce for the arithmetic language."""

    NUMBER = 1  # integer literal
    BANG = 2    # '!'
    CARET = 3   # '^'
    STAR = 4    # '*'
    SLASH = 5   # '/'
    PLUS = 6    # '+'
    MINUS = 7   # '-'
    LPAREN = 8  # '('
    RPAREN = 9  # ')'
class Token:
    """A lexeme paired with the token type that classifies it.

    Attributes:
        typ: The ``TokenType`` member for this token.
        value: The token payload. The lexer stores an ``int`` for numbers
            and the operator character for every other token.
    """

    def __init__(self, typ: "TokenType", value: Any):
        self.typ = typ
        self.value = value

    def __repr__(self):
        """Return e.g. ``Token(PLUS, '+')`` or ``Token(NUMBER, 3)``."""
        # Format the two fields separately: interpolating
        # ``{self.typ.name, self.value}`` in one placeholder renders a tuple,
        # which produced ``Token(('PLUS', '+'))``.
        return f'Token({self.typ.name}, {self.value!r})'
class Lexer:
    """Split an arithmetic expression string into a list of ``Token`` objects.

    Whitespace is skipped, runs of digits become ``NUMBER`` tokens carrying an
    ``int``, and every other ASCII character must be one of the operator or
    parenthesis characters ``! ^ * / + - ( )``.

    Attributes:
        code: The source text being lexed.
        loc: Index into ``code`` of the next unread character.
        line: 1-based line of the next unread character (for error messages).
        col: 1-based column of the next unread character (for error messages).
    """

    def __init__(self, code: str):
        self.code = code
        self.loc = 0
        self.line = 1
        self.col = 1

    def lex(self):
        """Consume the whole input and return the tokens in source order.

        Raises:
            Exception: On a non-ASCII character, an unknown operator
                character, or a number written with a leading zero.
        """
        tokens = []
        while self.has_next():
            c = self.peek()
            if c.isspace():
                self.advance()
            elif c.isdigit():
                tokens.append(self.number())
            elif c.isascii():
                tokens.append(self.operator())
            else:
                raise Exception(f'{c}: unknown char ({self.line}, {self.col})')
        return tokens

    def peek(self):
        """Return the next unread character without consuming it."""
        return self.code[self.loc]

    def advance(self):
        """Consume one character, keeping ``line``/``col`` in sync."""
        if self.peek() == '\n':
            self.line += 1
            self.col = 1
        else:
            self.col += 1
        self.loc += 1

    def eat(self):
        """Consume and return the next character."""
        char = self.peek()
        self.advance()
        return char

    def has_next(self):
        """Return True while unread input remains."""
        return self.loc < len(self.code)

    def number(self):
        """Lex an integer literal starting at the current position.

        Returns:
            A ``NUMBER`` token whose value is the parsed ``int``.

        Raises:
            Exception: If the literal has a leading zero (e.g. ``012``).
        """
        if self.peek() == '0':
            line, col = self.line, self.col
            self.advance()
            # An explicit raise instead of ``assert``: asserts are stripped
            # under ``python -O`` and carry no position information.
            if self.has_next() and self.peek().isdigit():
                raise Exception(f'0: leading zero in number ({line}, {col})')
            # A lone zero must still be a Token; returning the bare int 0
            # would break the parser, which reads ``.typ`` on every token.
            return Token(TokenType.NUMBER, 0)
        digits = []
        while self.has_next() and self.peek().isdigit():
            digits.append(self.eat())
        return Token(TokenType.NUMBER, int(''.join(digits)))

    def operator(self):
        """Lex a single-character operator or parenthesis.

        Returns:
            A token whose type matches the character and whose value is the
            character itself.

        Raises:
            Exception: If the character is not a known operator.
        """
        # Remember where the character sits before consuming it so the error
        # points at the offending character, as ``lex`` does.
        line, col = self.line, self.col
        op = self.eat()
        types = {
            '!': TokenType.BANG,
            '^': TokenType.CARET,
            '*': TokenType.STAR,
            '/': TokenType.SLASH,
            '+': TokenType.PLUS,
            '-': TokenType.MINUS,
            '(': TokenType.LPAREN,
            ')': TokenType.RPAREN,
        }
        typ = types.get(op)
        if typ is None:
            raise Exception(f'{op}: unknown op ({line},{col})')
        return Token(typ, op)
class Expr:
    """Base class of the AST nodes; concrete nodes override ``accept``."""

    def __init__(self):
        """Take no arguments and store nothing."""

    def accept(self, visitor):
        """Always raise ``NotImplementedError`` carrying ``visitor``."""
        raise NotImplementedError(visitor)
class Binary(Expr):
    """AST node for an infix operation: ``lhs op rhs``."""

    def __init__(self, lhs, op: Optional[Token] = None, rhs: Optional[Expr] = None):
        """Store the left operand, the operator token and the right operand."""
        self.lhs, self.op, self.rhs = lhs, op, rhs

    def accept(self, visitor):
        """Hand this node to ``visitor.visit_binary_expr`` and return its result."""
        handler = visitor.visit_binary_expr
        return handler(self)
class Unary(Expr):
    """AST node for a prefix operator applied to a single operand."""

    def __init__(self, op: Token, rhs: Optional[Expr] = None):
        """Store the operator token and the operand expression."""
        self.op, self.rhs = op, rhs

    def accept(self, visitor):
        """Hand this node to ``visitor.visit_unary_expr`` and return its result."""
        handler = visitor.visit_unary_expr
        return handler(self)
class Literal(Expr):
    """AST leaf node wrapping a literal value."""

    def __init__(self, value):
        """Store ``value`` unchanged on the node."""
        self.value = value

    def accept(self, visitor):
        """Hand this node to ``visitor.visit_literal_expr`` and return its result."""
        handler = visitor.visit_literal_expr
        return handler(self)
class Grouping(Expr):
    """AST node for a parenthesised sub-expression."""

    def __init__(self, expr: Expr):
        """Store the wrapped expression as ``self.expression``."""
        self.expression = expr

    def accept(self, visitor):
        """Hand this node to ``visitor.visit_grouping_expr`` and return its result."""
        handler = visitor.visit_grouping_expr
        return handler(self)
# Grammar
#
# expr -> term
# term -> factor ( ( '+' | '-' ) factor )*
# factor -> power ( ( '*' | '/' ) power )*
# power -> unary ( '^' unary )*
# unary -> ( '!' | '-' ) unary
# | primary
# primary -> NUMBER | '(' expr ')'
class Parser:
    """Recursive-descent parser that builds an AST from source text.

    Grammar::

        expr    -> term
        term    -> factor ( ( '+' | '-' ) factor )*
        factor  -> power ( ( '*' | '/' ) power )*
        power   -> unary ( '^' unary )*
        unary   -> ( '!' | '-' ) unary | primary
        primary -> NUMBER | '(' expr ')'

    Repeated binary operators fold to the left, so ``2^3^2`` parses as
    ``(2^3)^2``. Prefix operators bind tighter than ``^``.

    Attributes:
        tokens: The token list produced by ``Lexer`` for the input.
        loc: Index into ``tokens`` of the next unread token.
    """

    def __init__(self, code: str):
        self.tokens = Lexer(code).lex()
        self.loc = 0

    def parse(self):
        """Parse the whole input and return the root AST node.

        Raises:
            Exception: If the input is not a valid expression, or if tokens
                remain after a complete expression (e.g. ``1 2``).
        """
        expr = self.expression()
        # Leftover tokens mean the input was not one expression; report it
        # rather than silently returning the parsed prefix.
        if self.has_next():
            raise Exception(f'unexpected token {self.peek()}')
        return expr

    # expr -> term
    def expression(self):
        """Parse an ``expr``."""
        return self.term()

    # term -> factor ( ( '+' | '-' ) factor )*
    def term(self):
        """Parse a chain of additions and subtractions."""
        return self._binary_chain(self.factor, [TokenType.PLUS, TokenType.MINUS])

    # factor -> power ( ( '*' | '/' ) power )*
    def factor(self):
        """Parse a chain of multiplications and divisions."""
        return self._binary_chain(self.power, [TokenType.STAR, TokenType.SLASH])

    # power -> unary ( '^' unary )*
    def power(self):
        """Parse a chain of exponentiations (folded to the left)."""
        return self._binary_chain(self.unary, [TokenType.CARET])

    # unary -> ( '!' | '-' ) unary | primary
    def unary(self):
        """Parse zero or more prefix operators followed by a primary."""
        if self.match([TokenType.BANG, TokenType.MINUS]):
            op = self.eat()
            # Recurse into unary (not primary) so stacked prefixes such as
            # ``--1`` or ``!-1`` parse, as the grammar specifies.
            rhs = self.unary()
            return Unary(op, rhs)
        return self.primary()

    # primary -> NUMBER | '(' expr ')'
    def primary(self):
        """Parse a number literal or a parenthesised expression."""
        if self.match([TokenType.LPAREN]):
            self.advance()
            expr = self.expression()
            self.consume(TokenType.RPAREN)
            return Grouping(expr)
        number_token = self.consume(TokenType.NUMBER)
        return Literal(number_token.value)

    # parser helpers
    def _binary_chain(self, operand, types):
        """Parse ``operand ( op operand )*`` for operators in ``types``.

        The result is a left-nested tree of ``Binary`` nodes.
        """
        expr = operand()
        while self.match(types):
            op = self.eat()
            rhs = operand()
            expr = Binary(expr, op, rhs)
        return expr

    def has_next(self):
        """Return True while unread tokens remain."""
        return self.loc < len(self.tokens)

    def advance(self):
        """Skip the next token."""
        self.loc += 1

    def peek(self) -> Token:
        """Return the next token without consuming it."""
        return self.tokens[self.loc]

    def match(self, types):
        """Return the next token if its type is in ``types``, else ``None``.

        Does not consume the token. Returns ``None`` at end of input.
        """
        if not self.has_next():
            return None
        token = self.peek()
        return token if token.typ in types else None

    def eat(self):
        """Consume and return the next token."""
        token = self.peek()
        self.advance()
        return token

    def consume(self, typ):
        """Consume and return the next token, which must be of type ``typ``.

        Raises:
            Exception: If the next token has another type or input has ended.
        """
        if not self.match([typ]):
            # peek() would raise IndexError at end of input, hiding the real
            # parse error; describe the situation instead.
            found = self.peek() if self.has_next() else 'end of input'
            raise Exception(f'expected {typ} got {found}')
        return self.eat()
class PrinterVisitor:
    """Visitor that renders an AST as a parenthesised prefix string."""

    def visit_unary_expr(self, expr: Unary):
        """Render ``(op operand)``."""
        assert expr.rhs
        operand = expr.rhs.accept(self)
        return f'({expr.op} {operand})'

    def visit_binary_expr(self, expr: Binary):
        """Render ``(op left right)``."""
        assert expr.op and expr.lhs and expr.rhs
        left = expr.lhs.accept(self)
        right = expr.rhs.accept(self)
        return f'({expr.op} {left} {right})'

    def visit_literal_expr(self, expr: Literal):
        """Render the literal's value as a string."""
        return format(expr.value)

    def visit_grouping_expr(self, expr: Grouping):
        """Render the inner expression; no extra parentheses are emitted."""
        return expr.expression.accept(self)
class InterpretVisitor:
    """Visitor that evaluates an AST and returns the resulting value."""

    def visit_unary_expr(self, expr: Unary):
        """Evaluate a prefix operation.

        ``!`` yields ``not operand`` (a bool); ``-`` yields the negation.

        Raises:
            Exception: If the operator is neither ``!`` nor ``-``.
        """
        assert expr.op and expr.rhs
        rhs = self.eval(expr.rhs)
        if expr.op.typ == TokenType.BANG:
            return not rhs
        if expr.op.typ == TokenType.MINUS:
            return -rhs
        # Reject anything else explicitly, as visit_binary_expr does, instead
        # of silently treating it as negation.
        raise Exception(f'{expr.op}: unknown op')

    def visit_binary_expr(self, expr: Binary):
        """Evaluate both operands, then apply the infix operator.

        ``/`` is Python true division, so it returns a float.

        Raises:
            Exception: If the operator is not one of ``+ - * / ^``.
        """
        assert expr.op and expr.lhs and expr.rhs
        lhs = self.eval(expr.lhs)
        rhs = self.eval(expr.rhs)
        if expr.op.typ == TokenType.PLUS:
            return lhs + rhs
        if expr.op.typ == TokenType.MINUS:
            return lhs - rhs
        if expr.op.typ == TokenType.STAR:
            return lhs * rhs
        if expr.op.typ == TokenType.SLASH:
            return lhs / rhs
        if expr.op.typ == TokenType.CARET:
            return lhs ** rhs
        raise Exception(f'{expr.op}: unknown op')

    def visit_literal_expr(self, expr: Literal):
        """Return the literal's stored value."""
        return expr.value

    def visit_grouping_expr(self, expr: Grouping):
        """Evaluate the parenthesised inner expression."""
        return self.eval(expr.expression)

    def eval(self, expr: Expr):
        """Evaluate ``expr`` by dispatching it back to this visitor."""
        return expr.accept(self)
def main():
    """Parse a sample expression, then print its prefix form and its value."""
    # Other inputs to try:
    #   '3 + (2 - 1)'
    #   '-1 + 2 '
    #   '-273'
    #   '2*2^2^2'
    source = '3 * 7 + 8 ^ (1 + 1)'
    tree = Parser(source).parse()
    rendered = tree.accept(PrinterVisitor())
    print('repr', rendered)
    value = tree.accept(InterpretVisitor())
    print('res', value)
# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment