Skip to content

Instantly share code, notes, and snippets.

@mbillingr
Last active June 12, 2019 09:25
Show Gist options
  • Save mbillingr/049ac36959f1de442a396604ecd41308 to your computer and use it in GitHub Desktop.
Save mbillingr/049ac36959f1de442a396604ecd41308 to your computer and use it in GitHub Desktop.
Conceptual Brainfuck JIT using llvmlite
from ctypes import CFUNCTYPE, c_int32, c_voidp, c_char
import llvmlite.binding as llvm
from llvmlite import ir
# Annotated Brainfuck "Hello World!" program used as the demo input below.
# BFCompiler only acts on the characters > < + - . [ ] and ignores everything
# else, so the explanatory prose is inert. It deliberately avoids those
# command characters. Per its own annotations, the program prints
# "Hello World!" followed by a newline, using tape cells #0-#6.
hello_world_source = """
++++++++ Set Cell #0 to 8
[
>++++ Add 4 to Cell #1; this will always set Cell #1 to 4
[ as the cell will be cleared by the loop
>++ Add 2 to Cell #2
>+++ Add 3 to Cell #3
>+++ Add 3 to Cell #4
>+ Add 1 to Cell #5
<<<<- Decrement the loop counter in Cell #1
] Loop till Cell #1 is zero; number of iterations is 4
>+ Add 1 to Cell #2
>+ Add 1 to Cell #3
>- Subtract 1 from Cell #4
>>+ Add 1 to Cell #6
[<] Move back to the first zero cell you find; this will
be Cell #1 which was cleared by the previous loop
<- Decrement the loop Counter in Cell #0
] Loop till Cell #0 is zero; number of iterations is 8
The result of this is:
Cell No : 0 1 2 3 4 5 6
Contents: 0 0 72 104 88 32 8
Pointer : ^
>>. Cell #2 has value 72 which is 'H'
>---. Subtract 3 from Cell #3 to get 101 which is 'e'
+++++++..+++. Likewise for 'llo' from Cell #3
>>. Cell #5 is 32 for the space
<-. Subtract 1 from Cell #4 for 87 to give a 'W'
<. Cell #3 was set to 'o' from the end of 'Hello'
+++.------.--------. Cell #3 for 'rl' and 'd'
>>+. Add 1 to Cell #5 gives us an exclamation point
>++. And finally a newline from Cell #6
"""
# Before llvmlite can generate code, three things must be initialized: LLVM
# itself, the host target, and the host asm printer. All three are required
# for code generation; the asm printer cannot be skipped.
# NOTE(review): newer llvmlite releases may deprecate initialize(); confirm
# against the installed version.
for _llvm_init in (llvm.initialize,
                   llvm.initialize_native_target,
                   llvm.initialize_native_asmprinter):
    _llvm_init()
del _llvm_init
def create_execution_engine():
    """Build an MCJIT execution engine that targets the host CPU.

    The engine is backed by an empty module, so any number of modules
    can later be added to it for JIT compilation.

    :returns: the ``llvmlite.binding`` execution engine.
    """
    # A target machine for the default (host) triple.
    host_machine = llvm.Target.from_default_triple().create_target_machine()
    # MCJIT needs some module to start from; an empty one is enough.
    return llvm.create_mcjit_compiler(llvm.parse_assembly(""), host_machine)
def compile_ir(engine, llvm_ir):
    """Verify LLVM IR, add it to *engine* and finalize it for execution.

    :param engine: execution engine, e.g. from ``create_execution_engine()``.
    :param llvm_ir: textual LLVM IR, or an ``llvmlite.ir.Module``. A module
        is serialized to text first.
    :returns: the parsed binding-layer module that was added to *engine*.
    """
    if isinstance(llvm_ir, ir.module.Module):
        # An ir.Module is a pure-Python IR builder object. It has no
        # verify() and cannot be passed to engine.add_module(). Go through
        # its textual form to obtain a binding-layer module.
        llvm_ir = str(llvm_ir)
    mod = llvm.parse_assembly(llvm_ir)
    mod.verify()
    # Hand the module to the engine and emit machine code for it.
    engine.add_module(mod)
    engine.finalize_object()
    engine.run_static_constructors()
    return mod
class BFCompiler:
    """Compile Brainfuck source into an LLVM IR module using llvmlite.

    ``compile_to_ir`` emits a function ``main(tape, output)`` that takes two
    cell pointers:

    * ``tape`` -- the data tape; the cell cursor starts at its first cell.
    * ``output`` -- each ``.`` stores the current cell here and advances.

    The function returns the i32 value 0. The supported commands are
    ``> < + - . [ ]``. Every other character is ignored, including ``,``
    (input), which is not implemented. No bounds checks are emitted, so the
    caller must provide buffers large enough for the program.

    The cursor and output positions are carried as plain ``i64`` integers
    and converted with ``inttoptr`` at each access. This assumes 64-bit host
    pointers.
    """

    def __init__(self, name, cell_bits=32):
        """Create a compiler whose IR module is named *name*.

        :param name: name given to the generated ``ir.Module``.
        :param cell_bits: width of one cell in bits (default 32). It must be
            a positive multiple of 8 so that cells are byte-addressable.
        :raises ValueError: if *cell_bits* is not a positive multiple of 8.
        """
        if cell_bits <= 0 or cell_bits % 8:
            raise ValueError(
                f"cell_bits must be a positive multiple of 8, got {cell_bits!r}"
            )
        self.cell_t = ir.IntType(cell_bits)
        self.cell_pt = ir.IntType(64)  # integer type used to hold addresses
        self.cell_p = ir.PointerType(self.cell_t)
        # The return type stays i32 whatever the cell width, so callers can
        # always declare the JIT'd function as returning a 32-bit int.
        ret_t = ir.IntType(32)
        self.fnty = ir.FunctionType(ret_t, (self.cell_p, self.cell_p))
        self.RET_OK = ir.Constant(ret_t, 0)
        self.ZERO = ir.Constant(self.cell_t, 0)
        self.ONE = ir.Constant(self.cell_t, 1)
        # Byte distance between neighbouring cells (4 for the default i32).
        self.STEP = ir.Constant(self.cell_pt, cell_bits // 8)
        self.module = ir.Module(name=name)

    def compile_to_ir(self, source):
        """Compile Brainfuck *source* into a function ``main`` in ``self.module``.

        :param source: the program, as a string or any iterable of characters.
        :returns: ``self.module``, which now contains ``main``.
        :raises ValueError: if *source* contains a ``]`` with no open ``[``
            before it. Without this check, such a ``]`` would end compilation
            and silently drop the rest of the program. Unclosed ``[`` are
            accepted and closed implicitly at end of input.
        """
        commands = list(source)  # materialize: the input is scanned twice
        self._check_brackets(commands)
        self.func = ir.Function(self.module, self.fnty, name="main")
        tape, output = self.func.args
        builder = ir.IRBuilder(self.func.append_basic_block(name="entry"))
        cursor = builder.ptrtoint(tape, self.cell_pt, name="cursor")
        output = builder.ptrtoint(output, self.cell_pt, name="output")
        # The builder object is shared and ends up positioned at the last block.
        self.compile_sequence(builder, iter(commands), cursor, output)
        builder.ret(self.RET_OK)
        return self.module

    @staticmethod
    def _check_brackets(commands):
        """Raise ValueError at the first ``]`` that has no open ``[``."""
        depth = 0
        for pos, cmd in enumerate(commands):
            if cmd == "[":
                depth += 1
            elif cmd == "]":
                if depth == 0:
                    raise ValueError(f"unmatched ']' at position {pos}")
                depth -= 1

    def compile_sequence(self, builder, source, cursor, output):
        """Emit code for commands read from *source* until ``]`` or end of input.

        :param builder: IRBuilder positioned where code should be appended.
        :param source: iterator over the program's characters. It is shared
            with nested loops, which consume their own bodies from it.
        :param cursor: i64 address of the current tape cell.
        :param output: i64 address of the next free output cell.
        :returns: ``(cursor, output, builder)`` -- the updated addresses and
            the builder, positioned at the block where compilation ended.
        """
        for cmd in source:
            if cmd == ">":
                cursor = builder.add(cursor, self.STEP, name="cursor")
            elif cmd == "<":
                cursor = builder.sub(cursor, self.STEP, name="cursor")
            elif cmd == "+":
                p = builder.inttoptr(cursor, self.cell_p)
                cell = builder.load(p, name="cell")
                cell = builder.add(cell, self.ONE, name="cell")
                builder.store(cell, p)
            elif cmd == "-":
                p = builder.inttoptr(cursor, self.cell_p)
                cell = builder.load(p, name="cell")
                cell = builder.sub(cell, self.ONE, name="cell")
                builder.store(cell, p)
            elif cmd == ".":
                # Copy the current cell to the output buffer and advance it.
                p = builder.inttoptr(cursor, self.cell_p)
                cell = builder.load(p, name="cell")
                p = builder.inttoptr(output, self.cell_p)
                builder.store(cell, p)
                output = builder.add(output, self.STEP, name="output")
            elif cmd == "[":
                cursor, output, builder = self.compile_loop(builder, source, cursor, output)
            elif cmd == "]":
                # End of the current loop body. An unclosed "[" is closed
                # implicitly when the input runs out and this loop just ends.
                break
        return cursor, output, builder

    def compile_loop(self, builder, source, cursor, output):
        """Emit a ``[ ... ]`` loop whose body is read from *source*.

        Three blocks are created:

        * ``loop_cond`` merges the cursor/output values from before the loop
          and from the end of the body using phi nodes, then tests the cell.
        * ``loop_body`` holds the body.
        * ``loop_then`` is where compilation continues after the loop.

        :returns: ``(cursor, output, builder)``. The builder is positioned in
            ``loop_then``, and cursor/output are the phi values.
        """
        prev_block = builder.block
        cond_block = self.func.append_basic_block(name="loop_cond")
        loop_block = self.func.append_basic_block(name="loop_body")
        next_block = self.func.append_basic_block(name="loop_then")
        builder.branch(cond_block)

        builder.position_at_start(cond_block)
        cursor_phi = builder.phi(self.cell_pt, name="cursor")
        cursor_phi.add_incoming(cursor, prev_block)
        output_phi = builder.phi(self.cell_pt, name="output")
        output_phi.add_incoming(output, prev_block)
        p = builder.inttoptr(cursor_phi, self.cell_p)
        cell = builder.load(p, name="cell")
        is_zero = builder.icmp_unsigned("==", cell, self.ZERO)
        # Leave the loop when the current cell is zero.
        builder.cbranch(is_zero, next_block, loop_block)

        builder.position_at_end(loop_block)
        cursor, output, builder = self.compile_sequence(
            builder, source, cursor_phi, output_phi)
        # A nested loop may have moved the builder to another block, so the
        # back edge comes from wherever the body finished.
        cursor_phi.add_incoming(cursor, builder.block)
        output_phi.add_incoming(output, builder.block)
        builder.branch(cond_block)

        builder.position_at_end(next_block)
        # After the loop, the live values are those seen by the failing test.
        return cursor_phi, output_phi, builder
# Demo driver: compile, optimize, JIT-compile and run the Hello World program.
engine = create_execution_engine()
bfir = BFCompiler("test").compile_to_ir(hello_world_source)
print(bfir)

# Optimize *before* the module is handed to the engine. compile_ir() calls
# engine.finalize_object(), which emits machine code, so running passes
# afterwards would only change the printed IR, not what actually executes.
mod = llvm.parse_assembly(str(bfir))
mod.verify()
print("Before optimization:")
print(mod)
# perform some optimizations
mpm = llvm.passmanagers.create_module_pass_manager()
mpm.add_global_optimizer_pass()
mpm.add_gvn_pass()
mpm.add_instruction_combining_pass()
mpm.run(mod)
print("After optimization:")
print(mod)

# JIT-compile the *optimized* IR.
mod = compile_ir(engine, str(mod))

# Look up the function pointer (a Python int)
func_ptr = engine.get_function_address("main")
tape = bytearray(1024 * 4)
out = bytearray(1024 * 4)
cfunc = CFUNCTYPE(c_int32, c_voidp, c_voidp)(func_ptr)
res = cfunc((c_char * len(tape)).from_buffer(tape),
            (c_char * len(out)).from_buffer(out))

# The generated code uses 32-bit cells in native byte order. View the buffers
# as C ints ('i', 4 bytes on mainstream platforms) instead of slicing bytes.
tape_cells = memoryview(tape).cast("i")
out_cells = memoryview(out).cast("i")
print(list(tape_cells[:10]))
# Each output cell holds one character code and unused cells stay zero, so
# strip the zero padding instead of printing ~1000 NUL characters.
print(bytes(c & 0xFF for c in out_cells).rstrip(b"\0").decode())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment