Last active
March 24, 2025 15:06
-
-
Save monomere/63e917f01e71ebb51dd86fd2f2b22236 to your computer and use it in GitHub Desktop.
Tiny compiler written in python that generates text LLVM IR.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# License: MIT | |
# Author: monomere | |
# Changelog: | |
# Note: With every update, the doc comment is almost always updated | |
# to reflect new additions/improvements to the language. | |
# Update 7.1 - 2024/09/13: | |
# - Fix variadic arguments not being passed. | |
# Update 7 - 2024/01/27: | |
# - Add changelog note. | |
# - Fix bug in parser error handling (didn't `get` non-bracket tokens before) | |
# - Fix bug in `LTypeArray.align()`, it returned `self.of.size()` instead of `.align()`. | |
# - Add type name mangling. | |
# - Add type generics. | |
# - Change "Changes" to "Changelog". | |
# - `CodeGen._ir_ty(LTypeFunc)` works. `CodeGen._ir_func_ty(t, name="")` | |
# - Add C varargs support. Use `c_varargs` as last argument to specify varargs. | |
# Example: `extern func printf(s: *uint8, c_varargs): void;` | |
# Update 6 - 2024/01/26: | |
# - Add `qq_write_uint8s` to write strings. | |
# - Fix parsing optional type annotations in let statements. | |
# - Add named argument support (parsing). | |
# - Now code is compiled with PIC by default. | |
# - Add basic string literals. | |
# Requires a `struct uint8s { len: uint64, data: *uint8 }` type to be present. | |
# - Fix type path resolve bug (did first part twice). | |
# - Add structs. | |
# - Add pointer types (parsing and resolving). | |
# - Implement better parser error handling. | |
# Now the parser skips to the next top-level item | |
# when it reaches the global scope. TODO: Different bracket types? | |
# Update 5 - 2024/01/23: | |
# - Add implicit casting. | |
# - Fix CodeGen._gen_int_cast. | |
# - Rename `RstExpr.type` to `RstExpr.ty` to be consistent with everything else. | |
# - Add `as` casting. | |
# - Add `qq_read_int32`, `qq_stdin` to libqqrt. | |
# - Rename `qq_write_int` to `qq_write_int32` in libqqrt. | |
# - Fix parser precedence parsing bugs. | |
# - Add support for recursive functions. | |
# - Add support for basic type inference (see `expected_type`) | |
# For example, integer literals can now infer their type. | |
# - Add missing errors. | |
# Update 4 - 2024/01/22: | |
# - A single bool type, since LLVM supports 1-bit ints. | |
# - Fix block parsing bug (last stmt is duplicated into the tail). | |
# - Move from QBE to LLVM. | |
# Update 3 - 2024/01/22: | |
# - Add comparison operators. | |
# - Add if expressions. | |
# - Add boolean types: bool{8,16,32,64}. | |
# Update 2 - 2024/01/22: | |
# - Add assignment expressions. | |
# - Debug flags print the assembly generated by QBE. | |
# - Codegen for moving stuff. Qbe values can now be indirect and non-indirect. | |
# - Change license to MIT. Will maybe change to GPL or something... this needs to be open source. | |
# - Add some comments to codegen code. I still need to add more comments. | |
# - Add local variables. | |
# - Add size and align methods to LType. | |
# - Add additional CLI arguments (--use-asm and --use-qbe). | |
# - Fix void handling in codegen. | |
# Update 1 - 2024/01/21: | |
# - Improve some error messages. | |
# - Fix some parser and error reporter bugs. | |
# - Add down/up casting of integer types in codegen. | |
# - Fix binary operator typechecking (still limits to ints tho). | |
# - Add sized and signed/unsigned integer types: (u)int{8,16,32,64}.
# - Add `export` top-level decl attribute. | |
# - CLI improvement. | |
# - Parse more than one argument for functions. | |
# - Some code readability changes. | |
# - Remove unneeded debug stuff from `parse_expr`. | |
# - Fix formatting of doc comment. | |
# Initial release - 2024/01/21 | |
''' | |
# Requirements | |
- LLVM (llc specifically) | |
- Assembler (Clang by default) | |
- libqqrt.o - runtime library | |
# Runtime library | |
## Should provide | |
- `void qq_write_uint8s(uint64_t fout, struct qq_slice_uint8 i)` | |
- `void qq_write_int32(uint64_t fout, int32_t i)` | |
- print an int and a newline to specified file. | |
- `int32_t qq_read_int32(uint64_t fin)`
- read an int from the specified file (akin to `int(input())` in python). | |
- `uint64_t qq_stdout()` | |
- `uint64_t qq_stderr()` | |
- `uint64_t qq_stdin()` | |
## Possible C code | |
```c | |
#include <stdio.h> | |
#include <stdint.h> | |
#include <stddef.h> | |
#include <inttypes.h> | |
extern uint64_t qq_stdout() { return (uint64_t)(uintptr_t)stdout; } | |
extern uint64_t qq_stderr() { return (uint64_t)(uintptr_t)stderr; } | |
extern uint64_t qq_stdin() { return (uint64_t)(uintptr_t)stdin; } | |
extern void qq_write_int32(uint64_t fout, int32_t i) { | |
FILE *pfout = (void*)(uintptr_t)fout; | |
fprintf(pfout, "%" PRId32, i); | |
} | |
struct qq_slice_uint8 { | |
uint64_t len; | |
uint8_t *data; | |
}; | |
extern void qq_write_uint8s(uint64_t fout, struct qq_slice_uint8 s) { | |
FILE *pfout = (void*)(uintptr_t)fout; | |
fwrite(s.data, s.len, 1, pfout); | |
} | |
extern int32_t qq_read_int32(uint64_t fin) { | |
FILE *pfin = (void*)(uintptr_t)fin; | |
int32_t r = 0; | |
fscanf(pfin, " %" PRId32, &r); | |
return r; | |
} | |
``` | |
# Example code | |
```go | |
extern func qq_read_int32(o: uint64): int32; | |
extern func qq_stdin(): uint64; | |
extern func read_int32() = qq_read_int32(qq_stdin()); | |
struct slice[T] { | |
len: uint64, | |
ptr: *T | |
} | |
func fib(i: uint64): uint64 = | |
if i <= 1 { 1 } | |
else { fib(i - 1) + fib(i - 2) }; | |
struct ProgramResults[T] { | |
index: T, | |
result: T | |
} | |
func run_program(): ProgramResults[uint64] { | |
let inp: uint64 = read_int32(); | |
let res = fib(inp); | |
ProgramResults::[uint64]:( | |
index = inp, | |
result = res | |
) | |
} | |
extern func printf(s: *uint8, c_varargs): void; | |
export func main(argc: int32): int32 { | |
let r = run_program(); | |
printf("The fibonacci number at index %lu is %lu.\n\0".ptr, r.index, r.result); | |
0 | |
} | |
``` | |
''' | |
from __future__ import annotations | |
import typing, dataclasses, enum, abc, sys, \ | |
subprocess, argparse, os.path, functools, \ | |
shlex, pprint | |
def print(*args, **kwargs):
    """Module-wide replacement for the builtin ``print``.

    Every message is prefixed with ``';'`` and written to ``sys.stderr``
    unless the caller supplies an explicit ``file=`` keyword.

    The builtin is reached through the ``builtins`` module: ``__builtins__``
    is a module only in ``__main__`` and a plain dict when this file is
    imported, so ``__builtins__.print`` is not portable.
    """
    import builtins
    # Default (rather than force) the destination so that an explicit
    # ``file=`` no longer raises a duplicate-keyword TypeError.
    kwargs.setdefault("file", sys.stderr)
    builtins.print(';', *args, **kwargs)
class Input(abc.ABC):
    """Abstract source of program text (contents plus a display name)."""

    @abc.abstractmethod
    def source(self) -> str:
        """Return the complete source text."""
        ...

    @abc.abstractmethod
    def filename(self) -> str:
        """Return the name used when reporting locations in this input."""
        ...

    def to_line_col(self, offset: int) -> tuple[str, int, int]:
        """Translate a character offset into ``(line_text, line, col)``.

        ``line`` and ``col`` are zero-based. ``line_text`` is the text of the
        line containing ``offset`` without its trailing newline. Offsets past
        the end of the source are clamped to the end; negative offsets are
        treated as 0.
        """
        src = self.source()
        end = min(max(offset, 0), len(src))
        # Number of newlines before the offset == zero-based line index.
        line = src.count("\n", 0, end)
        # Start of the current line is just after the previous newline
        # (rfind returns -1 when there is none, giving 0).
        lstart = src.rfind("\n", 0, end) + 1
        col = end - lstart
        # Search for the line end starting AT lstart: starting one past it
        # would skip the newline of an empty line and return the next line.
        lend = src.find("\n", lstart)
        if lend == -1:
            lend = len(src)
        return (src[lstart:lend], line, col)
class InputStr(Input):
    """An `Input` backed by an in-memory string."""

    def __init__(self, s: str, fname: str = "<input>"):
        # Keep both values verbatim; they are handed back unchanged below.
        self.s, self.fname = s, fname

    def source(self) -> str:
        """Return the string given at construction."""
        return self.s

    def filename(self) -> str:
        """Return the display name given at construction."""
        return self.fname
@dataclasses.dataclass
class SrcSpan:
    """A region of an input, identified by file index and two offsets."""
    file: int   # index into InputManager.inputs
    start: int  # offset of the first character
    stop: int   # offset where the region ends

    def clone(self):
        """Return an independent copy, so later mutation of one span
        does not affect the other."""
        return dataclasses.replace(self)
@dataclasses.dataclass
class SrcSpanInfo:
    """Human-readable location data derived from a `SrcSpan`."""
    as_str: str    # "filename:line:col" with one-based numbers
    line_str: str  # text of the source line
    line: int      # zero-based line index
    col: int       # zero-based column index
class InputManager:
    """Process-wide registry of inputs; a `SrcSpan.file` indexes into it."""
    inputs: list[Input] = []

    @staticmethod
    def add(inp: Input) -> int:
        """Register ``inp`` and return its index in the registry."""
        index = len(InputManager.inputs)
        InputManager.inputs.append(inp)
        return index

    @staticmethod
    def span_info(span: SrcSpan) -> SrcSpanInfo:
        """Resolve the start of ``span`` into printable location data."""
        source = InputManager.inputs[span.file]
        text, row, column = source.to_line_col(span.start)
        # Displayed line/column numbers are one-based.
        location = f"{source.filename()}:{row + 1}:{column + 1}"
        return SrcSpanInfo(location, text, row, column)
class IrrecoverableError(Exception):
    """Raised when processing cannot continue (e.g. too many reports,
    or an unexpected end of file while parsing)."""
class ErrorReporter:
    """Writes coloured diagnostics to ``sys.stderr`` and counts them."""
    report_count = 0  # number of diagnostics emitted so far (all kinds)
    ignore = False    # when True, `report` prints nothing and counts nothing

    @staticmethod
    def report(
        msg: str,
        span: SrcSpan | None, *,
        kind: typing.Literal['error', 'note', 'success'] = 'error',
        title: str | None = None
    ) -> Exception:
        """Print a diagnostic and return an ``Exception`` carrying ``msg``.

        Args:
            msg: message text.
            span: location to point at, or ``None`` for no location. When
                given, the source line is echoed with a ``^~~~`` marker.
            kind: selects the colour and the default title.
            title: overrides the default title for ``kind``.

        Returns:
            An exception object the caller may raise (it is not raised here).

        Raises:
            IrrecoverableError: once more than 15 diagnostics were reported.
            KeyError: if ``kind`` is not one of the known kinds.
        """
        if ErrorReporter.ignore:
            return Exception(f"ignored {kind}: {title}: {msg} @{span}")
        ErrorReporter.report_count += 1
        fout = sys.stderr
        info = InputManager.span_info(span) if span is not None else None
        color = {"note": "34", "error": "1;31", "success": "1;32"}[kind]
        if title is None:
            title = {"note": "Note:", "error": "Error:", "success": "Success!"}[kind]

        fout.write(f"\x1b[{color}m{title} ")
        if info is not None:
            fout.write(f"at {info.as_str} ")
        fout.write(msg)
        fout.write("\x1b[m\n")

        if info is not None:
            assert span is not None
            line = info.line_str
            pre = f"{info.line + 1} | "
            # Tabs are shown two columns wide, both in the echoed line and in
            # the marker line below, so the caret lines up with its target.
            fout.write(pre + line.replace('\t', "  ") + "\n")
            lead = line[:info.col]
            pad = len(pre) + len(lead) + lead.count('\t') + max(0, info.col - len(line))
            fout.write(" " * pad + "^")
            for i in range(span.stop - span.start - 1):
                # The caret covers column `col`; tilde number i covers the
                # character at col + i + 1, so that is the one to test.
                idx = info.col + i + 1
                if idx >= len(line):
                    break
                fout.write("~~" if line[idx] == '\t' else "~")
            fout.write("\n")

        if ErrorReporter.report_count > 15:
            raise IrrecoverableError()
        return Exception(msg)
class TokenKind(enum.Enum):
    """Kinds of non-punctuation tokens (punctuation uses its own text as kind)."""
    ERROR = enum.auto()
    EOF = enum.auto()   # end of input
    ID = enum.auto()    # identifier / keyword
    INT = enum.auto()   # integer literal
    STR = enum.auto()   # string literal
@dataclasses.dataclass
class Token:
    """A single lexical token."""
    span: SrcSpan                    # where the token was read from
    kind: TokenKind | str            # a TokenKind, or the punctuation text itself
    value: str | int | None = None   # identifier/string text, integer value,
                                     # or bracket level for bracket tokens
# Bracket characters; the tokenizer attaches a nesting level to these tokens.
OPEN_BR = ('{', '[', '(')    # opening brackets
CLOSE_BR = ('}', ']', ')')   # closing brackets, same order as OPEN_BR
def tokenize(file: int, inp: str) -> typing.Iterable[Token]:
    """Split ``inp`` into tokens, ending with a single EOF token.

    Token shapes produced:
      * identifiers: ``TokenKind.ID`` with the text as value;
      * string literals: ``TokenKind.STR`` with escapes decoded;
      * integers: ``TokenKind.INT``; an optional ``0x``/``0d``/``0o``/``0b``
        prefix selects the base and only digits valid for that base are
        consumed;
      * ``==``, ``!=``, ``<=``, ``>=`` and ``::`` as two-character tokens;
      * brackets, whose value is the bracket nesting level (after the
        increment for an opener, after the decrement for a closer);
      * any other character as a one-character token.

    Every span uses an exclusive ``stop`` offset.

    Args:
        file: file index stored in each token's span.
        inp: source text.

    Raises:
        SyntaxError: unknown escape, unterminated string or unknown base.
    """
    off = 0
    EQ_OPS = ('>', '<', '=', '!')
    BASES = {"x": 16, "d": 10, "o": 8, "b": 2}
    ESCAPES = {
        'n': '\n', 't': '\t', 'r': '\r', '0': '\0', 'b': '\b', 'e': '\033',
        '\\': '\\', '"': '"',
    }
    HEX_DIGITS = "0123456789abcdef"
    cur_level = 0  # bracket level

    def iseof() -> bool:
        return off >= len(inp)

    def peek() -> str:
        return "" if iseof() else inp[off]

    def get() -> str:
        # Returns "" at EOF and does not advance past the end, so offsets
        # recorded afterwards stay within the input.
        nonlocal off
        c = peek()
        if c:
            off += 1
        return c

    def digit_value(c: str) -> int:
        # Value of an ASCII hex digit, or -1 for anything else (including "").
        return HEX_DIGITS.find(c.lower()) if c else -1

    while not iseof():
        while peek().isspace():
            off += 1
        if iseof():
            break
        start = off
        if peek().isalpha():
            cs = [get()]
            while peek().isalpha() or peek().isdigit() or peek() == '_':
                cs.append(get())
            yield Token(SrcSpan(file, start, off), TokenKind.ID, "".join(cs))
        elif peek() == '"':
            get()
            cs = []
            while not iseof() and peek() != '"':
                if peek() == '\\':
                    get()
                    e = get()
                    if e not in ESCAPES:
                        raise SyntaxError(f"Unknown escape character: '{e}'.")
                    cs.append(ESCAPES[e])
                else:
                    cs.append(get())
            if get() != '"':  # input ended before the closing quote
                raise SyntaxError("Expected string end.")
            yield Token(SrcSpan(file, start, off), TokenKind.STR, "".join(cs))
        elif "0" <= peek() <= "9":
            n = digit_value(get())
            base = 10
            if n == 0 and peek().isalpha():  # 0[xdob]
                b = get()
                if b not in BASES:
                    raise SyntaxError(f"Unknown base: '{b}'.")
                base = BASES[b]
            # Stop at the first character that is not a digit of `base`.
            while 0 <= (d := digit_value(peek())) < base:
                get()
                n = n * base + d
            yield Token(SrcSpan(file, start, off), TokenKind.INT, n)
        else:
            c = get()
            if c in EQ_OPS and peek() == '=':
                get()
                yield Token(SrcSpan(file, start, off), c + '=')
            elif c in OPEN_BR:
                cur_level += 1
                yield Token(SrcSpan(file, start, off), c, cur_level)
            elif c in CLOSE_BR:
                cur_level = cur_level - 1 if cur_level > 0 else 0
                yield Token(SrcSpan(file, start, off), c, cur_level)
            elif c == ':' and peek() == ':':
                get()
                yield Token(SrcSpan(file, start, off), '::', cur_level)
            else:
                yield Token(SrcSpan(file, start, off), c)
    yield Token(SrcSpan(file, off, off), TokenKind.EOF)
@dataclasses.dataclass
class Ast:
    """Base of all syntax-tree nodes; every node records its source span."""
    span: SrcSpan
@dataclasses.dataclass
class TypeAst(Ast):
    """Base class of type-expression nodes; adds no fields."""
@dataclasses.dataclass
class TypeAstParam(TypeAst):
    """A declared generic type parameter (built by `Parser.parse_type_params`)."""
    name: str  # parameter name
@dataclasses.dataclass
class TypeAstPtr(TypeAst):
    """Pointer type written as ``*T``."""
    to: TypeAst  # the pointee type
@dataclasses.dataclass
class TypeAstPath(TypeAst):
    """A ``::``-separated name path, optionally rooted at another type."""
    root: TypeAst | None  # type the path starts from, or None for a plain path
    names: list[str]      # path segments in source order
@dataclasses.dataclass
class TypeAstInst(TypeAst):
    """Generic instantiation ``G[A, B, ...]``."""
    generic: TypeAst       # the type being instantiated
    params: list[TypeAst]  # type arguments between the square brackets
@dataclasses.dataclass
class TypeAstStruct(Ast):
    """Body of a struct declaration (note: derives from `Ast`, not `TypeAst`)."""
    attrs: set[str]                 # attributes; the visible parser passes an empty set
    fields: list[AstDecl[TypeAst]]  # field declarations in source order
# Type of the `ty` field of a declaration (a type node, possibly optional).
TT = typing.TypeVar('TT', covariant=True)


@dataclasses.dataclass
class AstDecl(Ast, typing.Generic[TT]):
    """A ``name: type = value`` declaration (argument, field or ``let``)."""
    attrs: list[str]       # attributes; the visible parser passes an empty list
    name: str              # declared name
    ty: TT                 # declared type (may be None where TT allows it)
    value: ExprAst | None  # initializer expression, if any
@dataclasses.dataclass
class ExprAst(Ast):
    """Base class of expression nodes; adds no fields."""
@dataclasses.dataclass
class ExprAstInt(ExprAst):
    """Integer literal."""
    value: int  # the literal's numeric value
@dataclasses.dataclass
class ExprAstStr(ExprAst):
    """String literal."""
    value: str  # decoded text of the literal
@dataclasses.dataclass
class ExprAstBool(ExprAst):
    """Boolean literal (``true`` / ``false``)."""
    value: bool
@dataclasses.dataclass
class ExprAstError(ExprAst):
    """Placeholder returned by the parser where no expression could be parsed."""
class BinOp(enum.Enum):
    """Binary operators; `Parser.OPS` maps source text to these members."""
    # arithmetic: + - * / %
    ADD = enum.auto()
    SUB = enum.auto()
    MUL = enum.auto()
    DIV = enum.auto()
    MOD = enum.auto()
    # comparison: == != < > <= >=
    EQ = enum.auto()
    NE = enum.auto()
    LT = enum.auto()
    GT = enum.auto()
    LE = enum.auto()
    GE = enum.auto()
    # assignment: =
    ASSIGN = enum.auto()
@dataclasses.dataclass
class ExprAstBinOp(ExprAst):
    """Binary operation ``lhs op rhs``."""
    lhs: ExprAst  # left operand
    rhs: ExprAst  # right operand
    op: BinOp     # the operator
class UnOp(enum.Enum):
    """Unary operators."""
    POS = enum.auto()
    NEG = enum.auto()
    NOT = enum.auto()
    BIN_NOT = enum.auto()
@dataclasses.dataclass
class ExprAstUnOp(ExprAst):
    """Unary operation applied to a single operand."""
    val: ExprAst  # the operand
    op: UnOp      # the operator
@dataclasses.dataclass
class ExprAstId(ExprAst):
    """Reference to a name, stored as its ``::``-separated segments."""
    names: list[str]
@dataclasses.dataclass
class ExprAstItem(ExprAst):
    """Member access ``of.name``."""
    of: ExprAst  # expression on the left of the dot
    name: str    # member name on the right of the dot
@dataclasses.dataclass
class ExprAstCall(ExprAst):
    """Call expression ``func(args..., name = value...)``."""
    func: ExprAst                          # the callee expression
    args: list[ExprAst]                    # positional arguments, in order
    named_args: list[tuple[str, ExprAst]]  # (name, value) pairs, in order
@dataclasses.dataclass
class ExprAstCons(ExprAst):
    """Constructor expression ``Type:(field = value, ...)``."""
    ty: TypeAst                        # the type being constructed
    fields: list[tuple[str, ExprAst]]  # (field name, value) pairs, in order
@dataclasses.dataclass
class ExprAstCast(ExprAst):
    """Cast expression ``expr as ty``."""
    expr: ExprAst  # value being cast
    ty: TypeAst    # target type
@dataclasses.dataclass
class ExprAstBlock(ExprAst):
    """Brace-delimited block ``{ stmts... tail }``."""
    stmts: list[StmtAst]  # statements in order
    tail: ExprAst | None  # trailing expression, if the block ends with one
@dataclasses.dataclass
class ExprAstIf(ExprAst):
    """``if cond { ... } else { ... }`` expression."""
    cond: ExprAst          # the condition
    true: ExprAst          # branch taken when the condition holds
    false: ExprAst | None  # ``else`` branch (a block or another ``if``), if present
@dataclasses.dataclass
class StmtAst(Ast):
    """Base class of statement nodes; adds no fields."""
@dataclasses.dataclass
class StmtExprAst(StmtAst):
    """Statement consisting of a single expression."""
    expr: ExprAst
@dataclasses.dataclass
class StmtDeclAst(StmtAst):
    """``let`` statement; the declared type is optional."""
    decl: AstDecl[TypeAst | None]
@dataclasses.dataclass
class AstFunc(Ast):
    """Top-level function declaration or definition."""
    attrs: set[str]               # subset of {"export", "extern"} as parsed
    name: str                     # function name
    args: list[AstDecl[TypeAst]]  # declared parameters, in order
    ret: TypeAst | None           # return type annotation, if written
    body: ExprAst | None          # body expression; None for a bare ``;`` declaration
    c_varargs: bool               # True when the parameter list ends with ``c_varargs``
@dataclasses.dataclass
class AstStruct(Ast):
    """Top-level struct declaration."""
    name: str                             # struct name
    ty_params: list[TypeAstParam] | None  # generic parameters, or None when no ``[...]`` was written
    ty: TypeAstStruct                     # the struct body
class Parser:
    """Recursive-descent parser over a list of `Token`s.

    Methods decorated with ``@_parser`` get a fresh ``self._span`` that
    starts at the next token and grows as tokens are consumed through
    `_get`. Failed expectations are reported through `ErrorReporter`;
    `_handle_eat_fail` may skip ahead and raise `FailedMiserably`.
    """

    @staticmethod
    def _parser(fn):
        """Decorator: run ``fn`` with its own span, then merge it into the
        caller's span and restore the caller's state."""
        @functools.wraps(fn)
        def inner(self: Parser, *args, **kwargs):
            saved_level = self._last_level
            # NOTE(review): this re-assigns the value just read, so it has no
            # effect and `_last_level` stays at its initial 0. It may have
            # been meant to be `self._current_level` — confirm before
            # changing, since that would alter error recovery.
            self._last_level = saved_level
            saved_span = self._span
            self._span = self._peek().span.clone()
            r = fn(self, *args, **kwargs)
            # The enclosing span must cover everything the callee consumed.
            saved_span.stop = self._span.stop
            self._span = saved_span
            self._last_level = saved_level
            return r
        return inner

    class FailedMiserably(Exception):
        """Raised by `_handle_eat_fail` after it has skipped ahead."""
        pass

    # Operator text -> (precedence, operator); higher precedence binds tighter.
    OPS: dict[str, tuple[int, BinOp]] = {
        '+': (10, BinOp.ADD),
        '-': (10, BinOp.SUB),
        '*': (20, BinOp.MUL),
        '/': (20, BinOp.DIV),
        '%': (20, BinOp.MOD),
        '==': (5, BinOp.EQ),
        '!=': (5, BinOp.NE),
        '<=': (6, BinOp.LE),
        '>=': (6, BinOp.GE),
        '<': (6, BinOp.LT),
        '>': (6, BinOp.GT),
        '=': (1, BinOp.ASSIGN),
    }

    def __init__(self, toks: list[Token]):
        self.idx = 0
        self.toks = toks
        self._rep_tok: Token | None = None
        self._rep_count = 0
        self._span: SrcSpan = SrcSpan(-1, -1, -1)  # current @_parser span
        self._current_level = 0  # current "bracket level"
        self._last_level = 0  # saved _current_level.

    def _peek(self):
        """Return the next token without consuming it."""
        return self.toks[self.idx]

    def _get(self):
        """Consume and return the next token; the last token is never passed."""
        t = self._peek()
        if self.idx < len(self.toks) - 1:
            self.idx += 1
        self._span.stop = t.span.stop
        return t

    def _is(self, kind: TokenKind | str, allow_eof = False):
        """Tell whether the next token has ``kind``.

        Reaching EOF while asking for something else is reported and raises
        `IrrecoverableError` unless ``allow_eof`` is true.
        """
        k = self._peek().kind
        if kind == TokenKind.EOF: return k == kind
        if k == TokenKind.EOF and not allow_eof:
            ErrorReporter.report("EOF not allowed", self._peek().span)
            raise IrrecoverableError()
        return k == kind

    def _is_kw(self, kw: str):
        """Tell whether the next token is the identifier ``kw``."""
        return self._is(TokenKind.ID) and self._peek().value == kw

    def _opt_eat(self, kind: TokenKind | str, allow_eof = False) -> Token | None:
        """Consume the next token if it has ``kind``; otherwise return None."""
        if self._is(TokenKind.EOF) and not allow_eof:
            ErrorReporter.report("EOF not allowed", self._peek().span)
            raise IrrecoverableError()
        # Forward allow_eof: without it, `_is` raised at EOF even when the
        # caller had explicitly allowed EOF.
        if self._is(kind, allow_eof):
            t = self._get()
            if kind in OPEN_BR or kind in CLOSE_BR:
                self._current_level = typing.cast(int, t.value)
            return t
        return None

    def _handle_eat_fail(self, t: Token):
        """Recovery after a failed expectation on token ``t``.

        When ``t`` is a bracket at or above the saved level, skip tokens up
        to and including the next bracket whose level is 0 (or to EOF) and
        raise `FailedMiserably`.
        """
        if t.kind in OPEN_BR or t.kind in CLOSE_BR:
            assert isinstance(t.value, int)
            if self._last_level <= t.value:
                while not self._is(TokenKind.EOF):
                    t = self._peek()
                    if t.kind in OPEN_BR or t.kind in CLOSE_BR:
                        assert isinstance(t.value, int)
                        if t.value == 0:
                            self._get()
                            break
                    self._get()
                raise Parser.FailedMiserably()  # go back to toplevel

    def _eat(self, kind: TokenKind | str) -> Token | None:
        """Consume a token of ``kind``; report an error and return None if absent."""
        if t := self._opt_eat(kind): return t
        t = self._peek()
        ErrorReporter.report(f"Expected {kind}.", t.span)
        self._handle_eat_fail(t)
        return None

    def _opt_eat_kw(self, kw: str, allow_eof = False) -> Token | None:
        """Consume the keyword ``kw`` if it is next; otherwise return None.

        At EOF nothing is consumed; an error is reported unless ``allow_eof``.
        """
        if self._is(TokenKind.EOF):
            # Return here in both cases: falling through to `_is_kw` would
            # raise at EOF even though the caller allowed it.
            if not allow_eof:
                ErrorReporter.report("EOF not allowed", self._peek().span)
            return None
        return self._get() if self._is_kw(kw) else None

    def _eat_kw(self, kw: str) -> Token | None:
        """Consume the keyword ``kw``; report an error and return None if absent."""
        if t := self._opt_eat_kw(kw): return t
        t = self._peek()
        ErrorReporter.report(f"Expected keyword '{kw}'.", t.span)
        self._handle_eat_fail(t)
        return None

    def _eat_name(self) -> str:
        """Consume an identifier and return its text.

        On failure the offending token is consumed and ``"<error>"`` returned.
        """
        if t := self._eat(TokenKind.ID):
            return typing.cast(str, t.value)
        self._get()
        return "<error>"

    @_parser
    def parse_type_atom(self) -> TypeAst:
        """Parse a name followed by any number of ``::name`` / ``[T, ...]`` suffixes."""
        name = self._eat_name()
        n = TypeAstPath(self._span, None, [name])
        while self._is('::') or self._is('['):
            if self._is('['):  # generic instantiation
                n.span = n.span.clone()  # make sure the span doesn't change.
                self._get()
                params: list[TypeAst] = []
                while not self._is(']'):
                    params.append(self.parse_type())
                    if not self._opt_eat(','): break
                self._eat(']')
                n = TypeAstInst(self._span.clone(), n, params)
            else:
                if not isinstance(n, TypeAstPath):
                    n = TypeAstPath(self._span.clone(), n, [])
                self._get()
                n.names.append(self._eat_name())
        n.span = n.span.clone()  # make sure it doesn't go anywhere.
        return n

    @_parser
    def parse_type_prefix(self) -> TypeAst:
        """Parse zero or more ``*`` prefixes followed by a type atom."""
        if self._opt_eat('*'):
            other = self.parse_type_prefix()
            return TypeAstPtr(self._span, other)
        return self.parse_type_atom()

    # @_parser XXX: don't forget
    def parse_type(self) -> TypeAst:
        """Parse a type expression."""
        return self.parse_type_prefix()

    @_parser
    def parse_atom(self) -> ExprAst:
        """Parse a literal, a (possibly qualified) name, or a constructor
        ``Type:(field = value, ...)``. Reports an error and returns
        `ExprAstError` when none matches."""
        if t := self._opt_eat(TokenKind.INT):
            return ExprAstInt(self._span, typing.cast(int, t.value))
        if t := self._opt_eat(TokenKind.STR):
            return ExprAstStr(self._span, typing.cast(str, t.value))
        if t := self._opt_eat(TokenKind.ID):
            if t.value == "true": return ExprAstBool(self._span, True)
            if t.value == "false": return ExprAstBool(self._span, False)
            n = TypeAstPath(self._span, None, [typing.cast(str, t.value)])
            while self._opt_eat('::'):
                if self._is('['):  # generic instantiation
                    n.span = n.span.clone()  # make sure the span doesn't change.
                    self._get()
                    params: list[TypeAst] = []
                    while not self._is(']'):
                        params.append(self.parse_type())
                        if not self._opt_eat(','): break
                    self._eat(']')
                    n = TypeAstInst(self._span.clone(), n, params)
                else:
                    if not isinstance(n, TypeAstPath):
                        n = TypeAstPath(self._span.clone(), n, [])
                    n.names.append(self._eat_name())
            n.span = n.span.clone()  # make sure it doesn't go anywhere.
            if self._opt_eat(':'):  # constructors
                self._eat('(')
                _, fields = self.parse_call_args(allow_positional=False)
                self._eat(')')
                return ExprAstCons(self._span, n, fields)
            if not isinstance(n, TypeAstPath) or n.root is not None:
                ErrorReporter.report("Expected an expression, not a type.", self._span)
                return ExprAstError(self._span)
            return ExprAstId(self._span, n.names)
        ErrorReporter.report("Expected an atom.", self._peek().span)
        return ExprAstError(self._span)

    def parse_call_args(
        self, *,
        sep: TokenKind | str = ',',
        end: TokenKind | str = ')',
        allow_positional = True
    ) -> tuple[list[ExprAst], list[tuple[str, ExprAst]]]:
        """Parse arguments up to (not including) ``end``.

        An unparenthesised ``name = value`` is a named argument. Returns
        ``(positional, named)``.
        """
        args: list[ExprAst] = []
        named_args: list[tuple[str, ExprAst]] = []
        while not self._is(end):
            # A parenthesised assignment is an ordinary expression argument.
            had_parens = self._is('(')
            e = self.parse_expr()
            if isinstance(e, ExprAstBinOp) and e.op == BinOp.ASSIGN and not had_parens:
                if not isinstance(e.lhs, ExprAstId) or len(e.lhs.names) != 1:
                    ErrorReporter.report(
                        "Arbitrary assignment expressions are not allowed in function arguments.",
                        e.span
                    )
                    ErrorReporter.report("Surround with parentheses", e.span, kind='note')
                else:
                    named_args.append((e.lhs.names[0], e.rhs))
            else:
                if len(named_args) > 0:
                    ErrorReporter.report("Positional arguments not allowed after named arguments.", e.span)
                if not allow_positional:
                    ErrorReporter.report("Positional arguments not allowed here.", e.span)
                args.append(e)
            if not self._opt_eat(sep):
                break
        return args, named_args

    @_parser
    def parse_suffix_expr(self) -> ExprAst:
        """Parse an atom followed by calls / member accesses, then ``as`` casts."""
        n = self.parse_atom()
        while self._is('(') or self._is('.'):
            if self._opt_eat('('):
                args, named_args = self.parse_call_args()
                self._eat(')')
                n = ExprAstCall(self._span.clone(), n, args, named_args)
            elif self._opt_eat('.'):
                name = self._eat_name()
                n = ExprAstItem(self._span.clone(), n, name)
            else: assert False  # unreachable hopefully
        while self._opt_eat_kw('as'):
            t = self.parse_type()
            n = ExprAstCast(self._span.clone(), n, t)
        return n

    @_parser
    def parse_bin_expr(self, min_prec: int = 0) -> ExprAst:
        """Parse binary operators whose precedence is above ``min_prec``."""
        lhs = self.parse_suffix_expr()
        while (opk := self._peek().kind) in Parser.OPS:
            prec, op = Parser.OPS[typing.cast(str, opk)]
            if prec <= min_prec: break
            self._get()
            rhs = self.parse_bin_expr(prec)
            # Clone: the live span keeps growing while this loop continues,
            # and an earlier node must not be stretched over later operands.
            lhs = ExprAstBinOp(self._span.clone(), lhs, rhs, op)
        return lhs

    @_parser
    def parse_stmt(self) -> tuple[StmtAst, int]:  # ast, end[1=; 2=} 0]
        """Parse one statement.

        The second result tells how it ended: 1 for ``;``, 2 for an ``if``
        expression without ``;``, 0 otherwise.
        """
        while self._opt_eat(';'): continue
        if self._opt_eat_kw('let'):
            decl = self.parse_decl_opt_type(allow_value=True, opt_type=True)
            self._eat(';')
            return StmtDeclAst(self._span, decl), 1
        e = self.parse_expr()
        had_semicolon = 0
        if isinstance(e, ExprAstIf): had_semicolon = 2
        if self._opt_eat(';') is not None: had_semicolon = 1
        return StmtExprAst(e.span, e), had_semicolon

    @_parser
    def parse_block(self) -> ExprAstBlock:
        """Parse ``{ stmt... [tail] }``; a final expression without ``;``
        becomes the block's tail."""
        self._eat('{')
        stmts: list[StmtAst] = []
        last, end = self.parse_stmt()
        while end:
            if self._is('}'):
                if end == 1:  # { ... x; }
                    stmts.append(last)
                    last = None
                    end = None
                # else: # { ... if ... } } or { ... x }
                break
            else:
                stmts.append(last)
                last = None
            last, end = self.parse_stmt()
        self._eat('}')
        tail = None
        if last is not None:
            if not isinstance(last, StmtExprAst):
                ErrorReporter.report("Expected an expression, not a statement.", last.span)
            else:
                tail = last.expr
        return ExprAstBlock(self._span, stmts, tail)

    @_parser
    def parse_expr(self) -> ExprAst:
        """Parse a block, an ``if`` expression, or a binary expression."""
        if self._is('{'):
            return self.parse_block()
        if self._opt_eat_kw('if'):
            cond = self.parse_expr()
            true = self.parse_block()
            false = None
            if self._opt_eat_kw('else'):
                if self._is_kw('if'):  # allow 'else if'
                    false = self.parse_expr()
                else:
                    false = self.parse_block()
            return ExprAstIf(self._span, cond, true, false)
        if self._is(TokenKind.EOF):
            return ExprAstError(self._span)
        return self.parse_bin_expr()

    @_parser
    def parse_decl_opt_type(self, *, allow_value: bool, opt_type: bool = False) -> AstDecl[TypeAst | None]:
        """Parse ``name [: type] [= value]``.

        The type is mandatory unless ``opt_type``; a value is parsed but
        reported and dropped unless ``allow_value``.
        """
        name = self._eat_name()
        ty = None
        if not opt_type or self._is(':'):
            self._eat(':')
            ty = self.parse_type()
        val = None
        if self._opt_eat('='):
            val = self.parse_expr()
            if not allow_value:
                ErrorReporter.report("Value not allowed here.", val.span)
                val = None
        return AstDecl[TypeAst | None](self._span, [], name, ty, val)

    @_parser
    def parse_decl(self, *, allow_value: bool) -> AstDecl[TypeAst]:
        """Parse a declaration whose type annotation is mandatory."""
        r = self.parse_decl_opt_type(allow_value=allow_value, opt_type=False)
        return typing.cast(AstDecl[TypeAst], r)

    # NOT a @_parser
    def parse_func(self, attrs: set[str]) -> AstFunc | None:
        """Parse a function after the ``func`` keyword.

        Accepted bodies: a block, ``= expr;``, or a bare ``;`` (no body).
        """
        name = self._eat_name()
        args: list[AstDecl[TypeAst]] = []
        c_varargs = False
        self._eat('(')
        while not self._opt_eat(')'):
            if self._opt_eat_kw("c_varargs"):
                # c_varargs must be the last entry of the parameter list.
                c_varargs = True
                self._eat(')')
                break
            args.append(self.parse_decl(allow_value=False))
            if not self._opt_eat(','):
                self._eat(')')
                break
        ret = None
        if self._opt_eat(':'):
            ret = self.parse_type()
        if self._is('{'):
            body = self.parse_expr()
        elif self._opt_eat(';'):
            body = None
        else:
            self._eat('=')
            body = self.parse_expr()
            self._eat(';')
        return AstFunc(
            span = self._span,
            attrs = attrs,
            name = name,
            args = args,
            ret = ret,
            body = body,
            c_varargs = c_varargs
        )

    def parse_using(self):
        """Placeholder for ``using`` items; currently parses nothing."""
        return None

    def parse_type_params(
        self,
        sep: TokenKind | str = ',',
        end: TokenKind | str = ']'
    ) -> list[TypeAstParam]:
        """Parse generic parameter names up to (not including) ``end``."""
        params: list[TypeAstParam] = []
        while not self._is(end):
            s = self._peek().span.clone()
            name = self._eat_name()
            params.append(TypeAstParam(s, name))
            if not self._opt_eat(sep): break
        return params

    @_parser
    def parse_struct(self) -> AstStruct:
        """Parse ``Name [ [T, ...] ] { field: type, ... }`` after ``struct``."""
        name = self._eat_name()
        ty_params = None
        if self._opt_eat('['):
            ty_params = self.parse_type_params()
            self._eat(']')
        self._eat('{')
        fields: list[AstDecl[TypeAst]] = []
        while not self._is('}'):
            decl = self.parse_decl(allow_value=False)
            fields.append(decl)
            if self._is('}'): break
            self._eat(',')
        self._eat('}')
        return AstStruct(
            self._span,
            name = name,
            ty_params = ty_params,
            ty = TypeAstStruct(
                self._span,
                attrs = set(),
                fields = fields
            )
        )

    @_parser
    def parse_toplevel(self) -> AstFunc | AstStruct | None:
        """Parse one top-level item.

        Returns None for ``using`` items and, after reporting an error, when
        no item starts at the current token (the token is not consumed).
        """
        if self._opt_eat_kw("using"):
            return self.parse_using()
        if self._opt_eat_kw("struct"):
            return self.parse_struct()
        attrs = set[str]()
        if self._opt_eat_kw("export"): attrs.add("export")
        if self._opt_eat_kw("extern"): attrs.add("extern")
        if self._opt_eat_kw("func"):
            return self.parse_func(attrs)
        ErrorReporter.report("Expected a top-level item.", self._peek().span)
        return None
@dataclasses.dataclass
class Rst:
    """Base of the second tree representation; every node records its span.

    NOTE(review): presumably the resolved/typed counterpart of `Ast` —
    confirm against the code that builds these nodes.
    """
    span: SrcSpan
@dataclasses.dataclass
class ExprRst(Rst):
    """Base of expression nodes; each carries the expression's type."""
    ty: LType
@dataclasses.dataclass
class ExprRstInt(ExprRst):
    """Integer constant."""
    value: int
@dataclasses.dataclass
class ExprRstStr(ExprRst):
    """String constant, stored as bytes."""
    value: bytes
@dataclasses.dataclass
class ExprRstAssign(ExprRst):
    """Assignment of a value to a variable."""
    var: LVar     # assignment target
    val: ExprRst  # value being assigned
@dataclasses.dataclass
class ExprRstBinOp(ExprRst):
    """Binary operation."""
    lhs: ExprRst  # left operand
    rhs: ExprRst  # right operand
    op: BinOp     # the operator
    op_ty: LType  # NOTE(review): presumably the operand type the operation
                  # is performed in (vs. the result type `ty`) — confirm
@dataclasses.dataclass
class ExprRstUnOp(ExprRst):
    """Unary operation."""
    val: ExprRst  # the operand
    op: UnOp      # the operator
@dataclasses.dataclass
class ExprRstVar(ExprRst):
    """Reference to a variable."""
    var: LVar
@dataclasses.dataclass
class ExprRstIf(ExprRst):
    """Conditional expression."""
    cond: ExprRst          # the condition
    true: ExprRst          # branch for a true condition
    false: ExprRst | None  # branch for a false condition, if any
@dataclasses.dataclass
class ExprRstBlock(ExprRst):
    """Block of statements with an optional trailing expression."""
    stmts: list[StmtRst]
    tail: ExprRst | None
@dataclasses.dataclass
class ExprRstCall(ExprRst):
    """Call expression; unlike `ExprAstCall` it has a single argument list."""
    func: ExprRst        # the callee
    args: list[ExprRst]  # arguments, in order
@dataclasses.dataclass
class ExprRstCons(ExprRst):
    """Struct construction."""
    ty: LTypeStruct        # narrows the inherited `ty` field to a struct type
    fields: list[ExprRst]  # field values (NOTE(review): presumably in the
                           # struct's declaration order — confirm)
@dataclasses.dataclass
class ExprRstCast(ExprRst):
    """Cast node wrapping another expression.

    NOTE(review): the inherited `ty` is presumably the cast target — confirm.
    """
    expr: ExprRst
@dataclasses.dataclass
class StmtRst(Rst):
    """Base class of statement nodes; adds no fields."""
@dataclasses.dataclass
class StmtExprRst(StmtRst):
    """Statement consisting of a single expression."""
    expr: ExprRst
@dataclasses.dataclass
class StmtDeclRst(StmtRst):
    """Variable declaration statement."""
    ty: LType           # type of the declared variable
    name: str           # declared name
    id: int             # numeric identifier of the variable
    val: ExprRst | None # initializer, if any
class Scope(abc.ABC):
    """Interface for looking up variables and types by name."""

    @abc.abstractmethod
    def get_var(self, name: str, span: SrcSpan | None) -> LVar:
        """Return the variable called ``name``; ``span`` is the lookup site."""
        ...

    @abc.abstractmethod
    def has_type(self, name: str) -> bool:
        """Tell whether a type called ``name`` is known."""
        ...

    @abc.abstractmethod
    def get_type(self, name: str, span: SrcSpan | None) -> LType:
        """Return the type called ``name``; ``span`` is the lookup site."""
        ...
@dataclasses.dataclass
class LType(abc.ABC):
    """Abstract base of language types."""

    @abc.abstractmethod
    def size(self) -> int:
        """Size of a value of this type."""
        ...

    @abc.abstractmethod
    def align(self) -> int:
        """Alignment of a value of this type."""
        ...

    @abc.abstractmethod
    def mangled_name(self) -> str:
        """Encoded name of this type for use in mangled symbols."""
        ...
@dataclasses.dataclass
class LTypeError(LType):
    """Placeholder type; prints as ``<error>`` and mangles to ``E``."""

    def size(self) -> int:
        return 1

    def align(self) -> int:
        return 1

    def mangled_name(self) -> str:
        return "E"

    def __str__(self) -> str:
        return "<error>"
@dataclasses.dataclass
class LTypeOpaque(LType):
    """A type known only by name; size and alignment are reported as -1."""
    name: str

    def size(self) -> int:
        return -1

    def align(self) -> int:
        return -1

    def mangled_name(self) -> str:
        # "O" + length-prefixed name
        return "O" + str(len(self.name)) + self.name

    def __str__(self) -> str:
        return f"{self.name}"
@dataclasses.dataclass
class LTypeInt(LType):
    """Integer type of a given bit width and signedness."""
    VALID_BITS: typing.ClassVar[set[int]] = {8, 16, 32, 64}
    bits: int     # width in bits
    signed: bool  # True for signed, False for unsigned

    def __post_init__(self):
        # assert self.bits in LTypeInt.VALID_BITS
        pass

    def extend_to(self, other: LTypeInt) -> None | LTypeInt:
        """Return ``other`` if this type can be extended to it, else None.

        Extension is possible to any wider type, or from signed to unsigned
        at the same width; it is never possible to a narrower or an
        identical type.
        """
        if other.bits > self.bits: return other
        if other.bits < self.bits: return None
        if self.signed == other.signed: return None
        if not other.signed: return other
        return None

    def __str__(self) -> str:
        return ("u" if not self.signed else "") + f"int{self.bits}"

    def mangled_name(self) -> str:
        """``I<size in bytes><s|u>``."""
        # The sign letter is computed outside the f-string: reusing the
        # enclosing quote character inside a replacement field is only
        # valid from Python 3.12 on.
        sign = "s" if self.signed else "u"
        return f"I{self.bits // 8}{sign}"

    def size(self) -> int: return self.bits // 8

    def align(self) -> int: return max(4, self.size())
@dataclasses.dataclass
class LTypeBool(LTypeInt):
    """Boolean type, modelled as a 1-bit unsigned integer."""

    def __init__(self):
        # bits=1, signed=False
        super().__init__(1, False)

    def size(self) -> int:
        return 1

    def align(self) -> int:
        return 1

    def mangled_name(self) -> str:
        return "B"

    def __str__(self) -> str:
        return "bool"
@dataclasses.dataclass
class LTypeVoid(LType):
    """The ``void`` type: size 0, alignment 1, mangled as ``V``."""

    def size(self) -> int:
        return 0

    def align(self) -> int:
        return 1

    def mangled_name(self) -> str:
        return "V"

    def __str__(self) -> str:
        return "void"
@dataclasses.dataclass
class LTypePtr(LType):
    """Pointer type; size and alignment are fixed at 8."""
    to: LType  # the pointee type

    def size(self) -> int:
        return 8

    def align(self) -> int:
        return 8

    def mangled_name(self) -> str:
        # "P" + length-prefixed mangled pointee
        inner = self.to.mangled_name()
        return "P" + str(len(inner)) + inner

    def __str__(self) -> str:
        return f"*{str(self.to)}"
@dataclasses.dataclass
class LTypeFunc(LType):
    '''Function type: argument types, return type and a C-varargs flag.'''
    # Types of the fixed arguments, in order.
    args: list[LType]
    # Return type.
    ret: LType
    # True when the function additionally takes C variadic arguments.
    c_varargs: bool

    def __str__(self) -> str:
        shown = ", ".join(str(a) for a in self.args)
        if self.c_varargs:
            shown += ", c_varargs"
        return f"func({shown}) -> {str(self.ret)}"

    def mangled_name(self) -> str:
        # Layout: F [v] <len><ret> <argc> _ (<len><arg>)*
        pieces = ["F"]
        if self.c_varargs:
            pieces.append("v")
        ret = self.ret.mangled_name()
        pieces.append(f"{len(ret)}{ret}{len(self.args)}_")
        for arg in self.args:
            mangled = arg.mangled_name()
            pieces.append(f"{len(mangled)}{mangled}")
        return "".join(pieces)

    def size(self) -> int:
        return 8

    def align(self) -> int:
        return 8
@dataclasses.dataclass
class LTypeArray(LType):
    '''Array of `len` elements of type `of`.'''
    # Element type.
    of: LType
    # Number of elements.
    len: int

    def __str__(self) -> str:
        return "[" + str(self.len) + "]" + str(self.of)

    def mangled_name(self) -> str:
        # "A" <count> "x" + length-prefixed mangled element type.
        elem = self.of.mangled_name()
        return f"A{self.len}x{len(elem)}{elem}"

    def size(self) -> int:
        return self.len * self.of.size()

    def align(self) -> int:
        # An array is aligned like its element type.
        return self.of.align()
@dataclasses.dataclass
class LTypeGeneric(LType, Scope):
    '''A type parameterised over type parameters.

    It acts as a scope in which each parameter name resolves to an
    `LTypeOpaque` carrying that name; everything else is delegated to `parent`.
    '''
    # Enclosing scope used for every lookup that is not a type parameter.
    parent: Scope
    # Declared type parameters.
    params: list[TypeAstParam]
    # The type being parameterised, built with this generic as its scope.
    sub: LType

    def __init__(
        self,
        parent: Scope,
        params: list[TypeAstParam],
        get_sub: typing.Callable[[typing.Self], LType]
    ):
        self.parent = parent
        self.params = params
        # `get_sub` receives `self` so the inner type can see the parameters.
        self.sub = get_sub(self)

    def get_var(self, name: str, span: SrcSpan | None) -> LVar:
        return self.parent.get_var(name, span)

    def has_type(self, name: str) -> bool:
        if any(param.name == name for param in self.params):
            return True
        return self.parent.has_type(name)

    def get_type(self, name: str, span: SrcSpan | None) -> LType:
        match = next((param for param in self.params if param.name == name), None)
        if match is not None:
            return LTypeOpaque(match.name)
        return self.parent.get_type(name, span)

    def __str__(self) -> str:
        names = ", ".join(f"{param.name}" for param in self.params)
        return f"[{names}] {self.sub}"

    def mangled_name(self) -> str:
        raise NotImplementedError("Generic types can't be mangled.")

    def size(self) -> int:
        # -1: an uninstantiated generic has no layout.
        return -1

    def align(self) -> int:
        return -1
@dataclasses.dataclass
class LTypeStruct(LType, Scope):
    '''A struct type, which is also a scope for nested type names.

    Fields are laid out in declaration order; each one is placed at the next
    offset that satisfies its own alignment.
    '''
    @dataclasses.dataclass
    class Field:
        # Field name.
        name: str
        # Position of the field in declaration order.
        index: int
        # Byte offset of the field from the start of the struct.
        offset: int
        # Field type.
        ty: LType

    name: str
    fields: list[Field]
    types: dict[str, LType]
    parent: Scope | None
    _size: int = 0
    _params: list[LType] | None = None

    def __init__(
        self,
        name: str,
        fields: list[typing.Callable[[typing.Self], tuple[str, LType]]],
        parent: Scope | None,
        params: list[LType] | None = None
    ):
        '''Build the struct and compute its layout.

        Args:
            name: struct name.
            fields: one callback per field; each is called with this struct
                (usable as a scope) and returns `(field name, field type)`.
            parent: enclosing scope, or None.
            params: type arguments this struct was instantiated with, or None.
        '''
        self.name = name
        self.types = {}
        self.parent = parent
        self.fields = []
        self._size = 0
        self._params = params
        for index, make_field in enumerate(fields):
            field_name, field_ty = make_field(self)
            field_align = field_ty.align()
            offset = self._size
            # Pad up to the field's alignment *before* recording its offset.
            # Alignments <= 1 (including the -1 "unknown" sentinel) need no padding.
            if field_align > 1:
                offset = (offset + field_align - 1) & -field_align
            self.fields.append(LTypeStruct.Field(
                name = field_name,
                index = index,
                offset = offset,
                ty = field_ty
            ))
            self._size = offset + field_ty.size()

    def __str__(self) -> str:
        if self.name is not None: return f"struct {self.name}"
        return f"struct {{ {', '.join(f'{f.name}: {f.ty}' for f in self.fields)} }}"

    def mangled_name(self) -> str:
        '''Mangle as `S[x]<len><name>`, or `St[x]<len><name><n>_<params>` when instantiated.

        A leading '!' in the name is dropped and encoded as the `x` marker.
        '''
        n = self.name
        x = ""
        if n.startswith('!'):
            n = n[1:]
            x = "x"
        if self._params is not None:
            ps = "".join(f"{len(t)}{t}" for t in (t.mangled_name() for t in self._params))
            # Drop the "[...]" suffix from the already-stripped name; a name
            # without '[' is used unchanged.
            base = n.partition('[')[0]
            return f"St{x}{len(base)}{base}{len(self._params)}_{ps}"
        return f"S{x}{len(n)}{n}"

    def get_field(self, name: str, span: SrcSpan | None) -> Field | None:
        '''Return the field called `name`, or report an error and return None.'''
        for field in self.fields:
            if field.name == name:
                return field
        ErrorReporter.report(f"No field '{name}' in {str(self)}.", span)
        return None

    def get_var(self, of: ExprRst, name: str, span: SrcSpan | None) -> LVar:
        '''Resolve `name` as a field of the struct value `of`, else via the parent scope.'''
        for field in self.fields:
            if field.name == name:
                return LStructFieldVar(field.ty, field.name, of, field.offset, field.index)
        if self.parent is not None: return self.parent.get_var(name, span)
        ErrorReporter.report(f"No '{name}' in {str(self)}.", span)
        return LErrorVar(LTypeError(), name)

    def has_type(self, name: str) -> bool:
        return name in self.types

    def get_type(self, name: str, span: SrcSpan | None) -> LType:
        '''Look `name` up in this struct's types, then in the parent scope.'''
        if name in self.types:
            return self.types[name]
        if self.parent is not None: return self.parent.get_type(name, span)
        ErrorReporter.report(f"No type {name} in {str(self)}.", span)
        return LTypeError()

    def size(self) -> int:
        return self._size

    def align(self) -> int:
        # A struct with no fields falls back to alignment 1 instead of raising.
        return max((f.ty.align() for f in self.fields), default=1)
@dataclasses.dataclass
class LVar:
    '''Base class of everything a name can resolve to: a type plus a name.'''
    # Type of the variable.
    ty: LType
    # Name of the variable.
    name: str
@dataclasses.dataclass
class LErrorVar(LVar):
    '''Variable returned by scope lookups after a failed lookup has been reported.'''
@dataclasses.dataclass
class LLocalVar(LVar):
    '''A local variable of a function.'''
    # Numeric id of the local, assigned by `LFunc.add_local`.
    id: int
@dataclasses.dataclass
class LArgVar(LVar):
    '''A function argument, as returned by `LFunc.get_var`.'''
@dataclasses.dataclass
class LGlobalVar(LVar):
    '''A module-level variable, stored in `ModuleScope.globs`.'''
@dataclasses.dataclass
class LStructFieldVar(LVar):
    '''A field of a struct value, as produced by `LTypeStruct.get_var`.'''
    # Expression that yields the struct the field is read from.
    of: ExprRst
    # Byte offset of the field inside the struct.
    offset: int
    # Declaration index of the field.
    index: int
@dataclasses.dataclass
class LUnionFieldVar(LVar):
    '''Variable kind for union fields (carries no data beyond `LVar`).'''
@dataclasses.dataclass
class LPointerDerefVar(LVar):
    '''Variable kind for the target of a pointer expression.'''
    # Expression that yields the pointer.
    pointer: ExprRst
@dataclasses.dataclass
class FuncScope(Scope, abc.ABC):
    '''For scopes that are only inside of functions.'''
    # The function this scope belongs to.
    func: LFunc
class LocalScope(FuncScope, abc.ABC):
    '''A function scope in which local variables can be declared.'''

    @abc.abstractmethod
    def add_local(self, ty: LType, name: str, span: SrcSpan | None) -> LLocalVar:
        '''Declare a local `name` of type `ty` and return the new variable.'''
        ...
@dataclasses.dataclass
class BlockScope(LocalScope):
    '''Scope of a block expression: its own locals and types, falling back to `parent`.'''
    # Enclosing scope consulted when a name is not found here.
    parent: Scope
    # Locals declared directly in this block, by name.
    vars: dict[str, LLocalVar]
    # Types declared directly in this block, by name.
    types: dict[str, LType]

    def get_var(self, name: str, span: SrcSpan | None) -> LVar:
        found = self.vars.get(name)
        if found is not None:
            return found
        if self.parent is None:
            ErrorReporter.report(f"No such variable: '{name}'", span)
            return LErrorVar(LTypeError(), name)
        return self.parent.get_var(name, span)

    def has_type(self, name: str) -> bool:
        if name in self.types:
            return True
        return self.parent.has_type(name) if self.parent is not None else False

    def get_type(self, name: str, span: SrcSpan | None) -> LType:
        if name in self.types:
            return self.types[name]
        if self.parent is None:
            ErrorReporter.report(f"No type {name} in {str(self)}", span)
            return LTypeError()
        return self.parent.get_type(name, span)

    def add_local(self, ty: LType, name: str, span: SrcSpan | None) -> LLocalVar:
        '''Declare a local in the owning function and bind it in this block.'''
        previous = self.vars.get(name)
        if previous is not None:
            # Mark the variable that is being replaced under this name.
            previous.name += " shadowed"
        created = self.func.add_local(ty, name, span)
        self.vars[name] = created
        return created
@dataclasses.dataclass
class LFunc(LVar, FuncScope):
    '''A function: both a variable (it can be named and called) and a scope.'''
    module: ModuleScope
    span: SrcSpan
    arg_names: list[str]
    attrs: set[str]
    body: ExprRst | None
    locls: list[LLocalVar]
    next_local_id = 0
    types: dict[str, LType]
    ty: LTypeFunc

    def __init__(
        self,
        module: ModuleScope,
        span: SrcSpan,
        name: str,
        ret_ty: LType | None,
        arg_types: list[LType],
        c_varargs: bool,
        arg_names: list[str],
        attrs: set[str],
        body: ExprAst | None,
        resolver: Resolver
    ):
        '''Create the function and resolve its body.

        When `ret_ty` is None the return type is taken from the resolved body;
        a function with neither a body nor a return type is reported as an error.
        '''
        self.func = self
        self.name = name
        self.span = span
        self.module = module
        self.arg_names = arg_names
        self.attrs = attrs
        self.code = []
        self.locls = []
        self.types = {}
        # Provisional type while the body is resolved; the return type may
        # still be unknown at this point.
        provisional_ret = ret_ty if ret_ty is not None else LTypeError()
        self.ty = LTypeFunc(arg_types, provisional_ret, c_varargs)
        # Resolution happens here because it needs `self` as the scope.
        self.body = resolver.resolve_expr(self, body, ret_ty) if body else None
        if self.body is not None:
            final_ret = self.body.ty
        elif ret_ty is not None:
            final_ret = ret_ty
        else:
            ErrorReporter.report("Cannot infer return type of function with no body.", span)
            final_ret = LTypeError()
        if ret_ty is not None and final_ret != ret_ty:
            ErrorReporter.report(
                f"Mismatching return types. Expected {ret_ty}, got {final_ret}.",
                span
            )
        # Replace the provisional type with the final one.
        self.ty = LTypeFunc(arg_types, final_ret, c_varargs)

    def get_var(self, name: str, span: SrcSpan | None) -> LVar:
        '''Resolve the function itself, then its arguments, then module-level names.'''
        if name == self.name:
            return self
        for position, arg_name in enumerate(self.arg_names):
            if arg_name == name:
                return LArgVar(self.ty.args[position], name)
        return self.module.get_var(name, span)

    def has_type(self, name: str) -> bool:
        return name in self.types or self.module.has_type(name)

    def get_type(self, name: str, span: SrcSpan | None) -> LType:
        if name in self.types:
            return self.types[name]
        return self.module.get_type(name, span)

    def add_local(self, ty: LType, name: str, span: SrcSpan | None) -> LLocalVar:
        '''Register a new local with the next free id and return it.'''
        local = LLocalVar(ty, name, self.next_local_id)
        self.next_local_id += 1
        self.locls.append(local)
        return local
class ModuleScope(Scope):
    '''Top-level scope holding a module's globals, functions and types.'''

    def __init__(self, name: str):
        self.name = name
        self.globs: dict[str, LGlobalVar] = {}
        self.funcs: dict[str, LFunc] = {}
        self.types: dict[str, LType] = {}

    def get_var(self, name: str, span: SrcSpan | None) -> LVar:
        '''Look `name` up among globals first, then functions.'''
        for table in (self.globs, self.funcs):
            if name in table:
                return table[name]
        ErrorReporter.report(f"No variable '{name}' in module.", span)
        return LErrorVar(LTypeError(), name)

    def has_type(self, name: str) -> bool:
        return name in self.types

    def get_type(self, name: str, span: SrcSpan | None) -> LType:
        if name not in self.types:
            ErrorReporter.report(f"No type '{name}' in module.", span)
            return LTypeError()
        return self.types[name]

    def __str__(self) -> str:
        lines = [f"module {self.name} {{"]
        lines.extend(f" {g.name}: {g.ty}" for g in self.globs.values())
        lines.extend(f" {f}" for f in self.funcs.values())
        lines.extend(f" {name} = {t}" for name, t in self.types.items())
        lines.append("}")
        return "\n".join(lines)
class Resolver: | |
BOOL_OPS = { | |
BinOp.EQ, | |
BinOp.NE, | |
BinOp.LT, | |
BinOp.GT, | |
BinOp.LE, | |
BinOp.GE, | |
} | |
def instantiate_generic(self, span: SrcSpan, scope: Scope, gen: LType, params: list[LType]): | |
if not isinstance(gen, LTypeGeneric): | |
ErrorReporter.report(f"Cannot instantiate non-generic type '{gen}'.", span) | |
return LTypeError() | |
if len(params) != len(gen.params): | |
ErrorReporter.report( | |
f"Mismatching number of type parameters. Expected {len(gen.params)}, got {len(params)}.", | |
span | |
) | |
return LTypeError() | |
tys: list[LType] = [] | |
m: dict[str, int] = {} | |
for i, (t, p) in enumerate(zip(params, gen.params)): | |
tys.append(t) | |
m[p.name] = i | |
return self.substitute_types(gen.sub, m, tys)[0] | |
def resolve_type(self, scope: Scope, t: TypeAst) -> LType: | |
match t: | |
case TypeAstPtr(span, to): | |
return LTypePtr(self.resolve_type(scope, to)) | |
case TypeAstInst(span, generic, params): | |
rgen = self.resolve_type(scope, generic) | |
rparams = [self.resolve_type(scope, t) for t in params] | |
return self.instantiate_generic(span, scope, rgen, rparams) | |
case TypeAstPath(span, root, names): | |
assert len(names) > 0 | |
if root is None: | |
n = names[0] | |
match n: | |
case "uint8": return LTypeInt(8, False) | |
case "uint16": return LTypeInt(16, False) | |
case "uint32": return LTypeInt(32, False) | |
case "uint64": return LTypeInt(64, False) | |
case "int8": return LTypeInt(8, True) | |
case "int16": return LTypeInt(16, True) | |
case "int32": return LTypeInt(32, True) | |
case "int64": return LTypeInt(64, True) | |
case "bool": return LTypeBool() | |
case "void": return LTypeVoid() | |
case _: pass | |
current = scope | |
else: | |
current = self.resolve_type(scope, root) | |
for i, name in enumerate(names): | |
if isinstance(current, Scope): | |
current = current.get_type(name, span) | |
elif i != len(names) - 1: | |
ErrorReporter().report(f"Not a scope: {current}.", span) | |
assert isinstance(current, LType) | |
return current | |
case _: | |
raise NotImplementedError(f"resolving {t} is not implemented yet.") | |
def resolve_struct(self, module: ModuleScope, p: AstStruct) -> None: | |
def get_sub(scope: Scope): | |
return LTypeStruct( | |
p.name, | |
[(lambda t, f=field: (f.name, self.resolve_type(t, f.ty))) for field in p.ty.fields], | |
scope | |
) | |
if p.ty_params is not None: | |
module.types[p.name] = LTypeGeneric(module, p.ty_params, get_sub) | |
else: | |
module.types[p.name] = get_sub(module) | |
def resolve_func(self, mod: ModuleScope, p: AstFunc) -> None: | |
mod.funcs[p.name] = LFunc( | |
module = mod, | |
name = p.name, | |
span = p.span, | |
ret_ty = self.resolve_type(mod, p.ret) if p.ret is not None else None, | |
arg_types = [self.resolve_type(mod, a.ty) for a in p.args], | |
c_varargs = p.c_varargs, | |
arg_names = [a.name for a in p.args], | |
attrs = p.attrs, | |
body = p.body, | |
resolver = self | |
) | |
def resolve_stmt(self, scope: LocalScope, e: StmtAst) -> StmtRst: | |
match e: | |
case StmtDeclAst(span, decl): | |
t = self.resolve_type(scope, decl.ty) if decl.ty is not None else None | |
v = self.resolve_expr(scope, decl.value, t) if decl.value is not None else None | |
if decl.ty is not None: | |
assert t is not None | |
elif v is not None: | |
t = v.ty | |
else: | |
ErrorReporter.report("Cannot infer variable type.", span) | |
t = LTypeError() | |
var = scope.add_local(t, decl.name, span) | |
return StmtDeclRst( | |
span, | |
name = decl.name, | |
ty = t, | |
id = var.id, | |
val = v | |
) | |
case StmtExprAst(span, expr): | |
return StmtExprRst(span, self.resolve_expr(scope, expr, None)) | |
case _: | |
raise NotImplementedError(f"Resolving {e} is not implemented yet.") | |
def can_implicit_cast(self, a: LType, b: LType) -> str | None: | |
if isinstance(a, LTypeInt) and isinstance(b, LTypeInt): | |
if a.bits > b.bits: | |
return f"Cannot implicitly truncate from {a} to {b}. Use an explicit `as` cast." | |
if b.signed and not a.signed: | |
return f"Cannot implicitly cast from unsigned {a} to signed {b}. Use an explicit `as` cast." | |
return None | |
return f"Mismatching types. Expected {b}, got {a}." | |
def resolve_expr(self, scope: FuncScope, e: ExprAst, expected_type: LType | None) -> ExprRst: | |
r = self._resolve_expr(scope, e, expected_type) | |
if expected_type is not None: | |
if r.ty == expected_type: | |
return r | |
if (err := self.can_implicit_cast(r.ty, expected_type)) is not None: | |
ErrorReporter.report(err, r.span, kind='note') | |
else: | |
return ExprRstCast(r.span, expected_type, r) | |
return r | |
def _resolve_expr(self, scope: FuncScope, e: ExprAst, expected_type: LType | None) -> ExprRst: | |
match e: | |
case ExprAstInt(span, v): | |
l = v.bit_length() | |
if l > 64: | |
ErrorReporter.report("Integer too large.", span) | |
return ExprRstInt(span, LTypeInt(64, False), v & 0xFFFFFFFFFFFFFFFF) | |
if isinstance(expected_type, LTypeInt): | |
if expected_type.bits <= expected_type.bits: | |
return ExprRstInt(span, expected_type, v) | |
if l > 63: return ExprRstInt(span, LTypeInt(64, False), v) | |
if l > 32: return ExprRstInt(span, LTypeInt(64, True), v) | |
if l > 31: return ExprRstInt(span, LTypeInt(32, False), v) | |
else: return ExprRstInt(span, LTypeInt(32, True), v) | |
case ExprAstStr(span, value): | |
st = scope.get_type('slice', span) | |
st2 = self.instantiate_generic(span, scope, st, [LTypeInt(8, False)]) | |
if not isinstance(st2, LTypeStruct): | |
ErrorReporter.report("slice[uint8] is not a struct.", span) | |
return ExprRst(span, LTypeError()) | |
return ExprRstStr(span, st2, value.encode('utf-8')) | |
case ExprAstCast(span, value, ty): | |
return ExprRstCast(span, self.resolve_type(scope, ty), self.resolve_expr(scope, value, None)) | |
case ExprAstBool(span, v): | |
return ExprRstInt(span, LTypeBool(), v) | |
case ExprAstItem(span, of, name): | |
rof = self.resolve_expr(scope, of, None) | |
if not isinstance(rof.ty, LTypeStruct): | |
ErrorReporter.report(f"Cannot get item '{name}' from non-struct {rof.ty}.", span) | |
return ExprRst(span, LTypeError()) | |
v = rof.ty.get_var(rof, name, span) | |
return ExprRstVar(span, v.ty, v) | |
case ExprAstId(span, names): | |
current = scope | |
for i, name in enumerate(names): | |
if isinstance(current, Scope): | |
current = current.get_var(name, span) | |
elif i != len(names) - 1: | |
ErrorReporter().report(f"Not a scope: {current}.", span) | |
assert isinstance(current, LVar) | |
return ExprRstVar(span, current.ty, current) | |
case ExprAstError(span): return ExprRst(span, LTypeError()) | |
case ExprAstBinOp(span, lhs, rhs, op): | |
rlhs, rrhs = self.resolve_expr(scope, lhs, None), self.resolve_expr(scope, rhs, None) | |
if op == BinOp.ASSIGN: | |
match rlhs: | |
case ExprRstVar(_, ty, var): | |
return ExprRstAssign(span, ty, var, rrhs) | |
case _: | |
ErrorReporter.report("Left operand of assignment is not an lvalue.", rlhs.span) | |
return ExprRst(span, LTypeError()) | |
if not isinstance(rlhs.ty, LTypeInt): | |
ErrorReporter.report("Left operand is not an integer.", rlhs.span) | |
return ExprRst(span, LTypeError()) | |
if not isinstance(rrhs.ty, LTypeInt): | |
ErrorReporter.report("Right operand is not an integer.", rrhs.span) | |
return ExprRst(span, LTypeError()) | |
if ty := rlhs.ty.extend_to(rrhs.ty): pass | |
elif ty := rrhs.ty.extend_to(rlhs.ty): pass | |
else: ty = rlhs.ty | |
if op in Resolver.BOOL_OPS: | |
rty = LTypeBool() | |
else: | |
rty = ty | |
return ExprRstBinOp(span, rty, rlhs, rrhs, op, ty) | |
case ExprAstUnOp(span, val, op): | |
re = self.resolve_expr(scope, val, None) | |
return ExprRstUnOp( | |
span, | |
re.ty, | |
re, | |
op, | |
) | |
case ExprAstIf(span, cond, true, false): | |
rcond = self.resolve_expr(scope, cond, LTypeBool()) | |
if not isinstance(rcond.ty, LTypeBool): | |
ErrorReporter.report(f"Condition must be of type bool, got {rcond.ty}.", rcond.span) | |
rtrue = self.resolve_expr(scope, true, expected_type) | |
if false is None: | |
if not isinstance(rtrue.ty, LTypeVoid): | |
ErrorReporter.report( | |
"The if expression doesn't have an else branch, " | |
"so the return type must be void.", | |
span | |
) | |
ty = LTypeVoid() | |
rfalse = None | |
else: | |
rfalse = self.resolve_expr(scope, false, rtrue.ty) | |
if rfalse.ty != rtrue.ty: | |
ErrorReporter.report( | |
"Mismatching branch types. " | |
f"True branch is {rtrue.ty}, but false branch is {rfalse.ty}", | |
span | |
) | |
ty = rtrue.ty | |
return ExprRstIf( | |
span, | |
ty, | |
rcond, | |
rtrue, | |
rfalse | |
) | |
case ExprAstBlock(span, stmts, tail): | |
nscope = BlockScope(scope.func, scope, {}, {}) | |
rstmts = [self.resolve_stmt(nscope, s) for s in stmts] | |
rtail = self.resolve_expr(nscope, tail, expected_type) if tail is not None else None | |
return ExprRstBlock( | |
span, | |
rtail.ty if rtail is not None else LTypeVoid(), | |
stmts = rstmts, | |
tail = rtail | |
) | |
case ExprAstCons(span, ty, fields): | |
rty = self.resolve_type(scope, ty) | |
if not isinstance(rty, LTypeStruct): | |
ErrorReporter.report(f"Cannot construct non-struct type {rty}.", span) | |
return ExprRst(span, LTypeError()) | |
rfields: list[ExprRst | None] = [None] * len(rty.fields) | |
for name, value in fields: | |
fld = rty.get_field(name, span) | |
if fld is None: | |
continue | |
rval = self.resolve_expr(scope, value, fld.ty) | |
if rval.ty != fld.ty: | |
ErrorReporter.report( | |
f"Incorrect type for field '{fld.name}'. Expected {fld.ty}, got {rval.ty}.", | |
rval.span | |
) | |
rfields[fld.index] = rval | |
if None in rfields: | |
ErrorReporter.report("Not all fields initialized.", span) | |
return ExprRst(span, LTypeError()) | |
return ExprRstCons(span, rty, typing.cast(list[ExprRst], rfields)) | |
case ExprAstCall(span, func, args, named_args): | |
if len(named_args) > 0: | |
ErrorReporter.report("Named arguements are not yet supported.", span) | |
return ExprRst(span, LTypeError()) | |
rfunc = self.resolve_expr(scope, func, None) | |
ret_type = LTypeError() | |
if not isinstance(rfunc.ty, LTypeFunc): | |
ErrorReporter.report(f"Cannot call non-function {rfunc.ty}.", span) | |
rargs = [self.resolve_expr(scope, arg, None) for arg in args] | |
else: | |
if rfunc.ty.c_varargs: | |
if len(args) < len(rfunc.ty.args): | |
ErrorReporter.report( | |
f"Not enough arguments provided. Expected at least {len(rfunc.ty.args)}, got {len(args)}.", | |
span | |
) | |
rargs = [self.resolve_expr(scope, arg, t) for t, arg in zip(rfunc.ty.args, args)] | |
for arg in args[len(rfunc.ty.args):]: | |
rargs.append(self.resolve_expr(scope, arg, None)) | |
else: | |
if len(args) != len(rfunc.ty.args): | |
ErrorReporter.report( | |
f"Mismatching number of arguments. Expected {len(rfunc.ty.args)}, got {len(args)}.", | |
span | |
) | |
rargs = [self.resolve_expr(scope, arg, t) for t, arg in zip(rfunc.ty.args, args)] | |
ret_type = rfunc.ty.ret | |
for i, (narg, argt) in enumerate(zip(rargs, rfunc.ty.args)): | |
if narg.ty != argt: # TODO: compatible with | |
ErrorReporter.report( | |
f"Incorrect type of argument #{i + 1}. Expected {argt}, got {narg.ty}.", | |
narg.span | |
) | |
return ExprRstCall( | |
span, | |
ret_type, | |
func = rfunc, | |
args = rargs | |
) | |
case _: | |
raise NotImplementedError(f"resolving {e} is not implemented yet.") | |
def substitute_types(self, orig: LType, m: dict[str, int], tys: list[LType]) -> tuple[LType, bool]: | |
match orig: | |
case LTypeOpaque(name) if name in m: | |
return tys[m[name]], True | |
case LTypePtr(to): | |
t, r = self.substitute_types(to, m, tys) | |
return LTypePtr(t), r | |
case LTypeArray(of, l): | |
t, r = self.substitute_types(of, m, tys) | |
return LTypeArray(t, l), r | |
case LTypeStruct(): | |
fs: list[typing.Callable[[LTypeStruct], tuple[str, LType]]] = [] | |
did_sub = False | |
for field in orig.fields: | |
t, r = self.substitute_types(field.ty, m, tys) | |
did_sub = did_sub or r | |
fs.append(lambda _, t=t, f=field: (f.name, t)) | |
if not did_sub: return orig, False | |
s = LTypeStruct( | |
f"{orig.name}[{", ".join(map(str, tys))}]", | |
fs, | |
orig.parent, | |
tys | |
), True | |
return s | |
case _: return orig, False | |
class Llvm:
    '''Base class of the operand kinds that code generation passes around.'''
@dataclasses.dataclass
class LlvmInt(Llvm):
    '''An integer constant operand.'''
    # The constant's value.
    val: int

    def __str__(self):
        # int() makes a Python bool print as 1/0; "True"/"False" is not valid IR.
        return str(int(self.val))
@dataclasses.dataclass
class LlvmVoid(Llvm):
    '''Operand standing for "no value"; prints as `void`.'''

    def __str__(self):
        return "void"
@dataclasses.dataclass
class LlvmVal(Llvm):
    '''A numbered temporary, printed as `%t<n>`.'''
    # Number of the temporary.
    val: int

    def __str__(self):
        return "%t" + f"{self.val}"
@dataclasses.dataclass
class LlvmNamedVal(Llvm):
    '''A named local value, printed as `%_n_<name>`.'''
    # Name without the `%_n_` prefix.
    name: str

    def __str__(self):
        return "%_n_" + f"{self.name}"
@dataclasses.dataclass
class LlvmLocalVarVal(Llvm):
    '''The stack slot of a local variable, printed as `%_loc_<id>`.'''
    # Id of the local variable.
    id: int

    def __str__(self):
        return "%_loc_" + f"{self.id}"
@dataclasses.dataclass
class LlvmPrivateGlobal(Llvm):
    '''A numbered global, printed as `@.<id>`.'''
    # Number of the global.
    id: int

    def __str__(self):
        return "@." + f"{self.id}"
@dataclasses.dataclass
class LlvmGlobal(Llvm):
    '''A named global, printed as `@<name>`.'''
    # Name without the `@` prefix.
    name: str

    def __str__(self):
        return "@" + f"{self.name}"
class CodeGen: | |
def __init__(self, is_le: bool): | |
self._output = bytearray() | |
self._indent_level = 0 | |
self._label_no = 0 | |
self._val_no = 0 | |
self._is_le = is_le | |
self._struct_tys: set[str] = set() | |
def _write_pre(self, s: str): | |
self._output = s.encode() + b"\n" + self._output | |
def _write(self, s: str, indent: int = 0): | |
self._output += b" " * (self._indent_level + indent) | |
self._output += s.encode() | |
self._output += b"\n" | |
def _newpriv_global(self) -> LlvmPrivateGlobal: | |
v = LlvmPrivateGlobal(self._val_no) | |
self._val_no += 1 | |
return v | |
def _newval(self) -> LlvmVal: | |
v = LlvmVal(self._val_no) | |
self._val_no += 1 | |
return v | |
def _newlab(self) -> str: | |
l = f"l{self._label_no}" | |
self._label_no += 1 | |
return l | |
def _label(self, name: str) -> str: | |
self._write(name + ':', -1) | |
return name | |
def _newvaleq( | |
self, | |
instr: str, | |
*args: Llvm | int | str, | |
surround = ' ', | |
skip_assign = False | |
) -> Llvm: | |
v = self._newval() if not skip_assign else LlvmVoid() | |
a = f"{v} = " if not skip_assign else "" | |
self._write(f"{a}{instr}{surround[0:1]}{', '.join(map(str, args))}{surround[1:2]}") | |
return v | |
def _indent(self): self._indent_level += 1 | |
def _dedent(self): self._indent_level -= 1 | |
def _ir_struct_ty_lit(self, t: LTypeStruct) -> str: | |
return f"{{ {', '.join(self._ir_ty(f.ty) for f in t.fields)} }}" | |
def _ir_struct_ty(self, t: LTypeStruct) -> str: | |
name = t.mangled_name() | |
if name not in self._struct_tys: | |
self._write_pre(f"%struct.{name} = type {self._ir_struct_ty_lit(t)}") | |
self._struct_tys.add(name) | |
return f"%struct.{name}" | |
def _ir_func_ty(self, t: LTypeFunc, name = "") -> str: | |
if name: name = " " + name + " " | |
va = ", ..." if t.c_varargs else "" | |
return f"{self._ir_ty(t.ret)}{name}({", ".join(map(self._ir_ty, t.args))}{va})" | |
def _ir_ty(self, t: LType) -> str: | |
match t: | |
case LTypeArray(of, len): return f"[{of} x {len}]" | |
case LTypePtr(): return "ptr" | |
case LTypeFunc(): return self._ir_func_ty(t) | |
case LTypeInt(bits): return f"i{bits}" | |
case LTypeStruct(): return self._ir_struct_ty(t) | |
case LTypeVoid(): return "void" | |
case _: raise NotImplementedError(f"_ir_ty({t})") | |
def _ir_binop(self, op: BinOp, signed: bool, floating: bool) -> str: | |
'''BinOp to IR instruction name.''' | |
sign = "" if floating else ("s" if signed else "u") | |
tyi = "f" if floating else "" | |
ty = "f" if floating else "i" | |
match op: | |
case BinOp.ADD: return f"{tyi}add" | |
case BinOp.SUB: return f"{tyi}sub" | |
case BinOp.MUL: return f"{tyi}mul" | |
case BinOp.DIV: return f"{tyi}{sign}div" | |
case BinOp.MOD: return f"{tyi}{sign}rem" | |
case BinOp.EQ: return f"{ty}cmp eq" | |
case BinOp.NE: return f"{ty}cmp ne" | |
case BinOp.LT: return f"{ty}cmp {sign}lt" | |
case BinOp.GT: return f"{ty}cmp {sign}gt" | |
case BinOp.LE: return f"{ty}cmp {sign}le" | |
case BinOp.GE: return f"{ty}cmp {sign}ge" | |
case _: raise NotImplementedError(f"_ir_binop({op}, {signed})") | |
def _gen_int_cast(self, a: LTypeInt, b: LTypeInt, v: Llvm) -> Llvm: | |
'''Generate integer conversion instructions. (maybe generate none)''' | |
if a.bits == b.bits: return v | |
if a.bits < b.bits: | |
ins = f"{'s' if b.signed else 'z'}ext {self._ir_ty(a)} {v} to {self._ir_ty(b)}" | |
return self._newvaleq(ins) | |
# a.bits > b.bits | |
return self._newvaleq(f"trunc {self._ir_ty(a)} {v} to {self._ir_ty(b)}") | |
def _gen_store_move(self, ty: LType, val: Llvm, to: Llvm): | |
self._write(f"store {self._ir_ty(ty)} {val}, ptr {to}") | |
def gen_move(self, e: ExprRst, to_ty: LType, to: Llvm): | |
'''Generate move of `e` to address `to`''' | |
match e: | |
case ExprRstInt(_, ty, value): | |
assert isinstance(ty, LTypeInt) | |
self._gen_store_move(ty, LlvmInt(value), to) | |
case ExprRstBlock(_, ty, stmts, tail): | |
for stmt in stmts: self.gen_stmt(stmt) | |
assert tail is not None | |
self.gen_move(tail, to_ty, to) | |
case _: | |
v = self.gen_expr(e) | |
assert e.ty == to_ty | |
self._gen_store_move(e.ty, v, to) | |
def gen_stmt(self, e: StmtRst): | |
match e: | |
case StmtExprRst(_, expr): | |
self.gen_expr(expr) | |
case StmtDeclRst(_, ty, name, id, val): | |
v = LlvmLocalVarVal(id) | |
self._write(f"{v} = alloca {self._ir_ty(ty)} ; {name}: {ty}") | |
if val is not None: | |
self.gen_move(val, ty, v) | |
def _ir_str_const(self, s: bytes) -> str: | |
return "c\"" + "".join(chr(c) if 32 <= c <= 127 else f"\\{c:02X}" for c in s) + "\"" | |
def _gen_init_struct(self, ty: LTypeStruct, vs: typing.Iterable[Llvm | int | str]) -> Llvm: | |
'''Returns pointer to struct.''' | |
sty = self._ir_struct_ty(ty) | |
p = self._newvaleq("alloca", sty, f"align {ty.align()}") | |
cur = self._newvaleq("load", sty, f"ptr {p}") | |
for i, (v, f) in enumerate(zip(vs, ty.fields)): | |
cur = self._newvaleq("insertvalue", f"{sty} {cur}", f"{self._ir_ty(f.ty)} {v}", i) | |
return cur | |
def gen_expr(self, e: ExprRst) -> Llvm: | |
match e: | |
case ExprRstInt(_, ty, value): return LlvmInt(value) | |
case ExprRstStr(_, ty, value): | |
assert isinstance(ty, LTypeStruct) | |
g = self._newpriv_global() | |
self._write_pre( | |
f"{g} = private unnamed_addr constant [{len(value)} x i8] " + | |
self._ir_str_const(value) | |
) | |
return self._gen_init_struct(ty, [len(value), g]) | |
case ExprRstBlock(_, ty, stmts, tail): | |
for stmt in stmts: | |
self.gen_stmt(stmt) | |
return self.gen_expr(tail) if tail else LlvmVoid() | |
case ExprRstVar(_, ty, var): | |
match var: | |
case LArgVar(ty, name): return LlvmNamedVal(f"arg_{name}") | |
case LStructFieldVar(ty, name, of, _, index): | |
v = self.gen_expr(of) | |
return self._newvaleq("extractvalue", f"{self._ir_ty(of.ty)} {v}", index) | |
# case LGlobalVar(ty, name): return LlvmGlobalVal(name) | |
case LLocalVar(ty, name, id): | |
v = self._newval() | |
self._write(f"{v} = load {self._ir_ty(ty)}, ptr {LlvmLocalVarVal(id)}") | |
return v | |
case _: raise NotImplementedError(f"gen_expr(..)/ExprRstVar(_, _, {var})") | |
case ExprRstIf(_, ty, cond, true, false): | |
ret_lab = self._newlab() | |
true_lab = self._newlab() | |
false_lab = self._newlab() if false is not None else ret_lab | |
cond_v = self.gen_expr(cond) | |
# cond_v_i1 = self._newvaleq(f"trunc {self._ir_ty(cond.type)} {cond_v} to i1") | |
self._write(f"br i1 {cond_v}, label %{true_lab}, label %{false_lab}") | |
self._label(true_lab) | |
true_v = self.gen_expr(true) | |
self._write(f"br label %{ret_lab}") | |
false_v = None | |
if false is not None: | |
self._label(false_lab) | |
false_v = self.gen_expr(false) | |
self._write(f"br label %{ret_lab}") | |
self._label(ret_lab) | |
if not isinstance(ty, LTypeVoid) and false_v is not None: | |
# <=> false is not None | |
return self._newvaleq( | |
f"phi {self._ir_ty(ty)}", | |
f"[{true_v}, %{true_lab}]", | |
f"[{false_v}, %{false_lab}]" | |
) | |
return LlvmVoid() | |
case ExprRstAssign(_, ty, var, val): | |
match var: | |
case LLocalVar(_, name, id): | |
v = LlvmLocalVarVal(id) | |
self.gen_move(val, ty, v) | |
return v | |
case _: raise NotImplementedError(f"gen_expr(..)/ExprRstAssign(_, _, {var}, _)") | |
case ExprRstCast(_, ty, val): | |
rval = self.gen_expr(val) | |
if isinstance(ty, LTypeInt) and isinstance(val.ty, LTypeInt): | |
return self._gen_int_cast(val.ty, ty, rval) | |
else: | |
raise NotImplementedError(f"gen_expr(..)/ExprRstCast(_, {ty}, {val})") | |
case ExprRstBinOp(_, ty, lhs, rhs, op, op_ty): | |
l, r = self.gen_expr(lhs), self.gen_expr(rhs) | |
assert isinstance(ty, LTypeInt) | |
assert isinstance(op_ty, LTypeInt) | |
assert isinstance(lhs.ty, LTypeInt) | |
assert isinstance(rhs.ty, LTypeInt) | |
l2 = self._gen_int_cast(lhs.ty, op_ty, l) | |
r2 = self._gen_int_cast(lhs.ty, op_ty, r) | |
op = self._ir_binop(op, ty.signed, False) | |
v = self._newvaleq(op, f"{self._ir_ty(op_ty)} {l2}", r2) | |
# if op.startswith("icmp"): | |
# return self._newvaleq(f"zext i1 {v} to i32") | |
return v | |
case ExprRstCons(_, ty, fields): | |
return self._gen_init_struct(ty, map(self.gen_expr, fields)) | |
case ExprRstCall(_, ty, func, args): | |
assert isinstance(func, ExprRstVar) | |
fn = func.var | |
assert isinstance(fn, LFunc) | |
if fn.ty.c_varargs: | |
fargs = (f"{self._ir_ty(arg.ty)} {self.gen_expr(arg)}" for arg in args) | |
else: | |
fargs = (f"{self._ir_ty(ty)} {self.gen_expr(arg)}" for arg, ty in zip(args, fn.ty.args)) | |
return self._newvaleq( | |
f"call ccc {self._ir_ty(fn.ty)} @{fn.name}", | |
*fargs, | |
surround="()", | |
skip_assign = fn.ty.ret == LTypeVoid() | |
) | |
case _: raise NotImplementedError(f"gen_expr({e})") | |
def gen_func(self, p: LFunc): | |
s = "" | |
no_body = False | |
if p.body is None: | |
s += "declare " | |
no_body = True | |
else: | |
s += "define " | |
if "export" in p.attrs: s += "external " | |
s += "ccc " | |
s += f"{self._ir_ty(p.ty.ret)} " | |
s += f"@{p.name}(" | |
s += ", ".join(f"{self._ir_ty(a)} %_n_arg_{n}" for n, a in zip(p.arg_names, p.ty.args)) | |
if p.ty.c_varargs: s += ", ..." | |
s += ")" | |
if no_body: | |
self._write(s) | |
return | |
s += " {" | |
self._write(s) | |
self._indent() | |
self._label(self._newlab()) | |
assert p.body is not None | |
v = self.gen_expr(p.body) | |
if not isinstance(v, LlvmVoid): | |
self._write(f"ret {self._ir_ty(p.ty.ret)} {v}") | |
else: | |
self._write("ret void") | |
self._dedent() | |
self._write("}") | |
def replace_in_list[T](l: list[T], a: T, b: T) -> list[T]: | |
'''Replace `a` with `b` in list `l`, mutating it. Return the updated list.''' | |
for i, e in enumerate(l): | |
if e == a: l[i] = b | |
return l | |
def cli_build(args: argparse.Namespace) -> str: | |
in_file: typing.TextIO = args.file | |
inp = InputStr(in_file.read(), fname = in_file.name) | |
in_file.close() | |
inp_id = InputManager.add(inp) | |
toks = list(tokenize(inp_id, inp.s)) | |
par = Parser(toks) | |
mod = ModuleScope("main") | |
gen = CodeGen(True) | |
res = Resolver() | |
while not par._is(TokenKind.EOF): | |
try: | |
n = par.parse_toplevel() | |
if ErrorReporter.report_count > 0: | |
exit(1) | |
if args.debug: | |
pprint.pprint(n) | |
match n: | |
case AstFunc(): | |
res.resolve_func(mod, n) | |
if ErrorReporter.report_count > 0: | |
exit(1) | |
gen.gen_func(mod.funcs[n.name]) | |
case AstStruct(): | |
res.resolve_struct(mod, n) | |
if ErrorReporter.report_count > 0: | |
exit(1) | |
except Parser.FailedMiserably: | |
pass | |
except IrrecoverableError: | |
exit(1) | |
if ErrorReporter.report_count > 0: | |
exit(1) | |
if args.debug: | |
sys.stderr.buffer.write(gen._output) | |
with subprocess.Popen( | |
replace_in_list(shlex.split(args.llc_cmd), '%', str(max(0, min(args.opt_level, 3)))), | |
stdout=subprocess.PIPE, | |
stderr=subprocess.PIPE, | |
stdin=subprocess.PIPE | |
) as n: | |
assert n.stdin is not None | |
llc_output, stderr = n.communicate(gen._output) | |
if n.returncode != 0: | |
ErrorReporter.report("Failed to run LLC", None) | |
sys.stderr.buffer.write(stderr) | |
exit(1) | |
else: | |
sys.stderr.buffer.write(stderr) | |
if args.debug: | |
sys.stderr.buffer.write(llc_output) | |
bin_name = in_file.name.removesuffix(".qq") | |
with subprocess.Popen( | |
replace_in_list(shlex.split(args.asm_cmd), '%', bin_name), | |
stdout=subprocess.PIPE, | |
stderr=subprocess.PIPE, | |
stdin=subprocess.PIPE | |
) as n: | |
assert n.stdin is not None | |
_, stderr = n.communicate(llc_output) | |
if n.returncode != 0: | |
ErrorReporter.report("Failed to run assembler", None) | |
sys.stderr.buffer.write(stderr) | |
exit(1) | |
else: | |
sys.stderr.buffer.write(stderr) | |
if ErrorReporter.report_count > 0: | |
exit(1) | |
ErrorReporter.report("", None, kind='success', title=f"Compiled {bin_name} successfully!") | |
return bin_name | |
def cli_run(args: argparse.Namespace):
    """Build the input file, run the resulting binary, and exit with its return code.

    `args.rest` is passed to the binary as its command-line arguments.
    """
    built = cli_build(args)
    # Flush our own streams before starting the child process
    # (presumably so our output is not interleaved after the child's).
    for stream in (sys.stdout, sys.stderr):
        stream.flush()
    command = [os.path.abspath(built), *args.rest]
    finished = subprocess.run(command)
    exit(finished.returncode)
def cli():
    """Parse the command line and dispatch to the selected sub-command.

    Two sub-commands are defined, `build` and `run`. Both share the build
    options (`--use-asm`, `--use-llc`, `-O` and the input file); `run`
    additionally collects the remaining arguments for the compiled program.
    """
    parser = argparse.ArgumentParser(prog='qqc')
    parser.add_argument('-d', '--debug', action='store_true')
    subparsers = parser.add_subparsers(help="Commands:", required=True)

    # Options shared by `build` and `run`; add_help=False because this parser
    # is only used as a parent of the sub-command parsers.
    common_build_parser = argparse.ArgumentParser(add_help=False)
    common_build_parser.add_argument(
        '--use-asm',
        type=str,
        help="Assembler command to use. '%%' is replaced with the output file.",
        default="clang -fPIC libqqrt.o -xassembler -fuse-ld=lld -o % -",
        dest="asm_cmd"
    )
    common_build_parser.add_argument(
        '--use-llc',
        type=str,
        help="LLC command to use. '%%' is replaced with the opt level. "
             "Input is on the stdin, output is read from stdout.",
        default="llc -relocation-model=pic -O %",
        dest="llc_cmd"
    )
    common_build_parser.add_argument(
        '-O',
        type=int,
        # The build step clamps the level to 0..3, so 0 is accepted as well.
        help="Optimization level, 0 to 3 (values outside that range are clamped). "
             "Default is 2.",
        default=2,
        dest="opt_level"
    )
    common_build_parser.add_argument('file', type=argparse.FileType("r", encoding='utf-8'), help="Input file")

    build_parser = subparsers.add_parser('build', help="Compile file.", parents=[common_build_parser])
    build_parser.set_defaults(func=cli_build)

    run_parser = subparsers.add_parser('run', help="Compile and run file.", parents=[common_build_parser])
    run_parser.set_defaults(func=cli_run)
    run_parser.add_argument('rest', nargs=argparse.REMAINDER, help="Arguments passed to output program.")

    args = parser.parse_args()
    # Each sub-command stored its handler in `func` via set_defaults.
    args.func(args)
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    cli()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Requirements
Runtime library
Should provide
void qq_write_uint8s(uint64_t fout, struct qq_slice_uint8 i)
void qq_write_int32(uint64_t fout, int32_t i)
int32_t qq_read_int32(uint64_t fout)
(like int(input()) in python).
uint64_t qq_stdout()
uint64_t qq_stderr()
uint64_t qq_stdin()
Possible C code
Example code