Created
February 4, 2021 04:46
-
-
Save suica/575a0d7065ec811678a0a20f4b1d0f0b to your computer and use it in GitHub Desktop.
A simplistic Scheme interpreter
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import reduce | |
from typing import List, Tuple | |
import operator | |
class SchemeList:
    """A cons cell: a ``value`` plus a ``next`` link (another cell or any tail)."""

    # Default tail for cells that were never linked to anything.
    next = None

    def __init__(self, value):
        self.value = value

    def get_value(self):
        """Return the value stored in this cell."""
        return self.value

    def get_next(self):
        """Return whatever this cell links to (``None`` unless it was set)."""
        return self.next

    def get_nested(self):
        """Return the chain as nested two-element lists ``[value, rest]``.

        The innermost ``rest`` is the first link that is not a
        ``SchemeList`` (``None`` for a proper list).
        """
        # Walk the chain collecting values, then fold them back up from the tail.
        collected = []
        node = self
        while isinstance(node, SchemeList):
            collected.append(node.get_value())
            node = node.get_next()
        nested = node
        for item in reversed(collected):
            nested = [item, nested]
        return nested

    def __repr__(self):
        return str(self.get_nested())

    @staticmethod
    def cons(a, b):
        """Create a new cell holding ``a`` whose link is ``b``."""
        cell = SchemeList(a)
        cell.next = b
        return cell
class SchemeQuoteList:
    """Wrapper marking a parsed form as quoted.

    ``value`` holds the wrapped form unchanged; this class never evaluates it.
    """

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        # The format string previously had no placeholder, so the wrapped
        # value never appeared in the output.
        return "(quote {!r})".format(self.value)
def tokenize(s: str) -> List[str]:
    """Split Scheme source text into a flat list of tokens.

    ``(`` and ``)`` are always tokens of their own.  Any whitespace character
    (space, tab, newline, ...) ends the token being accumulated.  Every other
    character, including ``'``, is appended to the current token.

    :param s: source text to split.
    :return: the tokens in source order; empty tokens are never produced.
    """
    tokens: List[str] = []
    pending = ''
    for c in s:
        if c == '(' or c == ')':
            if pending:
                tokens.append(pending)
                pending = ''
            tokens.append(c)
        elif c.isspace():
            # Tabs and newlines must separate tokens exactly like spaces do;
            # skipping them would glue "1\n2" together into "12".
            if pending:
                tokens.append(pending)
                pending = ''
        else:
            pending += c
    if pending:
        tokens.append(pending)
    return tokens
def parse(tokens: List[str], start=0):
    """Build a tree from ``tokens``, beginning at index ``start``.

    Atoms stay as strings, a parenthesised group becomes a tuple of its parsed
    contents, and ``'`` followed by a parenthesised group becomes a
    ``SchemeQuoteList`` wrapping ``('list', *contents)``.

    :param tokens: token list to read from.
    :param start: index of the first token to consume.
    :return: ``(tree, index)``.  ``tree`` is the list of items parsed at this
        nesting level.  ``index`` is the position of the ``)`` that ended the
        level, or the position just past the last consumed token when the
        input ran out first.
    :raises Exception: when ``'`` is not directly followed by ``(``.
    """
    tree = []
    index = start
    while index < len(tokens):
        token = tokens[index]
        if token == ')':
            # Closes the level we are in; the caller steps past it.
            return tree, index
        if token == '(':
            subtree, end_i = parse(tokens, start=index + 1)
            tree.append(tuple(subtree))
            index = end_i + 1
        elif token == "'":
            if index + 1 >= len(tokens) or tokens[index + 1] != '(':
                raise Exception('quote must be followed by a parenthesised list')
            # Consume exactly the one list that follows the quote, so later
            # siblings in the enclosing form are still parsed afterwards.
            quoted, end_i = parse(tokens, start=index + 2)
            tree.append(SchemeQuoteList(('list',) + tuple(quoted)))
            index = end_i + 1
        else:
            tree.append(token)
            index += 1
    return tree, index
def _make_lambda(parameters, body, env):
    """Build the Python callable for a Scheme lambda closing over ``env``.

    Each call evaluates ``body`` in a fresh copy of ``env`` extended with the
    argument bindings, so nested or recursive calls cannot overwrite one
    another's parameters.
    """
    parameters = list(parameters)

    def _lambda(*args):
        if len(args) != len(parameters):
            raise Exception('arity error, expect {} but given {}'.format(len(parameters), len(args)))
        frame = env.copy()
        frame.update(zip(parameters, args))
        return evaluate(body, frame)

    return _lambda


def apply(expr: Tuple, context):
    """Evaluate one parenthesised form ``expr`` in ``context``.

    ``expr[0]`` selects the behaviour:

    * a special form (``define``, ``if``, ``lambda``, ``cond``, ``and``,
      ``or``, ``list``, ``car``, ``cdr``) receives its operands unevaluated;
    * ``+``, ``-``, ``*`` or a name bound to a callable in ``context`` is
      called with the evaluated operands;
    * a tuple is evaluated to obtain the function to call.

    ``if`` treats only Python ``True`` as true; ``and`` fails only on
    ``False``; ``or`` succeeds only on ``True``.  ``define`` mutates
    ``context`` and returns the string ``'def:<name>'``.

    :param expr: the form, as a tuple produced by ``parse``.
    :param context: dict mapping names to values.
    :raises NotImplementedError: for ``cond``, ``/`` and a ``define`` whose
        body is not exactly one form.
    :raises TypeError: when calling a function obtained from a nested form
        fails for any reason.
    :raises Exception: for empty, malformed or unrecognised forms.
    """
    if len(expr) == 0:
        raise Exception('empty expression to apply')
    op = expr[0]

    if isinstance(op, tuple):
        # The operator is itself a form, e.g. ((lambda (x) x) 1).
        operands = [evaluate(operand, context) for operand in expr[1:]]
        func = apply(op, context)
        try:
            return func(*operands)
        except Exception as e:
            print(e)
            raise TypeError from e

    if not isinstance(op, str):
        raise Exception('unrecognised form {}'.format(op))

    special_forms = ('define', 'if', 'lambda', 'cond', 'and', 'or', 'list', 'car', 'cdr')
    if op in special_forms:
        # Special forms decide themselves which operands get evaluated.
        operands = expr[1:]
        if op == 'define':
            middle, *body = operands
            if len(body) != 1:
                raise NotImplementedError
            body = body[0]
            if isinstance(middle, tuple):
                # Function definition: (define (name params...) body).
                name, *parameters = middle
                closure_env = context.copy()
                func = _make_lambda(parameters, body, closure_env)
                # Bind the name inside its own closure so the body can recurse.
                closure_env[name] = func
                context[name] = func
            else:
                # Constant definition: (define name value).
                name = middle
                context[name] = evaluate(body, context)
            return 'def:{}'.format(name)
        if op == 'if':
            if len(operands) != 3:
                raise Exception('if expression is malformed')
            predicate, consequent, alternative = operands
            if evaluate(predicate, context) is True:
                return evaluate(consequent, context)
            return evaluate(alternative, context)
        if op == 'cond':
            raise NotImplementedError
        if op == 'lambda':
            if len(operands) != 2:
                raise Exception('lambda expression is malformed: {}'.format(expr))
            parameters, body = operands
            # The closure sees a snapshot of the context at creation time.
            return _make_lambda(parameters, body, context.copy())
        if op == 'and':
            for unevaluated in operands:
                if evaluate(unevaluated, context) is False:
                    return False
            return True
        if op == 'or':
            for unevaluated in operands:
                if evaluate(unevaluated, context) is True:
                    return True
            return False
        if op == 'list':
            # Dummy head cell; the real list starts at root.next.
            root = SchemeList(None)
            cur = root
            for operand in operands:
                cur.next = SchemeList(evaluate(operand, context))
                cur = cur.next
            return root.next
        # Only 'car' and 'cdr' remain.
        if len(operands) != 1:
            raise Exception("arity error for car/cdr")
        operand = evaluate(operands[0], context)
        if isinstance(operand, SchemeQuoteList):
            operand = evaluate(operand.value, context)
        if isinstance(operand, SchemeList):
            return operand.get_value() if op == 'car' else operand.get_next()
        raise Exception('car/cdr error: {} is not a list'.format(operand))

    # op names an ordinary function, so all arguments are evaluated eagerly.
    operands = [evaluate(operand, context) for operand in expr[1:]]
    if op == '+':
        return sum(operands)
    if op == '-':
        if not operands:
            raise Exception('arity error, - expects at least one operand')
        if len(operands) == 1:
            # (- x) negates its single operand.
            return -operands[0]
        return operands[0] - sum(operands[1:])
    if op == '*':
        return reduce(operator.mul, operands, 1)
    if op == '/':
        raise NotImplementedError
    if op in context:
        func = context[op]
        if callable(func):
            return func(*operands)
        raise Exception('{} is not callable'.format(op))
    raise Exception('unrecognised form {}'.format(op))
def _parse_number(token):
    """Return ``token`` as an int or float, or ``None`` if it is not numeric."""
    try:
        return int(token)
    except ValueError:
        pass
    # Require a digit so that names such as "inf" or "nan", which float()
    # would accept, are still looked up as symbols.
    if any(ch.isdigit() for ch in token):
        try:
            return float(token)
        except ValueError:
            pass
    return None


def evaluate(expr, context):
    """Evaluate a parsed expression in ``context``.

    * ``list`` - each item is evaluated in order and the list of results is
      returned;
    * ``tuple`` - handed to ``apply``;
    * ``str`` - a numeric literal becomes an ``int`` or ``float``, anything
      else is looked up in ``context``;
    * ``SchemeQuoteList`` - returned unchanged.

    :param expr: a node produced by ``parse``.
    :param context: dict mapping names to values.
    :raises Exception: for an empty list or a name missing from ``context``.
    :raises TypeError: for any other kind of node.
    """
    if isinstance(expr, list):
        if len(expr) == 0:
            raise Exception('attempt to evaluate an empty expression')
        return [evaluate(item, context) for item in expr]
    if isinstance(expr, tuple):
        return apply(expr, context)
    if isinstance(expr, str):
        number = _parse_number(expr)
        if number is not None:
            return number
        if expr in context:
            return context[expr]
        raise Exception("{} is not in context {}".format(expr, context))
    if isinstance(expr, SchemeQuoteList):
        return expr
    raise TypeError(expr)
def format_result(lis):
    """Return a copy of ``lis`` with every callable replaced by ``'lambda'``.

    :raises NotImplementedError: if any item is a tuple.
    """
    def _render(item):
        if isinstance(item, tuple):
            raise NotImplementedError
        return 'lambda' if callable(item) else item

    return [_render(item) for item in lis]
def equal(list_a, list_b):
    """Structurally compare two values that may be ``SchemeList`` chains.

    Two ``None`` values are equal and ``None`` never equals anything else.
    Two ``SchemeList`` cells are equal when their values and their tails are
    recursively equal.  A ``SchemeList`` never equals a non-list.  Any other
    pair is compared with ``==``.

    :param list_a: a ``SchemeList``, ``None`` or any atom.
    :param list_b: a ``SchemeList``, ``None`` or any atom.
    :return: ``True`` when the two structures match.
    """
    if list_a is None and list_b is None:
        return True
    if list_a is None or list_b is None:
        return False
    a_is_list = isinstance(list_a, SchemeList)
    b_is_list = isinstance(list_b, SchemeList)
    if a_is_list and b_is_list:
        # Recurse on the values too, so nested lists are compared by content.
        return (equal(list_a.get_value(), list_b.get_value())
                and equal(list_a.get_next(), list_b.get_next()))
    if a_is_list or b_is_list:
        # A list compared with an atom, e.g. (list 1 2) against (cons 1 2).
        return False
    return list_a == list_b
# Names every program can use without defining them first.
predefined_context = {
    # Boolean literals.
    'true': True,
    'false': False,
    # Binary comparisons.
    '<': operator.lt,
    '<=': operator.le,
    '>': operator.gt,
    '>=': operator.ge,
    '=': operator.eq,
    # Predicates.
    'not': operator.not_,
    'null?': lambda value: value is None,
    'equal?': equal,
    # List construction; ``nil`` is represented by ``None``.
    'cons': SchemeList.cons,
    'nil': None,
}
# entry
def tokenize_and_evaluate(s, context=None):
    """Tokenize, parse and evaluate the program text ``s``.

    :param s: Scheme source containing one or more top-level expressions.
    :param context: dict of bindings to evaluate in.  When omitted, a fresh
        copy of ``predefined_context`` is used, so a ``define`` in one program
        does not leak into the next call.  Pass your own dict to keep
        definitions across calls.
    :return: one formatted result per top-level expression.
    """
    if context is None:
        context = predefined_context.copy()
    tree, _ = parse(tokenize(s))
    return format_result(evaluate(tree, context))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
from unittest import skip | |
import scheme_interpreter | |
class SchemeInterpreterTestCases(unittest.TestCase):
    """Tests for the tokenizer, the parser and the evaluator."""

    def _assert_programs(self, cases):
        """Evaluate every ``(program, expected)`` pair in its own sub-test.

        Using ``subTest`` lets one failing program be reported with its source
        text without hiding the outcome of the remaining cases.
        """
        for _input, expected in cases:
            with self.subTest(program=_input):
                result = scheme_interpreter.tokenize_and_evaluate(_input)
                self.assertEqual(expected, result)

    def test_tokenize(self):
        cases = [
            ['0', ['0']],
            ['(0)', ['(', '0', ')']],
            ['(+ 1 (- 1 2))', ['(', '+', '1', '(', '-', '1', '2', ')', ')']],
            ['(+ 1 22222)', ['(', '+', '1', '22222', ')']],
            ["' (1 2 )", ["'", '(', '1', '2', ')']],
        ]
        for _input, expected in cases:
            with self.subTest(source=_input):
                result = scheme_interpreter.tokenize(_input)
                self.assertEqual(expected, result)

    def test_parse(self):
        self.assertEqual(['0', '1', '2'], scheme_interpreter.parse(scheme_interpreter.tokenize('0 1 2'))[0])
        self.assertEqual([('+', '91', '9')], scheme_interpreter.parse(scheme_interpreter.tokenize('(+ 91 9)'))[0])
        self.assertEqual([('+', '91', ('*', '3', '3'))],
                         scheme_interpreter.parse(scheme_interpreter.tokenize('(+ 91 (* 3 3))'))[0])
        self.assertEqual([('+', '91', ('*', '3', '3')), ('+', '1', '2')],
                         scheme_interpreter.parse(scheme_interpreter.tokenize('(+ 91 (* 3 3)) (+ 1 2)'))[0])
        parsed = (scheme_interpreter.parse(scheme_interpreter.tokenize(
            """
            (car (list 1))
            (car (list 1))
            """
        )))[0]
        self.assertEqual(2, len(parsed))

    def test_parse_with_quote(self):
        parsed = (scheme_interpreter.parse(scheme_interpreter.tokenize(
            """
            (car '(1))
            """
        )))[0]
        self.assertEqual(1, len(parsed))
        parsed = (scheme_interpreter.parse(scheme_interpreter.tokenize(
            """
            (car '(1))
            (car '(1))
            """
        )))[0]
        self.assertEqual(2, len(parsed))
        self.assertEqual(2, len(parsed[0]))
        self.assertEqual(2, len(parsed[1]))

    def test_evaluate(self):
        self._assert_programs([
            ['0', [0]],
            ['0 114514 2', [0, 114514, 2]],
            ['(+ 1 2)', [3]],
            ['(* 100 (+ 1 2))', [300]],
            ['(+ (* 3 (+ (* 2 4) (+ 3 5))) (+ (- 10 7) 6))', [57]],
        ])

    def test_lambda(self):
        self._assert_programs([
            ['((lambda (f x) (+ x f)) 1 2)', [3]],
            ['((lambda () 1))', [1]],
            ['(lambda () 1)', ['lambda']],
            ['((lambda (wocao ssd c) (+ ssd 3)) 1 2 3)', [5]],
        ])

    def test_define(self):
        self._assert_programs([
            [
                '(define (square x) (* x x)) (square 10)',
                ['def:square', 100]
            ],
            [
                '(define money 1) money',
                ['def:money', 1]
            ],
            [
                '(define (square x) (* x x))',
                ['def:square']
            ],
            [
                '((lambda () 1))',
                [1]
            ],
            ['((lambda (wocao ssd c) (+ ssd 3)) 1 2 3)', [5]],
            [
                '(define pi 314) ((lambda (x) (* 2 x)) pi)',
                ['def:pi', 628]
            ],
            # Defining a function must not evaluate its body.
            ['(define (infinite_loop x) (infinite_loop))', ['def:infinite_loop']],
            ['(define (infinite_loop) (infinite_loop))', ['def:infinite_loop']],
        ])

    def test_multiline_program(self):
        self._assert_programs([
            [
                "(define (double x) (* x 2)) (double 2333)",
                ['def:double', 4666]
            ],
            [
                """
                ( define pi 314)
                ((lambda ( x) ( * 2 x)) pi)
                """,
                [
                    'def:pi',
                    628
                ]
            ],
            [
                """
                (define ( double
                    x) ( * x 2 ))
                (double 2333)
                (define pi 314)
                ((lambda (x) (* 2 x)) pi)
                """,
                ['def:double', 4666, 'def:pi', 628]
            ],
            [
                """
                (define pi 314) \n\n\n ((\n lambda\n (x)\n (*\n 2\n x)) pi)
                pi
                """,
                ['def:pi', 628, 314]
            ],
        ])

    def test_compare_operators(self):
        # <
        self.assertEqual(scheme_interpreter.tokenize_and_evaluate("""
            (< 1 2)
            (< 2 1)
            """), [True, False])
        # not
        self.assertEqual(scheme_interpreter.tokenize_and_evaluate("""
            (not (< 1 2))
            (not false)
            """), [False, True])
        # and / or short-circuit: operands after the deciding one are never
        # evaluated, so they may be arbitrary unbound forms.
        self.assertEqual(scheme_interpreter.tokenize_and_evaluate("""
            (and true (> (+ 1 2) 3))
            (define (x) (x))
            (or true (x) (哈哈哈哈随便什么都行!因为这个根本不求值))
            (and false (还行吧))
            (or (and true false) (not false))
            """), [False, 'def:x', True, False, True])

    def test_if(self):
        self._assert_programs([
            ['true', [True]],
            ['false', [False]],
            [
                '(if true 1 2)',
                [1]
            ],
            [
                '(if false 1 2)',
                [2]
            ],
            [
                """
                (define (infinite_loop) (infinite_loop))
                (if false (infinite_loop) 2)
                """,
                ['def:infinite_loop', 2]
            ],
            [
                ("(define (double x) (* x 2))\n"
                 "(if false (what ever 什么都哈哈哈哈行???) (double 2))\n"),
                ['def:double', 4]
            ],
            [
                """
                (if
                    (< 1 2) (+ (if (= 1 2) 1 2) 1)
                    (aksdjhaskjdhakshdkj))
                """,
                [3]
            ],
        ])

    def test_list(self):
        self.assertEqual([True, False, True, True],
                         scheme_interpreter.tokenize_and_evaluate("(equal? (list 1111 23 4) (list 1111 23 4))"
                                                                  "(equal? (cons 1 2) (list 1 2))"
                                                                  "(equal? (cons 1 (cons 2 nil)) (list 1 2))"
                                                                  "(equal? (cons 1 nil) (list 1))"))
        self.assertEqual([1, None], scheme_interpreter.tokenize_and_evaluate("(car (cons 1 nil))"
                                                                             "(cdr (cons 1 nil))"
                                                                             ))

    def test_quote(self):
        self.assertEqual(
            [1, 1, 2222, True],
            scheme_interpreter.tokenize_and_evaluate("(car '(1(define(x)(x)) 3))"
                                                     "(car '(1)) (car '(2222 1)) "
                                                     "(equal? (cdr '(222 1)) (cons 1 nil))"))
# Run the whole test suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment