/*
 * Partial application JIT demo, with arena.
 * GitHub gist skeeto/5df632bad47bd71f0034d5683e26c998,
 * last active November 15, 2023 13:20.
 */
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <windows.h>
// new(a, t, n): allocate a zero-initialized array of n objects of type t
// from arena a. The count is passed to allocn() as its own argument, so an
// expression like new(a, int, k+1) is sized as k+1 elements, and the
// element-size multiply is checked for overflow there.
#define new(a, t, n) (t *)allocn(a, sizeof(t), n)

// A bump allocator over [beg, end). Allocations are carved from the high
// end, so end moves down toward beg; nothing is freed individually.
// Copying the struct copies both pointers, so allocating through a copy
// does not move the original's end.
typedef struct {
    char *beg, *end;
} arena;

// Alignment applied to every allocation, so mixing odd-sized byte buffers
// with pointer arrays in one arena never yields a misaligned object.
enum { ARENA_ALIGN = 16 };

// Allocate size zeroed bytes from the top of the arena, aligned to
// ARENA_ALIGN. Aborts on a negative size or when the arena lacks room.
static void *alloc(arena *a, ptrdiff_t size)
{
    ptrdiff_t avail = a->end - a->beg;
    if (size < 0 || size > avail) {
        abort();  // out of memory, or a bogus (e.g. wrapped) size
    }
    // Padding needed to round the new end down to an ARENA_ALIGN boundary.
    ptrdiff_t pad = (ptrdiff_t)((uintptr_t)(a->end - size) & (ARENA_ALIGN - 1));
    if (pad > avail - size) {
        abort();
    }
    a->end -= size + pad;
    return memset(a->end, 0, (size_t)size);
}

// Allocate count zeroed objects of size bytes each. Aborts if count is
// negative or size*count would overflow ptrdiff_t. Backs the new() macro.
static void *allocn(arena *a, ptrdiff_t size, ptrdiff_t count)
{
    if (count < 0 || (size > 0 && count > PTRDIFF_MAX / size)) {
        abort();
    }
    return alloc(a, size * count);
}
// Create an arena of executable memory for creating functions. | |
static arena newjitarena(ptrdiff_t size) | |
{ | |
arena a = {0}; | |
int type = MEM_RESERVE|MEM_COMMIT; | |
a.beg = VirtualAlloc(0, size, type, PAGE_EXECUTE_READWRITE); | |
a.end = a.beg ? a.beg + size : 0; | |
return a; | |
} | |
// The generators below target the Windows x64 calling convention, where the
// first integer arguments arrive in RCX, RDX, R8 (then R9). Each one emits a
// small trampoline into the executable arena. The trampoline loads a bound
// argument into the right register and then tail-jumps (jmp, not call) to
// target, so target returns directly to the trampoline's caller. RAX holds
// the jump address.

// Append x to p as an 8-byte little-endian immediate (the imm64 operand of a
// movabs) and return the write position just past it.
static unsigned char *put_imm64(unsigned char *p, uintptr_t x)
{
    uint64_t v = x;  // widen so the shifts are defined for any uintptr_t width
    for (int i = 0; i < 8; i++) {
        *p++ = (unsigned char)(v >> (8*i));
    }
    return p;
}

// Append "movabs $target, %rax; jmp *%rax" to p and return the end of the
// emitted code.
static unsigned char *put_tailjmp(unsigned char *p, void *target)
{
    *p++ = 0x48; *p++ = 0xb8;               // movabs imm64, %rax
    p = put_imm64(p, (uintptr_t)target);
    *p++ = 0xff; *p++ = 0xe0;               // jmp *%rax
    return p;
}

// Partially-apply a 1-arg function into a 0-arg function: the returned
// code, called with no arguments, calls target(arg). Allocated in *jit.
static void *partial_1to0(void *target, uintptr_t arg, arena *jit)
{
    unsigned char *f = new(jit, unsigned char, 24);  // 22 bytes emitted
    unsigned char *p = f;
    *p++ = 0x48; *p++ = 0xb9;               // movabs $arg, %rcx (arg 1)
    p = put_imm64(p, arg);
    put_tailjmp(p, target);
    return f;
}

// Partial-left-apply a 2-arg function into a 1-arg function: the returned
// code, called as g(x), calls target(arg, x). Allocated in *jit.
static void *partial_left2to1(void *target, uintptr_t arg, arena *jit)
{
    unsigned char *f = new(jit, unsigned char, 32);  // 25 bytes emitted
    unsigned char *p = f;
    *p++ = 0x48; *p++ = 0x89; *p++ = 0xca;  // mov %rcx, %rdx (x becomes arg 2)
    // Opcode B8+r picks the destination: B9 is RCX. BA would be RDX and
    // would overwrite the x just moved there, calling target(x, arg).
    *p++ = 0x48; *p++ = 0xb9;               // movabs $arg, %rcx (arg 1)
    p = put_imm64(p, arg);
    put_tailjmp(p, target);
    return f;
}

// Partial-right-apply a 3-arg function into a 2-arg function: the returned
// code, called as g(x, y), calls target(x, y, arg). Allocated in *jit.
static void *partial_right3to2(void *target, uintptr_t arg, arena *jit)
{
    unsigned char *f = new(jit, unsigned char, 24);  // 22 bytes emitted
    unsigned char *p = f;
    *p++ = 0x49; *p++ = 0xb8;               // movabs $arg, %r8 (REX.B -> r8, arg 3)
    p = put_imm64(p, arg);
    put_tailjmp(p, target);
    return f;
}
// Demonstration of partial_1to0

// Type of a function with no parameters returning int, which is what
// partial_1to0 produces from an int(int) function.
typedef int func0(void);

// Return x multiplied by itself.
static int square(int x)
{
    int product = x;
    product *= x;
    return product;
}
// Build an n-entry table in which entry k is generated code that calls
// square(k) when invoked. The table and every entry live in *jit.
static func0 **gensquarers(int n, arena *jit)
{
    func0 **table = new(jit, func0 *, n);
    for (int k = 0; k < n; ++k) {
        void *thunk = partial_1to0(square, (uintptr_t)k, jit);
        table[k] = thunk;
    }
    return table;
}
// Run the partial_1to0 demo: generate n squarers and print each result.
// scratch is received by value, so these allocations do not move the
// caller's arena.
static void demo1(arena scratch, int n)
{
    func0 **fns = gensquarers(n, &scratch);
    int i = 0;
    while (i < n) {
        int value = fns[i]();
        printf("square[%d]() = %d\n", i, value);
        i++;
    }
}
// Demonstration of partial_left2to1

// Type of an int -> int function, which is what partial_left2to1 produces
// from an int(int, int) function.
typedef int func1(int);

// Return the sum of x and y.
static int add(int x, int y)
{
    int sum = x + y;
    return sum;
}
// Build an n-entry table in which entry i, called with y, returns i + y
// (add with i bound). The table and every entry live in *jit.
static func1 **genbiasers(int n, arena *jit)
{
    func1 **table = new(jit, func1 *, n);
    int i = 0;
    while (i < n) {
        table[i] = partial_left2to1(add, (uintptr_t)i, jit);
        ++i;
    }
    return table;
}
// Run the partial_left2to1 demo: build n biasing functions and print each
// one applied to 10. scratch is received by value, so these allocations do
// not move the caller's arena.
void demo2(arena scratch, int n)
{
    func1 **fns = genbiasers(n, &scratch);
    for (int idx = 0; idx < n; ++idx) {
        int out = fns[idx](10);
        printf("bias[%d](10) = %d\n", idx, out);
    }
}
// Demonstration of partial_right3to2 | |
typedef enum {ASCEND, DESCEND} sortdir; | |
static int intcmp(int *a, int *b, sortdir dir) | |
{ | |
return dir ? *a - *b : *b - *a; | |
} | |
static void demo3(arena scratch) | |
{ | |
int array[] = {4, 1, 3, 2}; | |
typedef int (*qsortcmp)(const void *, const void *); | |
qsortcmp descend = partial_right3to2(intcmp, ASCEND, &scratch); | |
qsortcmp ascend = partial_right3to2(intcmp, DESCEND, &scratch); | |
qsort(array, 4, sizeof(*array), ascend); | |
printf("%d %d %d %d\n", array[0], array[1], array[2], array[3]); | |
qsort(array, 4, sizeof(*array), descend); | |
printf("%d %d %d %d\n", array[0], array[1], array[2], array[3]); | |
} | |
int main(void) | |
{ | |
arena jit = newjitarena(1<<21); | |
demo1(jit, 6); | |
demo2(jit, 6); | |
demo3(jit); | |
} |
/* Sign up for free to join this conversation on GitHub.
 * Already have an account? Sign in to comment. */