Skip to content

Instantly share code, notes, and snippets.

@RealNeGate
Last active July 29, 2024 22:27
Show Gist options
  • Save RealNeGate/3261bb54c21e0a0bade07b7bee6bd80d to your computer and use it in GitHub Desktop.
Save RealNeGate/3261bb54c21e0a0bade07b7bee6bd80d to your computer and use it in GitHub Desktop.
Learning Hindley-Milner
// keywords.inc: X-macro list of the language's reserved words. The including
// site defines X(name) first; the list undefines X again after expanding.
X(fn)
X(do)
X(let)
X(new)
X(while)
X(return)
#undef X
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <stdarg.h>
#include <string.h>
#include <assert.h>
#include <stdbool.h>
#include <sys/stat.h>
#include <inttypes.h>
#ifdef _WIN32
#define fileno _fileno
#define fstat _fstat
#define stat _stat
#endif
// FOR_N: run the body with i = start, start+1, ..., end-1. `end` is evaluated
// once, up front.
#define FOR_N(i, start, end) for (ptrdiff_t i = (start), _end_ = (end); i < _end_; i++)
// FOR_REV_N: run the body with i = end-1, end-2, ..., start. `start` is
// evaluated once, up front; the decrement happens inside the loop condition.
#define FOR_REV_N(i, start, end) for (ptrdiff_t i = (end), _start_ = (start); i-- > _start_;)
// Interned string: a byte length followed by the characters stored inline in
// a flexible array member.
typedef struct {
int len;
char data[];
} Intern;
// Reads the whole file at `filepath` into a freshly malloc'd buffer and
// NUL-terminates it.
//
// Returns NULL when the file cannot be opened, its size cannot be determined,
// or the buffer cannot be allocated. The caller owns the returned buffer and
// releases it with free().
static char* read_entire_file(const char* filepath) {
    // The open must not live inside assert(): with NDEBUG the whole call
    // would be compiled out and `file` left uninitialized.
    FILE* file = fopen(filepath, "rb");
    if (file == NULL) {
        return NULL;
    }

    // measure the file by seeking to its end, then rewind for the read
    long length = -1;
    if (fseek(file, 0, SEEK_END) == 0) {
        length = ftell(file);
    }

    char* data = NULL;
    if (length >= 0 && fseek(file, 0, SEEK_SET) == 0) {
        data = malloc((size_t) length + 1);
        if (data != NULL) {
            // terminate after what was actually read, which may be short
            size_t length_read = fread(data, 1, (size_t) length, file);
            data[length_read] = '\0';
        }
    }

    // single exit: the stream is closed on every path
    fclose(file);
    return data;
}
// murmur3 32-bit
//
// Hashes the `len` bytes starting at `key`, seeded with `h`, and returns the
// 32-bit digest. Whole 4-byte blocks are loaded in the host's byte order.
uint32_t mur3_32(const void *key, int len, uint32_t h) {
    const uint8_t *bytes = key;
    const int block_cnt = len / 4;

    // main body, work on 32-bit blocks at a time
    for (int i = 0; i < block_cnt; i++) {
        // memcpy rather than a uint32_t* cast: `key` carries no alignment
        // guarantee and is not necessarily a uint32_t object
        uint32_t k;
        memcpy(&k, bytes + 4*i, sizeof k);
        k *= 0xcc9e2d51;
        k = ((k << 15) | (k >> 17)) * 0x1b873593;
        h ^= k;
        h = ((h << 13) | (h >> 19))*5 + 0xe6546b64;
    }

    // load/mix up to 3 remaining tail bytes into a tail block
    uint32_t t = 0;
    const uint8_t *tail = bytes + 4*block_cnt;
    switch (len & 3) {
        case 3: t ^= (uint32_t) tail[2] << 16; // fallthrough
        case 2: t ^= (uint32_t) tail[1] << 8;  // fallthrough
        case 1: {
            t ^= tail[0];
            t *= 0xcc9e2d51;
            h ^= ((t << 15) | (t >> 17)) * 0x1b873593;
        } break;
        default: break;
    }

    // finalization mix, including key length
    h ^= (uint32_t) len;
    h ^= h >> 16;
    h *= 0x85ebca6b;
    h ^= h >> 13;
    h *= 0xc2b2ae35;
    h ^= h >> 16;
    return h;
}
////////////////////////////////
// Type system
////////////////////////////////
typedef struct Type Type;
// One field of a table type: an interned field name paired with the field's
// type.
typedef struct {
Intern* name;
Type* type;
// data layout
} TableEntry;
// A type term. Terms live in a disjoint-set forest linked through `parent`;
// a term whose `parent` is itself is the representative of its set.
struct Type {
// which kind of term this is; selects the active union member below
enum {
TYPE_VAR,
// monotypes
TYPE_VOID,
TYPE_INT,
TYPE_FLT,
TYPE_INTFLT,
TYPE_FUNC,
TYPE_TUPLE,
TYPE_ARRAY,
TYPE_TABLE,
} tag;
// helpful for debugging & worklists
int uid;
// disjoint-set UF
Type* parent;
// per-tag payload
union {
struct {
// tuple types
int count;
Type** elems;
} tuple;
// TYPE_FUNC: argument type and return type
struct {
Type* args;
Type* ret;
} fn;
// TYPE_TABLE: growable array of named fields
struct {
// sorted by address
int cap, count;
TableEntry* elems;
} table;
// TYPE_ARRAY: element type
Type* array_elem;
};
};
// Disjoint-set find: returns the representative of the set containing `a`
// and repoints every term on the walked path directly at it.
static Type* type_find(Type* a) {
    // first pass: climb to the root (the term that is its own parent)
    Type* root = a;
    while (root->parent != root) {
        root = root->parent;
    }

    // second pass: walk the same path again, hanging each term off the root
    for (Type* cur = a; cur->parent != cur;) {
        Type* next = cur->parent;
        cur->parent = root;
        cur = next;
    }
    return root;
}
// Disjoint-set union: makes the representative of `b` the leader of `a`'s
// set. Returns true when two distinct sets were merged, false when `a` and
// `b` already shared a representative.
static bool type_union(Type* a, Type* b) {
    Type* from = type_find(a);
    Type* into = type_find(b);
    if (from != into) {
        from->parent = into;
        return true;
    }
    return false;
}
// Unifies `a` and `b`, merging their disjoint sets wherever the two types are
// compatible. Returns true when at least one set was merged (progress) and
// false when nothing changed. Each call logs the pair to stdout.
//
//   var   & T      = T   (a type variable adopts the other side)
//   tuple & tuple  : element-wise, arity must match
//   func  & func   : args and ret are unified
//   array & array  : element types are unified, then the arrays are merged
//   void  & T      = T
//   flt   & intflt = flt
//   int   & intflt = int
//
// Any other pairing is a type error and traps via __debugbreak().
static bool type_unify(Type* a, Type* b) {
    printf("UNIFY T%d = T%d\n", a->uid, b->uid);
    a = type_find(a);
    b = type_find(b);
    if (a == b) {
        // already matching
        return false;
    }

    if (a->tag == TYPE_VAR || b->tag == TYPE_VAR) {
        // The variable side is hung off the other side so the resolved (or
        // semi-resolved) type becomes the leader. With two variables `a`
        // simply joins `b`.
        if (a->tag != TYPE_VAR) {
            return type_union(b, a);
        }
        return type_union(a, b);
    }

    if (a->tag == b->tag) {
        bool progress = false;
        if (a->tag == TYPE_TUPLE) {
            assert(a->tuple.count == b->tuple.count);
            FOR_N(i, 0, a->tuple.count) {
                progress |= type_unify(a->tuple.elems[i], b->tuple.elems[i]);
            }
        } else if (a->tag == TYPE_FUNC) {
            progress |= type_unify(a->fn.args, b->fn.args);
            progress |= type_unify(a->fn.ret, b->fn.ret);
        } else if (a->tag == TYPE_ARRAY) {
            // merging the arrays alone would drop whatever was known about
            // a's element type; the elements have to agree as well
            progress |= type_unify(a->array_elem, b->array_elem);
            progress |= type_union(a, b);
        } else {
            progress |= type_union(a, b);
        }
        return progress;
    }

    // mixed tags: order the pair so `a` carries the smaller tag
    if (a->tag > b->tag) {
        Type* tmp = a;
        a = b;
        b = tmp;
    }
    // void & T = T
    if (a->tag == TYPE_VOID) {
        return type_union(a, b);
    }
    // intflt & flt = flt, intflt & int = int: the undecided numeric type
    // resolves to whichever concrete numeric type it meets
    if (b->tag == TYPE_INTFLT && (a->tag == TYPE_FLT || a->tag == TYPE_INT)) {
        return type_union(b, a);
    }
    __debugbreak();
    return false;
}
// Number of types created so far; each new type takes the next value as uid.
static int TYPE_CNT = 0;

// Shallow-copies `src` into a new heap type that has its own uid and is the
// sole member of a fresh disjoint set. Payload pointers are shared with `src`.
static Type* type_clone(Type* src) {
    Type* copy = malloc(sizeof *copy);
    *copy = *src;
    copy->uid = ++TYPE_CNT;
    copy->parent = copy;
    return copy;
}
static Type* type_new_mono(int tag) {
Type* t = malloc(sizeof(Type));
*t = (Type){ .tag = tag, .uid = ++TYPE_CNT, .parent = t };
return t;
}
static Type* type_new_var(void) {
Type* t = malloc(sizeof(Type));
*t = (Type){ .tag = TYPE_VAR, .uid = ++TYPE_CNT, .parent = t };
return t;
}
// Allocates an array type whose element type is `elem`.
static Type* type_new_array(Type* elem) {
    Type* arr = malloc(sizeof *arr);
    Type init = { .tag = TYPE_ARRAY, .uid = ++TYPE_CNT, .parent = arr, .array_elem = elem };
    *arr = init;
    return arr;
}
// Allocates an empty table type with room for 4 entries. `elem` is accepted
// but not used by the body.
static Type* type_new_table(Type* elem) {
    Type* tbl = malloc(sizeof *tbl);
    // positional init of `table`: first member is cap, count stays 0
    Type init = { .tag = TYPE_TABLE, .uid = ++TYPE_CNT, .parent = tbl, .table = { 4 } };
    *tbl = init;
    tbl->table.elems = malloc(tbl->table.cap * sizeof(TableEntry));
    return tbl;
}
static Type* type_new_tuple(int arg_count) {
Type* t = malloc(sizeof(Type));
*t = (Type){ .tag = TYPE_TUPLE, .uid = ++TYPE_CNT, .parent = t, .tuple = { arg_count, malloc(arg_count * sizeof(Type*)) } };
return t;
}
// Allocates a function type taking `args` and returning `ret`.
static Type* type_new_func(Type* args, Type* ret) {
    Type* fn = malloc(sizeof *fn);
    Type init = { .tag = TYPE_FUNC, .uid = ++TYPE_CNT, .parent = fn, .fn = { args, ret } };
    *fn = init;
    return fn;
}
////////////////////////////////
// Parse IR
////////////////////////////////
typedef struct Node Node;
// A node of the parse graph. Each node records its inputs (`ins`) and keeps a
// list of the nodes that use it (`uses`).
struct Node {
// node kind; selects which union member below is meaningful
enum {
NODE_NULL,
NODE_LEN,
NODE_INT,
NODE_REAL,
NODE_DECL,
NODE_COMPOUND,
// core
NODE_BINOP,
NODE_APPLY,
NODE_SYMBOL,
NODE_LAMBDA,
NODE_TERNARY,
// tuples
NODE_PROJ,
NODE_TUPLE,
// tables
NODE_ACCESS,
// array
NODE_SUBSCRIPT,
NODE_ALLOC_ARR,
} tag;
// type attached to this node; may be NULL
Type* type;
// per-kind payload
union {
struct {
int arg_count;
Node* body;
} fn;
struct {
// callsite type
Type* site;
} apply;
// NODE_DECL: `prev` links to the previously pushed declaration
struct {
Node* prev;
Intern* name;
int param;
} decl;
// NODE_ACCESS: accessed field name
Intern* access;
// NODE_PROJ: projected index
int proj_i;
// NODE_INT: literal value
uint64_t num;
// NODE_REAL: literal value
double flt;
// NODE_BINOP: token type of the operator
int binop;
};
// helpful for debugging & worklists
int uid;
// unordered list of uses
int use_cnt, use_cap;
Node** uses;
// not SSA, just a semi-stupid parse graph for inference
int in_cnt;
Node* ins[];
};
// Number of nodes created so far; each new node takes the next value as uid.
static int NODE_CNT;

// Stores `v` into input slot `i` of `n` and keeps the use lists consistent:
// `n` is dropped from the previous input's use list once no slot of `n`
// refers to it any more, and `n` is added to `v`'s use list unless it is
// already there. Either the old or the new input may be NULL.
static void set_in(Node* n, Node* v, int i) {
    Node* old = n->ins[i];
    n->ins[i] = v;
    if (old == v) {
        return;
    }

    if (old != NULL) {
        // does any slot of n still point at the old input?
        bool still_used = false;
        for (int k = 0; k < n->in_cnt && !still_used; k++) {
            still_used = (n->ins[k] == old);
        }

        if (!still_used) {
            // unordered removal: overwrite n's entry with the last one
            for (int k = 0; k < old->use_cnt; k++) {
                if (old->uses[k] != n) {
                    continue;
                }
                old->use_cnt -= 1;
                old->uses[k] = old->uses[old->use_cnt];
                break;
            }
        }
    }

    if (v == NULL) {
        return;
    }

    // each user is listed at most once
    for (int k = 0; k < v->use_cnt; k++) {
        if (v->uses[k] == n) {
            return;
        }
    }

    // double the list when it is full, then append
    if (v->use_cnt == v->use_cap) {
        v->use_cap *= 2;
        v->uses = realloc(v->uses, v->use_cap * sizeof(Node*));
    }
    v->uses[v->use_cnt] = n;
    v->use_cnt += 1;
}
// Allocates a node with `in_cnt` input slots. The variadic arguments are the
// inputs, in slot order, each passed as a Node*; every one is wired up
// through set_in so use lists stay in sync.
static Node* new_node(int tag, Type* type, int in_cnt, ...) {
    Node* node = malloc(sizeof(Node) + in_cnt*sizeof(Node*));
    *node = (Node){
        .tag = tag,
        .type = type,
        .uid = ++NODE_CNT,
        .use_cap = 2,
        .uses = malloc(2 * sizeof(Node*)),
        .in_cnt = in_cnt,
    };

    va_list inputs;
    va_start(inputs, in_cnt);
    for (int slot = 0; slot < in_cnt; slot++) {
        // clear the slot first so set_in sees "no previous input"
        node->ins[slot] = NULL;
        Node* input = va_arg(inputs, Node*);
        set_in(node, input, slot);
    }
    va_end(inputs);
    return node;
}
// Allocates a node with `in_cnt` input slots, all NULL. Callers wire the
// inputs afterwards with set_in.
static Node* new_node2(int tag, Type* type, int in_cnt) {
    Node* node = malloc(sizeof(Node) + in_cnt*sizeof(Node*));
    *node = (Node){
        .tag = tag,
        .type = type,
        .uid = ++NODE_CNT,
        .use_cap = 2,
        .uses = malloc(2 * sizeof(Node*)),
        .in_cnt = in_cnt,
    };
    for (int slot = 0; slot < in_cnt; slot++) {
        node->ins[slot] = NULL;
    }
    return node;
}
// Creates a NODE_PROJ of the given type that selects index `idx` of `n`.
static Node* new_proj(Type* type, Node* n, int idx) {
    Node* projection = new_node2(NODE_PROJ, type, 1);
    set_in(projection, n, 0);
    projection->proj_i = idx;
    return projection;
}
////////////////////////////////
// Parser
////////////////////////////////
// TODO(NeGate): fill in good errors
// set once the KW_* pointers below have been filled in
static bool kw_init;
// Expands keywords.inc into one `static Intern* KW_<name>;` per keyword.
#define X(name) static Intern* KW_ ## name;
#include "keywords.inc"
// Worklist of nodes: a stack (`arr`) plus a bitset (`visited`) indexed by
// node uid that ws_push consults to skip nodes already marked.
typedef struct {
// number of 32-bit words tracked in `visited`
int visited_cap;
uint32_t* visited;
// stack length and capacity
int cnt, cap;
Node** arr;
} Worklist;
// Initializes an empty worklist. `cap` is rounded up to a power of two and
// used as the stack capacity; the bitset is recorded as cap/32 words, though
// the zeroed allocation itself is `cap` words long.
static void ws_init(Worklist* ws, int cap) {
    int rounded = 1ull << (64 - __builtin_clzll(cap - 1));

    ws->visited_cap = rounded / 32;
    ws->visited = calloc(rounded, sizeof(uint32_t));

    ws->cnt = 0;
    ws->cap = rounded;
    ws->arr = malloc(rounded * sizeof(Node*));
}
// Pushes `n` onto the worklist unless its visited bit is already set.
// Grows the visited bitset and the stack as needed; aborts if either
// reallocation fails.
static void ws_push(Worklist* ws, Node* n) {
    size_t word_i = n->uid / 32u;
    uint32_t bit = 1u << (n->uid % 32u);

    if (word_i >= (size_t) ws->visited_cap) {
        // Grow to the smallest power of two strictly greater than word_i so
        // that visited[word_i] is always inside the new allocation, including
        // when word_i is itself a power of two.
        size_t new_cap = 1;
        while (new_cap <= word_i) {
            new_cap *= 2;
        }
        uint32_t* grown = realloc(ws->visited, new_cap*sizeof(uint32_t));
        if (grown == NULL) {
            abort();
        }
        // newly tracked words start out unvisited
        for (size_t i = (size_t) ws->visited_cap; i < new_cap; i++) {
            grown[i] = 0;
        }
        ws->visited = grown;
        ws->visited_cap = (int) new_cap;
    } else if (ws->visited[word_i] & bit) {
        // already marked
        return;
    }
    ws->visited[word_i] |= bit;

    // the stack grows by element count, not by uid
    if (ws->cnt >= ws->cap) {
        int new_cap = ws->cap > 0 ? ws->cap * 2 : 1;
        Node** grown = realloc(ws->arr, new_cap*sizeof(Node*));
        if (grown == NULL) {
            abort();
        }
        ws->arr = grown;
        ws->cap = new_cap;
    }
    ws->arr[ws->cnt++] = n;
}
// Pops the most recently pushed node, or returns NULL when the worklist is
// empty. The node's visited bit is cleared so it can be pushed again later;
// a driver that re-pushes on "progress" therefore terminates only if its
// transfer functions stop reporting progress at a fixed point.
static Node* ws_pop(Worklist* ws) {
    if (ws->cnt == 0) {
        return NULL;
    }
    Node* n = ws->arr[--ws->cnt];
    // same word/bit addressing as ws_push: uid/32 picks the word, uid%32 the bit
    ws->visited[n->uid / 32u] &= ~(1u << (n->uid % 32u));
    return n;
}
// Parser state: a cursor into the source text, the current token, and the
// scope bookkeeping.
typedef struct {
// start of the source text
const char* src;
// start of the most recently lexed token
const char* prev;
// lexer cursor (next unread character)
const char* curr;
// current token
struct {
// single-character tokens use the character's own value as their type
enum {
TKN_EOF = 0,
TKN_IDENT = 128, TKN_INT, TKN_REAL,
// operators with an equals following them get +256
TKN_EQ = 256 + '=',
TKN_NE = 256 + '!',
TKN_GE = 256 + '>',
TKN_LE = 256 + '<',
} type;
Intern* str; // IDENT
uint64_t num; // INT
double flt; // REAL
} token;
// TODO(NeGate): make proper hashtables
// most recent declaration; earlier ones are reached through decl.prev
Node* symtab;
// lambda currently being parsed (target of `return`)
Node* top_fn;
} Parser;
// Character classes used by the lexer.

// decimal digit, as in 0.123
static bool num(int ch) {
    return '0' <= ch && ch <= '9';
}

// ident0 is the starting char of an identifier: an ASCII letter
static bool ident0(int ch) {
    bool upper = ('A' <= ch && ch <= 'Z');
    bool lower = ('a' <= ch && ch <= 'z');
    return upper || lower;
}

// any later identifier char: a letter or a digit
static bool ident1(int ch) {
    if (num(ch)) {
        return true;
    }
    return ident0(ch);
}

// TODO(NeGate): it's tested last but we really should define the sigils
static bool sigil(int ch) {
    return true;
}

// blank, tab, carriage return or newline
static bool space(int ch) {
    switch (ch) {
        case ' ':
        case '\t':
        case '\r':
        case '\n':
            return true;
        default:
            return false;
    }
}
// We intern all identifiers: equal spellings map to the same Intern*, so
// names can be compared by pointer. The table is a fixed 1024-slot
// open-addressing hash table with linear probing; the process aborts once
// every slot has been probed without a match or a free slot.
static Intern* parser_intern(int len, const char* start) {
    // TODO(NeGate): this prolly shouldn't be fixed size
    static Intern* table[1024];
    const uint32_t mask = 1023;

    const uint32_t home = mur3_32(start, len, 0) & mask;
    uint32_t slot = home;
    do {
        Intern* entry = table[slot];
        if (entry == NULL) {
            // free slot: copy the spelling in, NUL-terminated
            entry = malloc(sizeof(Intern) + len + 1);
            entry->len = len;
            memcpy(entry->data, start, len);
            entry->data[len] = 0;
            table[slot] = entry;
            return entry;
        }
        if (entry->len == len && memcmp(entry->data, start, len) == 0) {
            return entry;
        }
        slot = (slot + 1) & mask;
    } while (slot != home);

    abort();
}
// True when the current token, which must be an identifier, is the interned
// string `other`.
static bool parser_tokeq(Parser* restrict p, Intern* other) {
    assert(p->token.str && p->token.type == TKN_IDENT);
    bool same = (other == p->token.str);
    return same;
}
// Lexes the next token into p->token. Returns false at end of input (token
// type TKN_EOF) and true otherwise.
static bool parser_token(Parser* restrict p) {
    // skip ws & comments, in any interleaving
    for (;;) {
        if (p->curr[0] == '/' && p->curr[1] == '/') {
            // line comment: runs up to (not including) the newline
            p->curr++;
            while (*p->curr != '\0' && *p->curr != '\n') {
                p->curr++;
            }
        } else if (space(*p->curr)) {
            p->curr++;
            while (space(*p->curr)) {
                p->curr++;
            }
        } else {
            break;
        }
    }

    // eof
    if (*p->curr == 0) {
        p->token.type = TKN_EOF;
        return false;
    }

    const char* begin = p->curr;
    int ch = *begin;
    p->prev = begin;
    p->token.str = NULL;

    if (num(ch)) {
        // integer part
        unsigned long long whole = 0;
        while (num(ch)) {
            whole = (whole * 10) + (ch - '0');
            p->curr++;
            ch = *p->curr;
        }

        if (ch != '.') {
            p->token.type = TKN_INT;
            p->token.num = whole;
        } else {
            // fractional part: each digit is weighted by a shrinking power of ten
            p->curr += 1;
            double frac = 0, scale = 0.1;
            while (num(*p->curr)) {
                frac += (*p->curr - '0')*scale;
                scale *= 0.1;
                p->curr++;
            }
            p->token.type = TKN_REAL;
            p->token.flt = whole+frac;
        }
    } else if (ident0(ch)) {
        // identifier: consume the leading letter, then letters and digits
        p->curr++;
        while (ident1(*p->curr)) {
            p->curr++;
        }
        p->token.type = TKN_IDENT;
        p->token.str = parser_intern(p->curr - begin, begin);
    } else if (sigil(ch)) {
        // single char token, optionally fused with a following '='
        p->token.type = ch;
        p->curr++;
        bool takes_eq = (ch == '=' || ch == '>' || ch == '<' || ch == '!');
        if (takes_eq && *p->curr == '=') {
            p->token.type += 256;
            p->curr += 1;
        }
    } else {
        abort();
    }
    return true;
}
// If the current token is the identifier `other`, consumes it and returns
// true; otherwise leaves the token alone and returns false.
static bool parser_try_eatid(Parser* restrict p, Intern* other) {
    if (p->token.type != TKN_IDENT || p->token.str != other) {
        return false;
    }
    parser_token(p);
    return true;
}
// If the current token's type equals the character `ch`, consumes it and
// returns true; otherwise leaves the token alone and returns false.
static bool parser_try_eat(Parser* restrict p, char ch) {
    if (p->token.type != ch) {
        return false;
    }
    parser_token(p);
    return true;
}
// Consumes the current token, which must have type `ch`; aborts the process
// on any other token.
static void parser_eat(Parser* restrict p, char ch) {
    if (p->token.type == ch) {
        parser_token(p);
        return;
    }
    abort();
}
// Consumes the current token, which must be an identifier, and returns its
// interned name; aborts the process on any other token.
static Intern* parser_eat_ident(Parser* restrict p) {
    if (p->token.type != TKN_IDENT) {
        abort();
    }
    // grab the name before the lexer overwrites the token
    Intern* name = p->token.str;
    parser_token(p);
    return name;
}
// Maps a token type to its binary-operator level: 3 for comparisons, 2 for
// * / %, 1 for + -, and 0 when the token is not a binary operator.
static int get_binop(int token_type) {
    switch (token_type) {
        case TKN_EQ:
        case TKN_NE:
        case TKN_GE:
        case TKN_LE:
        case '>':
        case '<':
            return 3;

        case '*':
        case '/':
        case '%':
            return 2;

        case '+':
        case '-':
            return 1;

        default:
            return 0;
    }
}
// Creates a NODE_DECL for `name` with the given type and initializer, linked
// to the current top of the symbol chain. The declaration becomes the new
// top only when it has a name.
static Node* parser_push_decl(Parser* restrict p, Intern* name, Type* type, Node* init) {
    Node* decl = new_node2(NODE_DECL, type, 1);
    decl->decl.name = name;
    decl->decl.param = -1;
    decl->decl.prev = p->symtab;
    if (name != NULL) {
        p->symtab = decl;
    }
    set_in(decl, init, 0);
    return decl;
}
// Walks the symbol chain from the newest declaration backwards and returns
// the first one named `name`, or NULL when there is none. Names are interned,
// so they are compared by pointer.
static Node* parser_lookup_decl(Parser* restrict p, Intern* name) {
    for (Node* decl = p->symtab; decl != NULL; decl = decl->decl.prev) {
        assert(decl->tag == NODE_DECL);
        if (decl->decl.name == name) {
            return decl;
        }
    }
    return NULL;
}
static Node* parse_expr(Parser* restrict p);
static Node* parse_ternary(Parser* restrict p);
// Parses a parenthesized, comma-separated expression list.
//
// Returns NULL (consuming nothing) when the current token is not '('. A
// single element is returned as-is; two or more are wrapped in a NODE_TUPLE
// whose tuple type holds one fresh type variable per element, each unified
// with that element's type. Aborts when the list has more elements than the
// temporary array can hold.
static Node* parse_tuple(Parser* restrict p) {
    if (!parser_try_eat(p, '(')) {
        return NULL;
    }

    // tmp array should probably be resizable
    enum { MAX_TUPLE_ARGS = 16 };
    int arg_cnt = 0;
    Node* args[MAX_TUPLE_ARGS];
    while (p->token.type != ')' && p->token.type != TKN_EOF) {
        // refuse to write past the end of the fixed-size array
        if (arg_cnt == MAX_TUPLE_ARGS) {
            abort();
        }
        args[arg_cnt++] = parse_expr(p);
        if (!parser_try_eat(p, ',')) { break; }
    }
    parser_eat(p, ')');

    assert(arg_cnt != 0);
    if (arg_cnt == 1) {
        return args[0];
    }

    // create placeholder type for application
    Type* type = type_new_tuple(arg_cnt);
    FOR_N(i, 0, arg_cnt) {
        type->tuple.elems[i] = type_new_var();
        type_unify(args[i]->type, type->tuple.elems[i]);
    }

    Node* n = new_node2(NODE_TUPLE, type, arg_cnt);
    FOR_N(i, 0, arg_cnt) {
        set_in(n, args[i], i);
    }
    return n;
}
// TUPLE ::= '(' EXPR (',' EXPR)* ')'
// ATOM ::= '#'? (IDENT | NUMBER | TUPLE) ('.' IDENT)* TUPLE?
//
// Parses a primary expression followed by any number of postfix forms:
// `.name` (table access), `[expr]` (subscript) and `(args)` (application).
// A leading '#' wraps the atom that follows in a NODE_LEN. Aborts on an
// identifier that has no declaration in scope.
static Node* parse_atom(Parser* restrict p) {
    if (p->token.type == '#') {
        parser_token(p);
        Node* operand = parse_atom(p);
        return new_node(NODE_LEN, type_new_mono(TYPE_INT), 1, operand);
    }

    // primary
    Node* atom = NULL;
    switch (p->token.type) {
        case TKN_IDENT: {
            Node* decl = parser_lookup_decl(p, p->token.str);
            if (decl == NULL) {
                // parser_err();
                abort();
            }
            parser_token(p);
            atom = new_node(NODE_SYMBOL, decl->type, 1, decl);
        } break;

        case TKN_INT: {
            atom = new_node(NODE_INT, type_new_mono(TYPE_INTFLT), 0);
            atom->num = p->token.num;
            parser_token(p);
        } break;

        case TKN_REAL: {
            atom = new_node(NODE_REAL, type_new_mono(TYPE_FLT), 0);
            atom->flt = p->token.flt;
            parser_token(p);
        } break;

        case '(': {
            atom = parse_tuple(p);
        } break;

        default: break;
    }

    // postfix chain
    for (;;) {
        if (parser_try_eat(p, '.')) {
            // table access
            Intern* field = parser_eat_ident(p);
            atom = new_node(NODE_ACCESS, type_new_var(), 1, atom);
            atom->access = field;
            continue;
        }

        if (parser_try_eat(p, '[')) {
            // array subscript
            Node* index = parse_ternary(p);
            parser_eat(p, ']');
            Type* elem_type = type_new_var();
            atom = new_node(NODE_SUBSCRIPT, elem_type, 2, atom, index);
            continue;
        }

        if (p->token.type == '(') {
            // application: the call site gets its own function type whose
            // return is a fresh variable, which is also the node's type
            Node* call_args = parse_tuple(p);
            Type* site = type_new_func(call_args->type, type_new_var());
            Node* callee = atom;
            atom = new_node2(NODE_APPLY, site->fn.ret, 2);
            atom->apply.site = site;
            set_in(atom, callee, 0);
            set_in(atom, call_args, 1);
            continue;
        }

        break;
    }
    return atom;
}
// Parses a chain of binary operators whose level (see get_binop) is at least
// `min_prec`. The operand types of each operator are unified as it is parsed
// and the resulting NODE_BINOP takes the right operand's type.
static Node* parse_binop(Parser* restrict p, int min_prec) {
    Node* acc = parse_atom(p);
    for (;;) {
        int prec = get_binop(p->token.type);
        if (prec == 0 || prec < min_prec) {
            break;
        }

        int op = p->token.type;
        parser_token(p);

        // the right operand only absorbs operators of a strictly higher level
        Node* rhs = parse_binop(p, prec + 1);
        type_unify(acc->type, rhs->type);

        acc = new_node(NODE_BINOP, rhs->type, 2, acc, rhs);
        acc->binop = op;
    }
    return acc;
}
// TERNARY ::= BINOP ('?' EXPR (':' TERNARY)?)?
//
// Also handles `new [count]`, which yields a NODE_ALLOC_ARR typed as an array
// of a fresh type variable. `new` followed by anything else traps.
static Node* parse_ternary(Parser* restrict p) {
    if (parser_try_eatid(p, KW_new)) {
        if (!parser_try_eat(p, '[')) {
            __debugbreak();
        } else {
            Node* count = parse_ternary(p);
            parser_eat(p, ']');
            Type* arr_type = type_new_array(type_new_var());
            return new_node(NODE_ALLOC_ARR, arr_type, 1, count);
        }
    }

    Node* cond = parse_binop(p, 0);
    if (!parser_try_eat(p, '?')) {
        return cond;
    }

    Node* on_true = parse_expr(p);

    // if there's no false case, it returns void
    Node* on_false;
    if (parser_try_eat(p, ':')) {
        on_false = parse_ternary(p);
    } else {
        on_false = new_node(NODE_NULL, type_new_mono(TYPE_VOID), 0);
    }

    // both arms must agree; the node takes the false arm's type
    type_unify(on_true->type, on_false->type);
    return new_node(NODE_TERNARY, on_false->type, 3, cond, on_true, on_false);
}
// Parses a parameter list up to and including the closing ')' (the opening
// '(' has already been consumed by the caller). Each parameter becomes a
// NODE_PROJ of `n`; a named parameter is also declared in the symbol chain
// with a fresh type variable, and a parenthesized parameter is parsed
// recursively as a nested pattern.
//
// Returns void for no parameters, the parameter's own type for exactly one,
// and a tuple type otherwise. Aborts when there are more parameters than the
// temporary array can hold.
static Type* parse_params(Parser* restrict p, Node* n) {
    enum { MAX_PARAMS = 16 };
    int count = 0;
    Node* params[MAX_PARAMS];
    while (p->token.type != ')') {
        // refuse to write past the end of the fixed-size array
        if (count == MAX_PARAMS) {
            abort();
        }

        Node* proj = new_proj(NULL, n, count);
        if (parser_try_eat(p, '(')) {
            proj->type = parse_params(p, proj);
        } else {
            Intern* param_name = parser_eat_ident(p);
            proj->type = type_new_var();
            parser_push_decl(p, param_name, proj->type, proj);
            printf(" PARAM: %.*s\n", param_name->len, param_name->data);
        }
        params[count++] = proj;
        if (!parser_try_eat(p, ',')) { break; }
    }
    parser_eat(p, ')');

    Type* type;
    if (count == 0) {
        type = type_new_mono(TYPE_VOID);
    } else if (count == 1) {
        type = params[0]->type;
    } else {
        type = type_new_tuple(count);
        FOR_N(i, 0, count) {
            type->tuple.elems[i] = params[i]->type;
        }
    }
    return type;
}
// PARAM ::= IDENT
// DECL ::= IDENT ':' '=' EXPR
// FUNC ::= 'fn' '(' (PARAM (',' PARAM)*)? ')' EXPR
// COMPOUND ::= '(' EXPR (';' EXPR)* ')'
// RETURN ::= 'return' EXPR
// EXPR ::= DECL | COMPOUND | FUNC | TERNARY | RETURN
//
// As implemented here: declarations are written `let IDENT = EXPR`, compound
// expressions are delimited by '{' and '}', and `while` is not implemented
// (it traps after parsing its condition and yields NULL).
static Node* parse_expr(Parser* restrict p) {
    if (parser_try_eatid(p, KW_return)) {
        // the returned value must match the enclosing function's return type
        Node* value = parse_ternary(p);
        assert(p->top_fn->type->tag == TYPE_FUNC);
        type_unify(value->type, p->top_fn->type->fn.ret);
        return new_node(NODE_COMPOUND, type_new_mono(TYPE_VOID), 1, value);
    }

    if (parser_try_eatid(p, KW_while)) {
        parse_ternary(p);
        __debugbreak();
        return NULL;
    }

    if (parser_try_eatid(p, KW_let)) {
        Intern* name = parser_eat_ident(p);
        parser_eat(p, '=');
        // the declaration is pushed before its initializer is parsed
        Node* decl = parser_push_decl(p, name, type_new_var(), NULL);
        Node* init = parse_expr(p);
        set_in(decl, init, 0);
        return decl;
    }

    if (parser_try_eatid(p, KW_fn)) {
        // optional function name
        Intern* name = NULL;
        if (p->token.type == TKN_IDENT) {
            name = parser_eat_ident(p);
            printf("FN %.*s:\n", name->len, name->data);
        }

        // the lambda starts with an empty body slot and no type
        Node* lambda = new_node2(NODE_LAMBDA, NULL, 1);
        parser_eat(p, '(');
        Type* args = parse_params(p, lambda);
        Type* fn_type = type_new_func(args, type_new_var());
        lambda->type = fn_type;

        // a named function is declared before its body is parsed
        Node* sym = NULL;
        if (name != NULL) {
            sym = parser_push_decl(p, name, fn_type, lambda);
        }

        Node* outer_fn = p->top_fn;
        p->top_fn = lambda;
        Node* body = parse_expr(p);
        set_in(lambda, body, 0);
        type_unify(lambda->ins[0]->type, fn_type->fn.ret);
        p->top_fn = outer_fn;

        // none of the symbols within the function escape
        if (name != NULL) {
            p->symtab = sym;
        }
        return lambda;
    }

    if (parser_try_eat(p, '{')) {
        Node* seq = parse_expr(p);
        while (parser_try_eat(p, ';')) {
            // empty statement just means void
            Node* next;
            if (parser_try_eat(p, ';')) {
                next = new_node(NODE_NULL, type_new_mono(TYPE_VOID), 0);
            } else {
                next = parse_expr(p);
            }
            // the sequence takes the type of its latest expression
            seq = new_node(NODE_COMPOUND, next->type, 2, seq, next);
        }
        parser_eat(p, '}');
        return seq;
    }

    return parse_ternary(p);
}
// Prints the representative of `t`'s set to stdout, recursing into compound
// types. `p` is only passed along to the recursive calls.
static void print_type(Parser* restrict p, Type* restrict t) {
    Type* root = type_find(t);
    switch (root->tag) {
        case TYPE_VAR:
            printf("TV%d", root->uid);
            return;
        case TYPE_VOID:
            printf("void");
            return;
        case TYPE_INT:
            printf("int");
            return;
        case TYPE_FLT:
            printf("flt");
            return;
        case TYPE_INTFLT:
            printf("intflt");
            return;

        case TYPE_ARRAY:
            printf("[");
            print_type(p, root->array_elem);
            printf("]");
            return;

        case TYPE_TABLE:
            // fields as "name: type", comma separated
            printf("{");
            for (int i = 0; i < root->table.count; i++) {
                if (i != 0) {
                    printf(", ");
                }
                printf("%s: ", root->table.elems[i].name->data);
                print_type(p, root->table.elems[i].type);
            }
            printf("}");
            return;

        case TYPE_TUPLE:
            // elements separated by single spaces
            printf("(");
            for (int i = 0; i < root->tuple.count; i++) {
                if (i != 0) {
                    printf(" ");
                }
                print_type(p, root->tuple.elems[i]);
            }
            printf(")");
            return;

        case TYPE_FUNC:
            print_type(p, root->fn.args);
            printf(" -> ");
            print_type(p, root->fn.ret);
            return;
    }
}
// Pushes `root` and then every node reachable from it through non-NULL
// inputs. The worklist array doubles as the traversal queue: entries are
// scanned in order while new ones are appended behind them.
static void ws_push_all(Worklist* restrict ws, Node* root) {
    ws_push(ws, root);

    size_t head = 0;
    while (head < ws->cnt) {
        // re-read through ws->arr each time: ws_push may reallocate it
        Node* cur = ws->arr[head];
        for (int slot = 0; slot < cur->in_cnt; slot++) {
            Node* input = cur->ins[slot];
            if (input) {
                ws_push(ws, input);
            }
        }
        head++;
    }
}
// Dumps every node reachable from the symbol chain to stdout, one per line:
//   %<uid>: <type> = <kind> (<input uids>)
// Nodes are printed in reverse order of discovery. Aborts on a node kind or
// binary operator it has no name for.
static void print(Parser* restrict p) {
    Worklist ws;
    ws_init(&ws, 100);

    // push all nodes
    for (Node* n = p->symtab; n != NULL; n = n->decl.prev) {
        ws_push_all(&ws, n);
    }

    FOR_REV_N(i, 0, ws.cnt) {
        Node* n = ws.arr[i];
        printf("%%%-4d: ", n->uid);
        if (n->type) {
            print_type(p, n->type);
        }
        printf(" = ");

        switch (n->tag) {
            case NODE_DECL: printf("decl %s", n->decl.name ? n->decl.name->data : "___"); break;
            case NODE_APPLY: printf("apply"); break;
            case NODE_LAMBDA: printf("lambda"); break;
            case NODE_SYMBOL: printf("&sym %s", n->ins[0]->decl.name->data); break;
            case NODE_ACCESS: printf(".%s", n->access->data); break;
            case NODE_NULL: printf("null"); break;
            case NODE_PROJ: printf("proj%d", n->proj_i); break;
            case NODE_TUPLE: printf("tuple"); break;
            case NODE_TERNARY: printf("ternary"); break;
            case NODE_COMPOUND: printf("compound"); break;
            // the parser also builds these three kinds, so they need names too
            case NODE_LEN: printf("len"); break;
            case NODE_SUBSCRIPT: printf("subscript"); break;
            case NODE_ALLOC_ARR: printf("alloc_arr"); break;
            // num is a uint64_t, so it takes the unsigned conversion
            case NODE_INT: printf("%" PRIu64, n->num); break;
            case NODE_REAL: printf("%f", n->flt); break;
            case NODE_BINOP: {
                switch (n->binop) {
                    case '+': printf("add"); break;
                    case '-': printf("sub"); break;
                    case '*': printf("mul"); break;
                    case '/': printf("div"); break;
                    // '%' is accepted by get_binop, so it must be printable
                    case '%': printf("mod"); break;
                    case TKN_EQ: printf("eq"); break;
                    case TKN_NE: printf("ne"); break;
                    case TKN_GE: printf("ge"); break;
                    case TKN_LE: printf("le"); break;
                    case '>': printf(">"); break;
                    case '<': printf("<"); break;
                    default: abort();
                }
            } break;
            default: abort();
        }

        printf(" (");
        FOR_N(j, 0, n->in_cnt) {
            if (j) { printf(", "); }
            if (n->ins[j]) {
                printf("%%%d", n->ins[j]->uid);
            } else {
                printf("___");
            }
        }
        printf(")\n");
    }

    // the worklist is local to this dump
    free(ws.arr);
    free(ws.visited);
}
// Applies the typing constraint of a single node by unifying the types it
// relates. Returns true when a unification changed something, so the caller
// knows to revisit the node's neighbours. Aborts on node kinds that have no
// rule here.
static bool infer_transfer(Node* restrict n) {
    switch (n->tag) {
        // both operands share one type
        case NODE_BINOP:
            return type_unify(n->ins[0]->type, n->ins[1]->type);

        // both arms share one type
        case NODE_TERNARY:
            return type_unify(n->ins[1]->type, n->ins[2]->type);

        case NODE_APPLY: {
            // the arguments match the call site's parameter type, and the
            // callee matches the call site's function type
            bool progress = false;
            progress |= type_unify(n->ins[1]->type, n->apply.site->fn.args);
            progress |= type_unify(n->ins[0]->type, n->apply.site);
            return progress;
        }

        case NODE_DECL: {
            if (n->decl.param >= 0) {
                // parameter declaration: match the corresponding slot of the
                // function's argument type
                Type* fn_type = n->ins[0]->type;
                assert(fn_type->tag == TYPE_FUNC);
                if (fn_type->fn.args->tag == TYPE_TUPLE) {
                    return type_unify(fn_type->fn.args->tuple.elems[n->decl.param], n->type);
                } else {
                    return type_unify(fn_type->fn.args, n->type);
                }
            } else if (n->ins[0] && n->type) {
                // ordinary declaration: match the initializer
                return type_unify(n->ins[0]->type, n->type);
            }
            return false;
        }

        // the body's type is the function's return type
        case NODE_LAMBDA:
            assert(n->type->tag == TYPE_FUNC);
            return type_unify(n->ins[0]->type, n->type->fn.ret);

        case NODE_COMPOUND: {
            // A still-unresolved first input defaults to void. Once it is
            // resolved, void & T = T leaves it as it is, so there is nothing
            // to do. Unifying a brand-new void here would merge a new set on
            // every visit and report progress forever.
            Type* first = type_find(n->ins[0]->type);
            if (first->tag != TYPE_VAR) {
                return false;
            }
            return type_unify(type_new_mono(TYPE_VOID), first);
        }

        // no constraint of their own
        case NODE_INT:
        case NODE_REAL:
        case NODE_NULL:
        case NODE_PROJ:
        case NODE_TUPLE:
        case NODE_SYMBOL:
            return false;

        default: abort();
    }
}
// Runs type inference over every node reachable from the declaration chain
// that starts at `symtab`. Nodes are taken off a worklist one at a time; when
// a node's transfer function reports progress, its inputs and users are
// queued again. Progress is logged to stdout.
static void infer(Node* symtab) {
    Worklist ws;
    ws_init(&ws, 100);

    // push all nodes
    for (Node* n = symtab; n != NULL; n = n->decl.prev) {
        ws_push_all(&ws, n);
    }

    // monotone framework doing unification
    for (Node* n; n = ws_pop(&ws), n != NULL;) {
        printf("INFER %%%d\n", n->uid);
        if (infer_transfer(n)) {
            printf(" PROGRESS!\n");
            // add all neighbors; input slots may be empty and ws_push
            // dereferences its argument, so skip the NULL ones
            FOR_N(i, 0, n->in_cnt) {
                if (n->ins[i] != NULL) { ws_push(&ws, n->ins[i]); }
            }
            FOR_N(i, 0, n->use_cnt) { ws_push(&ws, n->uses[i]); }
        }
    }

    // the worklist is local to this pass
    free(ws.arr);
    free(ws.visited);
}
// Parses the NUL-terminated program text `src`, prints the parse graph, runs
// type inference, then prints the graph again. A NULL `src` is reported on
// stderr and otherwise ignored.
void compile(const char* src) {
    // the file reader returns NULL on failure; the lexer would dereference it
    if (src == NULL) {
        fprintf(stderr, "compile: no source text\n");
        return;
    }

    // intern the keywords once so the parser can compare them by pointer
    if (!kw_init) {
        kw_init = true;
#define X(name) KW_ ## name = parser_intern(sizeof(#name) - 1, #name);
#include "keywords.inc"
    }

    Parser p = { .src = src, .curr = src };
    // prime the first token, then parse ';'-separated top-level expressions
    parser_token(&p);
    while (p.token.type != TKN_EOF) {
        parse_expr(&p);
        if (!parser_try_eat(&p, ';')) { break; }
    }

    print(&p);
    infer(p.symtab);
    print(&p);
}
// Compiles the file named by the first command-line argument, or "test3.txt"
// when none is given. Returns 1 when the file cannot be read, 0 otherwise.
int main(int argc, char** argv) {
    const char* path = argc > 1 ? argv[1] : "test3.txt";

    char* src = read_entire_file(path);
    if (src == NULL) {
        fprintf(stderr, "could not read %s\n", path);
        return 1;
    }

    compile(src);
    free(src);
    return 0;
}
// wooyea
// fn iter(f, i, limit) { i < limit ? { f(i); iter(f, i+1, limit) } };
// fn main() iter(fn(i) { i }, 0, 10);
// fn complex(r, i) (r, i);
fn cmul((a, b), (c, d)) (a*c - b*d, a*d + b*c);
fn main() {
let a = (1.0, 2);
let b = (3, 4);
cmul(a, b);
// array init
let c = [0, 1];
c[0]
}
// (A[], A -> B) -> B[]
fn map(src, f) {
let dst = new[#src];
fn iter(i) {
i < #src ? dst[i] = f(src[i]); i=i+1
};
iter();
dst
};
// (intflt[], intflt -> intflt) -> intflt[]
fn foo(in) map(in, fn(x) 2*x - 1);
// flt[]
let c = foo([0.1, 0.4, 0.5, -0.3])
fn update(e, dt) {
e.fire = e.fire - dt
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment