Last active
July 29, 2024 22:27
-
-
Save RealNeGate/3261bb54c21e0a0bade07b7bee6bd80d to your computer and use it in GitHub Desktop.
Learning Hindley-Milner
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
X(fn) | |
X(do) | |
X(let) | |
X(new) | |
X(while) | |
X(return) | |
#undef X |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdio.h> | |
#include <stdlib.h> | |
#include <stdint.h> | |
#include <stdarg.h> | |
#include <string.h> | |
#include <assert.h> | |
#include <stdbool.h> | |
#include <sys/stat.h> | |
#include <inttypes.h> | |
// The Windows CRT spells these POSIX helpers with a leading underscore.
#ifdef _WIN32
#define fileno _fileno
#define fstat _fstat
#define stat _stat
#endif

// Iterate i over [start, end) counting up; `end` is evaluated once.
#define FOR_N(i, start, end) for (ptrdiff_t i = (start), for_end_ = (end); i < for_end_; i++)
// Iterate i over [start, end) counting down (end-1 first); `start` is evaluated once.
#define FOR_REV_N(i, start, end) for (ptrdiff_t i = (end), for_start_ = (start); i-- > for_start_;)
// An interned string: a length prefix followed by the bytes stored inline
// (flexible array member). parser_intern() also writes a '\0' at data[len].
typedef struct {
    int len;     // byte length, not counting the terminator
    char data[]; // `len` bytes, then '\0'
} Intern;
// Read a whole file into a freshly malloc'd, NUL-terminated buffer.
// Returns NULL if the file cannot be opened or read, or on allocation
// failure. The caller owns the result and must free() it.
static char* read_entire_file(const char* filepath) {
    FILE* file = fopen(filepath, "rb");
    if (file == NULL) {
        return NULL;
    }
    // Read in growing chunks rather than trusting a size from fstat: this is
    // portable standard C, cannot truncate st_size into an int, and does not
    // depend on fopen_s (MSVC / Annex K only).
    size_t cap = 4096, len = 0;
    char* data = malloc(cap);
    while (data != NULL) {
        size_t want = cap - 1 - len;  // always keep one byte for the '\0'
        size_t got = fread(data + len, 1, want, file);
        len += got;
        if (got < want) {
            break;  // EOF or a read error; ferror() below tells them apart
        }
        char* grown = (cap <= SIZE_MAX / 2) ? realloc(data, cap * 2) : NULL;
        if (grown == NULL) {
            free(data);
            data = NULL;
            break;
        }
        data = grown;
        cap *= 2;
    }
    if (data != NULL && ferror(file)) {
        free(data);
        data = NULL;
    }
    fclose(file);
    if (data != NULL) {
        data[len] = '\0';
    }
    return data;
}
// MurmurHash3, x86 32-bit variant.
//   key - bytes to hash (any alignment)
//   len - number of bytes
//   h   - seed
// 4-byte blocks are read in native byte order, so results match the
// reference implementation's published vectors on little-endian hosts.
uint32_t mur3_32(const void *key, int len, uint32_t h) {
    const uint8_t *bytes = key;
    int nblocks = len / 4;
    // main body, 32-bit blocks at a time
    for (int i = 0; i < nblocks; i++) {
        uint32_t k;
        // memcpy instead of a uint32_t* cast: no misaligned or aliasing UB
        memcpy(&k, bytes + 4*i, sizeof k);
        k *= 0xcc9e2d51;
        k = ((k << 15) | (k >> 17)) * 0x1b873593;
        h ^= k;
        h = ((h << 13) | (h >> 19)) * 5 + 0xe6546b64;
    }
    // load/mix up to 3 remaining tail bytes into a tail block
    const uint8_t *tail = bytes + 4*nblocks;
    uint32_t t = 0;
    switch (len & 3) {
        case 3: t ^= (uint32_t)tail[2] << 16; /* fallthrough */
        case 2: t ^= (uint32_t)tail[1] << 8;  /* fallthrough */
        case 1:
            t ^= tail[0];
            t *= 0xcc9e2d51;
            h ^= ((t << 15) | (t >> 17)) * 0x1b873593;
            break;
        default:
            break;
    }
    // finalization mix, including key length
    h ^= (uint32_t)len;
    h ^= h >> 16;
    h *= 0x85ebca6b;
    h ^= h >> 13;
    h *= 0xc2b2ae35;
    h ^= h >> 16;
    return h;
}
//////////////////////////////// | |
// Type system | |
//////////////////////////////// | |
typedef struct Type Type; | |
typedef struct { | |
Intern* name; | |
Type* type; | |
// data layout | |
} TableEntry; | |
struct Type { | |
enum { | |
TYPE_VAR, | |
// monotypes | |
TYPE_VOID, | |
TYPE_INT, | |
TYPE_FLT, | |
TYPE_INTFLT, | |
TYPE_FUNC, | |
TYPE_TUPLE, | |
TYPE_ARRAY, | |
TYPE_TABLE, | |
} tag; | |
// helpful for debugging & worklists | |
int uid; | |
// disjoint-set UF | |
Type* parent; | |
union { | |
struct { | |
// tuple types | |
int count; | |
Type** elems; | |
} tuple; | |
struct { | |
Type* args; | |
Type* ret; | |
} fn; | |
struct { | |
// sorted by address | |
int cap, count; | |
TableEntry* elems; | |
} table; | |
Type* array_elem; | |
}; | |
}; | |
// Union-find "find": locate the set leader of `a`, then repoint every node
// on the walked path straight at the leader (path compression).
static Type* type_find(Type* a) {
    Type* root = a;
    while (root->parent != root) {
        root = root->parent;
    }
    for (Type* next; a->parent != a; a = next) {
        next = a->parent;
        a->parent = root;
    }
    return root;
}

// Union-find "union": the leader of `a`'s set is attached under the leader
// of `b`'s set. Returns true if two distinct sets were merged (progress).
static bool type_union(Type* a, Type* b) {
    Type* ra = type_find(a);
    Type* rb = type_find(b);
    if (ra != rb) {
        ra->parent = rb;
    }
    return ra != rb;
}
// Unify two types so that both resolve (via type_find) to the same set.
// A type variable joins the other side's set, so the more-resolved type ends
// up as the leader. Tuples, functions and arrays are unified element-wise.
// A few mismatched tags are compatible:
//   void & T   = T
//   int & intflt = int
//   flt & intflt = flt
// Any other mismatch is a type error: a message goes to stderr and the
// program aborts.
//
// Returns true if any union happened (progress for the inference worklist),
// false if the two types were already equivalent.
static bool type_unify(Type* a, Type* b) {
    printf("UNIFY T%d = T%d\n", a->uid, b->uid);
    a = type_find(a);
    b = type_find(b);
    if (a == b) {
        // already matching
        return false;
    }
    // a type var (on either side) simply joins the other set
    if (a->tag == TYPE_VAR) {
        return type_union(a, b);
    }
    if (b->tag == TYPE_VAR) {
        return type_union(b, a);
    }
    if (a->tag == b->tag) {
        bool progress = false;
        switch (a->tag) {
            case TYPE_TUPLE:
                if (a->tuple.count != b->tuple.count) {
                    fprintf(stderr, "type error: tuple T%d has %d elements but T%d has %d\n",
                            a->uid, a->tuple.count, b->uid, b->tuple.count);
                    abort();
                }
                FOR_N(i, 0, a->tuple.count) {
                    progress |= type_unify(a->tuple.elems[i], b->tuple.elems[i]);
                }
                break;
            case TYPE_FUNC:
                progress |= type_unify(a->fn.args, b->fn.args);
                progress |= type_unify(a->fn.ret, b->fn.ret);
                break;
            case TYPE_ARRAY:
                // element types must agree, otherwise [T1] ~ [T2] loses T1 = T2
                progress |= type_unify(a->array_elem, b->array_elem);
                break;
            default:
                progress |= type_union(a, b);
                break;
        }
        return progress;
    }
    // order the pair so that a->tag < b->tag
    if (a->tag > b->tag) {
        Type* tmp = a;
        a = b;
        b = tmp;
    }
    // void & T = T
    if (a->tag == TYPE_VOID) {
        return type_union(a, b);
    }
    // int & intflt = int, flt & intflt = flt
    if ((a->tag == TYPE_INT || a->tag == TYPE_FLT) && b->tag == TYPE_INTFLT) {
        return type_union(b, a);
    }
    fprintf(stderr, "type error: cannot unify T%d with T%d\n", a->uid, b->uid);
    abort();
}
static int TYPE_CNT = 0; | |
static Type* type_clone(Type* src) { | |
Type* t = malloc(sizeof(Type)); | |
*t = *src; | |
t->parent = t; | |
t->uid = ++TYPE_CNT; | |
return t; | |
} | |
static Type* type_new_mono(int tag) { | |
Type* t = malloc(sizeof(Type)); | |
*t = (Type){ .tag = tag, .uid = ++TYPE_CNT, .parent = t }; | |
return t; | |
} | |
static Type* type_new_var(void) { | |
Type* t = malloc(sizeof(Type)); | |
*t = (Type){ .tag = TYPE_VAR, .uid = ++TYPE_CNT, .parent = t }; | |
return t; | |
} | |
static Type* type_new_array(Type* elem) { | |
Type* t = malloc(sizeof(Type)); | |
*t = (Type){ .tag = TYPE_ARRAY, .uid = ++TYPE_CNT, .parent = t, .array_elem = elem }; | |
return t; | |
} | |
static Type* type_new_table(Type* elem) { | |
Type* t = malloc(sizeof(Type)); | |
*t = (Type){ .tag = TYPE_TABLE, .uid = ++TYPE_CNT, .parent = t, .table = { 4 } }; | |
t->table.elems = malloc(t->table.cap * sizeof(TableEntry)); | |
return t; | |
} | |
static Type* type_new_tuple(int arg_count) { | |
Type* t = malloc(sizeof(Type)); | |
*t = (Type){ .tag = TYPE_TUPLE, .uid = ++TYPE_CNT, .parent = t, .tuple = { arg_count, malloc(arg_count * sizeof(Type*)) } }; | |
return t; | |
} | |
static Type* type_new_func(Type* args, Type* ret) { | |
Type* t = malloc(sizeof(Type)); | |
*t = (Type){ .tag = TYPE_FUNC, .uid = ++TYPE_CNT, .parent = t, .fn = { args, ret } }; | |
return t; | |
} | |
//////////////////////////////// | |
// Parse IR | |
//////////////////////////////// | |
typedef struct Node Node; | |
struct Node { | |
enum { | |
NODE_NULL, | |
NODE_LEN, | |
NODE_INT, | |
NODE_REAL, | |
NODE_DECL, | |
NODE_COMPOUND, | |
// core | |
NODE_BINOP, | |
NODE_APPLY, | |
NODE_SYMBOL, | |
NODE_LAMBDA, | |
NODE_TERNARY, | |
// tuples | |
NODE_PROJ, | |
NODE_TUPLE, | |
// tables | |
NODE_ACCESS, | |
// array | |
NODE_SUBSCRIPT, | |
NODE_ALLOC_ARR, | |
} tag; | |
Type* type; | |
union { | |
struct { | |
int arg_count; | |
Node* body; | |
} fn; | |
struct { | |
// callsite type | |
Type* site; | |
} apply; | |
struct { | |
Node* prev; | |
Intern* name; | |
int param; | |
} decl; | |
Intern* access; | |
int proj_i; | |
uint64_t num; | |
double flt; | |
int binop; | |
}; | |
// helpful for debugging & worklists | |
int uid; | |
// unordered list of uses | |
int use_cnt, use_cap; | |
Node** uses; | |
// not SSA, just a semi-stupid parse graph for inference | |
int in_cnt; | |
Node* ins[]; | |
}; | |
static int NODE_CNT; | |
// Store `v` into input slot `i` of `n`, keeping def-use lists consistent:
// `n` leaves the old input's `uses` once no slot of `n` refers to it, and
// joins `v->uses` unless it is already listed there.
static void set_in(Node* n, Node* v, int i) {
    Node* prev = n->ins[i];
    n->ins[i] = v;
    if (prev == v) {
        return;
    }

    if (prev != NULL) {
        bool still_input = false;
        for (int j = 0; j < n->in_cnt && !still_input; j++) {
            still_input = (n->ins[j] == prev);
        }
        if (!still_input) {
            for (int j = 0; j < prev->use_cnt; j++) {
                if (prev->uses[j] != n) {
                    continue;
                }
                prev->use_cnt--;
                prev->uses[j] = prev->uses[prev->use_cnt];
                break;
            }
        }
    }

    if (v == NULL) {
        return;
    }
    for (int j = 0; j < v->use_cnt; j++) {
        if (v->uses[j] == n) {
            return;
        }
    }
    if (v->use_cnt == v->use_cap) {
        v->use_cap *= 2;
        v->uses = realloc(v->uses, v->use_cap * sizeof(Node*));
    }
    v->uses[v->use_cnt++] = n;
}
// Allocate a node with `in_cnt` empty input slots, a fresh uid and an empty
// uses list (initial capacity 2).
static Node* new_node2(int tag, Type* type, int in_cnt) {
    Node* n = malloc(sizeof(Node) + in_cnt * sizeof(Node*));
    *n = (Node){
        .tag = tag,
        .type = type,
        .uid = ++NODE_CNT,
        .in_cnt = in_cnt,
        .use_cap = 2,
        .uses = malloc(2 * sizeof(Node*)),
    };
    for (int i = 0; i < in_cnt; i++) {
        n->ins[i] = NULL;
    }
    return n;
}

// Like new_node2, then wires the `in_cnt` trailing Node* arguments into the
// input slots in order.
static Node* new_node(int tag, Type* type, int in_cnt, ...) {
    Node* n = new_node2(tag, type, in_cnt);
    va_list ap;
    va_start(ap, in_cnt);
    for (int i = 0; i < in_cnt; i++) {
        set_in(n, va_arg(ap, Node*), i);
    }
    va_end(ap);
    return n;
}

// Projection node selecting element `idx` of `n`.
static Node* new_proj(Type* type, Node* n, int idx) {
    Node* proj = new_node2(NODE_PROJ, type, 1);
    proj->proj_i = idx;
    set_in(proj, n, 0);
    return proj;
}
//////////////////////////////// | |
// Parser | |
//////////////////////////////// | |
// TODO(NeGate): fill in good errors

// Interned keyword identifiers; compile() fills them in on first use.
static bool kw_init;
// expands to one `static Intern* KW_<name>;` per keywords.inc entry
#define X(name) static Intern* KW_ ## name;
#include "keywords.inc"
typedef struct { | |
int visited_cap; | |
uint32_t* visited; | |
int cnt, cap; | |
Node** arr; | |
} Worklist; | |
static void ws_init(Worklist* ws, int cap) { | |
cap = 1ull << (64 - __builtin_clzll(cap - 1)); | |
ws->visited_cap = cap / 32; | |
ws->visited = calloc(cap, sizeof(uint32_t)); | |
ws->cnt = 0; | |
ws->cap = cap; | |
ws->arr = malloc(cap * sizeof(Node*)); | |
} | |
static void ws_push(Worklist* ws, Node* n) { | |
size_t word_i = n->uid / 32u; | |
if (word_i >= ws->visited_cap) { | |
size_t new_cap = 1ull << (64 - __builtin_clzll(word_i - 1)); | |
ws->visited = realloc(ws->visited, new_cap*sizeof(uint32_t)); | |
FOR_N(i, ws->visited_cap, new_cap) { ws->visited[i] = 0; } | |
ws->visited_cap = new_cap; | |
} else if (ws->visited[word_i] & (1u << (n->uid % 32u))) { | |
return; | |
} | |
ws->visited[word_i] |= 1u << (n->uid % 32u); | |
if (n->uid >= ws->cap) { | |
size_t new_cap = 1ull << (64 - __builtin_clzll(n->uid - 1)); | |
ws->arr = realloc(ws->arr, new_cap*sizeof(Node*)); | |
ws->cap = new_cap; | |
} | |
ws->arr[ws->cnt++] = n; | |
} | |
static Node* ws_pop(Worklist* ws) { | |
if (ws->cnt == 0) { | |
return NULL; | |
} | |
Node* n = ws->arr[--ws->cnt]; | |
ws->visited[n->uid] = false; | |
return n; | |
} | |
typedef struct { | |
const char* src; | |
const char* prev; | |
const char* curr; | |
// current token | |
struct { | |
enum { | |
TKN_EOF = 0, | |
TKN_IDENT = 128, TKN_INT, TKN_REAL, | |
// operators with an equals following them get +256 | |
TKN_EQ = 256 + '=', | |
TKN_NE = 256 + '!', | |
TKN_GE = 256 + '>', | |
TKN_LE = 256 + '<', | |
} type; | |
Intern* str; // IDENT | |
uint64_t num; // INT | |
double flt; // REAL | |
} token; | |
// TODO(NeGate): make proper hashtables | |
Node* symtab; | |
Node* top_fn; | |
} Parser; | |
// Character classes used by the tokenizer.

// decimal digit
static bool num(int ch) { return '0' <= ch && ch <= '9'; }

// ident0: may start an identifier (ASCII letter); ident1: may continue one
static bool ident0(int ch) { return ('a' <= ch && ch <= 'z') || ('A' <= ch && ch <= 'Z'); }
static bool ident1(int ch) { return num(ch) || ident0(ch); }

// TODO(NeGate): it's tested last but we really should define the sigils
static bool sigil(int ch) { (void)ch; return true; }

static bool space(int ch) {
    switch (ch) {
        case ' ': case '\t': case '\r': case '\n':
            return true;
        default:
            return false;
    }
}
// We intern all identifiers | |
static Intern* parser_intern(int len, const char* start) { | |
// TODO(NeGate): this prolly shouldn't be fixed size | |
static Intern* table[1024]; | |
uint32_t hash = mur3_32(start, len, 0); | |
uint32_t first = hash & 1023, i = first; | |
do { | |
if (table[i] == NULL) { | |
table[i] = malloc(sizeof(Intern) + len + 1); | |
table[i]->len = len; | |
memcpy(table[i]->data, start, len); | |
table[i]->data[len] = 0; | |
return table[i]; | |
} else if (table[i]->len == len && memcmp(table[i]->data, start, len) == 0) { | |
return table[i]; | |
} | |
i = (i + 1) & 1023; | |
} while (i != first); | |
abort(); | |
} | |
static bool parser_tokeq(Parser* restrict p, Intern* other) { | |
assert(p->token.str && p->token.type == TKN_IDENT); | |
return p->token.str == other; | |
} | |
// Advance to the next token and store it in p->token.
// Whitespace and `//` line comments are skipped first. Returns false (with
// token.type == TKN_EOF) at end of input, true otherwise.
static bool parser_token(Parser* restrict p) {
    // skip ws & comments until neither applies
    for (;;) {
        if (p->curr[0] == '/' && p->curr[1] == '/') {
            do { p->curr++; } while (*p->curr && *p->curr != '\n');
        } else if (space(*p->curr)) {
            do { p->curr++; } while (space(*p->curr));
        } else {
            break;
        }
    }

    if (*p->curr == 0) {
        p->token.type = TKN_EOF;
        return false;
    }

    const char* start = p->curr;
    int ch = *start;
    p->prev = start;
    p->token.str = NULL;

    if (num(ch)) {
        unsigned long long whole = 0;
        while (num(*p->curr)) {
            whole = whole * 10 + (*p->curr - '0');
            p->curr++;
        }
        if (*p->curr != '.') {
            p->token.type = TKN_INT;
            p->token.num = whole;
            return true;
        }
        p->curr++;  // the '.'
        double frac = 0, place = 0.1;
        for (; num(*p->curr); p->curr++) {
            frac += (*p->curr - '0') * place;
            place *= 0.1;
        }
        p->token.type = TKN_REAL;
        p->token.flt = whole + frac;
        return true;
    }

    if (ident0(ch)) {
        do { p->curr++; } while (ident1(*p->curr));
        p->token.type = TKN_IDENT;
        p->token.str = parser_intern(p->curr - start, start);
        return true;
    }

    if (!sigil(ch)) {
        abort();
    }
    // punctuation: the token type is the character itself; '=', '>', '<'
    // and '!' immediately followed by '=' become one token (type + 256)
    p->token.type = ch;
    p->curr++;
    switch (ch) {
        case '=': case '>': case '<': case '!':
            if (*p->curr == '=') {
                p->token.type += 256;
                p->curr++;
            }
            break;
        default:
            break;
    }
    return true;
}
// Consume the current token if it is the identifier `other` (interned
// strings compare by pointer).
static bool parser_try_eatid(Parser* restrict p, Intern* other) {
    bool match = p->token.type == TKN_IDENT && p->token.str == other;
    if (match) {
        parser_token(p);
    }
    return match;
}

// Consume the current token if its type is the character `ch`.
static bool parser_try_eat(Parser* restrict p, char ch) {
    bool match = p->token.type == ch;
    if (match) {
        parser_token(p);
    }
    return match;
}

// Consume a `ch` token; abort if the current token is anything else.
static void parser_eat(Parser* restrict p, char ch) {
    if (!parser_try_eat(p, ch)) {
        abort();
    }
}

// Consume an identifier token and return its interned name; abort otherwise.
static Intern* parser_eat_ident(Parser* restrict p) {
    if (p->token.type != TKN_IDENT) {
        abort();
    }
    Intern* ident = p->token.str;
    parser_token(p);
    return ident;
}
static int get_binop(int token_type) { | |
switch (token_type) { | |
case TKN_EQ: return 3; | |
case TKN_NE: return 3; | |
case TKN_GE: return 3; | |
case TKN_LE: return 3; | |
case '>': return 3; | |
case '<': return 3; | |
case '*': return 2; | |
case '/': return 2; | |
case '%': return 2; | |
case '+': return 1; | |
case '-': return 1; | |
default: return 0; | |
} | |
} | |
// Create a declaration node with initializer `init` (may be NULL for now).
// Named declarations become the newest entry of the symbol chain;
// anonymous ones (name == NULL) are created but never become visible.
static Node* parser_push_decl(Parser* restrict p, Intern* name, Type* type, Node* init) {
    Node* decl = new_node2(NODE_DECL, type, 1);
    decl->decl.prev = p->symtab;
    decl->decl.name = name;
    decl->decl.param = -1;
    if (name) {
        p->symtab = decl;
    }
    set_in(decl, init, 0);
    return decl;
}

// Find the innermost visible declaration of `name`, or NULL if none.
static Node* parser_lookup_decl(Parser* restrict p, Intern* name) {
    for (Node* d = p->symtab; d != NULL; d = d->decl.prev) {
        assert(d->tag == NODE_DECL);
        if (d->decl.name == name) {
            return d;
        }
    }
    return NULL;
}
static Node* parse_expr(Parser* restrict p);
static Node* parse_ternary(Parser* restrict p);

// TUPLE ::= '(' (EXPR (',' EXPR)*)? ')'
// Returns NULL (consuming nothing) if the current token is not '('.
//   ()        -> a NODE_NULL typed void (matches parse_params' void for zero params)
//   (e)       -> e itself (plain parenthesised expression)
//   (e1, ...) -> a NODE_TUPLE whose tuple type mirrors the element types
static Node* parse_tuple(Parser* restrict p) {
    if (!parser_try_eat(p, '(')) {
        return NULL;
    }
    // growable scratch array: any number of elements is allowed
    int arg_cnt = 0, arg_cap = 8;
    Node** args = malloc(arg_cap * sizeof(Node*));
    if (args == NULL) {
        fprintf(stderr, "error: out of memory\n");
        abort();
    }
    while (p->token.type != ')' && p->token.type != TKN_EOF) {
        if (arg_cnt == arg_cap) {
            arg_cap *= 2;
            Node** grown = realloc(args, arg_cap * sizeof(Node*));
            if (grown == NULL) {
                fprintf(stderr, "error: out of memory\n");
                abort();
            }
            args = grown;
        }
        args[arg_cnt++] = parse_expr(p);
        if (!parser_try_eat(p, ',')) { break; }
    }
    parser_eat(p, ')');

    Node* n;
    if (arg_cnt == 0) {
        n = new_node(NODE_NULL, type_new_mono(TYPE_VOID), 0);
    } else if (arg_cnt == 1) {
        n = args[0];
    } else {
        // create placeholder type for application
        Type* type = type_new_tuple(arg_cnt);
        FOR_N(i, 0, arg_cnt) {
            type->tuple.elems[i] = type_new_var();
            type_unify(args[i]->type, type->tuple.elems[i]);
        }
        n = new_node2(NODE_TUPLE, type, arg_cnt);
        FOR_N(i, 0, arg_cnt) {
            set_in(n, args[i], i);
        }
    }
    free(args);
    return n;
}
// ATOM ::= '#' ATOM
//        | (IDENT | NUMBER | TUPLE) ('.' IDENT | '[' TERNARY ']' | TUPLE)*
// Anything that cannot start an atom (or an undefined identifier) is a parse
// error: a message is written to stderr and the program aborts, instead of
// returning NULL for callers to dereference.
static Node* parse_atom(Parser* restrict p) {
    if (p->token.type == '#') {
        parser_token(p);
        Node* n = parse_atom(p);
        return new_node(NODE_LEN, type_new_mono(TYPE_INT), 1, n);
    }

    Node* n = NULL;
    if (p->token.type == TKN_IDENT) {
        Node* sym = parser_lookup_decl(p, p->token.str);
        if (sym == NULL) {
            fprintf(stderr, "error: undefined symbol '%.*s'\n", p->token.str->len, p->token.str->data);
            abort();
        }
        parser_token(p);
        n = new_node(NODE_SYMBOL, sym->type, 1, sym);
    } else if (p->token.type == TKN_INT) {
        n = new_node(NODE_INT, type_new_mono(TYPE_INTFLT), 0);
        n->num = p->token.num;
        parser_token(p);
    } else if (p->token.type == TKN_REAL) {
        n = new_node(NODE_REAL, type_new_mono(TYPE_FLT), 0);
        n->flt = p->token.flt;
        parser_token(p);
    } else if (p->token.type == '(') {
        n = parse_tuple(p);
    }
    if (n == NULL) {
        fprintf(stderr, "error: expected an expression\n");
        abort();
    }

    // postfix operators, any number, in any order
    for (;;) {
        if (parser_try_eat(p, '.')) {
            // table access
            Intern* name = parser_eat_ident(p);
            n = new_node(NODE_ACCESS, type_new_var(), 1, n);
            n->access = name;
        } else if (parser_try_eat(p, '[')) {
            // array subscript
            Node* idx = parse_ternary(p);
            parser_eat(p, ']');
            n = new_node(NODE_SUBSCRIPT, type_new_var(), 2, n, idx);
        } else if (p->token.type == '(') {
            // application: the callsite type is args -> fresh return var
            Node* args = parse_tuple(p);
            Type* site = type_new_func(args->type, type_new_var());
            Node* target = n;
            n = new_node2(NODE_APPLY, site->fn.ret, 2);
            n->apply.site = site;
            set_in(n, target, 0);
            set_in(n, args, 1);
        } else {
            break;
        }
    }
    return n;
}
// BINOP ::= ATOM (op BINOP)*  -- precedence climbing driven by get_binop().
// Operands of each operator are unified, and the result takes the right
// operand's type.
static Node* parse_binop(Parser* restrict p, int min_prec) {
    Node* lhs = parse_atom(p);
    for (;;) {
        int prec = get_binop(p->token.type);
        if (prec == 0 || prec < min_prec) {
            break;
        }
        int op = p->token.type;
        parser_token(p);
        Node* rhs = parse_binop(p, prec + 1);
        type_unify(lhs->type, rhs->type);
        lhs = new_node(NODE_BINOP, rhs->type, 2, lhs, rhs);
        lhs->binop = op;
    }
    return lhs;
}
// TERNARY ::= 'new' '[' TERNARY ']'
//           | BINOP ('?' EXPR (':' TERNARY)?)?
// A ternary without ':' has a void NODE_NULL as its false case. Both
// branches are unified, and the node takes the false branch's type.
static Node* parse_ternary(Parser* restrict p) {
    if (parser_try_eatid(p, KW_new)) {
        // only `new[count]` (array allocation) exists so far; reject anything
        // else instead of hitting a non-portable debugger trap
        if (!parser_try_eat(p, '[')) {
            fprintf(stderr, "error: expected '[' after 'new'\n");
            abort();
        }
        Node* count = parse_ternary(p);
        parser_eat(p, ']');
        Type* t = type_new_array(type_new_var());
        return new_node(NODE_ALLOC_ARR, t, 1, count);
    }

    Node* n = parse_binop(p, 0);
    if (parser_try_eat(p, '?')) {
        Node* mhs = parse_expr(p);
        // if there's no false case, it returns void
        Node* rhs;
        if (parser_try_eat(p, ':')) {
            rhs = parse_ternary(p);
        } else {
            rhs = new_node(NODE_NULL, type_new_mono(TYPE_VOID), 0);
        }
        type_unify(mhs->type, rhs->type);
        n = new_node(NODE_TERNARY, rhs->type, 3, n, mhs, rhs);
    }
    return n;
}
// Parse a parameter list up to and including ')' (the '(' is already eaten).
//   PARAMS ::= (PARAM | '(' PARAMS) (',' (PARAM | '(' PARAMS))* ')'
// Each parameter becomes a projection of `n`. Plain names are declared in
// the current scope with a fresh type var; a parenthesised group
// destructures recursively. Returns void for no params, the single param's
// type for one, or a tuple of the param types otherwise.
static Type* parse_params(Parser* restrict p, Node* n) {
    // growable scratch array: any number of parameters is allowed
    int count = 0, cap = 8;
    Node** params = malloc(cap * sizeof(Node*));
    if (params == NULL) {
        fprintf(stderr, "error: out of memory\n");
        abort();
    }
    while (p->token.type != ')') {
        if (count == cap) {
            cap *= 2;
            Node** grown = realloc(params, cap * sizeof(Node*));
            if (grown == NULL) {
                fprintf(stderr, "error: out of memory\n");
                abort();
            }
            params = grown;
        }
        Node* proj = new_proj(NULL, n, count);
        if (parser_try_eat(p, '(')) {
            proj->type = parse_params(p, proj);
        } else {
            Intern* param_name = parser_eat_ident(p);
            proj->type = type_new_var();
            parser_push_decl(p, param_name, proj->type, proj);
            printf(" PARAM: %.*s\n", param_name->len, param_name->data);
        }
        params[count++] = proj;
        if (!parser_try_eat(p, ',')) { break; }
    }
    parser_eat(p, ')');

    Type* type;
    if (count == 0) {
        type = type_new_mono(TYPE_VOID);
    } else if (count == 1) {
        type = params[0]->type;
    } else {
        type = type_new_tuple(count);
        FOR_N(i, 0, count) {
            type->tuple.elems[i] = params[i]->type;
        }
    }
    free(params);
    return type;
}
// PARAM    ::= IDENT | '(' PARAM (',' PARAM)* ')'
// DECL     ::= 'let' IDENT '=' EXPR
// FUNC     ::= 'fn' IDENT? '(' (PARAM (',' PARAM)*)? ')' EXPR
// COMPOUND ::= '{' EXPR (';' EXPR?)* '}'
// RETURN   ::= 'return' TERNARY
// EXPR     ::= DECL | COMPOUND | FUNC | TERNARY | RETURN
// Unsupported or misplaced constructs (`while`, `return` outside a function)
// are reported on stderr and abort.
static Node* parse_expr(Parser* restrict p) {
    if (parser_try_eatid(p, KW_return)) {
        if (p->top_fn == NULL) {
            fprintf(stderr, "error: 'return' outside of a function\n");
            abort();
        }
        Node* e = parse_ternary(p);
        assert(p->top_fn->type->tag == TYPE_FUNC);
        type_unify(e->type, p->top_fn->type->fn.ret);
        return new_node(NODE_COMPOUND, type_new_mono(TYPE_VOID), 1, e);
    } else if (parser_try_eatid(p, KW_while)) {
        fprintf(stderr, "error: 'while' loops are not supported yet\n");
        abort();
    } else if (parser_try_eatid(p, KW_let)) {
        Intern* name = parser_eat_ident(p);
        parser_eat(p, '=');
        // declared before its initializer is parsed, so the initializer can refer to it
        Node* n = parser_push_decl(p, name, type_new_var(), NULL);
        set_in(n, parse_expr(p), 0);
        return n;
    } else if (parser_try_eatid(p, KW_fn)) {
        Intern* name = NULL;
        if (p->token.type == TKN_IDENT) {
            name = parser_eat_ident(p);
            printf("FN %.*s:\n", name->len, name->data);
        }
        // scope that the params and body locals are pushed on top of
        Node* outer = p->symtab;

        Node* n = new_node(NODE_LAMBDA, NULL, 1, NULL);
        parser_eat(p, '(');
        Type* args = parse_params(p, n);
        Type* type = n->type = type_new_func(args, type_new_var());

        // a named fn is visible inside its own body (recursion)
        Node* sym = NULL;
        if (name != NULL) {
            sym = parser_push_decl(p, name, type, n);
        }

        Node* old_top = p->top_fn;
        p->top_fn = n;
        set_in(n, parse_expr(p), 0);
        type_unify(n->ins[0]->type, type->fn.ret);
        p->top_fn = old_top;

        // none of the symbols within the function escape: drop the params and
        // locals from the chain, keeping only the function's own name (if any)
        if (sym != NULL) {
            sym->decl.prev = outer;
            p->symtab = sym;
        } else {
            p->symtab = outer;
        }
        return n;
    } else if (parser_try_eat(p, '{')) {
        Node* lhs = parse_expr(p);
        while (parser_try_eat(p, ';')) {
            // empty statement just means void
            Node* rhs;
            if (parser_try_eat(p, ';')) {
                rhs = new_node(NODE_NULL, type_new_mono(TYPE_VOID), 0);
            } else {
                rhs = parse_expr(p);
            }
            lhs = new_node(NODE_COMPOUND, rhs->type, 2, lhs, rhs);
        }
        parser_eat(p, '}');
        return lhs;
    } else {
        return parse_ternary(p);
    }
}
// Print the resolved form of `t` (its union-find leader, recursively) to
// stdout. `p` is not read; it is only passed along to the recursive calls.
static void print_type(Parser* restrict p, Type* restrict t) {
    Type* r = type_find(t);
    switch (r->tag) {
        case TYPE_VAR:    printf("TV%d", r->uid); break;
        case TYPE_VOID:   printf("void");   break;
        case TYPE_INT:    printf("int");    break;
        case TYPE_FLT:    printf("flt");    break;
        case TYPE_INTFLT: printf("intflt"); break;
        case TYPE_ARRAY:
            putchar('[');
            print_type(p, r->array_elem);
            putchar(']');
            break;
        case TYPE_TABLE:
            putchar('{');
            for (int i = 0; i < r->table.count; i++) {
                if (i > 0) {
                    fputs(", ", stdout);
                }
                printf("%s: ", r->table.elems[i].name->data);
                print_type(p, r->table.elems[i].type);
            }
            putchar('}');
            break;
        case TYPE_TUPLE:
            putchar('(');
            for (int i = 0; i < r->tuple.count; i++) {
                if (i > 0) {
                    putchar(' ');
                }
                print_type(p, r->tuple.elems[i]);
            }
            putchar(')');
            break;
        case TYPE_FUNC:
            print_type(p, r->fn.args);
            fputs(" -> ", stdout);
            print_type(p, r->fn.ret);
            break;
        default:
            break;
    }
}
// Breadth-first walk over the inputs of `root`, queueing every reachable
// node once. `ws->arr` doubles as the BFS queue: entries from `head` onward
// have not had their inputs expanded yet.
static void ws_push_all(Worklist* restrict ws, Node* root) {
    ws_push(ws, root);
    for (int head = 0; head < ws->cnt; head++) {
        Node* n = ws->arr[head];
        for (int j = 0; j < n->in_cnt; j++) {
            Node* in = n->ins[j];
            if (in != NULL) {
                ws_push(ws, in);
            }
        }
    }
}
// Dump every node reachable from the parser's declarations, one per line:
//   %uid : <resolved type> = <op> (%input, ...)
// Nodes are printed in reverse BFS discovery order. The worklist's storage
// is released afterwards.
static void print(Parser* restrict p) {
    Worklist ws;
    ws_init(&ws, 100);
    // push all nodes
    for (Node* n = p->symtab; n != NULL; n = n->decl.prev) {
        ws_push_all(&ws, n);
    }
    FOR_REV_N(i, 0, ws.cnt) {
        Node* n = ws.arr[i];
        printf("%%%-4d: ", n->uid);
        if (n->type) {
            print_type(p, n->type);
        }
        printf(" = ");
        switch (n->tag) {
            case NODE_DECL:      printf("decl %s", n->decl.name ? n->decl.name->data : "___"); break;
            case NODE_APPLY:     printf("apply"); break;
            case NODE_LAMBDA:    printf("lambda"); break;
            case NODE_SYMBOL:    printf("&sym %s", n->ins[0]->decl.name->data); break;
            case NODE_ACCESS:    printf(".%s", n->access->data); break;
            case NODE_NULL:      printf("null"); break;
            case NODE_PROJ:      printf("proj%d", n->proj_i); break;
            case NODE_TUPLE:     printf("tuple"); break;
            case NODE_TERNARY:   printf("ternary"); break;
            case NODE_COMPOUND:  printf("compound"); break;
            case NODE_LEN:       printf("len"); break;
            case NODE_SUBSCRIPT: printf("subscript"); break;
            case NODE_ALLOC_ARR: printf("new[]"); break;
            // `num` is a uint64_t, so it needs the unsigned conversion
            case NODE_INT:       printf("%" PRIu64, n->num); break;
            case NODE_REAL:      printf("%f", n->flt); break;
            case NODE_BINOP: {
                switch (n->binop) {
                    case '+':    printf("add"); break;
                    case '-':    printf("sub"); break;
                    case '*':    printf("mul"); break;
                    case '/':    printf("div"); break;
                    case '%':    printf("mod"); break;
                    case TKN_EQ: printf("eq"); break;
                    case TKN_NE: printf("ne"); break;
                    case TKN_GE: printf("ge"); break;
                    case TKN_LE: printf("le"); break;
                    case '>':    printf(">"); break;
                    case '<':    printf("<"); break;
                    default: abort();
                }
            } break;
            default: abort();
        }
        printf(" (");
        FOR_N(j, 0, n->in_cnt) {
            if (j) { printf(", "); }
            if (n->ins[j]) {
                printf("%%%d", n->ins[j]->uid);
            } else {
                printf("___");
            }
        }
        printf(")\n");
    }
    free(ws.arr);
    free(ws.visited);
}
// Apply the typing constraint of a single node. Returns true if any type
// sets were merged (progress), false if nothing changed. Aborts on node
// kinds it does not know about.
static bool infer_transfer(Node* restrict n) {
    switch (n->tag) {
        case NODE_BINOP:
            return type_unify(n->ins[0]->type, n->ins[1]->type);
        case NODE_TERNARY:
            return type_unify(n->ins[1]->type, n->ins[2]->type);
        case NODE_APPLY: {
            bool progress = false;
            progress |= type_unify(n->ins[1]->type, n->apply.site->fn.args);
            progress |= type_unify(n->ins[0]->type, n->apply.site);
            return progress;
        }
        case NODE_DECL: {
            if (n->decl.param >= 0) {
                // tags are only meaningful on union-find leaders
                Type* fn_type = type_find(n->ins[0]->type);
                assert(fn_type->tag == TYPE_FUNC);
                Type* args = type_find(fn_type->fn.args);
                if (args->tag == TYPE_TUPLE) {
                    return type_unify(args->tuple.elems[n->decl.param], n->type);
                }
                return type_unify(fn_type->fn.args, n->type);
            } else if (n->ins[0] && n->type) {
                return type_unify(n->ins[0]->type, n->type);
            }
            return false;
        }
        case NODE_LAMBDA:
            assert(n->type->tag == TYPE_FUNC);
            return type_unify(n->ins[0]->type, n->type->fn.ret);
        case NODE_COMPOUND: {
            // ins[0] is unified with void. Because void & T = T, that can only
            // change a still-unresolved type var (which becomes void).
            // Allocating a fresh void on every visit would always "make
            // progress", so only do it when it can actually change something.
            // NOTE(review): for `return e` compounds ins[0] is `e`, so an
            // unresolved return type gets pinned to void; confirm intended.
            Type* t = type_find(n->ins[0]->type);
            if (t->tag != TYPE_VAR) {
                return false;
            }
            return type_unify(type_new_mono(TYPE_VOID), t);
        }
        case NODE_LEN: {
            // `#x`: an unresolved x must be an array
            if (n->ins[0] == NULL) {
                return false;
            }
            Type* t = type_find(n->ins[0]->type);
            if (t->tag != TYPE_VAR) {
                return false;
            }
            return type_unify(t, type_new_array(type_new_var()));
        }
        case NODE_SUBSCRIPT: {
            // `x[i]`: x is an array whose element type is this node's type
            if (n->ins[0] == NULL) {
                return false;
            }
            Type* arr = type_find(n->ins[0]->type);
            if (arr->tag == TYPE_VAR) {
                return type_unify(arr, type_new_array(n->type));
            }
            if (arr->tag == TYPE_ARRAY) {
                return type_unify(arr->array_elem, n->type);
            }
            return false;
        }
        case NODE_ALLOC_ARR:
        case NODE_ACCESS:
            // TODO: constrain counts/indices to int and table field types once
            // those rules are settled; no constraint yet, so no progress.
            return false;
        case NODE_INT:
        case NODE_REAL:
        case NODE_NULL:
        case NODE_PROJ:
        case NODE_TUPLE:
        case NODE_SYMBOL:
            return false;
        default: abort();
    }
}
// Run the inference worklist over every node reachable from the declaration
// chain starting at `symtab`. Each popped node has its transfer applied; on
// progress its inputs and users are pushed (ws_push ignores nodes the
// worklist has already queued). The worklist's storage is released at the end.
static void infer(Node* symtab) {
    Worklist ws;
    ws_init(&ws, 100);
    // push all nodes
    for (Node* n = symtab; n != NULL; n = n->decl.prev) {
        ws_push_all(&ws, n);
    }
    // monotone framework doing unification
    for (Node* n; (n = ws_pop(&ws)) != NULL;) {
        printf("INFER %%%d\n", n->uid);
        if (!infer_transfer(n)) {
            continue;
        }
        printf(" PROGRESS!\n");
        // add all neighbors (input slots may be empty)
        FOR_N(i, 0, n->in_cnt) {
            if (n->ins[i] != NULL) { ws_push(&ws, n->ins[i]); }
        }
        FOR_N(i, 0, n->use_cnt) { ws_push(&ws, n->uses[i]); }
    }
    free(ws.arr);
    free(ws.visited);
}
void compile(const char* src) { | |
if (!kw_init) { | |
kw_init = true; | |
#define X(name) KW_ ## name = parser_intern(sizeof(#name) - 1, #name); | |
#include "keywords.inc" | |
} | |
Parser p = { .src = src, .curr = src }; | |
parser_token(&p); | |
while (p.token.type != TKN_EOF) { | |
parse_expr(&p); | |
if (!parser_try_eat(&p, ';')) { break; } | |
} | |
print(&p); | |
infer(p.symtab); | |
print(&p); | |
} | |
// Entry point: compile the file named by argv[1] (default "test3.txt").
// Returns 1 if the file cannot be read.
int main(int argc, char** argv) {
    const char* path = argc > 1 ? argv[1] : "test3.txt";
    char* src = read_entire_file(path);
    if (src == NULL) {
        fprintf(stderr, "error: could not read '%s'\n", path);
        return 1;
    }
    compile(src);
    free(src);
#if defined(_MSC_VER)
    // debugger breakpoint at the end of the run; __debugbreak only exists
    // on MSVC-compatible compilers
    __debugbreak();
#endif
    return 0;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Sample program for the inference prototype.
// NOTE(review): main() reads "test3.txt"; which of these sample files that
// is cannot be told from here.
// wooyea | |
// fn iter(f, i, limit) { i < limit ? { f(i); iter(f, i+1, limit) } }; | |
// fn main() iter(fn(i) { i }, 0, 10); | |
// fn complex(r, i) (r, i); | |
// cmul: complex multiply over two destructured (real, imag) pairs.
fn cmul((a, b), (c, d)) (a*c - b*d, a*d + b*c); | |
fn main() { | |
let a = (1.0, 2); | |
let b = (3, 4); | |
cmul(a, b); | |
// array init | |
// NOTE(review): parse_atom only starts an atom at an identifier, a number,
// '#' or '(' — a `[0, 1]` literal is not handled by the visible parser.
let c = [0, 1]; | |
c[0] | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Sample program: generic map over arrays. The type comments record the
// types the author expects inference to produce.
// (A[], A -> B) -> B[] | |
fn map(src, f) { | |
let dst = new[#src]; | |
// NOTE(review): '=' is not a binary operator in get_binop, so
// `dst[i] = ...` and `i=i+1` are not parsed as assignments; and `iter()`
// below passes no argument to the one-parameter `iter`.
fn iter(i) { | |
i < #src ? dst[i] = f(src[i]); i=i+1 | |
}; | |
iter(); | |
dst | |
}; | |
// (intflt[], intflt -> intflt) -> intflt[] | |
fn foo(in) map(in, fn(x) 2*x - 1); | |
// flt[] | |
// NOTE(review): parse_atom has no case for '[' array literals or a unary '-'.
let c = foo([0.1, 0.4, 0.5, -0.3]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Sample program exercising table field access (`e.fire` parses as NODE_ACCESS).
// NOTE(review): '=' is not a binary operator in get_binop, so parsing is
// expected to stop at the '=' rather than treat this as an assignment.
fn update(e, dt) { | |
e.fire = e.fire - dt | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment