# Gist jmikkola/eeb18e8ae590db955d325a3cb74a5a70, created December 22, 2019 16:44.
# (GitHub page notes: this gist can be saved to your computer and used in GitHub Desktop.
# GitHub warns that the file contains bidirectional Unicode text that may be interpreted
# or compiled differently than what appears below; to review it, open the file in an
# editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode
# characters.)
import re
from ctypes import CFUNCTYPE, c_int64, c_long

import llvmlite.binding as llvm
from llvmlite import ir
# One-time LLVM setup, run at import time: the core library, the host
# ("native") target, and that target's assembly printer. This happens before
# create_execution_engine() builds a target machine and MCJIT compiler for the host.
# NOTE(review): recent llvmlite releases deprecate llvm.initialize(); confirm
# against the installed llvmlite version.
llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()
# Tokens: a paren, a ';;' line comment, an atom (any run of non-space,
# non-paren characters), or a whitespace run (dropped later by tokenize()).
# The comment alternative stops *before* the newline instead of requiring one,
# so a comment on the last line of input without a trailing newline is still
# one comment token rather than being split into stray atoms.
# If quoted strings are supported, add:
# |"(?:[^"\\]|\\.)*"
# right after the pattern for comments
token_re = re.compile(r'([()]|;;[^\n]*|[^\s()]+|\s+)')
# An integer literal is an atom made only of decimal digits (no sign).
int_re = re.compile(r'^\d+$')
# The language has a single value type: 64-bit integers (compared as signed).
int64 = ir.IntType(64)
# i8*, LLVM's stand-in for C's char*/void*; used for printf's format argument.
voidptr = ir.PointerType(ir.IntType(8))
def make_int(i):
    """Return an i64 IR constant for `i` (anything int() accepts, e.g. '42' or 42)."""
    value = int(i)
    return ir.Constant(int64, value)
def make_fn_type(n_args):
    """Return the IR type of a function taking `n_args` i64 values and returning i64."""
    params = (int64,) * n_args
    return ir.FunctionType(int64, params)
def tokenize(text):
    """Split source text into tokens, discarding whitespace-only pieces.

    Comment tokens (starting with ';;') are kept here; parse() drops them.
    """
    return list(filter(lambda tok: not tok.isspace(), token_re.findall(text)))
def parse(tokens):
    """Build a nested-list AST from a token sequence.

    '(' opens a new list, ')' closes it, tokens starting with ';;' (comments)
    are dropped, and every other token is appended as a string atom.

    Returns:
        The list of top-level forms.

    Raises:
        ValueError: if the parentheses are unbalanced.
    """
    # Linked stack of open lists: each frame is (items, enclosing_frame).
    frame = ([], None)
    for t in tokens:
        if t == '(':
            frame = ([], frame)
        elif t == ')':
            finished_list, parent = frame
            if parent is None:
                raise ValueError("unmatched ')' in input")
            parent[0].append(finished_list)
            frame = parent
        elif not t.startswith(';;'):
            frame[0].append(t)
    if frame[1] is not None:
        # Without this check the innermost unfinished list would be returned
        # in place of the top-level forms.
        raise ValueError("unclosed '(' in input")
    return frame[0]
class Matcher:
    """Wrap one AST node and remember what the last pattern captured from it.

    This lets callers write `if m.matches(p): ... m.match ...` chains.
    """

    def __init__(self, ast):
        self.ast = ast
        # Captures from the most recent matches() call; None when nothing has matched.
        self.match = None

    def matches(self, pattern):
        """Apply `pattern` to the wrapped node, store its captures, and report success."""
        captures = pattern(self.ast)
        self.match = captures
        return captures is not None
def int_pattern(ast):
    """Pattern: match a decimal-integer atom and capture its value as an int."""
    if not isinstance(ast, str) or int_re.match(ast) is None:
        return None
    return [int(ast)]
def atom_pattern(ast):
    """Pattern: match any atom (string node) and capture it."""
    return [ast] if isinstance(ast, str) else None
def match_definition(ast, definition):
    """Match AST node `ast` against a pattern `definition`.

    Pattern language:
      '$'      any node (captured)
      '$l'     a list node (captured)
      '$a'     an atom / string node (captured)
      other str  a literal atom that must be equal (not captured)
      set      an atom that is a member of the set (captured)
      list     a list node matched element-wise; a trailing '*' matches zero or
               more remaining elements, captured together as one list

    Returns:
        The list of captures on success, or None when `ast` does not match.
    """
    if isinstance(definition, str):
        if definition == '$':
            return [ast]
        if definition == '$l':
            return [ast] if isinstance(ast, list) else None
        if definition == '$a':
            return [ast] if isinstance(ast, str) else None
        # Literals must match exactly but are not included in the results.
        return [] if ast == definition else None
    if isinstance(definition, set):
        # Only atoms can be members; testing a list node with `in` would raise
        # TypeError because lists are unhashable.
        if isinstance(ast, str) and ast in definition:
            return [ast]
        return None
    if isinstance(ast, str):
        return None
    variadic = bool(definition) and definition[-1] == '*'
    fixed = definition[:-1] if variadic else definition
    # A variadic pattern still needs every fixed element to be present;
    # otherwise e.g. `()` would match ['do', '*'].
    if variadic:
        if len(ast) < len(fixed):
            return None
    elif len(ast) != len(fixed):
        return None
    results = []
    for node, defn in zip(ast, fixed):
        result = match_definition(node, defn)
        if result is None:
            return None
        results.extend(result)
    if variadic:
        results.append(ast[len(fixed):])
    return results
def pattern(definition):
    """Turn a pattern definition into a callable usable with Matcher.matches()."""
    def matcher(ast):
        return match_definition(ast, definition)
    return matcher
# Operators accepted in (op left right) forms; a plain set, as match_definition requires.
binary_operators = {'+', '-', '*', '/', '==', '>', '<'}
def indent(depth):
    """Return `depth` spaces of indentation for trace output."""
    return ''.ljust(depth)
# Set to True to trace how recognize_top()/recognize() walk the AST.
TRACE_AST = False


def print_indent(s, depth):
    """Print `s` indented by `depth` levels when TRACE_AST is enabled.

    Tracing is off by default, so this is normally a no-op.
    """
    if TRACE_AST:
        print(indent(depth) + s)
class Context:
    """Mutable code-generation state shared by recognize_top() and recognize()."""

    def __init__(self, module):
        self.module = module
        # printf declaration and its format-string global; build() fills these in.
        self.printf = None
        self.global_fmt = None
        # Active IR builder and entry block while a function body is being emitted.
        self.builder = None
        self.block = None

    def get_function(self, name):
        """Return the first function in the module called `name`, or None."""
        return next((fn for fn in self.module.functions if fn.name == name), None)
def recognize_top(ast, context, depth=0):
    """Emit IR for one top-level form into context.module.

    Supported forms:
      (var NAME VALUE)         module-level global initialized from a constant VALUE
      (defn NAME (ARGS) BODY)  function taking one i64 per arg and returning i64
    Anything else prints 'invalid top-level expression' and is skipped.

    Raises:
        Exception: for a non-constant global initializer, a non-name parameter,
            or a function body that cannot be compiled.
    """
    m = Matcher(ast)
    if m.matches(pattern(['var', '$a', '$'])):
        print_indent('global var {} {}'.format(*m.match), depth)
        name, value_ast = m.match
        # No builder is active at top level, so only constant initializers work.
        value = recognize(value_ast, context, {}, depth+1)
        if not isinstance(value, ir.Constant):
            raise Exception('global {} must be initialized with a constant'.format(name))
        gvar = ir.GlobalVariable(context.module, value.type, name=name)
        gvar.initializer = value
    elif m.matches(pattern(['defn', '$a', '$l', '$'])):
        print_indent('function {}({}) = {}'.format(*m.match), depth)
        name, args, body = m.match
        if not all(isinstance(arg, str) for arg in args):
            raise Exception('parameters of {} must be names'.format(name))
        func = ir.Function(context.module, make_fn_type(len(args)), name=name)
        scope = dict(zip(args, func.args))
        block = func.append_basic_block(name='entry')
        builder = ir.IRBuilder(block)
        context.block = block
        context.builder = builder
        try:
            result = recognize(body, context, scope, depth+1)
            if result is None:
                raise Exception('cannot compile the body of ' + name)
            # Comparisons produce i1, but every function is declared to return i64.
            if result.type == ir.IntType(1):
                result = builder.zext(result, int64)
            builder.ret(result)
        finally:
            # Never leave a stale builder behind for the next top-level form.
            context.builder = None
            context.block = None
    else:
        print('invalid top-level expression')
def _as_int64(builder, value):
    """Widen an i1 comparison result to i64 so it can be used where numbers are expected.

    Values of any other type (and None) are returned unchanged.
    """
    if value is not None and value.type == ir.IntType(1):
        return builder.zext(value, int64)
    return value


def recognize(ast, context, scope, depth=0):
    """Emit IR for expression `ast` using context.builder and return its value.

    Args:
        ast: an atom (str) or a nested list produced by parse().
        context: the Context holding the module, printf declaration and builder.
        scope: maps local names (function parameters) to IR values.
        depth: nesting level, used only for trace output.

    Returns:
        The IR value of the expression, or None for unhandled forms (a
        "todo" trace line is emitted for those).

    Raises:
        Exception: when calling an undefined function or passing the wrong
            number of arguments. Unknown atoms propagate whatever
            module.get_global raises.
    """
    m = Matcher(ast)
    if m.matches(int_pattern):
        num = int(m.match[0])
        print_indent('int: {}'.format(num), depth)
        return make_int(num)
    elif m.matches(atom_pattern):
        atom = m.match[0]
        print_indent('var: {}'.format(atom), depth)
        if atom in scope:
            return scope[atom]
        value = context.module.get_global(atom)
        # A global variable is addressed by pointer; inside a function body the
        # expression's value is what it stores.
        if isinstance(value, ir.GlobalVariable) and context.builder is not None:
            return context.builder.load(value)
        return value
    elif m.matches(pattern([binary_operators, '$', '$'])):
        print_indent('bin_op: {} {} {}'.format(m.match[0], m.match[1], m.match[2]), depth)
        builder = context.builder
        left = _as_int64(builder, recognize(m.match[1], context, scope, depth+1))
        right = _as_int64(builder, recognize(m.match[2], context, scope, depth+1))
        op = m.match[0]
        if op == '+':
            return builder.add(left, right)
        elif op == '-':
            return builder.sub(left, right)
        elif op == '*':
            return builder.mul(left, right)
        elif op == '/':
            # IRBuilder has no generic `div`; values are signed i64.
            return builder.sdiv(left, right)
        elif op in ('<', '==', '>'):
            return builder.icmp_signed(op, left, right)
        else:
            print('todo, handle op ' + op)
    elif m.matches(pattern(['if', '$', '$', '$'])):
        print_indent('if2 {} {} {}'.format(*m.match), depth)
        builder = context.builder
        test = recognize(m.match[0], context, scope, depth+1)
        if test.type != ir.IntType(1):
            # A numeric condition is true when nonzero, as in C.
            test = builder.icmp_signed('!=', test, ir.Constant(test.type, 0))
        with builder.if_else(test) as (then, otherwise):
            with then:
                then_result = _as_int64(
                    builder, recognize(m.match[1], context, scope, depth+1))
                # Record the block the branch *ends* in: a nested `if` moves the
                # builder to a new block, and that block is the phi's predecessor.
                then_block = builder.block
            with otherwise:
                else_result = _as_int64(
                    builder, recognize(m.match[2], context, scope, depth+1))
                else_block = builder.block
        result = builder.phi(int64, name="ifresult")
        result.add_incoming(then_result, then_block)
        result.add_incoming(else_result, else_block)
        return result
    elif m.matches(pattern(['print', '$'])):
        print_indent('print {}'.format(m.match[0]), depth)
        val = _as_int64(context.builder, recognize(m.match[0], context, scope, depth+1))
        fmt_arg = context.builder.bitcast(context.global_fmt, voidptr)
        context.builder.call(context.printf, [fmt_arg, val])
        return make_int(0)
    elif m.matches(pattern(['do', '*'])):
        print_indent('do {}'.format(m.match[0]), depth)
        result = None
        for stmt in m.match[0]:
            result = recognize(stmt, context, scope, depth+1)
        if result is None:
            result = make_int(0)
        return result
    elif m.matches(pattern(['$a', '*'])):
        print_indent('call {} with {}'.format(*m.match), depth)
        fn_name, args = m.match
        function = context.get_function(fn_name)
        if function is None:
            raise Exception('undefined function ' + fn_name)
        if len(function.args) != len(args):
            raise Exception('wrong number of args for ' + fn_name)
        arg_values = [_as_int64(context.builder, recognize(arg, context, scope, depth+1))
                      for arg in args]
        return context.builder.call(function, arg_values, 'calltmp')
    else:
        print_indent("todo: {}".format(ast), depth)
def build(text):
    """Compile source `text` into a new llvmlite IR module and return it.

    Declares printf plus an internal format-string global so `print` forms can
    emit calls, then lowers each top-level form. The textual IR is also printed
    to stdout.
    """
    module = ir.Module(name='testmodule')
    context = Context(module=module)

    # Format string for `print`. Every value is an i64 (long long), so the
    # conversion must be %lld; %d would read the argument as a 32-bit int.
    # The explicit "\0" makes the global a NUL-terminated C string.
    print_fmt = "%lld\n\0"
    c_print_fmt = ir.Constant(
        ir.ArrayType(ir.IntType(8), len(print_fmt)),
        bytearray(print_fmt.encode("utf8")),
    )
    global_fmt = ir.GlobalVariable(module, c_print_fmt.type, name="print_fmt")
    global_fmt.linkage = 'internal'
    global_fmt.global_constant = True
    global_fmt.initializer = c_print_fmt
    context.global_fmt = global_fmt

    # int printf(i8*, ...): declaration only, no body is emitted here.
    printf_ty = ir.FunctionType(ir.IntType(32), [voidptr], var_arg=True)
    printf = ir.Function(module, printf_ty, name="printf")
    context.printf = printf

    for ast in parse(tokenize(text)):
        recognize_top(ast, context)

    print(module)
    return module
def create_execution_engine():
    """Create an MCJIT execution engine for the host's default target triple.

    The engine starts with an empty backing module; compile_ir() adds real code.
    """
    machine = llvm.Target.from_default_triple().create_target_machine()
    empty_module = llvm.parse_assembly('')
    return llvm.create_mcjit_compiler(empty_module, machine)
def compile_ir(engine, llvm_ir):
    """Parse textual IR, verify it, and JIT-compile it into `engine`.

    After this returns, the module's functions can be looked up with
    engine.get_function_address().
    """
    parsed = llvm.parse_assembly(llvm_ir)
    parsed.verify()
    engine.add_module(parsed)
    # Emit machine code, then run any static constructors the module defines.
    engine.finalize_object()
    engine.run_static_constructors()
# Demo: compile a small program, JIT it, and run its `main`.
module = build('''
(defn add (a b)
  (+ a b))
(defn fib (n)
  (if (< n 3)
      1
      (+ (fib (- n 1)) (fib (- n 2)))))
(defn main ()
  (do
    (print (add 1 2))
    (print (fib 37))))
''')
engine = create_execution_engine()
compile_ir(engine, str(module))
func_ptr = engine.get_function_address("main")
# main() returns i64; ctypes' c_long is only 32 bits on Windows, so use c_int64.
main = CFUNCTYPE(c_int64)(func_ptr)
main()
# Reference LLVM IR for the C call printf("%d\n", 12345) on x86_64 Linux,
# kept as a model for the printf declaration and format-string global:
# target triple = "x86_64-pc-linux-gnu"
# @.str = private unnamed_addr constant [4 x i8] c"%d\0A\00", align 1
# declare i32 @printf(i8*, ...) #1
# %1 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([4 x i8], [4 x i8]* @.str, i32 0, i32 0), i32 12345)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.