Last active
October 7, 2023 23:14
-
-
Save jorendorff/15b248840ef8037feab8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/bin/env python3 | |
| # jouet.py - Toy JS interpreter | |
| import re | |
| from collections import namedtuple | |
# Verbose-mode regular expression that matches ONE token, preceded by any
# amount of whitespace and comments.  Group 1 holds the token text.
#
# The block-comment alternative uses [^*/] after a run of stars so that the
# comment ends at the first "*/"; with [^/] the loop could step over a
# terminator written as "**/" and swallow real code up to a later "*/".
# Carriage returns count as whitespace so CRLF sources do not yield a stray
# one-character token.
token_re = r'''(?x)
    (?:                                 # skip whitespace and comments
        [ \t\r\n]
      | //.*
      | /\* (?:[^*]|\*+[^*/])* \*+/
    )*
    (                                   # an actual token
        [A-Za-z_$][0-9A-Za-z_$]*
      | [0-9]+
      | .
    )
'''
def lex(s):
    """Tokenize a fragment of Jouet source code.

    Scans *s* with ``token_re`` and returns the text captured by group 1 of
    each successive match, in order, as a list of strings.
    """
    tokens = []
    for match in re.finditer(token_re, s):
        tokens.append(match.group(1))
    return tokens
def is_identifier(s):
    """Return True if token *s* starts like an identifier.

    An identifier starts with an ASCII letter, ``_`` or ``$`` -- the same
    first-character rule the lexer's identifier pattern and
    ``Parser.take_identifier`` use.  The empty string is not an identifier.
    Keywords are not excluded here.
    """
    first = s[0:1]  # slicing keeps the empty string from raising IndexError
    return "a" <= first <= "z" or "A" <= first <= "Z" or first in ("_", "$")
def is_number_literal(s):
    """Return True if token *s* starts with an ASCII digit.

    Only ``0``-``9`` count, matching the lexer's ``[0-9]+`` rule.
    ``str.isdigit`` would also accept characters such as a superscript two,
    which ``int()`` cannot convert.  The empty string is not a number.
    """
    first = s[0:1]  # slicing keeps the empty string from raising IndexError
    return "0" <= first <= "9"
class ReturnException(Exception):
    """Control-flow exception that carries a ``return`` value.

    The exception message is always the string "return"; the returned value
    is stored on the ``value`` attribute.
    """

    def __init__(self, value):
        super().__init__("return")
        self.value = value
# AST node types.  Each node is a plain namedtuple; field order matters
# because the parser constructs nodes positionally.

# Statements
Program = namedtuple("Program", ["body"])
Block = namedtuple("Block", ["body"])
IfStatement = namedtuple("IfStatement", ["test", "body", "alt"])
WhileStatement = namedtuple("WhileStatement", ["test", "body"])
EmptyStatement = namedtuple("EmptyStatement", [])
ExprStatement = namedtuple("ExprStatement", ["expr"])
ReturnStatement = namedtuple("ReturnStatement", ["expr"])
FunctionDeclaration = namedtuple("FunctionDeclaration", ["code"])

# Expressions
CommaExpression = namedtuple("CommaExpression", ["left", "right"])
AssignmentExpression = namedtuple("AssignmentExpression", ["left", "right"])
BinaryExpression = namedtuple("BinaryExpression", ["left", "op", "right"])
CallExpression = namedtuple("CallExpression", ["fn", "args"])
FunctionExpression = namedtuple("FunctionExpression", ["code"])

# Shared payload of function declarations and function expressions
FunctionCode = namedtuple("FunctionCode", ["name", "args", "body"])

# Leaves
BoolLiteral = namedtuple("BoolLiteral", ["value"])
NullLiteral = namedtuple("NullLiteral", [])
NumberLiteral = namedtuple("NumberLiteral", ["value"])
Identifier = namedtuple("Identifier", ["name"])
class Parser:
    """Recursive-descent parser for Jouet.

    Consumes a list of token strings and builds an AST out of the node
    namedtuples defined in this module.  ``point`` is the index of the next
    unconsumed token.  Syntax errors are reported by raising ``ValueError``.
    """

    def __init__(self, tokens):
        self.tokens = tokens
        self.point = 0

    # Token-cursor helpers

    def looking_at(self, token):
        """Return True if the next token equals *token* (never consumes)."""
        return self.point < len(self.tokens) and self.tokens[self.point] == token

    def take(self, token):
        """Consume the next token if it equals *token*; return whether it did."""
        hit = self.looking_at(token)
        if hit:
            self.point += 1
        return hit

    def take_one_of(self, options):
        """Consume and return the next token if it is in *options*, else None."""
        if self.point < len(self.tokens):
            t = self.tokens[self.point]
            if t in options:
                self.point += 1
                return t
        return None

    def take_if_match(self, regexp):
        """Consume and return the next token if *regexp* matches at its start.

        Returns None (consuming nothing) otherwise or at end of input.
        """
        if self.point < len(self.tokens):
            t = self.tokens[self.point]
            if re.match(regexp, t):
                self.point += 1
                return t
        return None

    def insist(self, token):
        """Consume *token* or raise ValueError if it is not next."""
        if not self.take(token):
            raise ValueError("expected: %r" % token)

    # Parsing methods

    def program(self):
        """Parse statements until every token is consumed; return a Program."""
        stmts = []
        while self.point < len(self.tokens):
            stmts.append(self.stmt())
        return Program(stmts)

    def stmt(self):
        """Parse one statement and return its node."""
        if self.looking_at("{"):
            return self.block()
        elif self.take("function"):
            return FunctionDeclaration(self.function(name_required=True))
        elif self.take("if"):
            self.insist("(")
            test = self.expr()
            self.insist(")")
            body = self.stmt()
            alt = self.stmt() if self.take("else") else None
            return IfStatement(test, body, alt)
        elif self.take("while"):
            self.insist("(")
            test = self.expr()
            self.insist(")")
            body = self.stmt()
            return WhileStatement(test, body)
        elif self.take("return"):
            expr = self.expr()
            self.insist(";")
            return ReturnStatement(expr)
        elif self.take(";"):
            return EmptyStatement()
        else:
            expr = self.expr()
            self.insist(";")
            return ExprStatement(expr)

    def block(self):
        """Parse ``{ stmt* }`` and return a Block."""
        stmts = []
        self.insist("{")
        while not self.looking_at("}"):
            # At end of input stmt() raises, so a missing "}" cannot loop forever.
            stmts.append(self.stmt())
        self.insist("}")
        return Block(stmts)

    def take_identifier(self):
        """Consume and return the next token if it starts like an identifier."""
        return self.take_if_match(r"^[a-zA-Z_$]")

    def function(self, name_required):
        """Parse the part of a function after the ``function`` keyword.

        Parses an optional name, a parenthesised comma-separated parameter
        list and a block body; returns a FunctionCode.  Raises ValueError if
        *name_required* is true and no name is present.
        """
        name = self.take_identifier()
        if name_required and name is None:
            raise ValueError("function name expected")
        args = []
        self.insist("(")
        if not self.looking_at(")"):
            while True:
                arg_name = self.take_identifier()
                if arg_name is None:
                    raise ValueError("argument expected")
                args.append(arg_name)
                if not self.take(","):
                    break
        self.insist(")")
        body = self.block()
        return FunctionCode(name, args, body)

    def expr(self):
        """Parse a comma expression (left-associative)."""
        e1 = self.assignment_expr()
        while self.take(","):
            e2 = self.assignment_expr()
            e1 = CommaExpression(e1, e2)
        return e1

    def assignment_expr(self):
        """Parse ``target = value``; assignment is right-associative.

        ``a = b = c`` yields ``AssignmentExpression(a, AssignmentExpression(b, c))``.
        """
        target = self.addition_expr()
        if self.take("="):
            # Recurse for the right-hand side so chained assignments nest
            # to the right.
            return AssignmentExpression(target, self.assignment_expr())
        return target

    def addition_expr(self):
        """Parse a left-associative chain of ``+`` / ``-``."""
        e = self.multiplication_expr()
        op = self.take_one_of(["+", "-"])
        while op is not None:
            e = BinaryExpression(e, op, self.multiplication_expr())
            op = self.take_one_of(["+", "-"])
        return e

    def multiplication_expr(self):
        """Parse a left-associative chain of ``*`` / ``/`` / ``%``."""
        e = self.call_expr()
        op = self.take_one_of(["*", "/", "%"])
        while op is not None:
            e = BinaryExpression(e, op, self.call_expr())
            op = self.take_one_of(["*", "/", "%"])
        return e

    def call_expr(self):
        """Parse a primary expression followed by zero or more call suffixes."""
        e = self.primary_expr()
        while self.take("("):
            args = []
            if not self.looking_at(")"):
                # Arguments are assignment expressions so that "," separates
                # arguments instead of forming a comma expression.
                args.append(self.assignment_expr())
                while self.take(","):
                    args.append(self.assignment_expr())
            self.insist(")")
            e = CallExpression(e, args)
        return e

    def primary_expr(self):
        """Parse a literal, parenthesised expression, function or identifier.

        Raises ValueError at end of input or on a token that cannot start an
        expression (instead of silently returning None).
        """
        if self.take("true"):
            return BoolLiteral(True)
        elif self.take("false"):
            return BoolLiteral(False)
        elif self.take("null"):
            return NullLiteral()
        elif self.take("("):
            e = self.expr()
            self.insist(")")
            return e
        elif self.take("function"):
            return FunctionExpression(self.function(name_required=False))

        if self.point >= len(self.tokens):
            raise ValueError("unexpected end of input")

        # Same identifier rule as function and parameter names, so names
        # starting with "_" or "$" are usable in expressions too.
        name = self.take_identifier()
        if name is not None:
            return Identifier(name)
        digits = self.take_if_match(r"[0-9]")
        if digits is not None:
            return NumberLiteral(int(digits))
        raise ValueError("unexpected token: %r" % self.tokens[self.point])
def parse(s):
    """Lex and parse the Jouet source text *s*, returning its Program node."""
    return Parser(lex(s)).program()
def toBoolean(v):
    """Convert an interpreter value to a Python bool for ``if`` / ``while`` tests.

    ``None`` (Jouet ``null``), ``False``, zero and ``Undefined`` are falsy;
    every other value is truthy.
    """
    # Membership uses ==, so 0.0 (e.g. the result of a division) is falsy too.
    if v in (None, False, 0):
        return False
    # Undefined is what a call returns when the function body never executes
    # a return statement; it must not pass a condition.
    return v is not Undefined
class Env:
    """One scope in a chain of variable scopes.

    ``bindings`` maps names to values for this scope; ``outer`` is the
    enclosing Env, or None for the outermost scope.
    """

    def __init__(self, outer):
        self.outer = outer
        self.bindings = {}

    def lookup(self, name):
        """Return the value bound to *name* in the innermost scope that has it.

        Raises ValueError if no scope in the chain binds *name*.
        """
        scope = self
        while scope is not None:
            try:
                return scope.bindings[name]
            except KeyError:
                scope = scope.outer
        raise ValueError("%r is undefined" % name)
class UndefinedClass:
    """Type of the ``Undefined`` value; it has no attributes or behavior."""


# The single shared instance, returned by calls whose body never returns.
Undefined = UndefinedClass()
# Runtime records.  Fn pairs a function's code with an environment; the
# field order (code first, env second) is relied on by positional construction.
Code = namedtuple("Code", ["name", "args", "body"])
Fn = namedtuple("Fn", ["code", "env"])
def apply_fn(env, fn, actual_args):
    """Call the Jouet function *fn* with *actual_args* and return its result.

    Parameters:
        env: the caller's environment.  Kept for interface compatibility; it
            is not used to resolve names, because a function body must see
            the scope the function was created in, not the scope it happens
            to be called from.
        fn: an ``Fn`` (code plus defining environment).
        actual_args: list of already-evaluated argument values.

    Returns the value carried by a ``ReturnException`` raised while running
    the body, or ``Undefined`` if the body finishes without returning.
    """
    # A fresh scope per call, chained to the function's defining environment.
    fenv = Env(fn.env)
    for index, name in enumerate(fn.code.args):
        # Parameters with no matching argument are bound to Undefined rather
        # than left unbound; surplus arguments are reachable via "arguments".
        if index < len(actual_args):
            fenv.bindings[name] = actual_args[index]
        else:
            fenv.bindings[name] = Undefined
    # A parameter explicitly named "arguments" takes precedence.
    if "arguments" not in fenv.bindings:
        fenv.bindings["arguments"] = actual_args
    try:
        evaluate(fenv, fn.code.body)
    except ReturnException as exc:
        return exc.value
    return Undefined
def evaluate(env, ast):
    """Execute or evaluate the AST node *ast* in environment *env*.

    Expression nodes return their value.  Statement nodes return None,
    except that a ``ReturnStatement`` raises ``ReturnException`` carrying
    the value.  Raises ValueError for an unknown node type or an invalid
    assignment target; name lookup failures propagate from ``Env.lookup``.
    """
    ast_type = ast.__class__
    if ast_type in (Block, Program):
        for stmt in ast.body:
            evaluate(env, stmt)
    elif ast_type is FunctionDeclaration:
        env.bindings[ast.code.name] = Fn(ast.code, env)
    elif ast_type is IfStatement:
        if toBoolean(evaluate(env, ast.test)):
            evaluate(env, ast.body)
        elif ast.alt is not None:
            evaluate(env, ast.alt)
    elif ast_type is WhileStatement:
        while toBoolean(evaluate(env, ast.test)):
            evaluate(env, ast.body)
    elif ast_type is EmptyStatement:
        # A lone ";" parses to EmptyStatement and must be a no-op.
        pass
    elif ast_type is ExprStatement:
        evaluate(env, ast.expr)
    elif ast_type is ReturnStatement:
        raise ReturnException(evaluate(env, ast.expr))
    elif ast_type is CommaExpression:
        evaluate(env, ast.left)
        return evaluate(env, ast.right)
    elif ast_type is AssignmentExpression:
        if ast.left.__class__ is not Identifier:
            raise ValueError("invalid assignment target")
        name = ast.left.name
        value = evaluate(env, ast.right)
        # Update the innermost scope that already binds the name; if none
        # does, the loop stops at the outermost scope and creates it there.
        scope = env
        while name not in scope.bindings and scope.outer is not None:
            scope = scope.outer
        scope.bindings[name] = value
        return value
    elif ast_type is BinaryExpression:
        left = evaluate(env, ast.left)
        op = ast.op
        right = evaluate(env, ast.right)
        # NOTE(review): these use Python operator semantics ("/" is true
        # division and raises on zero; "%" takes the sign of the divisor).
        if op == "+":
            return left + right
        elif op == "-":
            return left - right
        elif op == "*":
            return left * right
        elif op == "%":
            return left % right
        elif op == "/":
            return left / right
        raise ValueError("internal error")
    elif ast_type is CallExpression:
        fn = evaluate(env, ast.fn)
        args = [evaluate(env, arg) for arg in ast.args]
        if isinstance(fn, Fn):
            return apply_fn(env, fn, args)
        # Anything else is treated as a native Python callable.
        return fn(*args)
    elif ast_type is FunctionExpression:
        # Fn fields are (code, env), in that order.
        return Fn(ast.code, env)
    elif ast_type in (BoolLiteral, NumberLiteral):
        return ast.value
    elif ast_type is NullLiteral:
        return None
    elif ast_type is Identifier:
        return env.lookup(ast.name)
    else:
        raise ValueError("unexpected node type %s" % ast_type.__name__)
def test():
    """Run the built-in self-tests; raise ValueError on the first failure.

    Each test program must call ``assertEq`` at least once with equal
    arguments; a program that never reaches an assertion counts as a failure.
    Prints "all tests passed" when every program succeeds.
    """
    hit = False
    current_code = None  # source of the program being run, for error messages

    def jouet_assert_eq(a, b):
        nonlocal hit
        if a != b:
            raise ValueError(
                "Assertion failed: got %r, expected %r (%r)" % (a, b, current_code)
            )
        hit = True

    def run_test(code):
        nonlocal hit, current_code
        hit = False
        current_code = code
        program = parse(code)
        env = Env(None)
        env.bindings["print"] = print
        env.bindings["assertEq"] = jouet_assert_eq
        evaluate(env, program)
        if not hit:
            raise ValueError("Test failed: assertion not hit (%r)" % (code,))

    run_test("if (3*0) assertEq(1, 2); else assertEq(2, 2);")
    run_test("function f(x) { assertEq(x - 1, x - 1); } f(33);")
    run_test("assertEq((1, 2, 3), 3);")
    run_test("function f(x) { return x + 1; } assertEq(f(3), 4);")
    run_test("function fac(x) { if (x) return x * fac(x - 1); return 1; } assertEq(fac(5), 120);")
    print("all tests passed")


test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment