Skip to content

Instantly share code, notes, and snippets.

@monomere
Last active March 24, 2025 15:06
Show Gist options
  • Save monomere/63e917f01e71ebb51dd86fd2f2b22236 to your computer and use it in GitHub Desktop.
Save monomere/63e917f01e71ebb51dd86fd2f2b22236 to your computer and use it in GitHub Desktop.
Tiny compiler written in python that generates text LLVM IR.
# License: MIT
# Author: monomere
# Changelog:
# Note: With every update, the doc comment is almost always updated
# to reflect new additions/improvements to the language.
# Update 7.1 - 2024/09/13:
# - Fix variadic arguments not being passed.
# Update 7 - 2024/01/27:
# - Add changelog note.
# - Fix bug in parser error handling (didn't `get` non-bracket tokens before)
# - Fix bug in `LTypeArray.align()`, it returned `self.of.size()` instead of `.align()`.
# - Add type name mangling.
# - Add type generics.
# - Change "Changes" to "Changelog".
# - `CodeGen._ir_ty(LTypeFunc)` works. `CodeGen._ir_func_ty(t, name="")`
# - Add C varargs support. Use `c_varargs` as last argument to specify varargs.
# Example: `extern func printf(s: *uint8, c_varargs): void;`
# Update 6 - 2024/01/26:
# - Add `qq_write_uint8s` to write strings.
# - Fix parsing optional type annotations in let statements.
# - Add named argument support (parsing).
# - Now code is compiled with PIC by default.
# - Add basic string literals.
# Requires a `struct uint8s { len: uint64, data: *uint8 }` type to be present.
# - Fix type path resolve bug (did first part twice).
# - Add structs.
# - Add pointer types (parsing and resolving).
# - Implement better parser error handling.
# Now the parser skips to the next top-level item
# when it reaches the global scope. TODO: Different bracket types?
# Update 5 - 2024/01/23:
# - Add implicit casting.
# - Fix CodeGen._gen_int_cast.
# - Rename `RstExpr.type` to `RstExpr.ty` to be consistent with everything else.
# - Add `as` casting.
# - Add `qq_read_int32`, `qq_stdin` to libqqrt.
# - Rename `qq_write_int` to `qq_write_int32` in libqqrt.
# - Fix parser precedence parsing bugs.
# - Add support for recursive functions.
# - Add support for basic type inference (see `expected_type`)
# For example, integer literals can now infer their type.
# - Add missing errors.
# Update 4 - 2024/01/22:
# - A single bool type, since LLVM supports 1-bit ints.
# - Fix block parsing bug (last stmt is duplicated into the tail).
# - Move from QBE to LLVM.
# Update 3 - 2024/01/22:
# - Add comparison operators.
# - Add if expressions.
# - Add boolean types: bool{8,16,32,64}.
# Update 2 - 2024/01/22:
# - Add assignment expressions.
# - Debug flags print the assembly generated by QBE.
# - Codegen for moving stuff. Qbe values can now be indirect and non-indirect.
# - Change license to MIT. Will maybe change to GPL or something... this needs to be open source.
# - Add some comments to codegen code. I still need to add more comments.
# - Add local variables.
# - Add size and align methods to LType.
# - Add additional CLI arguments (--use-asm and --use-qbe).
# - Fix void handling in codegen.
# Update 1 - 2024/01/21:
# - Improve some error messages.
# - Fix some parser and error reporter bugs.
# - Add down/up casting of integer types in codegen.
# - Fix binary operator typechecking (still limits to ints tho).
# - Add sized, signed/unsigned integer types: (u)int{8,16,32,64}.
# - Add `export` top-level decl attribute.
# - CLI improvement.
# - Parse more than one argument for functions.
# - Some code readability changes.
# - Remove unneeded debug stuff from `parse_expr`.
# - Fix formatting of doc comment.
# Initial release - 2024/01/21
'''
# Requirements
- LLVM (llc specifically)
- Assembler (Clang by default)
- libqqrt.o - runtime library
# Runtime library
## Should provide
- `void qq_write_uint8s(uint64_t fout, struct qq_slice_uint8 i)`
- `void qq_write_int32(uint64_t fout, int32_t i)`
- print an int and a newline to specified file.
- `int32_t qq_read_int32(uint64_t fin)`
- read an int from the specified file (akin to `int(input())` in python).
- `uint64_t qq_stdout()`
- `uint64_t qq_stderr()`
- `uint64_t qq_stdin()`
## Possible C code
```c
#include <stdio.h>
#include <stdint.h>
#include <stddef.h>
#include <inttypes.h>
extern uint64_t qq_stdout() { return (uint64_t)(uintptr_t)stdout; }
extern uint64_t qq_stderr() { return (uint64_t)(uintptr_t)stderr; }
extern uint64_t qq_stdin() { return (uint64_t)(uintptr_t)stdin; }
extern void qq_write_int32(uint64_t fout, int32_t i) {
FILE *pfout = (void*)(uintptr_t)fout;
fprintf(pfout, "%" PRId32, i);
}
struct qq_slice_uint8 {
uint64_t len;
uint8_t *data;
};
extern void qq_write_uint8s(uint64_t fout, struct qq_slice_uint8 s) {
FILE *pfout = (void*)(uintptr_t)fout;
fwrite(s.data, s.len, 1, pfout);
}
extern int32_t qq_read_int32(uint64_t fin) {
FILE *pfin = (void*)(uintptr_t)fin;
int32_t r = 0;
fscanf(pfin, " %" PRId32, &r);
return r;
}
```
# Example code
```go
extern func qq_read_int32(o: uint64): int32;
extern func qq_stdin(): uint64;
extern func read_int32() = qq_read_int32(qq_stdin());
struct slice[T] {
len: uint64,
ptr: *T
}
func fib(i: uint64): uint64 =
if i <= 1 { 1 }
else { fib(i - 1) + fib(i - 2) };
struct ProgramResults[T] {
index: T,
result: T
}
func run_program(): ProgramResults[uint64] {
let inp: uint64 = read_int32();
let res = fib(inp);
ProgramResults::[uint64]:(
index = inp,
result = res
)
}
extern func printf(s: *uint8, c_varargs): void;
export func main(argc: int32): int32 {
let r = run_program();
printf("The fibonacci number at index %lu is %lu.\n\0".ptr, r.index, r.result);
0
}
```
'''
from __future__ import annotations
import typing, dataclasses, enum, abc, sys, \
subprocess, argparse, os.path, functools, \
shlex, pprint
def print(*args, **kwargs):
    """Module-local replacement for the builtin ``print``.

    Every call is prefixed with ``';'`` and, unless the caller passes an
    explicit ``file=``, is written to ``sys.stderr``.

    The real builtin is reached through the ``builtins`` module rather than
    ``__builtins__``: the latter is a plain dict when this file is imported
    instead of run as a script, so attribute access on it would fail.
    """
    import builtins

    # Honour an explicit ``file=`` instead of raising "multiple values for
    # keyword argument 'file'".
    kwargs.setdefault("file", sys.stderr)
    builtins.print(';', *args, **kwargs)
class Input(abc.ABC):
    """Abstract source of program text, identified by a file name."""

    @abc.abstractmethod
    def source(self) -> str:
        """Return the complete source text."""
        ...

    @abc.abstractmethod
    def filename(self) -> str:
        """Return the name used for this input in diagnostics."""
        ...

    def to_line_col(self, offset: int) -> tuple[str, int, int]:
        """Translate a character offset into line information.

        Returns ``(line_text, line, col)`` where ``line`` and ``col`` are
        zero-based and ``line_text`` is the text of the line containing
        ``offset`` without its trailing newline. Offsets past the end of the
        source are treated as the end of the source; negative offsets as 0.
        """
        src = self.source()
        before = src[:max(offset, 0)]
        line = before.count("\n")
        # Index of the first character of the line that contains `offset`.
        lstart = before.rfind("\n") + 1
        col = len(before) - lstart
        # Search from `lstart` itself: starting one character later would step
        # over the newline of an empty line and return the *next* line too.
        lend = src.find("\n", lstart)
        if lend == -1:
            lend = len(src)
        return (src[lstart:lend], line, col)


class InputStr(Input):
    """An `Input` backed by an in-memory string."""

    def __init__(self, s: str, fname: str = "<input>"):
        self.s = s
        self.fname = fname

    def source(self) -> str:
        return self.s

    def filename(self) -> str:
        return self.fname
@dataclasses.dataclass
class SrcSpan:
    """A range of character offsets inside one registered input."""

    file: int   # index of the input (as returned by InputManager.add)
    start: int  # offset of the first character
    stop: int   # offset where the span ends

    def clone(self):
        """Return an independent copy that can be mutated separately."""
        return SrcSpan(file=self.file, start=self.start, stop=self.stop)


@dataclasses.dataclass
class SrcSpanInfo:
    """Human-readable location data derived from a `SrcSpan`."""

    as_str: str    # "filename:line:col" (one-based) for messages
    line_str: str  # text of the source line
    line: int      # zero-based line number
    col: int       # zero-based column number
class InputManager:
    """Registry of all inputs; spans refer to inputs by their index here."""

    # Shared by the whole process: the class is used as a namespace and is
    # never instantiated in the code shown.
    inputs: list[Input] = []

    @staticmethod
    def add(inp: Input) -> int:
        """Register `inp` and return the index it was stored at."""
        index = len(InputManager.inputs)
        InputManager.inputs.append(inp)
        return index

    @staticmethod
    def span_info(span: SrcSpan) -> SrcSpanInfo:
        """Resolve the start of `span` to file/line/column information."""
        source = InputManager.inputs[span.file]
        line_text, line_no, col_no = source.to_line_col(span.start)
        # Diagnostics show one-based positions; the stored values stay zero-based.
        location = f"{source.filename()}:{line_no + 1}:{col_no + 1}"
        return SrcSpanInfo(location, line_text, line_no, col_no)
class IrrecoverableError(Exception):
    """Raised when processing cannot continue (e.g. too many reports, or EOF
    reached where the parser does not allow it)."""
class ErrorReporter:
    """Writes coloured diagnostics to stderr and counts them."""

    report_count = 0  # number of reports emitted while `ignore` was False
    ignore = False    # when True, `report` prints nothing and counts nothing

    @staticmethod
    def report(
        msg: str,
        span: SrcSpan | None, *,
        kind: typing.Literal['error', 'note', 'success'] = 'error',
        title: str | None = None
    ) -> Exception:
        """Print a diagnostic and return an exception carrying `msg`.

        The exception is returned, not raised, so callers may choose to raise
        it. When `span` is given, the offending source line is echoed with a
        `^~~~` marker under the span. Raises `IrrecoverableError` once more
        than 15 reports have been emitted.
        """
        if ErrorReporter.ignore:
            return Exception(f"ignored {kind}: {title}: {msg} @{span}")

        ErrorReporter.report_count += 1
        out = sys.stderr
        info = None if span is None else InputManager.span_info(span)
        color = { "note": "34", "error": "1;31", "success": "1;32" }[kind]
        if title is None:
            title = { "note": "Note:", "error": "Error:", "success": "Success!" }[kind]

        # Header line: "<title> [at file:line:col ]<msg>"
        out.write(f"\x1b[{color}m{title} ")
        if info is not None:
            out.write("at ")
            out.write(info.as_str)
            out.write(" ")
        out.write(msg)
        out.write("\x1b[m\n")

        if info is not None:
            assert span is not None
            text = info.line_str
            gutter = f"{info.line + 1} | "

            # Echo the source line, replacing tabs.
            out.write(gutter)
            for ch in text:
                out.write(" " if ch == '\t' else ch)
            out.write("\n")

            # Pad up to the start column; a tab gets one extra space.
            out.write(" " * len(gutter))
            for i in range(info.col):
                if text[i] == '\t':
                    out.write(" ")
                out.write(" ")
            out.write("^")

            # Underline the rest of the span, stopping near the end of the line.
            for i in range(span.stop - span.start - 1):
                at = info.col + i
                if at >= len(text) - 1:
                    break
                if text[at] == '\t':
                    out.write("~")
                out.write("~")
            out.write("\n")

        if ErrorReporter.report_count > 15:
            raise IrrecoverableError()
        return Exception(msg)
class TokenKind(enum.Enum):
    """Kinds of tokens that are not plain punctuation/operators."""
    ERROR = enum.auto()
    EOF = enum.auto()
    ID = enum.auto()   # identifier or keyword; value is the text
    INT = enum.auto()  # integer literal; value is the parsed int
    STR = enum.auto()  # string literal; value is the unescaped text


@dataclasses.dataclass
class Token:
    """A lexed token.

    `kind` is a `TokenKind`, or the literal text for punctuation and operators
    (e.g. ``'('``, ``'<='``, ``'::'``). For bracket tokens `value` holds the
    bracket nesting level assigned by the tokenizer.
    """
    span: SrcSpan
    kind: TokenKind | str
    value: str | int | None = None


# Bracket characters recognised by the tokenizer and parser.
OPEN_BR = ('{', '[', '(')
CLOSE_BR = ('}', ']', ')')
def tokenize(file: int, inp: str) -> typing.Iterable[Token]:
    """Lazily split `inp` into tokens, ending with a single EOF token.

    Parameters:
        file: index of the input, stored in every token's `SrcSpan`.
        inp:  the source text.

    Token shapes:
        * identifiers   -> kind `TokenKind.ID`, value = text
        * strings       -> kind `TokenKind.STR`, value = unescaped text
        * integers      -> kind `TokenKind.INT`, value = int; an optional
          `0x` / `0d` / `0o` / `0b` prefix selects the base
        * everything else -> kind is the literal text (`'('`, `'<='`, `'::'`, ...)
          and brackets (and `'::'`) carry the current bracket level as value.

    All spans are half-open: `stop` is the offset just past the token.

    Raises:
        SyntaxError: unknown escape, unterminated string, or unknown base prefix.
    """
    off = 0
    EQ_OPS = ('>', '<', '=', '!')
    HEX_DIGITS = "0123456789abcdef"
    BASES = {"x": 16, "d": 10, "o": 8, "b": 2}
    ESCAPES = {
        'n': '\n', 't': '\t', 'r': '\r', '0': '\0', 'b': '\b', 'e': '\033',
        # Without these two a string literal could never contain a quote or
        # a backslash.
        '"': '"', '\\': '\\',
    }
    cur_level = 0  # bracket level

    def iseof() -> bool:
        return off >= len(inp)

    def peek() -> str:
        return "" if iseof() else inp[off]

    def get() -> str:
        nonlocal off
        c = peek()
        # Never step past the end, so the EOF token's span stays at len(inp).
        if c:
            off += 1
        return c

    def digit_value(c: str) -> int:
        """Value of `c` as a (hex) digit, or -1 if it is not one."""
        return HEX_DIGITS.find(c.lower()) if len(c) == 1 else -1

    while not iseof():
        while peek().isspace():
            off += 1

        if peek().isalpha():
            start = off
            cs = [get()]
            while peek().isalpha() or peek().isdigit() or peek() == '_':
                cs.append(get())
            yield Token(SrcSpan(file, start, off), TokenKind.ID, "".join(cs))

        elif peek() == '"':
            start = off
            get()
            cs = []
            while not iseof() and peek() != '"':
                if peek() == '\\':
                    get()
                    e = get()
                    if e not in ESCAPES:
                        raise SyntaxError(f"Unknown escape character: '{e}'.")
                    cs.append(ESCAPES[e])
                else:
                    cs.append(get())
            if get() != '"':  # only possible at EOF
                raise SyntaxError("Expected string end.")
            yield Token(SrcSpan(file, start, off), TokenKind.STR, "".join(cs))

        # ASCII digits only: str.isdigit() also accepts characters such as
        # '²' whose code point has nothing to do with their numeric value.
        elif "0" <= peek() <= "9":
            start = off
            n = digit_value(get())
            base = 10
            if n == 0 and peek().isalpha():  # 0[xdob]
                b = get()
                if b not in BASES:
                    raise SyntaxError(f"Unknown base: '{b}'.")
                base = BASES[b]
            # Only consume digits that are valid in `base`; anything else
            # (e.g. the '2' in `0b12`) starts the next token.
            while 0 <= (d := digit_value(peek())) < base:
                get()
                n = n * base + d
            yield Token(SrcSpan(file, start, off), TokenKind.INT, n)

        else:
            start = off
            c = get()
            if not c:
                break  # only trailing whitespace was left
            if c in EQ_OPS and peek() == '=':
                get()
                yield Token(SrcSpan(file, start, off), c + '=')
            elif c in OPEN_BR:
                cur_level += 1
                yield Token(SrcSpan(file, start, off), c, cur_level)
            elif c in CLOSE_BR:
                cur_level = cur_level - 1 if cur_level > 0 else 0
                yield Token(SrcSpan(file, start, off), c, cur_level)
            elif c == ':' and peek() == ':':
                get()
                yield Token(SrcSpan(file, start, off), '::', cur_level)
            else:
                yield Token(SrcSpan(file, start, off), c)

    yield Token(SrcSpan(file, off, off), TokenKind.EOF)
# ---------------------------------------------------------------------------
# AST: the untyped tree produced by the parser.
# Every node starts with the `span` it was parsed from.
# ---------------------------------------------------------------------------

@dataclasses.dataclass
class Ast:
    """Base of all AST nodes."""
    span: SrcSpan


# --- type expressions ------------------------------------------------------

@dataclasses.dataclass
class TypeAst(Ast):
    """Base of all type expressions."""


@dataclasses.dataclass
class TypeAstParam(TypeAst):
    """A generic type parameter, e.g. the `T` in `struct slice[T]`."""
    name: str


@dataclasses.dataclass
class TypeAstPtr(TypeAst):
    """Pointer type `*to`."""
    to: TypeAst


@dataclasses.dataclass
class TypeAstPath(TypeAst):
    """A `::`-separated type path, optionally rooted at another type."""
    root: TypeAst | None
    names: list[str]


@dataclasses.dataclass
class TypeAstInst(TypeAst):
    """Generic instantiation `generic[params...]`."""
    generic: TypeAst
    params: list[TypeAst]


@dataclasses.dataclass
class TypeAstStruct(Ast):
    """Body of a struct declaration (derives from `Ast`, not `TypeAst`)."""
    attrs: set[str]
    fields: list[AstDecl[TypeAst]]


# --- declarations ----------------------------------------------------------

TT = typing.TypeVar('TT', covariant=True)


@dataclasses.dataclass
class AstDecl(Ast, typing.Generic[TT]):
    """`name: ty = value`; `TT` is the static type of the `ty` slot."""
    attrs: list[str]
    name: str
    ty: TT
    value: ExprAst | None


# --- expressions -----------------------------------------------------------

@dataclasses.dataclass
class ExprAst(Ast):
    """Base of all expressions."""


@dataclasses.dataclass
class ExprAstInt(ExprAst):
    """Integer literal."""
    value: int


@dataclasses.dataclass
class ExprAstStr(ExprAst):
    """String literal (already unescaped)."""
    value: str


@dataclasses.dataclass
class ExprAstBool(ExprAst):
    """`true` / `false`."""
    value: bool


@dataclasses.dataclass
class ExprAstError(ExprAst):
    """Placeholder produced where an expression could not be parsed."""


class BinOp(enum.Enum):
    """Binary operators."""
    ADD = enum.auto()
    SUB = enum.auto()
    MUL = enum.auto()
    DIV = enum.auto()
    MOD = enum.auto()
    EQ = enum.auto()
    NE = enum.auto()
    LT = enum.auto()
    GT = enum.auto()
    LE = enum.auto()
    GE = enum.auto()
    ASSIGN = enum.auto()


@dataclasses.dataclass
class ExprAstBinOp(ExprAst):
    """`lhs op rhs`."""
    lhs: ExprAst
    rhs: ExprAst
    op: BinOp


class UnOp(enum.Enum):
    """Unary operators."""
    POS = enum.auto()
    NEG = enum.auto()
    NOT = enum.auto()
    BIN_NOT = enum.auto()


@dataclasses.dataclass
class ExprAstUnOp(ExprAst):
    """`op val`."""
    val: ExprAst
    op: UnOp


@dataclasses.dataclass
class ExprAstId(ExprAst):
    """A (possibly `::`-qualified) name used as an expression."""
    names: list[str]


@dataclasses.dataclass
class ExprAstItem(ExprAst):
    """Member access `of.name`."""
    of: ExprAst
    name: str


@dataclasses.dataclass
class ExprAstCall(ExprAst):
    """Call `func(args..., name = value...)`."""
    func: ExprAst
    args: list[ExprAst]
    named_args: list[tuple[str, ExprAst]]


@dataclasses.dataclass
class ExprAstCons(ExprAst):
    """Constructor `Type:(field = value, ...)`."""
    ty: TypeAst
    fields: list[tuple[str, ExprAst]]


@dataclasses.dataclass
class ExprAstCast(ExprAst):
    """`expr as ty`."""
    expr: ExprAst
    ty: TypeAst


@dataclasses.dataclass
class ExprAstBlock(ExprAst):
    """`{ stmts...; tail }`; `tail` is the block's value, if any."""
    stmts: list[StmtAst]
    tail: ExprAst | None


@dataclasses.dataclass
class ExprAstIf(ExprAst):
    """`if cond true [else false]`."""
    cond: ExprAst
    true: ExprAst
    false: ExprAst | None


# --- statements ------------------------------------------------------------

@dataclasses.dataclass
class StmtAst(Ast):
    """Base of all statements."""


@dataclasses.dataclass
class StmtExprAst(StmtAst):
    """An expression used as a statement."""
    expr: ExprAst


@dataclasses.dataclass
class StmtDeclAst(StmtAst):
    """`let` declaration; the type annotation may be absent."""
    decl: AstDecl[TypeAst | None]


# --- top-level items -------------------------------------------------------

@dataclasses.dataclass
class AstFunc(Ast):
    """Function declaration or definition (`body` is None for a declaration)."""
    attrs: set[str]
    name: str
    args: list[AstDecl[TypeAst]]
    ret: TypeAst | None
    body: ExprAst | None
    c_varargs: bool


@dataclasses.dataclass
class AstStruct(Ast):
    """Struct declaration; `ty_params` is None for a non-generic struct."""
    name: str
    ty_params: list[TypeAstParam] | None
    ty: TypeAstStruct
class Parser:
    """Recursive-descent parser turning a token list into AST nodes.

    Errors are reported through `ErrorReporter`. Some failures raise
    `Parser.FailedMiserably` after skipping ahead in the token stream; running
    off the end of the input raises `IrrecoverableError`.
    """

    @staticmethod
    def _parser(fn):
        """Decorator for parse methods that maintains `self._span`.

        While `fn` runs, `self._span` is a fresh span starting at the next
        token; `_get` keeps extending its `stop`. Afterwards the enclosing
        span is restored with its `stop` extended to cover what `fn` consumed.
        Nothing is restored if `fn` raises.
        """
        @functools.wraps(fn)
        def inner(self: Parser, *args, **kwargs):
            outer_level = self._last_level
            # NOTE(review): this assignment is a no-op, so `_last_level` never
            # leaves its initial 0. Given the "saved _current_level" comment in
            # __init__, `self._current_level` was probably intended - confirm
            # before changing, it affects error recovery in _handle_eat_fail.
            self._last_level = outer_level
            outer_span = self._span
            self._span = self._peek().span.clone()
            result = fn(self, *args, **kwargs)
            outer_span.stop = self._span.stop
            self._span = outer_span
            self._last_level = outer_level
            return result
        return inner

    class FailedMiserably(Exception):
        """Raised by `_handle_eat_fail` to unwind back to the top level."""

    # Token text -> (precedence, operator). Higher precedence binds tighter.
    OPS: dict[str, tuple[int, BinOp]] = {
        '+': (10, BinOp.ADD),
        '-': (10, BinOp.SUB),
        '*': (20, BinOp.MUL),
        '/': (20, BinOp.DIV),
        '%': (20, BinOp.MOD),
        '==': (5, BinOp.EQ),
        '!=': (5, BinOp.NE),
        '<=': (6, BinOp.LE),
        '>=': (6, BinOp.GE),
        '<': (6, BinOp.LT),
        '>': (6, BinOp.GT),
        '=': (1, BinOp.ASSIGN),
    }

    def __init__(self, toks: list[Token]):
        self.idx = 0
        self.toks = toks
        self._rep_tok: Token | None = None
        self._rep_count = 0
        self._span: SrcSpan = SrcSpan(-1, -1, -1)  # current @_parser span
        self._current_level = 0  # current "bracket level"
        self._last_level = 0  # saved _current_level.

    # --- token helpers -----------------------------------------------------

    def _peek(self):
        """Return the next token without consuming it."""
        return self.toks[self.idx]

    def _get(self):
        """Consume and return the next token (the last token is never passed)."""
        tok = self._peek()
        if self.idx < len(self.toks) - 1:
            self.idx += 1
        self._span.stop = tok.span.stop
        return tok

    def _is(self, kind: TokenKind | str, allow_eof = False):
        """Is the next token of `kind`? EOF is an error unless asked for or allowed."""
        next_kind = self._peek().kind
        if kind == TokenKind.EOF:
            return next_kind == kind
        if next_kind == TokenKind.EOF and not allow_eof:
            ErrorReporter.report("EOF not allowed", self._peek().span)
            raise IrrecoverableError()
        return next_kind == kind

    def _is_kw(self, kw: str):
        """Is the next token the identifier `kw`?"""
        return self._is(TokenKind.ID) and self._peek().value == kw

    def _opt_eat(self, kind: TokenKind | str, allow_eof = False) -> Token | None:
        """Consume the next token if it is of `kind`, else return None."""
        if self._is(TokenKind.EOF) and not allow_eof:
            ErrorReporter.report("EOF not allowed", self._peek().span)
            raise IrrecoverableError()
        if not self._is(kind):
            return None
        tok = self._get()
        if kind in OPEN_BR or kind in CLOSE_BR:
            # Bracket tokens carry their nesting level as value.
            self._current_level = typing.cast(int, tok.value)
        return tok

    def _handle_eat_fail(self, t: Token):
        """Recovery after a failed `_eat` whose offending token is `t`.

        If `t` is a bracket at a level >= `_last_level`, tokens are skipped up
        to and including the next bracket with level 0 (or up to EOF), then
        `FailedMiserably` is raised. Otherwise this returns normally.
        """
        if t.kind in OPEN_BR or t.kind in CLOSE_BR:
            assert isinstance(t.value, int)
            if self._last_level <= t.value:
                while not self._is(TokenKind.EOF):
                    t = self._peek()
                    if t.kind in OPEN_BR or t.kind in CLOSE_BR:
                        assert isinstance(t.value, int)
                        if t.value == 0:
                            self._get()
                            break
                    self._get()
                raise Parser.FailedMiserably()  # go back to toplevel

    def _eat(self, kind: TokenKind | str) -> Token | None:
        """Consume a token of `kind`; report an error if it is not there."""
        if tok := self._opt_eat(kind):
            return tok
        bad = self._peek()
        ErrorReporter.report(f"Expected {kind}.", bad.span)
        self._handle_eat_fail(bad)
        return None

    def _opt_eat_kw(self, kw: str, allow_eof = False) -> Token | None:
        """Consume the keyword `kw` if it is next, else return None."""
        if self._is(TokenKind.EOF) and not allow_eof:
            ErrorReporter.report("EOF not allowed", self._peek().span)
            return None
        return self._get() if self._is_kw(kw) else None

    def _eat_kw(self, kw: str) -> Token | None:
        """Consume the keyword `kw`; report an error if it is not there."""
        if tok := self._opt_eat_kw(kw):
            return tok
        bad = self._peek()
        ErrorReporter.report(f"Expected keyword '{kw}'.", bad.span)
        self._handle_eat_fail(bad)
        return None

    def _eat_name(self) -> str:
        """Consume an identifier and return its text.

        On failure the offending token is consumed and "<error>" is returned.
        """
        if tok := self._eat(TokenKind.ID):
            return typing.cast(str, tok.value)
        self._get()
        return "<error>"

    # --- types -------------------------------------------------------------

    @_parser
    def parse_type_atom(self) -> TypeAst:
        """Parse `name`, followed by any number of `::name` / `[types...]`."""
        first = self._eat_name()
        node = TypeAstPath(self._span, None, [first])
        while self._is('::') or self._is('['):
            if self._is('['):  # generic instantiation
                node.span = node.span.clone()  # make sure the span doesn't change.
                self._get()
                params: list[TypeAst] = []
                while not self._is(']'):
                    params.append(self.parse_type())
                    if not self._opt_eat(','):
                        break
                self._eat(']')
                node = TypeAstInst(self._span.clone(), node, params)
            else:
                if not isinstance(node, TypeAstPath):
                    # A path continuing after an instantiation is rooted at it.
                    node = TypeAstPath(self._span.clone(), node, [])
                self._get()
                node.names.append(self._eat_name())
        node.span = node.span.clone()  # make sure it doesn't go anywhere.
        return node

    @_parser
    def parse_type_prefix(self) -> TypeAst:
        """Parse zero or more `*` prefixes followed by a type atom."""
        if self._opt_eat('*'):
            pointee = self.parse_type_prefix()
            return TypeAstPtr(self._span, pointee)
        return self.parse_type_atom()

    # @_parser XXX: don't forget
    def parse_type(self) -> TypeAst:
        """Parse a type."""
        return self.parse_type_prefix()

    # --- expressions -------------------------------------------------------

    @_parser
    def parse_atom(self) -> ExprAst:
        """Parse a literal, a (qualified) name, or a constructor `Type:(...)`."""
        if tok := self._opt_eat(TokenKind.INT):
            return ExprAstInt(self._span, typing.cast(int, tok.value))
        if tok := self._opt_eat(TokenKind.STR):
            return ExprAstStr(self._span, typing.cast(str, tok.value))
        if tok := self._opt_eat(TokenKind.ID):
            if tok.value == "true":
                return ExprAstBool(self._span, True)
            if tok.value == "false":
                return ExprAstBool(self._span, False)
            # Parsed as a type path first; whether it is a type or a plain
            # name is decided once we know if a constructor follows.
            node = TypeAstPath(self._span, None, [typing.cast(str, tok.value)])
            while self._opt_eat('::'):
                if self._is('['):  # generic instantiation
                    node.span = node.span.clone()  # make sure the span doesn't change.
                    self._get()
                    params: list[TypeAst] = []
                    while not self._is(']'):
                        params.append(self.parse_type())
                        if not self._opt_eat(','):
                            break
                    self._eat(']')
                    node = TypeAstInst(self._span.clone(), node, params)
                else:
                    if not isinstance(node, TypeAstPath):
                        node = TypeAstPath(self._span.clone(), node, [])
                    node.names.append(self._eat_name())
            node.span = node.span.clone()  # make sure it doesn't go anywhere.
            if self._opt_eat(':'):  # constructors
                self._eat('(')
                _, fields = self.parse_call_args(allow_positional=False)
                self._eat(')')
                return ExprAstCons(self._span, node, fields)
            if not isinstance(node, TypeAstPath) or node.root is not None:
                ErrorReporter.report("Expected an expression, not a type.", self._span)
                return ExprAstError(self._span)
            return ExprAstId(self._span, node.names)
        ErrorReporter.report("Expected an atom.", self._peek().span)
        return ExprAstError(self._span)

    def parse_call_args(
        self, *,
        sep: TokenKind | str = ',',
        end: TokenKind | str = ')',
        allow_positional = True
    ) -> tuple[list[ExprAst], list[tuple[str, ExprAst]]]:
        """Parse arguments up to (not including) `end`.

        Returns `(positional, named)`. An unparenthesised `name = expr`
        argument is a named argument.
        """
        args: list[ExprAst] = []
        named_args: list[tuple[str, ExprAst]] = []
        while not self._is(end):
            had_parens = self._is('(')
            arg = self.parse_expr()
            is_assignment = (
                isinstance(arg, ExprAstBinOp)
                and arg.op == BinOp.ASSIGN
                and not had_parens
            )
            if is_assignment:
                if not isinstance(arg.lhs, ExprAstId) or len(arg.lhs.names) != 1:
                    ErrorReporter.report(
                        "Arbitrary assignment expressions are not allowed in function arguments.",
                        arg.span
                    )
                    ErrorReporter.report("Surround with parentheses", arg.span, kind='note')
                else:
                    named_args.append((arg.lhs.names[0], arg.rhs))
            else:
                if len(named_args) > 0:
                    ErrorReporter.report("Positional arguments not allowed after named arguments.", arg.span)
                if not allow_positional:
                    ErrorReporter.report("Positional arguments not allowed here.", arg.span)
                args.append(arg)
            if not self._opt_eat(sep):
                break
        return args, named_args

    @_parser
    def parse_suffix_expr(self) -> ExprAst:
        """Parse an atom followed by calls / member accesses, then `as` casts."""
        node = self.parse_atom()
        while self._is('(') or self._is('.'):
            if self._opt_eat('('):
                args, named_args = self.parse_call_args()
                self._eat(')')
                node = ExprAstCall(self._span.clone(), node, args, named_args)
            elif self._opt_eat('.'):
                member = self._eat_name()
                node = ExprAstItem(self._span.clone(), node, member)
            else:
                assert False  # unreachable hopefully
        while self._opt_eat_kw('as'):
            target = self.parse_type()
            node = ExprAstCast(self._span.clone(), node, target)
        return node

    @_parser
    def parse_bin_expr(self, min_prec: int = 0) -> ExprAst:
        """Precedence climbing over `Parser.OPS`; only operators with a
        precedence strictly above `min_prec` are consumed."""
        lhs = self.parse_suffix_expr()
        while (opk := self._peek().kind) in Parser.OPS:
            prec, op = Parser.OPS[typing.cast(str, opk)]
            if prec <= min_prec:
                break
            self._get()
            rhs = self.parse_bin_expr(prec)
            lhs = ExprAstBinOp(self._span, lhs, rhs, op)
        return lhs

    @_parser
    def parse_stmt(self) -> tuple[StmtAst, int]:  # ast, end[1=; 2=} 0]
        """Parse one statement.

        The second result tells how it ended: 1 = with `;`, 2 = an `if`
        expression without `;`, 0 = neither.
        """
        while self._opt_eat(';'):
            continue
        if self._opt_eat_kw('let'):
            decl = self.parse_decl_opt_type(allow_value=True, opt_type=True)
            self._eat(';')
            return StmtDeclAst(self._span, decl), 1
        expr = self.parse_expr()
        ending = 0
        if isinstance(expr, ExprAstIf):
            ending = 2
        if self._opt_eat(';') is not None:
            ending = 1
        return StmtExprAst(expr.span, expr), ending

    @_parser
    def parse_block(self) -> ExprAstBlock:
        """Parse `{ ... }`; a final statement not ended by `;` becomes the tail."""
        self._eat('{')
        stmts: list[StmtAst] = []
        last, end = self.parse_stmt()
        while end:
            if self._is('}'):
                if end == 1:  # { ... x; }
                    stmts.append(last)
                    last = None
                    end = None
                # else: # { ... if ... } } or { ... x }
                break
            else:
                stmts.append(last)
                last = None
                last, end = self.parse_stmt()
        self._eat('}')
        tail = None
        if last is not None:
            if isinstance(last, StmtExprAst):
                tail = last.expr
            else:
                ErrorReporter.report("Expected an expression, not a statement.", last.span)
        return ExprAstBlock(self._span, stmts, tail)

    @_parser
    def parse_expr(self) -> ExprAst:
        """Parse a block, an `if` expression, or a binary expression."""
        if self._is('{'):
            return self.parse_block()
        if self._opt_eat_kw('if'):
            cond = self.parse_expr()
            true = self.parse_block()
            false = None
            if self._opt_eat_kw('else'):
                # allow 'else if'
                false = self.parse_expr() if self._is_kw('if') else self.parse_block()
            return ExprAstIf(self._span, cond, true, false)
        if self._is(TokenKind.EOF):
            return ExprAstError(self._span)
        return self.parse_bin_expr()

    # --- declarations and top-level items ----------------------------------

    @_parser
    def parse_decl_opt_type(self, *, allow_value: bool, opt_type: bool = False) -> AstDecl[TypeAst | None]:
        """Parse `name [: type] [= value]`.

        The type is mandatory unless `opt_type`. A value is always parsed when
        present, but reported and dropped unless `allow_value`.
        """
        name = self._eat_name()
        ty = None
        if not opt_type or self._is(':'):
            self._eat(':')
            ty = self.parse_type()
        value = None
        if self._opt_eat('='):
            value = self.parse_expr()
            if not allow_value:
                ErrorReporter.report("Value not allowed here.", value.span)
                value = None
        return AstDecl[TypeAst | None](self._span, [], name, ty, value)

    @_parser
    def parse_decl(self, *, allow_value: bool) -> AstDecl[TypeAst]:
        """Parse a declaration whose type annotation is mandatory."""
        decl = self.parse_decl_opt_type(allow_value=allow_value, opt_type=False)
        return typing.cast(AstDecl[TypeAst], decl)

    # NOT a @_parser
    def parse_func(self, attrs: set[str]) -> AstFunc | None:
        """Parse a function after the `func` keyword.

        Accepts `name(args) [: ret]` followed by a block body, `;` (no body)
        or `= expr;`. `c_varargs` must be the last entry of the argument list.
        Uses the span of the enclosing @_parser.
        """
        name = self._eat_name()
        args: list[AstDecl[TypeAst]] = []
        c_varargs = False
        self._eat('(')
        while not self._opt_eat(')'):
            if self._opt_eat_kw("c_varargs"):
                c_varargs = True
                self._eat(')')
                break
            args.append(self.parse_decl(allow_value=False))
            if not self._opt_eat(','):
                self._eat(')')
                break
        ret = None
        if self._opt_eat(':'):
            ret = self.parse_type()
        if self._is('{'):
            body = self.parse_expr()
        elif self._opt_eat(';'):
            body = None
        else:
            self._eat('=')
            body = self.parse_expr()
            self._eat(';')
        return AstFunc(
            span = self._span,
            attrs = attrs,
            name = name,
            args = args,
            ret = ret,
            body = body,
            c_varargs = c_varargs
        )

    def parse_using(self):
        """Placeholder for `using`; consumes nothing and returns None."""
        return None

    def parse_type_params(
        self,
        sep: TokenKind | str = ',',
        end: TokenKind | str = ']'
    ) -> list[TypeAstParam]:
        """Parse `sep`-separated parameter names up to (not including) `end`."""
        params: list[TypeAstParam] = []
        while not self._is(end):
            span = self._peek().span.clone()
            name = self._eat_name()
            params.append(TypeAstParam(span, name))
            if not self._opt_eat(sep):
                break
        return params

    @_parser
    def parse_struct(self) -> AstStruct:
        """Parse `name [ [params] ] { field: type, ... }` after `struct`."""
        name = self._eat_name()
        ty_params = None
        if self._opt_eat('['):
            ty_params = self.parse_type_params()
            self._eat(']')
        self._eat('{')
        fields: list[AstDecl[TypeAst]] = []
        while not self._is('}'):
            fields.append(self.parse_decl(allow_value=False))
            if self._is('}'):
                break
            self._eat(',')
        self._eat('}')
        return AstStruct(
            self._span,
            name = name,
            ty_params = ty_params,
            ty = TypeAstStruct(
                self._span,
                attrs = set(),
                fields = fields
            )
        )

    @_parser
    def parse_toplevel(self) -> AstFunc | AstStruct | None:
        """Parse one top-level item; returns None if none could be parsed."""
        if self._opt_eat_kw("using"):
            return self.parse_using()
        if self._opt_eat_kw("struct"):
            return self.parse_struct()
        attrs = set[str]()
        if self._opt_eat_kw("export"):
            attrs.add("export")
        if self._opt_eat_kw("extern"):
            attrs.add("extern")
        if self._opt_eat_kw("func"):
            return self.parse_func(attrs)
        ErrorReporter.report("Expected a top-level item.", self._peek().span)
# ---------------------------------------------------------------------------
# RST: the tree after name/type resolution.
# Every expression node carries its resolved type in `ty`.
# ---------------------------------------------------------------------------

@dataclasses.dataclass
class Rst:
    """Base of all resolved nodes."""
    span: SrcSpan


@dataclasses.dataclass
class ExprRst(Rst):
    """Base of all resolved expressions."""
    ty: LType


@dataclasses.dataclass
class ExprRstInt(ExprRst):
    """Integer constant."""
    value: int


@dataclasses.dataclass
class ExprRstStr(ExprRst):
    """String constant, as bytes."""
    value: bytes


@dataclasses.dataclass
class ExprRstAssign(ExprRst):
    """Store `val` into `var`."""
    var: LVar
    val: ExprRst


@dataclasses.dataclass
class ExprRstBinOp(ExprRst):
    """Binary operation; `op_ty` is kept separately from the result type `ty`."""
    lhs: ExprRst
    rhs: ExprRst
    op: BinOp
    op_ty: LType


@dataclasses.dataclass
class ExprRstUnOp(ExprRst):
    """Unary operation."""
    val: ExprRst
    op: UnOp


@dataclasses.dataclass
class ExprRstVar(ExprRst):
    """Reference to a resolved variable."""
    var: LVar


@dataclasses.dataclass
class ExprRstIf(ExprRst):
    """Conditional expression; `false` is None without an `else`."""
    cond: ExprRst
    true: ExprRst
    false: ExprRst | None


@dataclasses.dataclass
class ExprRstBlock(ExprRst):
    """Statement list with an optional tail expression."""
    stmts: list[StmtRst]
    tail: ExprRst | None


@dataclasses.dataclass
class ExprRstCall(ExprRst):
    """Call of `func` with already-ordered arguments."""
    func: ExprRst
    args: list[ExprRst]


@dataclasses.dataclass
class ExprRstCons(ExprRst):
    """Struct construction; narrows `ty` to a struct type."""
    ty: LTypeStruct
    fields: list[ExprRst]


@dataclasses.dataclass
class ExprRstCast(ExprRst):
    """Cast of `expr` to this node's `ty`."""
    expr: ExprRst


@dataclasses.dataclass
class StmtRst(Rst):
    """Base of all resolved statements."""


@dataclasses.dataclass
class StmtExprRst(StmtRst):
    """Expression statement."""
    expr: ExprRst


@dataclasses.dataclass
class StmtDeclRst(StmtRst):
    """Local declaration with an optional initial value."""
    ty: LType
    name: str
    id: int
    val: ExprRst | None
class Scope(abc.ABC):
    """Something names can be looked up in."""

    @abc.abstractmethod
    def get_var(self, name: str, span: SrcSpan | None) -> LVar:
        """Look up the variable `name`; `span` locates the use for diagnostics."""
        ...

    @abc.abstractmethod
    def has_type(self, name: str) -> bool:
        """Whether a type called `name` is visible."""
        ...

    @abc.abstractmethod
    def get_type(self, name: str, span: SrcSpan | None) -> LType:
        """Look up the type `name`; `span` locates the use for diagnostics."""
        ...


@dataclasses.dataclass
class LType(abc.ABC):
    """Base of all language types."""

    @abc.abstractmethod
    def size(self) -> int:
        """Size in bytes."""
        ...

    @abc.abstractmethod
    def align(self) -> int:
        """Alignment in bytes."""
        ...

    @abc.abstractmethod
    def mangled_name(self) -> str:
        """Compact encoding of the type for use in mangled names."""
        ...


@dataclasses.dataclass
class LTypeError(LType):
    """Stand-in type produced after an error was reported."""

    def __str__(self) -> str:
        return "<error>"

    def mangled_name(self) -> str:
        return "E"

    def align(self) -> int:
        return 1

    def size(self) -> int:
        return 1


@dataclasses.dataclass
class LTypeOpaque(LType):
    """A type known only by name (used for generic parameters); its size and
    alignment are reported as -1."""
    name: str

    def __str__(self) -> str:
        return f"{self.name}"

    def mangled_name(self) -> str:
        return f"O{len(self.name)}{self.name}"

    def align(self) -> int:
        return -1

    def size(self) -> int:
        return -1


@dataclasses.dataclass
class LTypeInt(LType):
    """Integer type with a bit width and signedness."""
    VALID_BITS: typing.ClassVar[set[int]] = {8, 16, 32, 64}
    bits: int
    signed: bool

    def __post_init__(self):
        # assert self.bits in LTypeInt.VALID_BITS
        pass

    def extend_to(self, other: LTypeInt) -> None | LTypeInt:
        """Return `other` if this type may be widened to it, else None.

        Allowed: a strictly wider `other`, or same width going from signed to
        unsigned.
        """
        if other.bits > self.bits:
            return other
        if other.bits < self.bits:
            return None
        if self.signed == other.signed:
            return None
        if not other.signed:
            return other
        return None

    def __str__(self) -> str:
        return ("u" if not self.signed else "") + f"int{self.bits}"

    def mangled_name(self) -> str:
        sign = "s" if self.signed else "u"
        return f"I{self.bits // 8}{sign}"

    def size(self) -> int:
        return self.bits // 8

    def align(self) -> int:
        # NOTE(review): integers narrower than 4 bytes are still 4-aligned
        # here; confirm this is intended.
        return max(4, self.size())


@dataclasses.dataclass
class LTypeBool(LTypeInt):
    """Boolean, modelled as a 1-bit unsigned integer."""

    def __init__(self):
        super().__init__(1, False)

    def __str__(self) -> str:
        return "bool"

    def mangled_name(self) -> str:
        return "B"

    def size(self) -> int:
        return 1

    def align(self) -> int:
        return 1


@dataclasses.dataclass
class LTypeVoid(LType):
    """The `void` type."""

    def __str__(self) -> str:
        return "void"

    def mangled_name(self) -> str:
        return "V"

    def size(self) -> int:
        return 0

    def align(self) -> int:
        return 1


@dataclasses.dataclass
class LTypePtr(LType):
    """Pointer to `to`."""
    to: LType

    def __str__(self) -> str:
        return f"*{str(self.to)}"

    def mangled_name(self) -> str:
        inner = self.to.mangled_name()
        return f"P{len(inner)}{inner}"

    def size(self) -> int:
        return 8

    def align(self) -> int:
        return 8


@dataclasses.dataclass
class LTypeFunc(LType):
    """Function type; `c_varargs` marks a C variadic function."""
    args: list[LType]
    ret: LType
    c_varargs: bool

    def __str__(self) -> str:
        varargs = ", c_varargs" if self.c_varargs else ""
        return f"func({', '.join(map(str, self.args))}{varargs}) -> {str(self.ret)}"

    def mangled_name(self) -> str:
        varargs = "v" if self.c_varargs else ""
        ret = self.ret.mangled_name()
        # Each component is length-prefixed. Built outside the f-string so no
        # nested same-quote f-strings (Python 3.12+ only) are needed.
        args = "".join(f"{len(a)}{a}" for a in (t.mangled_name() for t in self.args))
        return f"F{varargs}{len(ret)}{ret}{len(self.args)}_{args}"

    def size(self) -> int:
        return 8

    def align(self) -> int:
        return 8


@dataclasses.dataclass
class LTypeArray(LType):
    """Fixed-length array of `len` elements of type `of`."""
    of: LType
    len: int

    def __str__(self) -> str:
        return f"[{self.len}]{str(self.of)}"

    def mangled_name(self) -> str:
        elem = self.of.mangled_name()
        # `len` is shadowed by the field inside the class body, but not here.
        return f"A{self.len}x{len(elem)}{elem}"

    def size(self) -> int:
        return self.of.size() * self.len

    def align(self) -> int:
        return self.of.align()


@dataclasses.dataclass
class LTypeGeneric(LType, Scope):
    """A generic type: `params` plus the parameterised type `sub`.

    Also a scope in which each parameter name resolves to an `LTypeOpaque`;
    all other lookups go to `parent`.
    """
    parent: Scope
    params: list[TypeAstParam]
    sub: LType

    def __init__(
        self,
        parent: Scope,
        params: list[TypeAstParam],
        get_sub: typing.Callable[[typing.Self], LType]
    ):
        self.parent = parent
        self.params = params
        # `get_sub` receives this object so the body can be resolved with the
        # parameters in scope.
        self.sub = get_sub(self)

    def get_var(self, name: str, span: SrcSpan | None) -> LVar:
        return self.parent.get_var(name, span)

    def has_type(self, name: str) -> bool:
        if any(param.name == name for param in self.params):
            return True
        return self.parent.has_type(name)

    def get_type(self, name: str, span: SrcSpan | None) -> LType:
        for param in self.params:
            if param.name == name:
                return LTypeOpaque(param.name)
        return self.parent.get_type(name, span)

    def __str__(self) -> str:
        return f"[{', '.join(f'{f.name}' for f in self.params)}] {self.sub}"

    def mangled_name(self) -> str:
        raise NotImplementedError("Generic types can't be mangled.")

    def size(self) -> int:
        return -1

    def align(self) -> int:
        return -1


@dataclasses.dataclass
class LTypeStruct(LType, Scope):
    """Struct type with sequentially laid-out fields; also a scope for the
    types in `types`."""

    @dataclasses.dataclass
    class Field:
        name: str
        index: int   # position in declaration order
        offset: int  # byte offset from the start of the struct
        ty: LType

    name: str
    fields: list[Field]
    types: dict[str, LType]
    parent: Scope | None
    _size: int = 0
    _params: list[LType] | None = None

    def __init__(
        self,
        name: str,
        fields: list[typing.Callable[[typing.Self], tuple[str, LType]]],
        parent: Scope | None,
        params: list[LType] | None = None
    ):
        """
        Parameters:
            name:   struct name.
            fields: one callable per field, invoked in order with this struct
                    and returning `(field_name, field_type)`.
            parent: enclosing scope for type/variable lookups, if any.
            params: type arguments when this is an instantiated generic.
        """
        self.name = name
        self.types = {}
        self.parent = parent
        self.fields = []
        self._size = 0
        self._params = params
        for index, make_field in enumerate(fields):
            field_name, ty = make_field(self)
            alignment = ty.align()
            # Round up to the field's alignment *before* recording its offset,
            # so the stored offset is where the field really lives.
            self._size = (self._size + (alignment - 1)) & -alignment
            self.fields.append(LTypeStruct.Field(
                name = field_name,
                index = index,
                offset = self._size,
                ty = ty
            ))
            self._size += ty.size()

    def __str__(self) -> str:
        if self.name is not None:
            return f"struct {self.name}"
        return f"struct {{ {', '.join(f'{f.name}: {f.ty}' for f in self.fields)} }}"

    def mangled_name(self) -> str:
        base = self.name
        marker = ""
        if base.startswith('!'):
            base = base[1:]
            marker = "x"
        if self._params is not None:
            params = "".join(f"{len(t)}{t}" for t in (p.mangled_name() for p in self._params))
            # Drop the "[...]" argument list from the (already '!'-stripped)
            # name; tolerate names that do not contain one.
            base = base.partition('[')[0]
            return f"St{marker}{len(base)}{base}{len(self._params)}_{params}"
        return f"S{marker}{len(base)}{base}"

    def get_field(self, name: str, span: SrcSpan | None) -> Field | None:
        """Return the field `name`, or report an error and return None."""
        for field in self.fields:
            if field.name == name:
                return field
        ErrorReporter.report(f"No field '{name}' in {str(self)}.", span)
        return None

    def get_var(self, of: ExprRst, name: str, span: SrcSpan | None) -> LVar:
        """Resolve `name` as a field of the struct value `of`.

        Falls back to the parent scope, then to an error variable. Note the
        extra leading `of` parameter compared to `Scope.get_var`.
        """
        for field in self.fields:
            if field.name == name:
                return LStructFieldVar(field.ty, field.name, of, field.offset, field.index)
        if self.parent is not None:
            return self.parent.get_var(name, span)
        ErrorReporter.report(f"No '{name}' in {str(self)}.", span)
        return LErrorVar(LTypeError(), name)

    def has_type(self, name: str) -> bool:
        # NOTE(review): unlike get_type, this does not consult `parent`.
        return name in self.types

    def get_type(self, name: str, span: SrcSpan | None) -> LType:
        if name in self.types:
            return self.types[name]
        if self.parent is not None:
            return self.parent.get_type(name, span)
        ErrorReporter.report(f"No type {name} in {str(self)}.", span)
        return LTypeError()

    def size(self) -> int:
        return self._size

    def align(self) -> int:
        # default=1 keeps a struct without fields from raising ValueError.
        return max((f.ty.align() for f in self.fields), default=1)
@dataclasses.dataclass
class LVar:
    """Base of everything a name can resolve to: a type plus a name."""
    ty: LType
    name: str


@dataclasses.dataclass
class LErrorVar(LVar):
    """Returned when a lookup failed and an error was already reported."""


@dataclasses.dataclass
class LLocalVar(LVar):
    """Function-local variable, identified by `id`."""
    id: int


@dataclasses.dataclass
class LArgVar(LVar):
    """Function argument."""


@dataclasses.dataclass
class LGlobalVar(LVar):
    """Global variable."""


@dataclasses.dataclass
class LStructFieldVar(LVar):
    """Field of the struct value `of`, at byte `offset` / field `index`."""
    of: ExprRst
    offset: int
    index: int


@dataclasses.dataclass
class LUnionFieldVar(LVar):
    """Union field (no extra data yet)."""


@dataclasses.dataclass
class LPointerDerefVar(LVar):
    """The location `pointer` points to."""
    pointer: ExprRst
@dataclasses.dataclass
class FuncScope(Scope, abc.ABC):
    '''For scopes that are only inside of functions.'''
    # Function this scope belongs to (an LFunc sets this to itself).
    func: LFunc
# Function-level scope that can additionally declare local variables.
class LocalScope(FuncScope, abc.ABC):
    @abc.abstractmethod
    def add_local(self, ty: LType, name: str, span: SrcSpan | None) -> LLocalVar:
        # Declare a local called `name` of type `ty` and return its record.
        ...
@dataclasses.dataclass
class BlockScope(LocalScope):
    """Scope of a single block expression, chained to an enclosing scope."""
    parent: Scope
    vars: dict[str, LLocalVar]
    types: dict[str, LType]

    def get_var(self, name: str, span: SrcSpan | None) -> LVar:
        """Find `name` among this block's locals, then in enclosing scopes."""
        try:
            return self.vars[name]
        except KeyError:
            pass
        if self.parent is None:
            ErrorReporter.report(f"No such variable: '{name}'", span)
            return LErrorVar(LTypeError(), name)
        return self.parent.get_var(name, span)

    def has_type(self, name: str) -> bool:
        """Tell whether `name` names a type here or in an enclosing scope."""
        if name in self.types:
            return True
        return self.parent is not None and self.parent.has_type(name)

    def get_type(self, name: str, span: SrcSpan | None) -> LType:
        """Find the type `name` here, then in enclosing scopes."""
        try:
            return self.types[name]
        except KeyError:
            pass
        if self.parent is None:
            ErrorReporter.report(f"No type {name} in {str(self)}", span)
            return LTypeError()
        return self.parent.get_type(name, span)

    def add_local(self, ty: LType, name: str, span: SrcSpan | None) -> LLocalVar:
        """Declare a local in this block, shadowing an earlier one of the same name."""
        # The shadowed variable stays allocated in the function; only its
        # recorded name is tagged.
        if name in self.vars:
            self.vars[name].name += " shadowed"
        created = self.func.add_local(ty, name, span)
        self.vars[name] = created
        return created
# A function. It is a variable (so it can be looked up and called by name)
# and also the outermost scope of its own body.
@dataclasses.dataclass
class LFunc(LVar, FuncScope):
    module: ModuleScope          # module the function is declared in
    span: SrcSpan                # source location used for diagnostics
    arg_names: list[str]         # argument names, parallel to ty.args
    attrs: set[str]              # attributes; codegen checks for "export"
    body: ExprRst | None         # resolved body, None when there is no body
    locls: list[LLocalVar]       # every local declared anywhere in the body
    next_local_id = 0            # id handed to the next local by add_local
    types: dict[str, LType]      # types declared inside the function
    ty: LTypeFunc
    def __init__(
        self,
        module: ModuleScope,
        span: SrcSpan,
        name: str,
        ret_ty: LType | None,
        arg_types: list[LType],
        c_varargs: bool,
        arg_names: list[str],
        attrs: set[str],
        body: ExprAst | None,
        resolver: Resolver
    ):
        # Builds the function and immediately resolves its body.
        # `ret_ty` is the declared return type; when None it is inferred
        # from the resolved body.
        self.func = self
        self.name = name
        self.span = span
        self.module = module
        self.arg_names = arg_names
        self.attrs = attrs
        self.code = []
        self.locls = []
        self.types = {}
        # this might not be the full type!
        self.ty = LTypeFunc(arg_types, LTypeError() if ret_ty is None else ret_ty, c_varargs)
        # we do this stuff in the constructor because we need `self`.
        rbody = resolver.resolve_expr(self, body, ret_ty) if body else None
        self.body = rbody
        if rbody is None:
            if ret_ty is None:
                ErrorReporter.report("Cannot infer return type of function with no body.", span)
                rret_ty = LTypeError()
            else:
                rret_ty = ret_ty
        else:
            rret_ty = rbody.ty
            if ret_ty is not None and rret_ty != ret_ty:
                ErrorReporter.report(
                    f"Mismatching return types. Expected {ret_ty}, got {rret_ty}.",
                    span
                )
        # Ensure the type is correct now.
        self.ty = LTypeFunc(arg_types, rret_ty, c_varargs)
    def get_var(self, name: str, span: SrcSpan | None) -> LVar:
        # Lookup order: the function itself, then its arguments, then the
        # module. The function's own name is checked first, so it wins over
        # an argument with the same name.
        if name == self.name:
            return self
        if name in self.arg_names:
            i = self.arg_names.index(name)
            return LArgVar(self.ty.args[i], name)
        return self.module.get_var(name, span)
    def has_type(self, name: str) -> bool:
        # True when the type is declared in this function or in the module.
        if name in self.types: return True
        return self.module.has_type(name)
    def get_type(self, name: str, span: SrcSpan | None) -> LType:
        # Function-level types take precedence over module-level ones.
        if name in self.types:
            return self.types[name]
        return self.module.get_type(name, span)
    def add_local(self, ty: LType, name: str, span: SrcSpan | None) -> LLocalVar:
        # Registers a new local with the next free id; `span` is unused here.
        var = LLocalVar(ty, name, self.next_local_id)
        self.locls.append(var)
        self.next_local_id += 1
        return var
class ModuleScope(Scope):
    """Top-level scope holding a module's globals, functions and types."""

    def __init__(self, name: str):
        self.name = name
        self.globs: dict[str, LGlobalVar] = {}
        self.funcs: dict[str, LFunc] = {}
        self.types: dict[str, LType] = {}

    def get_var(self, name: str, span: SrcSpan | None) -> LVar:
        """Look `name` up among globals first, then among functions."""
        for table in (self.globs, self.funcs):
            if name in table:
                return table[name]
        ErrorReporter.report(f"No variable '{name}' in module.", span)
        return LErrorVar(LTypeError(), name)

    def has_type(self, name: str) -> bool:
        """Tell whether the module declares a type called `name`."""
        return name in self.types

    def get_type(self, name: str, span: SrcSpan | None) -> LType:
        """Return the module-level type `name`, or report and return an error type."""
        try:
            return self.types[name]
        except KeyError:
            ErrorReporter.report(f"No type '{name}' in module.", span)
            return LTypeError()

    def __str__(self) -> str:
        """Render the module contents, one entry per line, for debugging."""
        lines = [f"module {self.name} {{"]
        lines.extend(f" {v.name}: {v.ty}" for v in self.globs.values())
        lines.extend(f" {v}" for v in self.funcs.values())
        lines.extend(f" {name} = {v}" for name, v in self.types.items())
        lines.append("}")
        return "\n".join(lines)
# Turns parsed AST nodes into resolved (typed) nodes and registers
# structs and functions in their module.
class Resolver:
    # Comparison operators: a binary expression using one of these is
    # typed bool instead of taking the operand type.
    BOOL_OPS = {
        BinOp.EQ,
        BinOp.NE,
        BinOp.LT,
        BinOp.GT,
        BinOp.LE,
        BinOp.GE,
    }
def instantiate_generic(self, span: SrcSpan, scope: Scope, gen: LType, params: list[LType]):
    """Instantiate the generic type `gen` with the type arguments `params`.

    Reports an error and returns an error type when `gen` is not generic
    or the number of arguments does not match its parameter list.
    """
    if not isinstance(gen, LTypeGeneric):
        ErrorReporter.report(f"Cannot instantiate non-generic type '{gen}'.", span)
        return LTypeError()
    expected, got = len(gen.params), len(params)
    if got != expected:
        ErrorReporter.report(
            f"Mismatching number of type parameters. Expected {len(gen.params)}, got {len(params)}.",
            span
        )
        return LTypeError()
    # Map each parameter name to the position of its argument.
    tys: list[LType] = list(params)
    m: dict[str, int] = {p.name: i for i, p in enumerate(gen.params)}
    substituted, _ = self.substitute_types(gen.sub, m, tys)
    return substituted
def resolve_type(self, scope: Scope, t: TypeAst) -> LType:
match t:
case TypeAstPtr(span, to):
return LTypePtr(self.resolve_type(scope, to))
case TypeAstInst(span, generic, params):
rgen = self.resolve_type(scope, generic)
rparams = [self.resolve_type(scope, t) for t in params]
return self.instantiate_generic(span, scope, rgen, rparams)
case TypeAstPath(span, root, names):
assert len(names) > 0
if root is None:
n = names[0]
match n:
case "uint8": return LTypeInt(8, False)
case "uint16": return LTypeInt(16, False)
case "uint32": return LTypeInt(32, False)
case "uint64": return LTypeInt(64, False)
case "int8": return LTypeInt(8, True)
case "int16": return LTypeInt(16, True)
case "int32": return LTypeInt(32, True)
case "int64": return LTypeInt(64, True)
case "bool": return LTypeBool()
case "void": return LTypeVoid()
case _: pass
current = scope
else:
current = self.resolve_type(scope, root)
for i, name in enumerate(names):
if isinstance(current, Scope):
current = current.get_type(name, span)
elif i != len(names) - 1:
ErrorReporter().report(f"Not a scope: {current}.", span)
assert isinstance(current, LType)
return current
case _:
raise NotImplementedError(f"resolving {t} is not implemented yet.")
def resolve_struct(self, module: ModuleScope, p: AstStruct) -> None:
    """Register the struct declaration `p` in `module`.

    A struct with type parameters is stored as a generic wrapping the
    struct builder; otherwise the struct is built right away.
    """
    def field_thunk(field):
        # Field types are resolved lazily, with the struct itself as scope.
        return lambda owner: (field.name, self.resolve_type(owner, field.ty))

    def get_sub(scope: Scope):
        thunks = [field_thunk(field) for field in p.ty.fields]
        return LTypeStruct(p.name, thunks, scope)

    if p.ty_params is None:
        module.types[p.name] = get_sub(module)
    else:
        module.types[p.name] = LTypeGeneric(module, p.ty_params, get_sub)
def resolve_func(self, mod: ModuleScope, p: AstFunc) -> None:
    """Resolve the function declaration `p` and store it in `mod.funcs`."""
    declared_ret = None if p.ret is None else self.resolve_type(mod, p.ret)
    param_types = [self.resolve_type(mod, a.ty) for a in p.args]
    param_names = [a.name for a in p.args]
    # LFunc resolves the body inside its constructor.
    mod.funcs[p.name] = LFunc(
        module = mod,
        name = p.name,
        span = p.span,
        ret_ty = declared_ret,
        arg_types = param_types,
        c_varargs = p.c_varargs,
        arg_names = param_names,
        attrs = p.attrs,
        body = p.body,
        resolver = self
    )
def resolve_stmt(self, scope: LocalScope, e: StmtAst) -> StmtRst:
match e:
case StmtDeclAst(span, decl):
t = self.resolve_type(scope, decl.ty) if decl.ty is not None else None
v = self.resolve_expr(scope, decl.value, t) if decl.value is not None else None
if decl.ty is not None:
assert t is not None
elif v is not None:
t = v.ty
else:
ErrorReporter.report("Cannot infer variable type.", span)
t = LTypeError()
var = scope.add_local(t, decl.name, span)
return StmtDeclRst(
span,
name = decl.name,
ty = t,
id = var.id,
val = v
)
case StmtExprAst(span, expr):
return StmtExprRst(span, self.resolve_expr(scope, expr, None))
case _:
raise NotImplementedError(f"Resolving {e} is not implemented yet.")
def can_implicit_cast(self, a: LType, b: LType) -> str | None:
    """Check whether a value of type `a` may be implicitly cast to `b`.

    Returns None when the cast is allowed, otherwise the error message
    explaining why it is not.
    """
    both_ints = isinstance(a, LTypeInt) and isinstance(b, LTypeInt)
    if not both_ints:
        return f"Mismatching types. Expected {b}, got {a}."
    if a.bits > b.bits:
        return f"Cannot implicitly truncate from {a} to {b}. Use an explicit `as` cast."
    if b.signed and not a.signed:
        return f"Cannot implicitly cast from unsigned {a} to signed {b}. Use an explicit `as` cast."
    return None
def resolve_expr(self, scope: FuncScope, e: ExprAst, expected_type: LType | None) -> ExprRst:
    """Resolve `e`, inserting an implicit cast to `expected_type` if allowed.

    When the types differ and no implicit cast applies, the reason is
    reported as a note and the uncast expression is returned.
    """
    r = self._resolve_expr(scope, e, expected_type)
    if expected_type is None or r.ty == expected_type:
        return r
    err = self.can_implicit_cast(r.ty, expected_type)
    if err is None:
        return ExprRstCast(r.span, expected_type, r)
    ErrorReporter.report(err, r.span, kind='note')
    return r
def _resolve_expr(self, scope: FuncScope, e: ExprAst, expected_type: LType | None) -> ExprRst:
match e:
case ExprAstInt(span, v):
l = v.bit_length()
if l > 64:
ErrorReporter.report("Integer too large.", span)
return ExprRstInt(span, LTypeInt(64, False), v & 0xFFFFFFFFFFFFFFFF)
if isinstance(expected_type, LTypeInt):
if expected_type.bits <= expected_type.bits:
return ExprRstInt(span, expected_type, v)
if l > 63: return ExprRstInt(span, LTypeInt(64, False), v)
if l > 32: return ExprRstInt(span, LTypeInt(64, True), v)
if l > 31: return ExprRstInt(span, LTypeInt(32, False), v)
else: return ExprRstInt(span, LTypeInt(32, True), v)
case ExprAstStr(span, value):
st = scope.get_type('slice', span)
st2 = self.instantiate_generic(span, scope, st, [LTypeInt(8, False)])
if not isinstance(st2, LTypeStruct):
ErrorReporter.report("slice[uint8] is not a struct.", span)
return ExprRst(span, LTypeError())
return ExprRstStr(span, st2, value.encode('utf-8'))
case ExprAstCast(span, value, ty):
return ExprRstCast(span, self.resolve_type(scope, ty), self.resolve_expr(scope, value, None))
case ExprAstBool(span, v):
return ExprRstInt(span, LTypeBool(), v)
case ExprAstItem(span, of, name):
rof = self.resolve_expr(scope, of, None)
if not isinstance(rof.ty, LTypeStruct):
ErrorReporter.report(f"Cannot get item '{name}' from non-struct {rof.ty}.", span)
return ExprRst(span, LTypeError())
v = rof.ty.get_var(rof, name, span)
return ExprRstVar(span, v.ty, v)
case ExprAstId(span, names):
current = scope
for i, name in enumerate(names):
if isinstance(current, Scope):
current = current.get_var(name, span)
elif i != len(names) - 1:
ErrorReporter().report(f"Not a scope: {current}.", span)
assert isinstance(current, LVar)
return ExprRstVar(span, current.ty, current)
case ExprAstError(span): return ExprRst(span, LTypeError())
case ExprAstBinOp(span, lhs, rhs, op):
rlhs, rrhs = self.resolve_expr(scope, lhs, None), self.resolve_expr(scope, rhs, None)
if op == BinOp.ASSIGN:
match rlhs:
case ExprRstVar(_, ty, var):
return ExprRstAssign(span, ty, var, rrhs)
case _:
ErrorReporter.report("Left operand of assignment is not an lvalue.", rlhs.span)
return ExprRst(span, LTypeError())
if not isinstance(rlhs.ty, LTypeInt):
ErrorReporter.report("Left operand is not an integer.", rlhs.span)
return ExprRst(span, LTypeError())
if not isinstance(rrhs.ty, LTypeInt):
ErrorReporter.report("Right operand is not an integer.", rrhs.span)
return ExprRst(span, LTypeError())
if ty := rlhs.ty.extend_to(rrhs.ty): pass
elif ty := rrhs.ty.extend_to(rlhs.ty): pass
else: ty = rlhs.ty
if op in Resolver.BOOL_OPS:
rty = LTypeBool()
else:
rty = ty
return ExprRstBinOp(span, rty, rlhs, rrhs, op, ty)
case ExprAstUnOp(span, val, op):
re = self.resolve_expr(scope, val, None)
return ExprRstUnOp(
span,
re.ty,
re,
op,
)
case ExprAstIf(span, cond, true, false):
rcond = self.resolve_expr(scope, cond, LTypeBool())
if not isinstance(rcond.ty, LTypeBool):
ErrorReporter.report(f"Condition must be of type bool, got {rcond.ty}.", rcond.span)
rtrue = self.resolve_expr(scope, true, expected_type)
if false is None:
if not isinstance(rtrue.ty, LTypeVoid):
ErrorReporter.report(
"The if expression doesn't have an else branch, "
"so the return type must be void.",
span
)
ty = LTypeVoid()
rfalse = None
else:
rfalse = self.resolve_expr(scope, false, rtrue.ty)
if rfalse.ty != rtrue.ty:
ErrorReporter.report(
"Mismatching branch types. "
f"True branch is {rtrue.ty}, but false branch is {rfalse.ty}",
span
)
ty = rtrue.ty
return ExprRstIf(
span,
ty,
rcond,
rtrue,
rfalse
)
case ExprAstBlock(span, stmts, tail):
nscope = BlockScope(scope.func, scope, {}, {})
rstmts = [self.resolve_stmt(nscope, s) for s in stmts]
rtail = self.resolve_expr(nscope, tail, expected_type) if tail is not None else None
return ExprRstBlock(
span,
rtail.ty if rtail is not None else LTypeVoid(),
stmts = rstmts,
tail = rtail
)
case ExprAstCons(span, ty, fields):
rty = self.resolve_type(scope, ty)
if not isinstance(rty, LTypeStruct):
ErrorReporter.report(f"Cannot construct non-struct type {rty}.", span)
return ExprRst(span, LTypeError())
rfields: list[ExprRst | None] = [None] * len(rty.fields)
for name, value in fields:
fld = rty.get_field(name, span)
if fld is None:
continue
rval = self.resolve_expr(scope, value, fld.ty)
if rval.ty != fld.ty:
ErrorReporter.report(
f"Incorrect type for field '{fld.name}'. Expected {fld.ty}, got {rval.ty}.",
rval.span
)
rfields[fld.index] = rval
if None in rfields:
ErrorReporter.report("Not all fields initialized.", span)
return ExprRst(span, LTypeError())
return ExprRstCons(span, rty, typing.cast(list[ExprRst], rfields))
case ExprAstCall(span, func, args, named_args):
if len(named_args) > 0:
ErrorReporter.report("Named arguements are not yet supported.", span)
return ExprRst(span, LTypeError())
rfunc = self.resolve_expr(scope, func, None)
ret_type = LTypeError()
if not isinstance(rfunc.ty, LTypeFunc):
ErrorReporter.report(f"Cannot call non-function {rfunc.ty}.", span)
rargs = [self.resolve_expr(scope, arg, None) for arg in args]
else:
if rfunc.ty.c_varargs:
if len(args) < len(rfunc.ty.args):
ErrorReporter.report(
f"Not enough arguments provided. Expected at least {len(rfunc.ty.args)}, got {len(args)}.",
span
)
rargs = [self.resolve_expr(scope, arg, t) for t, arg in zip(rfunc.ty.args, args)]
for arg in args[len(rfunc.ty.args):]:
rargs.append(self.resolve_expr(scope, arg, None))
else:
if len(args) != len(rfunc.ty.args):
ErrorReporter.report(
f"Mismatching number of arguments. Expected {len(rfunc.ty.args)}, got {len(args)}.",
span
)
rargs = [self.resolve_expr(scope, arg, t) for t, arg in zip(rfunc.ty.args, args)]
ret_type = rfunc.ty.ret
for i, (narg, argt) in enumerate(zip(rargs, rfunc.ty.args)):
if narg.ty != argt: # TODO: compatible with
ErrorReporter.report(
f"Incorrect type of argument #{i + 1}. Expected {argt}, got {narg.ty}.",
narg.span
)
return ExprRstCall(
span,
ret_type,
func = rfunc,
args = rargs
)
case _:
raise NotImplementedError(f"resolving {e} is not implemented yet.")
def substitute_types(self, orig: LType, m: dict[str, int], tys: list[LType]) -> tuple[LType, bool]:
match orig:
case LTypeOpaque(name) if name in m:
return tys[m[name]], True
case LTypePtr(to):
t, r = self.substitute_types(to, m, tys)
return LTypePtr(t), r
case LTypeArray(of, l):
t, r = self.substitute_types(of, m, tys)
return LTypeArray(t, l), r
case LTypeStruct():
fs: list[typing.Callable[[LTypeStruct], tuple[str, LType]]] = []
did_sub = False
for field in orig.fields:
t, r = self.substitute_types(field.ty, m, tys)
did_sub = did_sub or r
fs.append(lambda _, t=t, f=field: (f.name, t))
if not did_sub: return orig, False
s = LTypeStruct(
f"{orig.name}[{", ".join(map(str, tys))}]",
fs,
orig.parent,
tys
), True
return s
case _: return orig, False
class Llvm:
    """Base class of the operand values the code generator passes around."""
    pass

@dataclasses.dataclass
class LlvmInt(Llvm):
    """Integer literal operand."""
    val: int
    def __str__(self):
        return str(self.val)

@dataclasses.dataclass
class LlvmVoid(Llvm):
    """Marker for an expression that produced no value."""
    def __str__(self):
        return "void"

@dataclasses.dataclass
class LlvmVal(Llvm):
    """Numbered temporary, printed as ``%t<N>``."""
    val: int
    def __str__(self):
        return "%t" + str(self.val)

@dataclasses.dataclass
class LlvmNamedVal(Llvm):
    """Named local value, printed as ``%_n_<name>``."""
    name: str
    def __str__(self):
        return "%_n_" + str(self.name)

@dataclasses.dataclass
class LlvmLocalVarVal(Llvm):
    """Stack slot of a local variable, printed as ``%_loc_<id>``."""
    id: int
    def __str__(self):
        return "%_loc_" + str(self.id)

@dataclasses.dataclass
class LlvmPrivateGlobal(Llvm):
    """Numbered private global, printed as ``@.<id>``."""
    id: int
    def __str__(self):
        return "@." + str(self.id)

@dataclasses.dataclass
class LlvmGlobal(Llvm):
    """Named global, printed as ``@<name>``."""
    name: str
    def __str__(self):
        return "@" + str(self.name)
class CodeGen:
    """Emits textual LLVM IR into an in-memory buffer."""

    def __init__(self, is_le: bool):
        """Start with an empty buffer and fresh counters.

        Args:
            is_le: Stored on the instance; not read by this constructor.
        """
        self._is_le = is_le
        self._output = bytearray()
        self._struct_tys: set[str] = set()
        self._indent_level = 0
        self._label_no = 0
        self._val_no = 0
def _write_pre(self, s: str):
self._output = s.encode() + b"\n" + self._output
def _write(self, s: str, indent: int = 0):
self._output += b" " * (self._indent_level + indent)
self._output += s.encode()
self._output += b"\n"
def _newpriv_global(self) -> LlvmPrivateGlobal:
    """Allocate a fresh private global; it shares the temporary counter."""
    number = self._val_no
    self._val_no = number + 1
    return LlvmPrivateGlobal(number)
def _newval(self) -> LlvmVal:
    """Allocate the next numbered temporary."""
    number = self._val_no
    self._val_no = number + 1
    return LlvmVal(number)
def _newlab(self) -> str:
l = f"l{self._label_no}"
self._label_no += 1
return l
def _label(self, name: str) -> str:
self._write(name + ':', -1)
return name
def _newvaleq(
self,
instr: str,
*args: Llvm | int | str,
surround = ' ',
skip_assign = False
) -> Llvm:
v = self._newval() if not skip_assign else LlvmVoid()
a = f"{v} = " if not skip_assign else ""
self._write(f"{a}{instr}{surround[0:1]}{', '.join(map(str, args))}{surround[1:2]}")
return v
def _indent(self): self._indent_level += 1
def _dedent(self): self._indent_level -= 1
def _ir_struct_ty_lit(self, t: LTypeStruct) -> str:
return f"{{ {', '.join(self._ir_ty(f.ty) for f in t.fields)} }}"
def _ir_struct_ty(self, t: LTypeStruct) -> str:
name = t.mangled_name()
if name not in self._struct_tys:
self._write_pre(f"%struct.{name} = type {self._ir_struct_ty_lit(t)}")
self._struct_tys.add(name)
return f"%struct.{name}"
def _ir_func_ty(self, t: LTypeFunc, name = "") -> str:
if name: name = " " + name + " "
va = ", ..." if t.c_varargs else ""
return f"{self._ir_ty(t.ret)}{name}({", ".join(map(self._ir_ty, t.args))}{va})"
def _ir_ty(self, t: LType) -> str:
match t:
case LTypeArray(of, len): return f"[{of} x {len}]"
case LTypePtr(): return "ptr"
case LTypeFunc(): return self._ir_func_ty(t)
case LTypeInt(bits): return f"i{bits}"
case LTypeStruct(): return self._ir_struct_ty(t)
case LTypeVoid(): return "void"
case _: raise NotImplementedError(f"_ir_ty({t})")
def _ir_binop(self, op: BinOp, signed: bool, floating: bool) -> str:
'''BinOp to IR instruction name.'''
sign = "" if floating else ("s" if signed else "u")
tyi = "f" if floating else ""
ty = "f" if floating else "i"
match op:
case BinOp.ADD: return f"{tyi}add"
case BinOp.SUB: return f"{tyi}sub"
case BinOp.MUL: return f"{tyi}mul"
case BinOp.DIV: return f"{tyi}{sign}div"
case BinOp.MOD: return f"{tyi}{sign}rem"
case BinOp.EQ: return f"{ty}cmp eq"
case BinOp.NE: return f"{ty}cmp ne"
case BinOp.LT: return f"{ty}cmp {sign}lt"
case BinOp.GT: return f"{ty}cmp {sign}gt"
case BinOp.LE: return f"{ty}cmp {sign}le"
case BinOp.GE: return f"{ty}cmp {sign}ge"
case _: raise NotImplementedError(f"_ir_binop({op}, {signed})")
def _gen_int_cast(self, a: LTypeInt, b: LTypeInt, v: Llvm) -> Llvm:
'''Generate integer conversion instructions. (maybe generate none)'''
if a.bits == b.bits: return v
if a.bits < b.bits:
ins = f"{'s' if b.signed else 'z'}ext {self._ir_ty(a)} {v} to {self._ir_ty(b)}"
return self._newvaleq(ins)
# a.bits > b.bits
return self._newvaleq(f"trunc {self._ir_ty(a)} {v} to {self._ir_ty(b)}")
def _gen_store_move(self, ty: LType, val: Llvm, to: Llvm):
self._write(f"store {self._ir_ty(ty)} {val}, ptr {to}")
def gen_move(self, e: ExprRst, to_ty: LType, to: Llvm):
    '''Generate move of `e` to address `to`'''
    match e:
        case ExprRstInt(_, ty, value):
            # Integer literals are stored directly, using the literal's
            # own type rather than `to_ty`.
            assert isinstance(ty, LTypeInt)
            self._gen_store_move(ty, LlvmInt(value), to)
        case ExprRstBlock(_, ty, stmts, tail):
            # Run the block's statements, then move its tail expression;
            # a block without a tail cannot be moved.
            for stmt in stmts: self.gen_stmt(stmt)
            assert tail is not None
            self.gen_move(tail, to_ty, to)
        case _:
            # Everything else: evaluate, then store the resulting value.
            v = self.gen_expr(e)
            assert e.ty == to_ty
            self._gen_store_move(e.ty, v, to)
def gen_stmt(self, e: StmtRst):
    # Generate code for one statement. Statement kinds other than the two
    # below match no case and emit nothing.
    match e:
        case StmtExprRst(_, expr):
            # Expression statement: evaluate and discard the result.
            self.gen_expr(expr)
        # NOTE(review): the resolver builds StmtDeclRst with keywords in the
        # order name, ty, id, val, while this pattern binds positionally as
        # (ty, name, id, val) - confirm it matches the field order.
        case StmtDeclRst(_, ty, name, id, val):
            # Reserve a stack slot for the local, then initialise it.
            v = LlvmLocalVarVal(id)
            self._write(f"{v} = alloca {self._ir_ty(ty)} ; {name}: {ty}")
            if val is not None:
                self.gen_move(val, ty, v)
def _ir_str_const(self, s: bytes) -> str:
return "c\"" + "".join(chr(c) if 32 <= c <= 127 else f"\\{c:02X}" for c in s) + "\""
def _gen_init_struct(self, ty: LTypeStruct, vs: typing.Iterable[Llvm | int | str]) -> Llvm:
'''Returns pointer to struct.'''
sty = self._ir_struct_ty(ty)
p = self._newvaleq("alloca", sty, f"align {ty.align()}")
cur = self._newvaleq("load", sty, f"ptr {p}")
for i, (v, f) in enumerate(zip(vs, ty.fields)):
cur = self._newvaleq("insertvalue", f"{sty} {cur}", f"{self._ir_ty(f.ty)} {v}", i)
return cur
def gen_expr(self, e: ExprRst) -> Llvm:
match e:
case ExprRstInt(_, ty, value): return LlvmInt(value)
case ExprRstStr(_, ty, value):
assert isinstance(ty, LTypeStruct)
g = self._newpriv_global()
self._write_pre(
f"{g} = private unnamed_addr constant [{len(value)} x i8] " +
self._ir_str_const(value)
)
return self._gen_init_struct(ty, [len(value), g])
case ExprRstBlock(_, ty, stmts, tail):
for stmt in stmts:
self.gen_stmt(stmt)
return self.gen_expr(tail) if tail else LlvmVoid()
case ExprRstVar(_, ty, var):
match var:
case LArgVar(ty, name): return LlvmNamedVal(f"arg_{name}")
case LStructFieldVar(ty, name, of, _, index):
v = self.gen_expr(of)
return self._newvaleq("extractvalue", f"{self._ir_ty(of.ty)} {v}", index)
# case LGlobalVar(ty, name): return LlvmGlobalVal(name)
case LLocalVar(ty, name, id):
v = self._newval()
self._write(f"{v} = load {self._ir_ty(ty)}, ptr {LlvmLocalVarVal(id)}")
return v
case _: raise NotImplementedError(f"gen_expr(..)/ExprRstVar(_, _, {var})")
case ExprRstIf(_, ty, cond, true, false):
ret_lab = self._newlab()
true_lab = self._newlab()
false_lab = self._newlab() if false is not None else ret_lab
cond_v = self.gen_expr(cond)
# cond_v_i1 = self._newvaleq(f"trunc {self._ir_ty(cond.type)} {cond_v} to i1")
self._write(f"br i1 {cond_v}, label %{true_lab}, label %{false_lab}")
self._label(true_lab)
true_v = self.gen_expr(true)
self._write(f"br label %{ret_lab}")
false_v = None
if false is not None:
self._label(false_lab)
false_v = self.gen_expr(false)
self._write(f"br label %{ret_lab}")
self._label(ret_lab)
if not isinstance(ty, LTypeVoid) and false_v is not None:
# <=> false is not None
return self._newvaleq(
f"phi {self._ir_ty(ty)}",
f"[{true_v}, %{true_lab}]",
f"[{false_v}, %{false_lab}]"
)
return LlvmVoid()
case ExprRstAssign(_, ty, var, val):
match var:
case LLocalVar(_, name, id):
v = LlvmLocalVarVal(id)
self.gen_move(val, ty, v)
return v
case _: raise NotImplementedError(f"gen_expr(..)/ExprRstAssign(_, _, {var}, _)")
case ExprRstCast(_, ty, val):
rval = self.gen_expr(val)
if isinstance(ty, LTypeInt) and isinstance(val.ty, LTypeInt):
return self._gen_int_cast(val.ty, ty, rval)
else:
raise NotImplementedError(f"gen_expr(..)/ExprRstCast(_, {ty}, {val})")
case ExprRstBinOp(_, ty, lhs, rhs, op, op_ty):
l, r = self.gen_expr(lhs), self.gen_expr(rhs)
assert isinstance(ty, LTypeInt)
assert isinstance(op_ty, LTypeInt)
assert isinstance(lhs.ty, LTypeInt)
assert isinstance(rhs.ty, LTypeInt)
l2 = self._gen_int_cast(lhs.ty, op_ty, l)
r2 = self._gen_int_cast(lhs.ty, op_ty, r)
op = self._ir_binop(op, ty.signed, False)
v = self._newvaleq(op, f"{self._ir_ty(op_ty)} {l2}", r2)
# if op.startswith("icmp"):
# return self._newvaleq(f"zext i1 {v} to i32")
return v
case ExprRstCons(_, ty, fields):
return self._gen_init_struct(ty, map(self.gen_expr, fields))
case ExprRstCall(_, ty, func, args):
assert isinstance(func, ExprRstVar)
fn = func.var
assert isinstance(fn, LFunc)
if fn.ty.c_varargs:
fargs = (f"{self._ir_ty(arg.ty)} {self.gen_expr(arg)}" for arg in args)
else:
fargs = (f"{self._ir_ty(ty)} {self.gen_expr(arg)}" for arg, ty in zip(args, fn.ty.args))
return self._newvaleq(
f"call ccc {self._ir_ty(fn.ty)} @{fn.name}",
*fargs,
surround="()",
skip_assign = fn.ty.ret == LTypeVoid()
)
case _: raise NotImplementedError(f"gen_expr({e})")
def gen_func(self, p: LFunc):
    """Emit the IR declaration or definition of the function `p`.

    A function without a body becomes a single `declare` line; otherwise a
    `define` with an entry label, the body, and a final `ret`.
    """
    has_body = p.body is not None
    header = "define " if has_body else "declare "
    if "export" in p.attrs:
        header += "external "
    header += "ccc "
    header += f"{self._ir_ty(p.ty.ret)} "
    header += f"@{p.name}("
    header += ", ".join(f"{self._ir_ty(a)} %_n_arg_{n}" for n, a in zip(p.arg_names, p.ty.args))
    if p.ty.c_varargs:
        header += ", ..."
    header += ")"
    if not has_body:
        self._write(header)
        return
    self._write(header + " {")
    self._indent()
    self._label(self._newlab())
    assert p.body is not None
    result = self.gen_expr(p.body)
    if isinstance(result, LlvmVoid):
        self._write("ret void")
    else:
        self._write(f"ret {self._ir_ty(p.ty.ret)} {result}")
    self._dedent()
    self._write("}")
def replace_in_list[T](l: list[T], a: T, b: T) -> list[T]:
'''Replace `a` with `b` in list `l`, mutating it. Return the updated list.'''
for i, e in enumerate(l):
if e == a: l[i] = b
return l
def cli_build(args: argparse.Namespace) -> str:
in_file: typing.TextIO = args.file
inp = InputStr(in_file.read(), fname = in_file.name)
in_file.close()
inp_id = InputManager.add(inp)
toks = list(tokenize(inp_id, inp.s))
par = Parser(toks)
mod = ModuleScope("main")
gen = CodeGen(True)
res = Resolver()
while not par._is(TokenKind.EOF):
try:
n = par.parse_toplevel()
if ErrorReporter.report_count > 0:
exit(1)
if args.debug:
pprint.pprint(n)
match n:
case AstFunc():
res.resolve_func(mod, n)
if ErrorReporter.report_count > 0:
exit(1)
gen.gen_func(mod.funcs[n.name])
case AstStruct():
res.resolve_struct(mod, n)
if ErrorReporter.report_count > 0:
exit(1)
except Parser.FailedMiserably:
pass
except IrrecoverableError:
exit(1)
if ErrorReporter.report_count > 0:
exit(1)
if args.debug:
sys.stderr.buffer.write(gen._output)
with subprocess.Popen(
replace_in_list(shlex.split(args.llc_cmd), '%', str(max(0, min(args.opt_level, 3)))),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
stdin=subprocess.PIPE
) as n:
assert n.stdin is not None
llc_output, stderr = n.communicate(gen._output)
if n.returncode != 0:
ErrorReporter.report("Failed to run LLC", None)
sys.stderr.buffer.write(stderr)
exit(1)
else:
sys.stderr.buffer.write(stderr)
if args.debug:
sys.stderr.buffer.write(llc_output)
bin_name = in_file.name.removesuffix(".qq")
with subprocess.Popen(
replace_in_list(shlex.split(args.asm_cmd), '%', bin_name),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
stdin=subprocess.PIPE
) as n:
assert n.stdin is not None
_, stderr = n.communicate(llc_output)
if n.returncode != 0:
ErrorReporter.report("Failed to run assembler", None)
sys.stderr.buffer.write(stderr)
exit(1)
else:
sys.stderr.buffer.write(stderr)
if ErrorReporter.report_count > 0:
exit(1)
ErrorReporter.report("", None, kind='success', title=f"Compiled {bin_name} successfully!")
return bin_name
def cli_run(args: argparse.Namespace):
    '''Build the program, run it with `args.rest`, and exit with its status.'''
    binary = cli_build(args)
    # Flush our own output before the child starts writing.
    for stream in (sys.stdout, sys.stderr):
        stream.flush()
    finished = subprocess.run([os.path.abspath(binary), *args.rest])
    exit(finished.returncode)
def cli():
    '''Parse the command line and dispatch to the `build` or `run` command.'''
    root = argparse.ArgumentParser(prog='qqc')
    root.add_argument('-d', '--debug', action='store_true')
    commands = root.add_subparsers(help="Commands:", required=True)

    # Options shared by both commands live on a parent parser.
    shared = argparse.ArgumentParser(add_help=False)
    shared.add_argument(
        '--use-asm',
        dest="asm_cmd",
        type=str,
        default="clang -fPIC libqqrt.o -xassembler -fuse-ld=lld -o % -",
        help="Assembler command to use. '%%' is replaced with the output file."
    )
    shared.add_argument(
        '--use-llc',
        dest="llc_cmd",
        type=str,
        default="llc -relocation-model=pic -O %",
        help="LLC command to use. '%%' is replaced with the opt level. "
        "Input is on the stdin, output is read from stdout."
    )
    shared.add_argument(
        '-O',
        dest="opt_level",
        type=int,
        default=2,
        help="Optimization level. 1, 2, or 3. Default is 2."
    )
    shared.add_argument('file', type=argparse.FileType("r", encoding='utf-8'), help="Input file")

    build_cmd = commands.add_parser('build', help="Compile file.", parents=[shared])
    build_cmd.set_defaults(func=cli_build)

    run_cmd = commands.add_parser('run', help="Compile and run file.", parents=[shared])
    run_cmd.set_defaults(func=cli_run)
    run_cmd.add_argument('rest', nargs=argparse.REMAINDER, help="Arguments passed to output program.")

    parsed = root.parse_args()
    parsed.func(parsed)
# Run the command-line interface when this file is executed as a script.
if __name__ == '__main__': cli()
@monomere
Copy link
Author

monomere commented Jan 21, 2024

Requirements

  • LLVM (llc specifically)
  • Assembler (Clang by default)
  • libqqrt.o - runtime library

Runtime library

Should provide

  • void qq_write_uint8s(uint64_t fout, struct qq_slice_uint8 s)
    • write the bytes of the slice to the specified file.
  • void qq_write_int32(uint64_t fout, int32_t i)
    • print an int and a newline to the specified file.
  • int32_t qq_read_int32(uint64_t fin)
    • read an int from the specified file (akin to int(input()) in python).
  • uint64_t qq_stdout()
  • uint64_t qq_stderr()
  • uint64_t qq_stdin()

Possible C code

#include <stdio.h>
#include <stdint.h>
#include <stddef.h>
#include <inttypes.h>
/* Each accessor returns a standard stream as an opaque 64-bit handle. */
extern uint64_t qq_stdout(void) {
  return (uint64_t)(uintptr_t)stdout;
}
extern uint64_t qq_stderr(void) {
  return (uint64_t)(uintptr_t)stderr;
}
extern uint64_t qq_stdin(void) {
  return (uint64_t)(uintptr_t)stdin;
}
/* Print `i` in decimal, followed by a newline, to the stream behind the
 * handle `fout` (as returned by qq_stdout/qq_stderr). The newline is part
 * of the documented contract of this function. */
extern void qq_write_int32(uint64_t fout, int32_t i) {
  FILE *pfout = (FILE *)(uintptr_t)fout;
  fprintf(pfout, "%" PRId32 "\n", i);
}
/* C view of the language's slice[uint8]: a byte count and a pointer. */
struct qq_slice_uint8 {
  uint64_t len;
  uint8_t *data;
};
/* Write the `s.len` bytes at `s.data` to the stream behind `fout`. */
extern void qq_write_uint8s(uint64_t fout, struct qq_slice_uint8 s) {
  FILE *stream = (FILE *)(uintptr_t)fout;
  fwrite(s.data, 1, s.len, stream);
}
/* Read one decimal integer, skipping leading whitespace, from the stream
 * behind the handle `fin`. Returns 0 when no integer could be read. */
extern int32_t qq_read_int32(uint64_t fin) {
  FILE *pfin = (FILE *)(uintptr_t)fin;
  int32_t r = 0;
  /* SCNd32 is the scanf-family conversion for int32_t; PRId32 is only
   * specified for the printf family. */
  if (fscanf(pfin, " %" SCNd32, &r) != 1) {
    return 0;
  }
  return r;
}

Example code

extern func qq_read_int32(o: uint64): int32;
extern func qq_stdin(): uint64;
extern func read_int32() = qq_read_int32(qq_stdin());
struct slice[T] {
  len: uint64,
  ptr: *T
}
func fib(i: uint64): uint64 =
  if i <= 1 { 1 }
  else { fib(i - 1) + fib(i - 2) };
struct ProgramResults[T] {
  index: T,
  result: T
}
func run_program(): ProgramResults[uint64] {
  let inp: uint64 = read_int32();
  let res = fib(inp);
  ProgramResults::[uint64]:(
    index = inp,
    result = res
  )
}
extern func printf(s: *uint8, c_varargs): void;
export func main(argc: int32): int32 {
  let r = run_program();
  printf("The fibonacci number at index %lu is %lu.\n\0".ptr, r.index, r.result);
  0
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment