Skip to content

Instantly share code, notes, and snippets.

@Mytherin
Created July 5, 2016 11:50
Show Gist options
  • Save Mytherin/ccd6dd333258bdf409f9525a7a35be36 to your computer and use it in GitHub Desktop.
Save Mytherin/ccd6dd333258bdf409f9525a7a35be36 to your computer and use it in GitHub Desktop.
JIT-compile a for loop using the LLVM 3.9 C API
/**
* LLVM equivalent of:
*
* void loop(double *result, double *a, double *b, size_t length) {
* for(size_t i = 0; i < length; i++) {
* result[i] = a[i] * b[i];
* }
* }
*/
#include <llvm-c/Core.h>
#include <llvm-c/ExecutionEngine.h>
#include <llvm-c/Target.h>
#include <llvm-c/Analysis.h>
#include <llvm-c/BitWriter.h>
#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>
// compile: clang `llvm-config --cflags` -c llvmtest.c -o llvmtest.o
// link: clang++ -std=c++11 `llvm-config --cxxflags --ldflags --libs --system-libs` -lffi llvmtest.o -o llvmtest
// both: clang `llvm-config --cflags` -c llvmtest.c -o llvmtest.o && clang++ -std=c++11 `llvm-config --cxxflags --ldflags --libs --system-libs` -lffi llvmtest.o -o llvmtest
/*
 * Write `name: [e0, e1, ...]` followed by a newline to `out`, formatting each
 * of the first `elements` values of `ptr` with "%f". When `elements` is 0 the
 * output is `name: []`. `ptr` is only read.
 */
static void fprint_array(FILE *out, const char *name, const double *ptr, size_t elements) {
    fprintf(out, "%s: [", name);
    for (size_t i = 0; i < elements; i++) {
        /* Emit the separator before every element except the first; this
           avoids computing `elements - 1`, which wraps when elements == 0. */
        if (i > 0) {
            fputs(", ", out);
        }
        fprintf(out, "%f", ptr[i]);
    }
    fputs("]\n", out);
}

/*
 * Print `name: [e0, e1, ...]` to stdout (see fprint_array). `ptr` is
 * const-qualified because it is only read; callers passing a plain
 * `double *` are unaffected.
 */
static void print_array(const char *name, const double *ptr, size_t elements) {
    fprint_array(stdout, name, ptr, elements);
}
int main(int argc, char const *argv[]) {
size_t i;
LLVMInitializeNativeTarget();
LLVMInitializeAllTargetMCs();
LLVMInitializeAllAsmPrinters();
LLVMInitializeAllAsmParsers();
LLVMContextRef context = LLVMContextCreate();
LLVMModuleRef module = LLVMModuleCreateWithName("LoopModule");
LLVMTypeRef doubleType = LLVMDoubleTypeInContext(context);
LLVMTypeRef doubleptrType = LLVMPointerType(doubleType, 0);
LLVMTypeRef int64Type = LLVMInt64TypeInContext(context);
LLVMValueRef constant_zero = LLVMConstInt(int64Type, 0, 1);
LLVMValueRef constant_one = LLVMConstInt(int64Type, 1, 1);
size_t argcount = 4;
LLVMTypeRef param_types[] = { doubleptrType, doubleptrType, doubleptrType, int64Type};
LLVMTypeRef prototype = LLVMFunctionType(LLVMVoidType(), param_types, argcount, 0);
LLVMValueRef function = LLVMAddFunction(module, "loop", prototype);
LLVMBuilderRef builder = LLVMCreateBuilderInContext(context);
LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry");
LLVMBasicBlockRef condition = LLVMAppendBasicBlock(function, "condition");
LLVMBasicBlockRef body = LLVMAppendBasicBlock(function, "body");
LLVMBasicBlockRef increment = LLVMAppendBasicBlock(function, "increment");
LLVMBasicBlockRef end = LLVMAppendBasicBlock(function, "end");
LLVMValueRef index_addr;
LLVMPositionBuilderAtEnd(builder, entry);
{
index_addr = LLVMBuildAlloca(builder, int64Type, "index");
LLVMBuildStore(builder, constant_zero, index_addr);
LLVMBuildBr(builder, condition);
}
LLVMPositionBuilderAtEnd(builder, condition);
{
LLVMValueRef index = LLVMBuildLoad(builder, index_addr, "[index]");
LLVMValueRef cond = LLVMBuildICmp(builder, LLVMIntSLT, index, LLVMGetParam(function, 3), "index < size");
LLVMBuildCondBr(builder, cond, body, end);
}
LLVMPositionBuilderAtEnd(builder, body);
{
LLVMValueRef index = LLVMBuildLoad(builder, index_addr, "[index]");
LLVMValueRef x_addr = LLVMBuildGEP(builder, LLVMGetParam(function, 1), &index, 1, "x[index]");
LLVMValueRef xindex = LLVMBuildLoad(builder, x_addr, "x[index]");
LLVMValueRef y_addr = LLVMBuildGEP(builder, LLVMGetParam(function, 2), &index, 1, "y[index]");
LLVMValueRef yindex = LLVMBuildLoad(builder, y_addr, "y[index]");
LLVMValueRef xmuly = LLVMBuildFMul(builder, xindex, yindex, "x[result] * y[result]");
LLVMValueRef result_addr = LLVMBuildGEP(builder, LLVMGetParam(function, 0), &index, 1, "result[index]");
LLVMBuildStore(builder, xmuly, result_addr);
LLVMBuildBr(builder, increment);
}
LLVMPositionBuilderAtEnd(builder, increment);
{
LLVMValueRef index = LLVMBuildLoad(builder, index_addr, "[index]");
LLVMValueRef indexpp = LLVMBuildAdd(builder, index, constant_one, "index++");
LLVMBuildStore(builder, indexpp, index_addr);
LLVMBuildBr(builder, condition);
}
LLVMPositionBuilderAtEnd(builder, end);
{
LLVMBuildRetVoid(builder);
}
char *error = NULL;
LLVMVerifyModule(module, LLVMAbortProcessAction, &error);
LLVMDisposeMessage(error);
LLVMLinkInInterpreter();
LLVMExecutionEngineRef engine;
error = NULL;
if (LLVMCreateExecutionEngineForModule(&engine, module, &error) != 0) {
fprintf(stderr, "failed to create execution engine\n");
abort();
}
if (error) {
fprintf(stderr, "error: %s\n", error);
LLVMDisposeMessage(error);
exit(EXIT_FAILURE);
}
size_t elements = 5;
double *result = malloc(sizeof(double) * elements);
double *x = malloc(sizeof(double) * elements);
double *y = malloc(sizeof(double) * elements);
for(i = 0; i < elements; i++) {
x[i] = i;
y[i] = 10 - i;
}
LLVMGenericValueRef args[] = {
LLVMCreateGenericValueOfPointer(result),
LLVMCreateGenericValueOfPointer(x),
LLVMCreateGenericValueOfPointer(y),
LLVMCreateGenericValueOfInt(int64Type, elements, 0)
};
LLVMRunFunction(engine, function, argcount, args);
print_array("x", x, 5);
print_array("y", y, 5);
print_array("result", result, 5);
LLVMDisposeBuilder(builder);
LLVMDisposeExecutionEngine(engine);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment