Skip to content

Instantly share code, notes, and snippets.

@jorendorff
Last active October 7, 2023 23:14
Show Gist options
  • Select an option

  • Save jorendorff/15b248840ef8037feab8 to your computer and use it in GitHub Desktop.

Select an option

Save jorendorff/15b248840ef8037feab8 to your computer and use it in GitHub Desktop.
#!/bin/env python3
# jouet.py - Toy JS interpreter
import re
from collections import namedtuple
# Regular expression matching one token, preceded by any amount of whitespace
# and comments.  Group 1 is the token text.  Group 1 is None only when the
# match reached end of input with nothing but whitespace/comments left; the
# `\Z` alternative exists so that the engine accepts that case instead of
# backtracking into the skipped text and producing a bogus token from it.
token_re = r'''(?x)
    (?: # whitespace and comments (skipped)
        [ \t\n]
      | //.*
      | /\* (?:[^*]|\*+[^/*])* \*+/
    )*
    (?:
        ( # an actual token
            [A-Za-z_$][0-9A-Za-z_$]*
          | [0-9]+
          | .
        )
      | \Z # only skipped text remained
    )
'''


def lex(s):
    """Split a fragment of Jouet code into a list of tokens.

    Tokens are identifiers/keywords, runs of decimal digits, or any other
    single character.  Whitespace, `// line` comments and `/* block */`
    comments are discarded, including when they appear at the very end of
    the input.

    s -- the source text.

    Returns a list of token strings (empty for empty or comment-only input).
    """
    found = (m.group(1) for m in re.finditer(token_re, s))
    # Group 1 is None for the end-of-input match; that is not a token.
    return [tok for tok in found if tok is not None]
def is_identifier(s):
    """Return True if token `s` starts like an identifier.

    An identifier starts with an ASCII letter, `_` or `$`, which is the same
    first-character rule the lexer uses for identifier tokens.  The empty
    string is not an identifier.
    """
    return re.match(r"[A-Za-z_$]", s) is not None
def is_number_literal(s):
    """Return True if token `s` starts with an ASCII decimal digit.

    Only `0`-`9` count: str.isdigit() would also accept characters such as
    superscript digits, which int() cannot convert.  The empty string is not
    a number literal.
    """
    return "0" <= s[0:1] <= "9"
class ReturnException(Exception):
    """Control-flow exception carrying the value of a `return` statement.

    The exception message is always "return"; the returned value is stored
    in the `value` attribute.
    """

    def __init__(self, value):
        super().__init__("return")
        self.value = value
# AST node types.  Every node is a plain namedtuple; the evaluator dispatches
# on the node's class.

# --- Statements ---
# `body` is a list of statement nodes.
Program = namedtuple("Program", "body")
Block = namedtuple("Block", "body")
# `alt` is the `else` statement, or None when there is no `else` clause.
IfStatement = namedtuple("IfStatement", "test body alt")
WhileStatement = namedtuple("WhileStatement", "test body")
EmptyStatement = namedtuple("EmptyStatement", "")
ExprStatement = namedtuple("ExprStatement", "expr")
ReturnStatement = namedtuple("ReturnStatement", "expr")
# `code` is a FunctionCode.
FunctionDeclaration = namedtuple("FunctionDeclaration", "code")

# --- Expressions ---
CommaExpression = namedtuple("CommaExpression", "left right")
AssignmentExpression = namedtuple("AssignmentExpression", "left right")
# `op` is the operator token, e.g. "+" or "*".
BinaryExpression = namedtuple("BinaryExpression", "left op right")
# `fn` is the callee expression; `args` is a list of argument expressions.
CallExpression = namedtuple("CallExpression", "fn args")
# `code` is a FunctionCode.
FunctionExpression = namedtuple("FunctionExpression", "code")
# Payload shared by function declarations and expressions: `name` is the
# function name (None for an anonymous function expression), `args` is the
# list of parameter names and `body` is a Block.
FunctionCode = namedtuple("FunctionCode", "name args body")

# --- Literals and names ---
BoolLiteral = namedtuple("BoolLiteral", "value")
NullLiteral = namedtuple("NullLiteral", "")
# `value` is a Python int.
NumberLiteral = namedtuple("NumberLiteral", "value")
Identifier = namedtuple("Identifier", "name")
class Parser:
    """Recursive-descent parser over a list of token strings.

    `tokens` is the token list and `point` is the index of the next
    unconsumed token.  The first group of methods inspects or consumes
    tokens; the parsing methods after them build AST nodes and raise
    ValueError on a syntax error.
    """

    def __init__(self, tokens):
        self.tokens = tokens
        self.point = 0

    # Token-level helpers

    def looking_at(self, token):
        """Return True if the next token equals `token`; consume nothing."""
        return self.point < len(self.tokens) and self.tokens[self.point] == token

    def take(self, token):
        """Consume the next token if it equals `token`; return whether it did."""
        hit = self.looking_at(token)
        if hit:
            self.point += 1
        return hit

    def take_one_of(self, options):
        """Consume and return the next token if it is in `options`, else None."""
        if self.point < len(self.tokens):
            t = self.tokens[self.point]
            if t in options:
                self.point += 1
                return t
        return None

    def take_if_match(self, regexp):
        """Consume and return the next token if `regexp` matches at its start.

        Returns None (consuming nothing) otherwise, or at end of input.
        """
        if self.point < len(self.tokens):
            t = self.tokens[self.point]
            if re.match(regexp, t):
                self.point += 1
                return t
        return None

    def insist(self, token):
        """Consume `token`, raising ValueError if it is not the next token."""
        if not self.take(token):
            raise ValueError("expected: %r" % token)

    # Parsing methods

    def program(self):
        """Parse a whole program.  Throw if not all input is consumed."""
        stmts = []
        while self.point < len(self.tokens):
            stmts.append(self.stmt())
        return Program(stmts)

    def stmt(self):
        """Parse one statement and return its AST node."""
        if self.looking_at("{"):
            return self.block()
        elif self.take("function"):
            return FunctionDeclaration(self.function(name_required=True))
        elif self.take("if"):
            self.insist("(")
            test = self.expr()
            self.insist(")")
            body = self.stmt()
            if self.take("else"):
                alt = self.stmt()
            else:
                alt = None
            return IfStatement(test, body, alt)
        elif self.take("while"):
            self.insist("(")
            test = self.expr()
            self.insist(")")
            body = self.stmt()
            return WhileStatement(test, body)
        elif self.take("return"):
            # A bare `return;` carries no expression; represent it with None.
            if self.take(";"):
                return ReturnStatement(None)
            expr = self.expr()
            self.insist(";")
            return ReturnStatement(expr)
        elif self.take(";"):
            return EmptyStatement()
        else:
            expr = self.expr()
            self.insist(";")
            return ExprStatement(expr)

    def block(self):
        """Parse `{ stmt* }` and return a Block."""
        stmts = []
        self.insist("{")
        while not self.looking_at("}"):
            stmts.append(self.stmt())
        self.insist("}")
        return Block(stmts)

    def take_identifier(self):
        """Consume and return the next token if it is identifier-like, else None."""
        return self.take_if_match(r"^[a-zA-Z_$]")

    def function(self, name_required):
        """Parse the rest of a function after the `function` keyword.

        name_required -- True for declarations (a name must be present),
                         False for function expressions (name optional).

        Returns a FunctionCode; its name is None for an anonymous function.
        """
        name = self.take_identifier()
        if name_required and name is None:
            raise ValueError("function name expected")
        args = []
        self.insist("(")
        if not self.looking_at(")"):
            while True:
                arg_name = self.take_identifier()
                if arg_name is None:
                    raise ValueError("argument expected")
                args.append(arg_name)
                if not self.take(","):
                    break
        self.insist(")")
        body = self.block()
        return FunctionCode(name, args, body)

    def expr(self):
        """Parse a comma expression (left-associative)."""
        e1 = self.assignment_expr()
        while self.take(","):
            e2 = self.assignment_expr()
            e1 = CommaExpression(e1, e2)
        return e1

    def assignment_expr(self):
        """Parse an assignment expression.

        Assignment is right-associative: `a = b = c` parses as `a = (b = c)`.
        """
        left = self.addition_expr()
        if self.take("="):
            return AssignmentExpression(left, self.assignment_expr())
        return left

    def addition_expr(self):
        """Parse a left-associative chain of `+` / `-` operations."""
        e = self.multiplication_expr()
        op = self.take_one_of(["+", "-"])
        while op is not None:
            e = BinaryExpression(e, op, self.multiplication_expr())
            op = self.take_one_of(["+", "-"])
        return e

    def multiplication_expr(self):
        """Parse a left-associative chain of `*` / `/` / `%` operations."""
        e = self.call_expr()
        op = self.take_one_of(["*", "/", "%"])
        while op is not None:
            e = BinaryExpression(e, op, self.call_expr())
            op = self.take_one_of(["*", "/", "%"])
        return e

    def call_expr(self):
        """Parse a primary expression followed by zero or more call suffixes."""
        e = self.primary_expr()
        while self.take("("):
            args = []
            if not self.looking_at(")"):
                # Arguments are assignment expressions, so a comma separates
                # arguments rather than forming a comma expression.
                args.append(self.assignment_expr())
                while self.take(","):
                    args.append(self.assignment_expr())
            self.insist(")")
            e = CallExpression(e, args)
        return e

    def primary_expr(self):
        """Parse a literal, parenthesized expression, function or identifier.

        Raises ValueError at end of input or on a token that cannot start an
        expression.
        """
        if self.take("true"):
            return BoolLiteral(True)
        elif self.take("false"):
            return BoolLiteral(False)
        elif self.take("null"):
            return NullLiteral()
        elif self.take("("):
            e = self.expr()
            self.insist(")")
            return e
        elif self.take("function"):
            return FunctionExpression(self.function(name_required=False))

        if self.point >= len(self.tokens):
            raise ValueError("unexpected end of input")
        # Use the same identifier rule as function names and parameters.
        name = self.take_identifier()
        if name is not None:
            return Identifier(name)
        digits = self.take_if_match(r"[0-9]")
        if digits is not None:
            return NumberLiteral(int(digits))
        raise ValueError("unexpected token: %r" % self.tokens[self.point])
def parse(s):
    """Lex the Jouet source text `s` and parse it into a Program node."""
    return Parser(lex(s)).program()
def toBoolean(v):
    """Convert a Jouet value to a Python bool using JavaScript truthiness.

    Falsy values are null (None), false, zero, NaN and Undefined; everything
    else (including functions and argument lists) is truthy.
    """
    if v is None or v is False:
        return False
    # `v != v` is true only for NaN.
    if isinstance(v, (int, float)) and (v == 0 or v != v):
        return False
    return v is not Undefined
class Env:
    """One scope in a chain of variable scopes.

    `bindings` maps names to values for this scope; `outer` is the enclosing
    Env, or None for the outermost scope.
    """

    def __init__(self, outer):
        self.outer = outer
        self.bindings = {}

    def lookup(self, name):
        """Return the value bound to `name` in the nearest enclosing scope.

        Raises ValueError if no scope in the chain binds `name`.
        """
        scope = self
        while scope is not None:
            try:
                return scope.bindings[name]
            except KeyError:
                scope = scope.outer
        raise ValueError("%r is undefined" % name)
class UndefinedClass:
    """Type of the `Undefined` singleton; it has no state or behavior."""


# The single shared instance of UndefinedClass.
Undefined = UndefinedClass()
# NOTE(review): Code has the same fields as FunctionCode and is not referenced
# anywhere else in the visible source -- possibly a leftover; confirm before
# removing.
Code = namedtuple("Code", "name args body")
# Runtime function value: `code` is the function's FunctionCode node and `env`
# is the Env in which the function was created.
Fn = namedtuple("Fn", "code env")
def apply_fn(env, fn, actual_args):
    """Call the Jouet function `fn` with `actual_args` and return its result.

    env         -- the caller's environment.  Kept for interface
                   compatibility but not used: the call scope is chained to
                   the environment the function was created in (`fn.env`),
                   so free variables resolve lexically, not dynamically.
    fn          -- an Fn value (code plus defining environment).
    actual_args -- list of already-evaluated argument values.

    Returns the value carried by the `return` statement that ended the
    call, or Undefined if the body finished without returning.
    """
    fenv = Env(fn.env)
    supplied = len(actual_args)
    for index, name in enumerate(fn.code.args):
        # Parameters without a matching argument are bound to Undefined
        # rather than left unbound; surplus arguments are only reachable
        # through `arguments`.
        fenv.bindings[name] = actual_args[index] if index < supplied else Undefined
    # A parameter explicitly named `arguments` takes precedence.
    if "arguments" not in fenv.bindings:
        fenv.bindings["arguments"] = actual_args
    try:
        evaluate(fenv, fn.code.body)
    except ReturnException as exc:
        return exc.value
    return Undefined
def evaluate(env, ast):
    """Execute the statement or evaluate the expression `ast` in `env`.

    env -- the Env used for variable lookup and binding.
    ast -- an AST node (one of the namedtuple node types).

    Returns the value of an expression node; statement nodes return None.
    Raises ReturnException when a `return` statement executes (caught by
    apply_fn), and ValueError for undefined names, invalid assignment
    targets and unknown node types.
    """
    ast_type = ast.__class__
    if ast_type in (Block, Program):
        for stmt in ast.body:
            evaluate(env, stmt)
    elif ast_type is EmptyStatement:
        # A lone `;` does nothing.
        pass
    elif ast_type is FunctionDeclaration:
        env.bindings[ast.code.name] = Fn(ast.code, env)
    elif ast_type is IfStatement:
        if toBoolean(evaluate(env, ast.test)):
            evaluate(env, ast.body)
        elif ast.alt is not None:
            evaluate(env, ast.alt)
    elif ast_type is WhileStatement:
        while toBoolean(evaluate(env, ast.test)):
            evaluate(env, ast.body)
    elif ast_type is ExprStatement:
        evaluate(env, ast.expr)
    elif ast_type is ReturnStatement:
        # A return statement without an expression yields Undefined.
        value = Undefined if ast.expr is None else evaluate(env, ast.expr)
        raise ReturnException(value)
    elif ast_type is CommaExpression:
        evaluate(env, ast.left)
        return evaluate(env, ast.right)
    elif ast_type is AssignmentExpression:
        target = ast.left
        if target.__class__ is not Identifier:
            raise ValueError("invalid assignment target")
        value = evaluate(env, ast.right)
        # Assign in the nearest scope that already binds the name; if none
        # does, create the binding in the outermost scope.
        scope = env
        while target.name not in scope.bindings and scope.outer is not None:
            scope = scope.outer
        scope.bindings[target.name] = value
        return value
    elif ast_type is BinaryExpression:
        left = evaluate(env, ast.left)
        op = ast.op
        right = evaluate(env, ast.right)
        if op == "+":
            return left + right
        elif op == "-":
            return left - right
        elif op == "*":
            return left * right
        elif op == "%":
            return left % right
        elif op == "/":
            return left / right
        raise ValueError("internal error")
    elif ast_type is CallExpression:
        fn = evaluate(env, ast.fn)
        args = [evaluate(env, arg) for arg in ast.args]
        if isinstance(fn, Fn):
            return apply_fn(env, fn, args)
        # Anything else is assumed to be a host (Python) callable.
        return fn(*args)
    elif ast_type is FunctionExpression:
        # Fn's fields are (code, env), in that order.
        return Fn(ast.code, env)
    elif ast_type in (BoolLiteral, NumberLiteral):
        return ast.value
    elif ast_type is NullLiteral:
        return None
    elif ast_type is Identifier:
        return env.lookup(ast.name)
    else:
        raise ValueError("unexpected node type %s" % ast_type.__name__)
    return None
def test():
    """Run the built-in self-tests.

    Each test program is parsed and evaluated in a fresh environment that
    provides the host functions `print` and `assertEq`.  A test fails if an
    `assertEq` call sees unequal values or if the program never calls
    `assertEq` at all.

    Raises ValueError on the first failing test; prints "all tests passed"
    when every test succeeds.
    """
    hit = False
    # Source of the program currently running, for failure messages.
    current_code = None

    def jouet_assert_eq(a, b):
        nonlocal hit
        if a != b:
            raise ValueError("Assertion failed: got %r, expected %r (%r)"
                             % (a, b, current_code))
        hit = True

    def run_test(code):
        nonlocal hit, current_code
        hit = False
        current_code = code
        program = parse(code)
        env = Env(None)
        env.bindings["print"] = print
        env.bindings["assertEq"] = jouet_assert_eq
        evaluate(env, program)
        if not hit:
            raise ValueError("Test failed: assertion not hit (%r)" % (code,))

    run_test("if (3*0) assertEq(1, 2); else assertEq(2, 2);")
    run_test("function f(x) { assertEq(x - 1, x - 1); } f(33);")
    run_test("assertEq((1, 2, 3), 3);")
    run_test("function f(x) { return x + 1; } assertEq(f(3), 4);")
    run_test("function fac(x) { if (x) return x * fac(x - 1); return 1; } assertEq(fac(5), 120);")
    print("all tests passed")
# Run the self-tests.  NOTE(review): this call is unguarded, so it also runs
# when the module is imported; consider an `if __name__ == "__main__":` guard.
test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment