Created
March 21, 2025 20:25
-
-
Save vastus/08eb085d5856840521de515c0b6e8c66 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from enum import Enum | |
from typing import Any, Optional | |
class TokenType(Enum):
    """Kinds of token produced by the Lexer.

    Members carry the fixed integer values 1-9 in declaration order.
    """

    # integer literal
    NUMBER = 1
    # operators ('!' is prefix-only; '-' is used as prefix and infix)
    BANG = 2
    CARET = 3
    STAR = 4
    SLASH = 5
    PLUS = 6
    MINUS = 7
    # grouping
    LPAREN = 8
    RPAREN = 9
class Token:
    """A single lexical token: its kind plus the value it carries.

    As produced by the Lexer, NUMBER tokens carry the parsed int and
    operator/parenthesis tokens carry their source character.
    """

    def __init__(self, typ: 'TokenType', value: Any):
        self.typ = typ
        self.value = value

    def __repr__(self):
        # Emit name and value as two separate fields. Writing
        # `{self.typ.name, self.value}` would format a tuple instead.
        return f'Token({self.typ.name}, {self.value!r})'
class Lexer:
    """Turn arithmetic source text into a list of Tokens.

    Recognises non-negative decimal integer literals, the operators
    ``! ^ * / + -`` and parentheses; whitespace is skipped. Line and
    column (both 1-based) are tracked for error messages.
    """

    def __init__(self, code: str):
        self.code = code
        self.loc = 0  # index of the next unread character in `code`
        self.line = 1
        self.col = 1

    def lex(self):
        """Consume the whole input and return its tokens in order.

        Raises:
            Exception: on a character that cannot start any token, on an
                unknown ASCII operator, or on a number with a leading zero.
        """
        tokens = []
        while self.has_next():
            c = self.peek()
            if c.isspace():
                self.advance()
            elif self._is_digit(c):
                tokens.append(self.number())
            elif c.isascii():
                tokens.append(self.operator())
            else:
                raise Exception(f'{c}: unknown char ({self.line}, {self.col})')
        return tokens

    def peek(self):
        """Return the current character without consuming it."""
        return self.code[self.loc]

    def advance(self):
        """Step past the current character, updating line and column."""
        if self.peek() == '\n':
            self.line += 1
            self.col = 1
        else:
            self.col += 1
        self.loc += 1

    def eat(self):
        """Consume and return the current character."""
        char = self.peek()
        self.advance()
        return char

    def has_next(self):
        """Return True while unread characters remain."""
        return self.loc < len(self.code)

    @staticmethod
    def _is_digit(c):
        """Return True only for the ASCII digits 0-9.

        str.isdigit() also accepts characters such as '²' that int() rejects.
        """
        return '0' <= c <= '9'

    def number(self):
        """Lex a decimal integer literal into a NUMBER token.

        A lone ``0`` is allowed, but a multi-digit literal may not start
        with ``0``.

        Raises:
            Exception: on a leading zero (e.g. ``012``).
        """
        line, col = self.line, self.col
        if self.peek() == '0':
            self.advance()
            if self.has_next() and self._is_digit(self.peek()):
                raise Exception(f'0: leading zero in number ({line}, {col})')
            # Must be a Token like every other literal; the parser reads `.typ`.
            return Token(TokenType.NUMBER, 0)
        digits = []
        while self.has_next() and self._is_digit(self.peek()):
            digits.append(self.eat())
        return Token(TokenType.NUMBER, int(''.join(digits)))

    def operator(self):
        """Lex a single-character operator or parenthesis token.

        Raises:
            Exception: if the character is not a known operator.
        """
        types = {
            '!': TokenType.BANG,
            '^': TokenType.CARET,
            '*': TokenType.STAR,
            '/': TokenType.SLASH,
            '+': TokenType.PLUS,
            '-': TokenType.MINUS,
            '(': TokenType.LPAREN,
            ')': TokenType.RPAREN,
        }
        # Record the position before eating so errors point at the character itself.
        line, col = self.line, self.col
        op = self.eat()
        typ = types.get(op)
        if typ is None:
            raise Exception(f'{op}: unknown op ({line},{col})')
        return Token(typ, op)
class Expr:
    """Base class for expression-tree nodes.

    Concrete nodes override `accept` to dispatch to the matching
    ``visit_*_expr`` method of a visitor object.
    """

    def __init__(self):
        """Take no arguments; the base node carries no state."""

    def accept(self, visitor):
        """Dispatch to `visitor`; concrete subclasses must override this.

        Raises:
            NotImplementedError: always, with `visitor` as its argument.
        """
        raise NotImplementedError(visitor)
class Binary(Expr):
    """Infix operation node representing ``lhs op rhs``.

    `op` is the operator Token. `op` and `rhs` default to None, which is
    why the visitors assert they are present before using them.
    """

    def __init__(self, lhs, op: Optional[Token] = None, rhs: Optional[Expr] = None):
        self.lhs, self.op, self.rhs = lhs, op, rhs

    def accept(self, visitor):
        """Forward to ``visitor.visit_binary_expr``."""
        handler = visitor.visit_binary_expr
        return handler(self)
class Unary(Expr):
    """Prefix operation node representing ``op rhs``.

    The parser builds these for the '!' and '-' prefix operators.
    """

    def __init__(self, op: Token, rhs: Optional[Expr] = None):
        self.op, self.rhs = op, rhs

    def accept(self, visitor):
        """Forward to ``visitor.visit_unary_expr``."""
        handler = visitor.visit_unary_expr
        return handler(self)
class Literal(Expr):
    """Leaf node holding a constant (the parser stores a NUMBER token's value)."""

    def __init__(self, value):
        self.value = value

    def accept(self, visitor):
        """Forward to ``visitor.visit_literal_expr``."""
        handler = visitor.visit_literal_expr
        return handler(self)
class Grouping(Expr):
    """Node for a parenthesised sub-expression, ``'(' expr ')'``.

    The wrapped node is stored on the `expression` attribute.
    """

    def __init__(self, expr: Expr):
        self.expression = expr

    def accept(self, visitor):
        """Forward to ``visitor.visit_grouping_expr``."""
        handler = visitor.visit_grouping_expr
        return handler(self)
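# Grammar (rules listed from lowest to highest precedence)
#
# expr    -> term
# term    -> factor ( ( '+' | '-' ) factor )*
# factor  -> power ( ( '*' | '/' ) power )*
# power   -> unary ( '^' unary )*
# unary   -> ( '!' | '-' ) unary
#          | primary
# primary -> NUMBER | '(' expr ')'
#
# As written, '^' is left-associative (2^3^2 means (2^3)^2), and prefix
# operators bind tighter than '^' (-2^2 means (-2)^2).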
class Parser:
    """Recursive-descent parser that builds an Expr tree from source text.

    Grammar (lowest to highest precedence):

        expr    -> term
        term    -> factor ( ( '+' | '-' ) factor )*
        factor  -> power ( ( '*' | '/' ) power )*
        power   -> unary ( '^' unary )*
        unary   -> ( '!' | '-' ) unary | primary
        primary -> NUMBER | '(' expr ')'

    NOTE(review): as written, '^' is left-associative (2^3^2 == (2^3)^2),
    whereas conventional notation is right-associative. Confirm this is
    intended.
    """

    def __init__(self, code: str):
        lexer = Lexer(code)
        self.tokens = lexer.lex()
        self.loc = 0  # index of the next unconsumed token

    def parse(self):
        """Parse the entire token stream as a single expression.

        Raises:
            Exception: on a syntax error, including tokens left over after
                a complete expression (e.g. ``1 2`` or ``1)``).
        """
        expr = self.expression()
        if self.has_next():
            raise Exception(f'unexpected {self.peek()} after expression')
        return expr

    # expr -> term
    def expression(self):
        return self.term()

    # term -> factor ( ( '+' | '-' ) factor )*
    def term(self):
        expr = self.factor()
        while self.match([TokenType.PLUS, TokenType.MINUS]):
            op = self.eat()
            rhs = self.factor()
            expr = Binary(expr, op, rhs)
        return expr

    # factor -> power ( ( '*' | '/' ) power )*
    def factor(self):
        expr = self.power()
        while self.match([TokenType.STAR, TokenType.SLASH]):
            op = self.eat()
            rhs = self.power()
            expr = Binary(expr, op, rhs)
        return expr

    # power -> unary ( '^' unary )*
    def power(self):
        expr = self.unary()
        while self.match([TokenType.CARET]):
            op = self.eat()
            rhs = self.unary()
            expr = Binary(expr, op, rhs)
        return expr

    # unary -> ( '!' | '-' ) unary | primary
    def unary(self):
        if self.match([TokenType.BANG, TokenType.MINUS]):
            op = self.eat()
            # Recurse per the grammar so stacked prefixes such as '--1' or
            # '!-1' parse; calling primary() here would reject them.
            return Unary(op, self.unary())
        return self.primary()

    # primary -> NUMBER | '(' expr ')'
    def primary(self):
        if self.match([TokenType.LPAREN]):
            self.advance()
            expr = self.expression()
            self.consume(TokenType.RPAREN)
            return Grouping(expr)
        number_token = self.consume(TokenType.NUMBER)
        return Literal(number_token.value)

    # parser helpers
    def has_next(self):
        """Return True while unconsumed tokens remain."""
        return self.loc < len(self.tokens)

    def advance(self):
        """Skip the current token."""
        self.loc += 1

    def peek(self) -> Token:
        """Return the current token without consuming it."""
        return self.tokens[self.loc]

    def match(self, types):
        """Return the current token if its type is in `types`, else None."""
        if self.has_next() and self.peek().typ in types:
            return self.peek()
        return None

    def eat(self):
        """Consume and return the current token."""
        token = self.peek()
        self.advance()
        return token

    def consume(self, typ):
        """Consume and return the current token, which must be of type `typ`.

        Raises:
            Exception: if the current token has another type or the input
                is exhausted.
        """
        if not self.match([typ]):
            # Don't peek() past the end while building the error message.
            found = self.peek() if self.has_next() else 'end of input'
            raise Exception(f'expected {typ} got {found}')
        return self.eat()
class PrinterVisitor:
    """Render an Expr tree as a prefix (Lisp-style) string, e.g. ``(+ 1 2)``.

    Groupings add no extra parentheses, because the prefix form already
    makes the tree structure explicit.
    """

    def visit_unary_expr(self, expr: 'Unary'):
        """Render ``(op operand)``."""
        assert expr.rhs is not None
        operand = expr.rhs.accept(self)
        # Print the operator's source text. Interpolating the Token itself
        # would embed its repr ("Token(...)") in the output.
        return f'({expr.op.value} {operand})'

    def visit_binary_expr(self, expr: 'Binary'):
        """Render ``(op lhs rhs)``."""
        assert expr.op is not None and expr.lhs is not None and expr.rhs is not None
        lhs = expr.lhs.accept(self)
        rhs = expr.rhs.accept(self)
        return f'({expr.op.value} {lhs} {rhs})'

    def visit_literal_expr(self, expr: 'Literal'):
        """Render the literal's value via str()."""
        return f'{expr.value}'

    def visit_grouping_expr(self, expr: 'Grouping'):
        """Render the wrapped expression unchanged."""
        return expr.expression.accept(self)
class InterpretVisitor:
    """Evaluate an Expr tree to a Python value.

    Python semantics apply: '/' is true division (float result), '^' maps
    to '**', and '!' is logical ``not`` (yielding a bool). Division by zero
    propagates ZeroDivisionError.
    """

    def visit_unary_expr(self, expr: 'Unary'):
        """Apply '!' (logical not) or '-' (negation) to the operand.

        Raises:
            Exception: for any other operator type.
        """
        assert expr.op is not None and expr.rhs is not None
        operand = self.eval(expr.rhs)
        if expr.op.typ == TokenType.BANG:
            return not operand
        if expr.op.typ == TokenType.MINUS:
            return -operand
        # Unknown operators are errors here too, matching visit_binary_expr.
        raise Exception(f'{expr.op}: unknown op')

    def visit_binary_expr(self, expr: 'Binary'):
        """Evaluate both operands (left first) and combine them.

        Raises:
            Exception: for an operator type with no arithmetic meaning.
        """
        assert expr.op is not None and expr.lhs is not None and expr.rhs is not None
        lhs = self.eval(expr.lhs)
        rhs = self.eval(expr.rhs)
        typ = expr.op.typ
        if typ == TokenType.PLUS:
            return lhs + rhs
        if typ == TokenType.MINUS:
            return lhs - rhs
        if typ == TokenType.STAR:
            return lhs * rhs
        if typ == TokenType.SLASH:
            return lhs / rhs
        if typ == TokenType.CARET:
            return lhs ** rhs
        raise Exception(f'{expr.op}: unknown op')

    def visit_literal_expr(self, expr: 'Literal'):
        """Return the literal's stored value."""
        return expr.value

    def visit_grouping_expr(self, expr: 'Grouping'):
        """Evaluate the wrapped expression."""
        return self.eval(expr.expression)

    def eval(self, expr: 'Expr'):
        """Evaluate `expr` by dispatching through its ``accept`` method."""
        return expr.accept(self)
def main(code: str = '3 * 7 + 8 ^ (1 + 1)'):
    """Parse `code`, print its prefix rendering, then evaluate and print it.

    Other inputs worth trying: '3 + (2 - 1)', '-1 + 2 ', '-273', '2*2^2^2'.
    """
    parser = Parser(code)
    tree = parser.parse()
    rendered = tree.accept(PrinterVisitor())
    print('repr', rendered)
    result = tree.accept(InterpretVisitor())
    print('res', result)


# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment