Skip to content

Instantly share code, notes, and snippets.

@mbillingr
Created June 14, 2019 10:25
Show Gist options
  • Select an option

  • Save mbillingr/f56c308d67e1dc888c5673003cf26f59 to your computer and use it in GitHub Desktop.

Select an option

Save mbillingr/f56c308d67e1dc888c5673003cf26f59 to your computer and use it in GitHub Desktop.
Some scheme compilation concepts with llvmlite
from ctypes import CFUNCTYPE, c_int32, c_voidp, c_char
import llvmlite.binding as llvm
from llvmlite import ir
# LLVM has to be initialised before any target machine or JIT engine is
# created. All three steps are required for code generation -- including the
# native asm printer.
for _llvm_init in (
    llvm.initialize,
    llvm.initialize_native_target,
    llvm.initialize_native_asmprinter,
):
    _llvm_init()
del _llvm_init  # keep the loop variable out of the module namespace


def create_execution_engine():
    """
    Build an MCJIT ExecutionEngine that generates code for the host CPU.

    The engine is created around an empty backing module; any number of
    further modules can be added to it afterwards (see compile_ir).
    """
    host_machine = llvm.Target.from_default_triple().create_target_machine()
    return llvm.create_mcjit_compiler(llvm.parse_assembly(""), host_machine)


def compile_ir(engine, llvm_ir):
    """
    Compile LLVM IR with the given engine and return the compiled module.

    *llvm_ir* may be an IR string or an ``llvmlite.ir.Module``. An
    ``ir.Module`` is serialised to IR text first, because verification and
    ``engine.add_module`` operate on binding-level modules produced by
    ``llvm.parse_assembly``, not on ``llvmlite.ir`` objects.

    Raises RuntimeError if the IR fails to parse or verify.
    """
    if isinstance(llvm_ir, ir.Module):
        llvm_ir = str(llvm_ir)
    # Create a binding-level LLVM module from the IR text and check it.
    mod = llvm.parse_assembly(llvm_ir)
    mod.verify()
    # Now add the module and make sure it is ready for execution.
    engine.add_module(mod)
    engine.finalize_object()
    engine.run_static_constructors()
    return mod


# Type tags stored in the first field of the {i64, i64} object struct that
# Compiler emits for every value.
TAG_NIL, TAG_NUMBER, TAG_FN = 0, 1, 2


class Compiler:
    """Compile a small Scheme-like AST into an llvmlite ``ir.Module``.

    Expressions are nested tuples/lists:
    - integers are numbers;
    - strings are variables;
    - ``('+', a, b)`` / ``('-', a, b)`` call the ``add``/``sub`` builtins;
    - ``('lambda', params, body)`` defines a function.

    Every value is an ``obj_t`` struct ``{i64, i64}``. The first field is a
    type tag (such as ``TAG_NUMBER`` or ``TAG_FN``). The second is a payload:
    the integer value, or a function address converted to ``i64``.
    """

    def __init__(self, name):
        """Create an empty IR module called *name* and declare the builtins."""
        self.obj_t = ir.LiteralStructType([ir.IntType(64), ir.IntType(64)])
        # NOTE(review): cell_t, cell_p, top_fnty, ZERO, ONE and STEP are not
        # read by any method of this class; they are kept so that outside
        # code accessing these attributes keeps working.
        self.cell_t = ir.IntType(32)
        self.cell_pt = ir.IntType(64)  # i64 used for function addresses
        self.cell_p = ir.PointerType(self.cell_t)
        self.top_fnty = ir.FunctionType(self.obj_t, ())
        self.fnty_binary = ir.FunctionType(self.obj_t, (self.obj_t, self.obj_t))
        self.ZERO = ir.Constant(self.cell_t, 0)
        self.ONE = ir.Constant(self.cell_t, 1)
        self.STEP = ir.Constant(self.cell_pt, 4)
        self.env = {}  # builtin name -> declared ir.Function
        self.builder = None  # IRBuilder of the function being emitted, if any
        self.module = ir.Module(name=name)
        self.compile_builtins()

    def compile_builtins(self):
        """Declare the runtime builtins ``add``, ``sub`` and ``lookup``.

        These are placeholders for now: only declarations (no bodies) are
        added to the module.
        """
        self.env["add"] = ir.Function(self.module, self.fnty_binary, name="add")
        self.env["sub"] = ir.Function(self.module, self.fnty_binary, name="sub")
        # lookup receives the variable name as a [1 x i8] array, so only
        # single-byte variable names type-check at the moment.
        lookup_ty = ir.FunctionType(self.obj_t, [ir.ArrayType(ir.IntType(8), 1)])
        self.env["lookup"] = ir.Function(self.module, lookup_ty, name="lookup")

    def compile_top_level(self, exp):
        """Compile *exp* as the body of a zero-argument ``main``; return the module."""
        self.compile_function([], exp, name="main")
        return self.module

    def compile_expression(self, exp):
        """Emit IR for *exp* and return the resulting ``obj_t`` value.

        Raises ValueError if *exp* is not a recognised expression form.
        """
        if is_self_evaluating(exp):
            return self.compile_self_evaluating(exp)
        elif is_variable(exp):
            return self.compile_variable(exp)
        elif is_hardcoded(exp):
            return self.compile_hardcoded(exp)
        elif is_lambda(exp):
            return self.compile_lambda(exp)
        else:
            raise ValueError("Unknown Expression Type: {}".format(exp))

    def compile_self_evaluating(self, exp):
        """Return the constant object ``{TAG_NUMBER, exp}`` for an integer literal.

        Raises ValueError for any other kind of literal.
        """
        if is_integer(exp):
            return ir.Constant(self.obj_t, [TAG_NUMBER, exp])
        raise ValueError("Unsupported self-evaluating expression: {}".format(exp))

    def compile_variable(self, exp):
        """Emit a tail call to the ``lookup`` builtin for the variable named *exp*."""
        # The name is passed by value as an [N x i8] array of its UTF-8 bytes;
        # it only matches lookup's [1 x i8] parameter when N == 1.
        name_bytes = ir.Constant.literal_array(
            [ir.Constant(ir.IntType(8), byte) for byte in exp.encode()]
        )
        return self.builder.call(self.env["lookup"], [name_bytes], tail=True)

    def compile_hardcoded(self, exp):
        """Emit a tail call to the builtin implementing ``(op lhs rhs)``.

        Supported operators are ``+`` (``add``) and ``-`` (``sub``).
        Raises ValueError for an unknown operator or a wrong operand count.
        """
        op = exp[0]
        if op == '+':
            builtin = "add"
        elif op == '-':
            builtin = "sub"
        else:
            raise ValueError("Unknown hardcoded function: {}".format(exp))
        # Both builtins are binary: reject extra operands instead of silently
        # dropping them, and missing ones instead of failing with IndexError.
        if len(exp) != 3:
            raise ValueError(
                "Hardcoded function expects exactly 2 operands: {}".format(exp)
            )
        lhs = self.compile_expression(exp[1])
        rhs = self.compile_expression(exp[2])
        return self.builder.call(self.env[builtin], [lhs, rhs], tail=True)

    def compile_lambda(self, exp):
        """Compile a ``(lambda params body)`` form into a new IR function.

        Returns the function object ``{TAG_FN, address}``. This must be called
        while a builder is active, because the address conversion is emitted
        into the current function.
        """
        func = self.compile_function(lambda_params(exp), lambda_body(exp), lambda_name())
        fptr = self.builder.ptrtoint(func, self.cell_pt, name="fptr")
        # Start from {TAG_FN, undef} and fill the payload slot with the address.
        fn_obj = ir.Constant(self.obj_t, [TAG_FN, ir.Undefined])
        return self.builder.insert_value(fn_obj, fptr, 1)

    def compile_function(self, params, body, name):
        """Emit a function called *name* that returns the value of *body*.

        The function takes one ``obj_t`` argument per entry in *params*.
        NOTE(review): *params* currently only fixes the arity. The names are
        not bound in ``self.env``; variable references inside *body* go
        through the runtime ``lookup`` builtin instead.
        """
        fnty = ir.FunctionType(self.obj_t, [self.obj_t] * len(params))
        func = ir.Function(self.module, fnty, name=name)
        entry = func.append_basic_block(name="entry")
        outer_builder = self.builder
        self.builder = ir.IRBuilder(entry)
        try:
            result = self.compile_expression(body)
            self.builder.ret(result)
        finally:
            # Always restore the enclosing function's builder. Otherwise a caught
            # compile error would leave self.builder pointing into this function.
            self.builder = outer_builder
        return func


def is_self_evaluating(exp):
    """Return True if *exp* compiles to a constant of itself (only numbers, for now)."""
    return is_number(exp)


def is_variable(exp):
    """Return True if *exp* names a variable; in this AST, names are strings."""
    return is_string(exp)


def is_hardcoded(exp):
    """Return True if *exp* is a call to a hardcoded operator, ``+`` or ``-``.

    Matches a non-empty list/tuple whose first element is exactly '+' or '-'.
    """
    # Use tuple membership, not a substring test: '' and '+-' must not match,
    # and non-string heads (e.g. numbers) must not raise TypeError. The length
    # check keeps an empty form from raising IndexError.
    return is_list(exp) and len(exp) > 0 and exp[0] in ('+', '-')


def is_lambda(exp):
    """Return True if *exp* is a ``('lambda', params, body)`` form.

    An empty list/tuple is not a lambda; the length check keeps it from
    raising IndexError.
    """
    return is_list(exp) and len(exp) > 0 and exp[0] == 'lambda'


def lambda_params(exp):
    """Return the parameter list (second element) of a lambda form."""
    params = exp[1]
    return params


def lambda_body(exp):
    """Return the body expression (third element) of a lambda form."""
    body = exp[2]
    return body


def lambda_name():
    """Return a fresh name 'lambda-N' for an anonymous function.

    The counter lives on the function object itself (``lambda_name.id``),
    so numbering is shared across all Compiler instances.
    """
    lambda_name.id = lambda_name.id + 1
    return f'lambda-{lambda_name.id}'


lambda_name.id = 0


def is_number(exp):
    """Return True if *exp* is a numeric literal (only integers are supported)."""
    return is_integer(exp)


def is_float(exp):
    """Return True if *exp* is a Python float.

    Floats are not yet treated as numbers: is_number only accepts integers.
    """
    return isinstance(exp, float)


def is_integer(exp):
    """Return True if *exp* is an integer literal.

    ``bool`` is a subclass of ``int`` but is not a number in the source
    language, so True/False are rejected. Accepting them would make the
    compiler emit an ``i64`` constant formatted as ``true``/``false``.
    """
    return isinstance(exp, int) and not isinstance(exp, bool)


def is_list(exp):
    """Return True if *exp* is a compound form, i.e. a list or a tuple."""
    return isinstance(exp, (list, tuple))


def is_string(exp):
    """Return True if *exp* is a ``str``; is_variable uses this to spot variable names."""
    return isinstance(exp, str)


engine = create_execution_engine()

# (lambda (x y) (+ x y)) encoded as nested tuples.
ast = ('lambda', ('x', 'y'), ('+', 'x', 'y'))
bfir = Compiler("test").compile_top_level(ast)
print(bfir)

# Parse, verify and optimise the IR BEFORE handing it to the JIT. Once
# finalize_object() has emitted machine code, running IR passes on the module
# only changes its printed form, not the code that executes.
mod = llvm.parse_assembly(str(bfir))
mod.verify()
print("Before optimization:")
print(mod)

# perform some optimizations
mpm = llvm.passmanagers.create_module_pass_manager()
mpm.add_instruction_combining_pass()
mpm.add_cfg_simplification_pass()
mpm.add_global_optimizer_pass()
mpm.add_gvn_pass()
mpm.add_cfg_simplification_pass()
mpm.add_instruction_combining_pass()
mpm.run(mod)
print("After optimization:")
print(mod)

# JIT-compile the optimised module.
# NOTE(review): add/sub/lookup are only declared (Compiler.compile_builtins).
# Unless the host process provides these symbols, resolving them here may
# fail -- confirm before relying on this step.
engine.add_module(mod)
engine.finalize_object()
engine.run_static_constructors()

# TODO: call the compiled code once the builtins have implementations.
# "main" takes no arguments and returns the {i64, i64} object struct by
# value, so a ctypes wrapper needs a matching Structure as its restype, e.g.
# CFUNCTYPE(ObjStruct)(engine.get_function_address("main")).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment