Skip to content

Instantly share code, notes, and snippets.

@38
Created June 28, 2018 02:31
Show Gist options
  • Save 38/427895789c3e79f9df08ffff7890a1ff to your computer and use it in GitHub Desktop.
Save 38/427895789c3e79f9df08ffff7890a1ff to your computer and use it in GitHub Desktop.
#include <stdio.h>
#include <unistd.h>
#include <stdlib.h>
#include <sys/wait.h>
#include <dlfcn.h>
/* Map one Brainfuck command to the C statement(s) that implement it.
   Returns NULL for every other character; Brainfuck treats those as comments. */
static const char* bf_command_to_c(char command)
{
    switch (command) {
    case '+': return "(*ptr)++;";
    case '-': return "(*ptr)--;";
    case '>': return "ptr++;";
    case '<': return "ptr--;";
    case '.': return "putchar(*ptr);";
    case ',': return "*ptr = getchar();";
    case '[': return "while(*ptr){";
    case ']': return "}";
    default:  return NULL;
    }
}

/* Write to `fout` the C source of a function `void bf_main(char* data)` that
   runs the NUL-terminated Brainfuck `program` with `data` as the tape.
   A NULL `program` yields an empty function body. Brackets are emitted
   verbatim, so an unbalanced program produces uncompilable C. The generated
   code calls putchar/getchar; the caller must emit the <stdio.h> include. */
void compile_brainfuck_to_c(const char* program, FILE* fout)
{
    const char* cursor;

    fputs("void bf_main(char* data) { char* ptr = data;", fout);
    if (program != NULL) {
        for (cursor = program; *cursor != '\0'; ++cursor) {
            const char* statement = bf_command_to_c(*cursor);
            if (statement != NULL)
                fputs(statement, fout);
        }
    }
    fputs("}", fout);
}
/*
 * Brainfuck runner that goes through the system C compiler:
 *   1. read the program named by argv[1] (at most 65535 bytes are used);
 *   2. translate it to C in ./hello.c;
 *   3. build ./libhello.so with `gcc -O3 -shared -fPIC` in a child process;
 *   4. dlopen() the library and call bf_main() on a zeroed 64 KiB tape.
 * Both temporary files are removed before returning. Returns 0 on success,
 * 1 on any usage, I/O, compiler or loader failure.
 */
int main(int argc, char** argv)
{
    char buf[65536];
    if (argc < 2) {
        fprintf(stderr, "usage: %s <program.bf>\n", argv[0]);
        return 1;
    }

    FILE* fp = fopen(argv[1], "r");
    if (fp == NULL) {
        perror(argv[1]);
        return 1;
    }
    /* Reserve one byte for the terminating NUL: reading a full sizeof(buf)
       bytes would make `buf[n] = 0` write past the end of the array. */
    size_t n = fread(buf, 1, sizeof(buf) - 1, fp);
    int read_failed = ferror(fp);
    fclose(fp);
    if (read_failed) {
        fprintf(stderr, "%s: read error\n", argv[1]);
        return 1;
    }
    buf[n] = 0;

    fp = fopen("hello.c", "w");
    if (fp == NULL) {
        perror("hello.c");
        return 1;
    }
    fprintf(fp, "#include <stdio.h>\n");
    compile_brainfuck_to_c(buf, fp);
    if (fclose(fp) != 0) {
        perror("hello.c");
        unlink("hello.c");
        return 1;
    }

    pid_t pid = fork();
    if (pid < 0) {
        perror("fork");
        unlink("hello.c");
        return 1;
    }
    if (pid == 0)
    {
        /* execvp() takes `char *const argv[]`, so the arguments are copied
           into writable arrays instead of pointing at string literals. */
        char arg_buf[][20] = {"gcc", "-O3", "-o", "libhello.so", "hello.c", "-shared", "-fPIC" };
        char* args[] = {arg_buf[0], arg_buf[1], arg_buf[2], arg_buf[3], arg_buf[4], arg_buf[5], arg_buf[6], NULL};
        execvp(args[0], args);
        /* Only reached if exec failed. _exit() skips atexit handlers and does
           not flush stdio buffers inherited from the parent a second time. */
        _exit(1);
    }

    int status = 0;
    pid_t waited = waitpid(pid, &status, 0);
    unlink("hello.c");
    if (waited != pid || !WIFEXITED(status) || WEXITSTATUS(status) != 0) {
        fprintf(stderr, "gcc failed to build libhello.so\n");
        unlink("libhello.so");
        return 1;
    }

    void* dl_handle = dlopen("./libhello.so", RTLD_LAZY);
    /* An already-loaded mapping stays valid after its file is unlinked, so
       the temporary library can be removed immediately. */
    unlink("libhello.so");
    if (dl_handle == NULL) {
        fprintf(stderr, "dlopen: %s\n", dlerror());
        return 1;
    }
    void (*bf_main)(char*) = (void (*)(char*))dlsym(dl_handle, "bf_main");
    if (bf_main == NULL) {
        fprintf(stderr, "dlsym: %s\n", dlerror());
        dlclose(dl_handle);
        return 1;
    }

    char mem[65536] = {0};
    bf_main(mem);
    dlclose(dl_handle);
    return 0;
}
#include <llvm/IR/IRBuilder.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/ExecutionEngine/ExecutionEngine.h>
#include <llvm/ExecutionEngine/JITSymbol.h>
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
#include <llvm/ExecutionEngine/Orc/CompileUtils.h>
#include <llvm/ExecutionEngine/Orc/IRCompileLayer.h>
#include <llvm/ExecutionEngine/Orc/LambdaResolver.h>
#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
#include <llvm/Transforms/InstCombine/InstCombine.h>
#include <llvm/Transforms/Scalar.h>
#include <llvm/Transforms/Scalar/GVN.h>
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
#include <memory>
#include <vector>
#include <dlfcn.h>
using namespace llvm;
using namespace std;
// Module produced by code_gen(); main() hands it to the JIT via a shared_ptr.
static Module* jit_module = nullptr;
// LLVM context in which the module, its types and its constants are created.
static LLVMContext context;
// IR builder used by code_gen(). It is constructed from `context`, so it must
// stay declared after it (globals in one TU initialize in declaration order).
static IRBuilder<> builder{context};
/// ORC compile functor used by the IRCompileLayer: optimizes a Module with
/// the O3 module pipeline and then lowers it to an in-memory object file for
/// the target described by `TM`.
struct Compiler {
    using CompileResult = object::OwningBinary<object::ObjectFile>;

    /// @param TM target to generate code for. It is held by reference, so it
    ///           must outlive this functor and every copy of it.
    explicit Compiler(TargetMachine &TM)
        : TM(TM) {}

    /// @brief Optimize and compile a Module to an ObjectFile.
    /// @return the object file together with the buffer that backs it, or an
    ///         empty CompileResult if the emitted bytes cannot be parsed as an
    ///         object file.
    CompileResult operator()(Module &M) {
        SmallVector<char, 0> ObjBufferSV;
        raw_svector_ostream ObjStream(ObjBufferSV);

        // Code-generation pipeline. It is set up first so that an unsupported
        // target is reported before any optimization work is done.
        legacy::PassManager CodeGenPM;
        MCContext *Ctx = nullptr;  // out-parameter of addPassesToEmitMC
        if (TM.addPassesToEmitMC(CodeGenPM, Ctx, ObjStream))
            report_fatal_error("Target does not support MC emission.");

        // IR optimization pipeline. It lives in its own PassManager and runs
        // to completion before code generation. If appended after the codegen
        // passes in one manager, the optimizations would only run once the
        // machine code had already been emitted.
        PassManagerBuilder pmb;
        pmb.OptLevel = 3;
        legacy::PassManager OptPM;
        pmb.populateModulePassManager(OptPM);
        OptPM.run(M);

        CodeGenPM.run(M);

        std::unique_ptr<MemoryBuffer> ObjBuffer(
            new ObjectMemoryBuffer(std::move(ObjBufferSV)));
        Expected<std::unique_ptr<object::ObjectFile>> Obj =
            object::ObjectFile::createObjectFile(ObjBuffer->getMemBufferRef());
        if (Obj)
            return CompileResult(std::move(*Obj), std::move(ObjBuffer));
        // Parsing failed: drop the error and hand back an empty result.
        consumeError(Obj.takeError());
        return CompileResult(nullptr, nullptr);
    }

    TargetMachine &TM;
};
using LinkLayer = orc::RTDyldObjectLinkingLayer;
// The custom Compiler above replaces orc::SimpleCompiler (kept for reference)
// so that an O3 optimization pipeline runs before code emission.
//using Compiler = orc::SimpleCompiler;
using CompileLayer = orc::IRCompileLayer<LinkLayer, Compiler>;
// Handle to the running program's global symbol scope. dummy_lookup() uses it
// to resolve the external functions (putchar/getchar) called by JIT code.
// POSIX requires the mode to contain RTLD_LAZY or RTLD_NOW; glibc rejects a
// mode of 0 ("invalid mode for dlopen()") and returns NULL.
void* handle = dlopen(nullptr, RTLD_LAZY);
/// Symbol resolver passed to the ORC lambda resolver: looks `name` up with
/// dlsym() in the scope referenced by the global `handle` and reports it as an
/// exported symbol. An unresolved name yields address 0.
/// NOTE(review): presumably ORC treats a zero address as "not found" — confirm.
JITSymbol dummy_lookup(const string& name)
{
    void* const found = dlsym(handle, name.c_str());
    const uintptr_t address = reinterpret_cast<uintptr_t>(found);
    return JITSymbol(address, JITSymbolFlags::Exported);
}
void code_gen(const char* program)
{
jit_module = new Module("Test JIT Compiler", context);
// (char* mem)
std::vector<Type*> param_type(1, Type::getInt8PtrTy(context));
// void (*)(char* mem)
FunctionType* prototype = FunctionType::get(Type::getVoidTy(context), param_type, false);
std::vector<Type*> putc_param(1, Type::getInt32Ty(context));
FunctionType* func_putc_proto = FunctionType::get(Type::getInt32Ty(context), putc_param, false);
Function* func_putc = Function::Create(func_putc_proto, Function::ExternalLinkage, "putchar", jit_module);
std::vector<Type*> getc_param;
FunctionType* func_getc_proto = FunctionType::get(Type::getInt32Ty(context), getc_param, false);
Function* func_getc = Function::Create(func_getc_proto, Function::ExternalLinkage, "getchar", jit_module);
Function *func = Function::Create(prototype, Function::ExternalLinkage, "bf_main", jit_module);
BasicBlock *begin = BasicBlock::Create(context, "begin", func);
BasicBlock *body = BasicBlock::Create(context, "body", func);
BasicBlock *end = BasicBlock::Create(context, "end", func);
builder.SetInsertPoint(begin);
builder.CreateBr(body);
builder.SetInsertPoint(end);
builder.CreateRetVoid();
builder.SetInsertPoint(body);
Value* ptr = &*(func->args().begin());
BasicBlock* stack[1024][3] = {{begin, body, end}};
int sp = 1;
for(;*program; program ++)
{
int delta = -1;
Value* temp = NULL;
switch(*program)
{
case '<':
delta = -delta;
case '>':
delta = -delta;
ptr = builder.CreateConstGEP1_32(ptr, delta);
break;
case '-':
delta = -delta;
case '+':
delta = -delta;
temp = builder.CreateLoad(ptr);
temp = builder.CreateAdd(temp, ConstantInt::get(Type::getInt8Ty(context), delta));
builder.CreateStore(temp, ptr);
break;
case '.':
temp = builder.CreateLoad(ptr);
builder.CreateCall(func_putc, temp);
break;
case ',':
temp = builder.CreateCall(func_getc);
temp = builder.CreateTrunc(temp, Type::getInt8Ty(context));
builder.CreateStore(temp, ptr);
break;
case '[':
stack[sp][0] = BasicBlock::Create(context, "loop_begin", func);
stack[sp][1] = BasicBlock::Create(context, "loop_body", func);
stack[sp][2] = BasicBlock::Create(context, "loop_end", func);
stack[sp-1][1] = BasicBlock::Create(context, "cont", func);
builder.CreateBr(stack[sp][0]);
builder.SetInsertPoint(stack[sp][2]);
builder.CreateBr(stack[sp-1][1]);
builder.SetInsertPoint(stack[sp][0]);
temp = builder.CreateLoad(ptr);
temp = builder.CreateICmpNE(temp, ConstantInt::get(Type::getInt8Ty(context), 0));
builder.CreateCondBr(temp, stack[sp][1], stack[sp][2]);
builder.SetInsertPoint(stack[sp][1]);
sp ++;
break;
case ']':
sp --;
builder.CreateBr(stack[sp][0]);
builder.SetInsertPoint(stack[sp-1][1]);
break;
}
}
builder.SetInsertPoint(stack[0][1]);
builder.CreateBr(stack[0][2]);
PassManagerBuilder pmb;
legacy::FunctionPassManager pass_manager(jit_module);
pmb.OptLevel = 3;
pmb.populateFunctionPassManager(pass_manager);
pass_manager.add(createInstructionCombiningPass());
pass_manager.add(createReassociatePass());
pass_manager.add(createGVNPass());
pass_manager.add(createCFGSimplificationPass());
pass_manager.add(createLoopUnrollPass());
pass_manager.doInitialization();
pass_manager.run(*func);
}
int main(int argc, char** argv)
{
// Initialization
LLVMInitializeNativeTarget();
InitializeNativeTargetAsmPrinter();
InitializeNativeTargetAsmParser();
TargetMachine* target = EngineBuilder().selectTarget();
// Emit the LLVM IR to the Module
//code_gen("++++++++++[>+++++++>++++++++++>+++>+<<<<-]>++.>+.+++++++..+++.>++.<<+++++++++++++++.>.+++.------.--------.>+.>.");
//code_gen(">++++[-<+++++++++++>]>-[[->+>+<<]>[-<+>]+>-]<-<-<[-<[[->>+>>+<<<<]>>>>[>>>>++++++++++<<<<[->+>>+>-[<-]<[->>+<<<<[->>>+<<<]>]<<]>+[-<+>]>>>[-]>[-<<<<+>>>>]<<<<]<[>++++++[<++++++++>-]<-.[-]<]<<<<[<<]<.>>>[>>]>[->+>>+<<<]>[-<+>]<<<<[-<[->>+<<]>>>+[>>]+<-[>-]>[-<<[<<]>[-]>[>>]>>[-<+<<+>>>]<[->+<]]<<<[<<]<<]>>>>[<[-<<+>>]<+>>->>]<[-]>>>[-]<<<<<]]u");
//code_gen("[>[>>+>+<<<-]>>>[<<<+>>>-]<[<+>-]<<<-]");
char buf[65536];
FILE* fp = fopen(argv[1], "r");
int n = fread(buf, 1, sizeof(buf), fp);
fclose(fp);
buf[n] = 0;
code_gen(buf);
// Compile the IR to Machine Code
const DataLayout dl = target->createDataLayout();
LinkLayer link_layer([]() { return std::make_shared<SectionMemoryManager>(); });
CompileLayer compile_layer(link_layer, Compiler(*target));
auto jit_module_handle = cantFail(compile_layer.addModule(std::shared_ptr<Module>(jit_module), orc::createLambdaResolver(dummy_lookup, dummy_lookup)));
// Run the compiled function !
JITSymbol symbol = compile_layer.findSymbolIn(jit_module_handle, "bf_main", false);
void (*native_func)(char*) = (decltype(native_func))cantFail(symbol.getAddress());
char mem[65536] = {};
native_func(mem);
delete target;
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment