Created
June 14, 2019 10:25
-
-
Save mbillingr/f56c308d67e1dc888c5673003cf26f59 to your computer and use it in GitHub Desktop.
Some Scheme compilation concepts with llvmlite
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from ctypes import CFUNCTYPE, c_int32, c_voidp, c_char | |
| import llvmlite.binding as llvm | |
| from llvmlite import ir | |
# LLVM has to be initialized before any code generation can happen: the core
# library, the native target, and the native assembly printer (the last one
# is required as well, even though this file never prints assembly itself).
for _llvm_init in (llvm.initialize,
                   llvm.initialize_native_target,
                   llvm.initialize_native_asmprinter):
    _llvm_init()
del _llvm_init
def create_execution_engine():
    """
    Build an MCJIT execution engine for the host CPU.

    The engine starts out with an empty backing module; any number of further
    modules can be added to it afterwards (see ``compile_ir``).
    """
    # Target machine describing the host, derived from the default triple.
    host_machine = llvm.Target.from_default_triple().create_target_machine()
    # MCJIT needs an initial module, so give it an empty one.
    empty_module = llvm.parse_assembly("")
    return llvm.create_mcjit_compiler(empty_module, host_machine)
def compile_ir(engine, llvm_ir):
    """
    Compile LLVM IR with the given engine and return the compiled module.

    Parameters
    ----------
    engine : llvm.ExecutionEngine
        Execution engine that takes ownership of the module
        (e.g. one returned by ``create_execution_engine``).
    llvm_ir : str or llvmlite.ir.Module
        Textual LLVM IR, or an ``llvmlite.ir.Module`` built with the IR
        builder. Both forms are parsed into an ``llvm.ModuleRef``.

    Returns
    -------
    llvm.ModuleRef
        The parsed and verified module, after it has been added to *engine*,
        finalized, and had its static constructors run.

    Raises
    ------
    RuntimeError
        Raised by llvmlite if the IR fails to parse or to verify.
    """
    # An llvmlite.ir.Module is only an IR-builder object: it has no verify()
    # method and cannot be added to an execution engine directly. Its str()
    # is textual IR, so route both input kinds through the parser.
    if isinstance(llvm_ir, ir.Module):
        llvm_ir = str(llvm_ir)
    mod = llvm.parse_assembly(llvm_ir)
    mod.verify()
    # Hand the module to the engine and generate machine code for it.
    engine.add_module(mod)
    engine.finalize_object()
    engine.run_static_constructors()
    return mod
# Type tags stored in the first field of compiled objects ({tag, payload}).
TAG_NIL, TAG_NUMBER, TAG_FN = range(3)
class Compiler:
    """
    Compile a tiny Scheme-like AST into an llvmlite ``ir.Module``.

    Every compiled value is an ``obj_t``: a ``{i64, i64}`` struct whose first
    field holds a ``TAG_*`` constant and whose second field holds the payload
    (an integer for ``TAG_NUMBER``, a function address for ``TAG_FN``).
    """

    def __init__(self, name):
        """Create an empty module called *name* and declare the builtins."""
        # Tagged value representation: {tag, payload}.
        self.obj_t = ir.LiteralStructType([ir.IntType(64), ir.IntType(64)])
        # NOTE(review): cell_t, cell_p, top_fnty, ZERO, ONE and STEP are not
        # read by any other method of this class; kept for compatibility.
        self.cell_t = ir.IntType(32)
        # Integer type that holds function addresses (see compile_lambda).
        self.cell_pt = ir.IntType(64)
        self.cell_p = ir.PointerType(self.cell_t)
        self.top_fnty = ir.FunctionType(self.obj_t, ())
        self.fnty_binary = ir.FunctionType(self.obj_t, (self.obj_t, self.obj_t))
        self.ZERO = ir.Constant(self.cell_t, 0)
        self.ONE = ir.Constant(self.cell_t, 1)
        self.STEP = ir.Constant(self.cell_pt, 4)
        # Builtin name ("add", "sub", "lookup") -> declared ir.Function.
        self.env = {}
        # IRBuilder of the function currently being compiled; None outside
        # compile_function.
        self.builder = None
        self.module = ir.Module(name=name)
        self.compile_builtins()

    def compile_builtins(self):
        """Declare the builtin functions that generated code calls."""
        # Only placeholders at the moment: these are declarations without
        # bodies; presumably a runtime must provide them before code can run.
        self.env["add"] = ir.Function(self.module, self.fnty_binary, name="add")
        self.env["sub"] = ir.Function(self.module, self.fnty_binary, name="sub")
        # Variable names are passed by value as a [1 x i8] array, so only
        # single-byte variable names are supported for now.
        lookup_ty = ir.FunctionType(self.obj_t, [ir.ArrayType(ir.IntType(8), 1)])
        self.env["lookup"] = ir.Function(self.module, lookup_ty, name="lookup")

    def compile_top_level(self, exp):
        """Compile *exp* as the body of a zero-argument ``main`` and return the module."""
        self.compile_function([], exp, name="main")
        return self.module

    def compile_expression(self, exp):
        """
        Emit IR for *exp* into the current function and return its value.

        Raises ValueError if *exp* is not a supported expression form.
        """
        if is_self_evaluating(exp):
            return self.compile_self_evaluating(exp)
        elif is_variable(exp):
            return self.compile_variable(exp)
        elif is_hardcoded(exp):
            return self.compile_hardcoded(exp)
        elif is_lambda(exp):
            return self.compile_lambda(exp)
        else:
            raise ValueError("Unknown Expression Type: {}".format(exp))

    def compile_self_evaluating(self, exp):
        """Return an ``obj_t`` constant for a literal; only integers are supported."""
        if is_integer(exp):
            return ir.Constant(self.obj_t, [TAG_NUMBER, exp])
        # Fail here rather than returning None, which would only surface later
        # as an obscure error inside the IR builder.
        raise ValueError("Unsupported self-evaluating expression: {}".format(exp))

    def compile_variable(self, exp):
        """
        Emit a call to the ``lookup`` builtin for the variable named *exp*.

        Raises ValueError if the name does not encode to exactly one byte,
        because ``lookup`` is declared to take a ``[1 x i8]`` argument.
        """
        name_bytes = exp.encode()
        if len(name_bytes) != 1:
            raise ValueError(
                "Only single-byte variable names are supported, got: {!r}".format(exp))
        name_arg = ir.Constant.literal_array(
            [ir.Constant(ir.IntType(8), b) for b in name_bytes])
        return self.builder.call(self.env["lookup"], [name_arg], tail=True)

    def compile_hardcoded(self, exp):
        """
        Compile a binary ``('+', a, b)`` or ``('-', a, b)`` form into a call
        to the ``add`` / ``sub`` builtin.

        Raises ValueError for an unknown operator or a wrong operand count.
        """
        op = exp[0]
        # Validate the operator before emitting any IR for the operands.
        if op == '+':
            builtin = self.env["add"]
        elif op == '-':
            builtin = self.env["sub"]
        else:
            raise ValueError("Unknown hardcoded function: {}".format(exp))
        if len(exp) != 3:
            raise ValueError(
                "Hardcoded function expects exactly two operands: {}".format(exp))
        lhs = self.compile_expression(exp[1])
        rhs = self.compile_expression(exp[2])
        return self.builder.call(builtin, [lhs, rhs], tail=True)

    def compile_lambda(self, exp):
        """
        Compile a lambda into its own LLVM function and return a ``TAG_FN``
        object whose payload is that function's address.
        """
        # NOTE: the parameter list only fixes the function's arity; the body
        # still resolves names through `lookup`, so arguments are not bound yet.
        func = self.compile_function(lambda_params(exp), lambda_body(exp), lambda_name())
        fptr = self.builder.ptrtoint(func, self.cell_pt, name="fptr")
        # Start from {TAG_FN, undef} and fill the payload slot with the address.
        fn_obj = ir.Constant(self.obj_t, [TAG_FN, ir.Undefined])
        return self.builder.insert_value(fn_obj, fptr, 1)

    def compile_function(self, params, body, name):
        """
        Compile *body* into a new function *name* taking one ``obj_t`` per
        entry of *params*, and return the ``ir.Function``.

        The builder that was active before the call is restored afterwards
        (even on error), so nested lambdas compile into the right function.
        """
        fnty = ir.FunctionType(self.obj_t, [self.obj_t] * len(params))
        func = ir.Function(self.module, fnty, name=name)
        block = func.append_basic_block(name="entry")
        outer_builder = self.builder
        self.builder = ir.IRBuilder(block)
        try:
            result = self.compile_expression(body)
            self.builder.ret(result)
        finally:
            self.builder = outer_builder
        return func
def is_self_evaluating(exp):
    """Return True if *exp* evaluates to itself (currently: numbers only)."""
    self_evaluating = is_number(exp)
    return self_evaluating
def is_variable(exp):
    """Return True if *exp* is a variable reference (a string name)."""
    variable = is_string(exp)
    return variable
def is_hardcoded(exp):
    """
    Return True if *exp* is a ``('+', ...)`` or ``('-', ...)`` form.

    The head is checked with equality against a tuple of operators, not with
    a substring test: ``x in '+-'`` would also accept ``''`` and ``'+-'``, and
    would raise TypeError for non-string heads such as a nested list.
    """
    # len() guard: an empty list/tuple has no head to inspect.
    return is_list(exp) and len(exp) > 0 and exp[0] in ('+', '-')
def is_lambda(exp):
    """Return True if *exp* is a ``('lambda', params, body)`` form."""
    # len() guard: an empty list/tuple would otherwise raise IndexError.
    return is_list(exp) and len(exp) > 0 and exp[0] == 'lambda'
def lambda_params(exp):
    """Return the parameter list of a ``('lambda', params, body)`` form."""
    params = exp[1]
    return params
def lambda_body(exp):
    """Return the body expression of a ``('lambda', params, body)`` form."""
    body = exp[2]
    return body
def lambda_name():
    """Return a fresh function name 'lambda-<n>' from a process-wide counter."""
    counter = lambda_name.id + 1
    lambda_name.id = counter
    return "lambda-" + str(counter)
lambda_name.id = 0  # number of names handed out so far
def is_number(exp):
    """Return True if *exp* is a number; only integers count for now."""
    numeric = is_integer(exp)
    return numeric
def is_float(exp):
    """Return True if *exp* is a float (or float subclass) instance."""
    return isinstance(exp, (float,))
def is_integer(exp):
    """Return True if *exp* is an int (or int subclass such as bool) instance."""
    return isinstance(exp, (int,))
def is_list(exp):
    """Return True if *exp* is a compound form: a list or a tuple."""
    return isinstance(exp, (list, tuple))
def is_string(exp):
    """Return True if *exp* is a str instance."""
    return isinstance(exp, (str,))
# Demo: compile a small Scheme expression, optimize the IR, then JIT it.
engine = create_execution_engine()

ast = ('lambda', ('x', 'y'), ('+', 'x', 'y'))

bfir = Compiler("test").compile_top_level(ast)
print(bfir)

# Parse, verify and optimize the IR *before* handing it to the JIT: MCJIT
# emits machine code in engine.finalize_object() (called by compile_ir), so
# passes run on the module afterwards would only change the printed IR, not
# the code that executes.
mod = llvm.parse_assembly(str(bfir))
mod.verify()
print("Before optimization:")
print(mod)

# perform some optimizations
mpm = llvm.passmanagers.create_module_pass_manager()
mpm.add_instruction_combining_pass()
mpm.add_cfg_simplification_pass()
mpm.add_global_optimizer_pass()
mpm.add_gvn_pass()
mpm.add_cfg_simplification_pass()
mpm.add_instruction_combining_pass()
mpm.run(mod)
print("After optimization:")
print(mod)

# JIT-compile the optimized IR.
# NOTE(review): add/sub/lookup are only declared (no bodies) by the Compiler,
# so symbol resolution may fail here until something provides them.
mod = compile_ir(engine, str(mod))

## Look up the function pointer (a Python int)
# NOTE(review): the snippet below does not match the current `main`, which
# takes no arguments and returns an {i64, i64} struct; update it before
# re-enabling.
#func_ptr = engine.get_function_address("main")
#
#tape = bytearray(1024 * 4)
#out = bytearray(1024 * 4)
#
#cfunc = CFUNCTYPE(c_int32, c_voidp, c_voidp)(func_ptr)
#res = cfunc((c_char * len(tape)).from_buffer(tape),
#            (c_char * len(out)).from_buffer(out))
#
#print(tape[:10])
#print(out[::4].decode())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment