-- Created June 2, 2022 22:50
-- Source: GitHub gist tooolbox/eb90c871c7680d66bb1d7dd4d4068cea (can be saved locally or opened with GitHub Desktop).
-- Modified version of Teal that runs in Algernon.
-- Note from the original page: this file contains hidden or bidirectional Unicode text that may be
-- interpreted or compiled differently than it appears. To review it, open the file in an editor that
-- reveals hidden Unicode characters (see GitHub's documentation on bidirectional Unicode characters).
-- Compatibility prelude: on Lua versions older than 5.3, try to load the
-- optional `compat53.module` shim and prefer its versions of the standard
-- library functions/tables; otherwise fall back to the built-in globals.
local _tl_compat
do
   local lua_version = tonumber((_VERSION or ''):match('[%d.]*$')) or 0
   if lua_version < 5.3 then
      local ok, compat = pcall(require, 'compat53.module')
      if ok then
         _tl_compat = compat
      end
   end
end

-- Start from the globals (the right-hand side is evaluated before these
-- locals come into scope), then override with any shim-provided version.
local assert, io, ipairs, load, math, os, package, pairs, string, table =
   assert, io, ipairs, load, math, os, package, pairs, string, table
if _tl_compat then
   assert = _tl_compat.assert or assert
   io = _tl_compat.io or io
   ipairs = _tl_compat.ipairs or ipairs
   load = _tl_compat.load or load
   math = _tl_compat.math or math
   os = _tl_compat.os or os
   package = _tl_compat.package or package
   pairs = _tl_compat.pairs or pairs
   string = _tl_compat.string or string
   table = _tl_compat.table or table
end

-- `unpack` is a global in Lua 5.1/LuaJIT; `table.unpack` from 5.2 on.
local _tl_table_unpack = unpack or table.unpack
local VERSION = "0.13.1+dev"

-- Public module table. The empty sub-tables presumably stand in for record
-- types declared in the Teal source (they carry no runtime fields here).
local tl = {}
for _, name in ipairs({
   "TypeCheckOptions", "Env", "Symbol", "Result",
   "Error", "TypeInfo", "TypeReport", "TypeReportEnv",
}) do
   tl[name] = {}
end

--- Return the compiler version string.
function tl.version()
   return VERSION
end

-- Set of accepted warning kind names.
tl.warning_kinds = {}
for _, kind in ipairs({ "unused", "redeclaration", "branch", "hint", "debug" }) do
   tl.warning_kinds[kind] = true
end

-- Numeric codes for types (bit fields). Values are part of the public API
-- and are reproduced exactly.
-- NOTE(review): IS_TABLE (0x8) equals STRING while TABLE is 0x10; looks
-- suspicious but is kept as-is since external tools may depend on it.
tl.typecodes = {
   -- basic Lua value types
   NIL = 0x00000001,
   NUMBER = 0x00000002,
   BOOLEAN = 0x00000004,
   STRING = 0x00000008,
   TABLE = 0x00000010,
   FUNCTION = 0x00000020,
   USERDATA = 0x00000040,
   THREAD = 0x00000080,
   -- masks over the basic types
   IS_TABLE = 0x00000008,
   IS_NUMBER = 0x00000002,
   IS_STRING = 0x00000004,
   LUA_MASK = 0x00000fff,
   -- Teal-specific types
   INTEGER = 0x00010002,
   ARRAY = 0x00010008,
   RECORD = 0x00020008,
   ARRAYRECORD = 0x00030008,
   MAP = 0x00040008,
   TUPLE = 0x00080008,
   EMPTY_TABLE = 0x00000008,
   ENUM = 0x00010004,
   -- masks over the Teal types
   IS_ARRAY = 0x00010008,
   IS_RECORD = 0x00020008,
   -- indirect types and their masks
   NOMINAL = 0x10000000,
   TYPE_VARIABLE = 0x08000000,
   IS_UNION = 0x40000000,
   IS_POLY = 0x20000020,
   -- special types and masks
   ANY = 0xffffffff,
   UNKNOWN = 0x80008000,
   INVALID = 0x80000000,
   IS_SPECIAL = 0x80000000,
   IS_VALID = 0x00000fff,
}

-- Local aliases. CompatMode, LoadMode, LoadFunction and TargetMode are not
-- keys of `tl` above, so those aliases are nil at this point.
local Result, Env, Error = tl.Result, tl.Env, tl.Error
local CompatMode, TypeCheckOptions, LoadMode = tl.CompatMode, tl.TypeCheckOptions, tl.LoadMode
local LoadFunction, TargetMode, TypeInfo = tl.LoadFunction, tl.TargetMode, tl.TypeInfo
local TypeReport, TypeReportEnv, Symbol = tl.TypeReport, tl.TypeReportEnv, tl.Symbol
local TokenKind, Token = {}, {}
do
   local LexState = {}

   -- Kind given to the unfinished token when input ends while the lexer is in
   -- a given state (see the end-of-input handling at the bottom of tl.lex).
   -- States not listed here drop the unfinished token.
   -- "got 0" is "integer" to match how a lone `0` is emitted mid-file
   -- (it used to be "number" only when `0` was the very last character).
   local last_token_kind = {
      ["identifier"] = "identifier",
      ["got -"] = "op",
      ["got ."] = ".",
      ["got .."] = "op",
      ["got ="] = "op",
      ["got ~"] = "op",
      ["got ["] = "[",
      ["got 0"] = "integer",
      ["got <"] = "op",
      ["got >"] = "op",
      ["got /"] = "op",
      ["got :"] = "op",
      ["string single"] = "$invalid_string$",
      ["string single got \\"] = "$invalid_string$",
      ["string double"] = "$invalid_string$",
      ["string double got \\"] = "$invalid_string$",
      ["string long"] = "$invalid_string$",
      ["string long got ]"] = "$invalid_string$",
      ["number dec"] = "integer",
      ["number decfloat"] = "number",
      ["number hex"] = "integer",
      ["number hexfloat"] = "number",
      ["number power"] = "number",
      ["number powersign"] = "$invalid_number$",
   }

   -- Reserved words: identifiers spelled like these become "keyword" tokens.
   local keywords = {}
   for _, word in ipairs({
      "and", "break", "do", "else", "elseif", "end", "false", "for",
      "function", "goto", "if", "in", "local", "nil", "not", "or",
      "repeat", "return", "then", "true", "until", "while",
   }) do
      keywords[word] = true
   end

   -- Set set[ch] = value for every character from `first` to `last`
   -- (inclusive, by byte value).
   local function mark_range(set, first, last, value)
      for code = string.byte(first), string.byte(last) do
         set[string.char(code)] = value
      end
   end

   -- State entered from "any" when a token starts with the given character.
   local lex_any_char_states = {
      ["\""] = "string double",
      ["'"] = "string single",
      ["-"] = "got -",
      ["."] = "got .",
      ["0"] = "got 0",
      ["<"] = "got <",
      [">"] = "got >",
      ["/"] = "got /",
      [":"] = "got :",
      ["="] = "got =",
      ["~"] = "got ~",
      ["["] = "got [",
   }
   mark_range(lex_any_char_states, "a", "z", "identifier")
   mark_range(lex_any_char_states, "A", "Z", "identifier")
   lex_any_char_states["_"] = "identifier"
   mark_range(lex_any_char_states, "1", "9", "number dec")

   -- Characters that may continue an identifier.
   local lex_word = { ["_"] = true }
   mark_range(lex_word, "a", "z", true)
   mark_range(lex_word, "A", "Z", true)
   mark_range(lex_word, "0", "9", true)

   local lex_decimals = {}
   mark_range(lex_decimals, "0", "9", true)

   local lex_hexadecimals = {}
   mark_range(lex_hexadecimals, "0", "9", true)
   mark_range(lex_hexadecimals, "a", "f", true)
   mark_range(lex_hexadecimals, "A", "F", true)

   -- One-character tokens emitted from the "any" state. Only consulted when
   -- lex_any_char_states has no entry for the character.
   local lex_any_char_kinds = {}
   local single_char_kinds = { "[", "]", "(", ")", "{", "}", ",", "#", "`", ";" }
   for _, c in ipairs(single_char_kinds) do
      lex_any_char_kinds[c] = c
   end
   for _, c in ipairs({ "+", "*", "|", "&", "%", "^" }) do
      lex_any_char_kinds[c] = "op"
   end

   local lex_space = {}
   for _, c in ipairs({ " ", "\t", "\v", "\n", "\r" }) do
      lex_space[c] = true
   end

   -- Characters valid immediately after a backslash with no further payload.
   local escapable_characters = {
      a = true,
      b = true,
      f = true,
      n = true,
      r = true,
      t = true,
      v = true,
      z = true,
      ["\\"] = true,
      ["\'"] = true,
      ["\""] = true,
      ["\r"] = true,
      ["\n"] = true,
   }

   --- Validate the escape sequence following a backslash in a short string.
   -- @param input whole source text
   -- @param i index of the character right after the backslash
   -- @param c that character
   -- @return number of characters after `c` that belong to the escape
   -- @return truthy if the escape is well-formed
   local function lex_string_escape(input, i, c)
      if escapable_characters[c] then
         return 0, true
      elseif c == "x" then
         -- \xXX: exactly two hexadecimal digits
         return 2, (
            lex_hexadecimals[input:sub(i + 1, i + 1)] and
            lex_hexadecimals[input:sub(i + 2, i + 2)])
      elseif c == "u" then
         -- \u{XXX}: one or more hexadecimal digits inside braces
         if input:sub(i + 1, i + 1) ~= "{" then
            -- Malformed `\u` with no brace. This path used to return nothing,
            -- which made the caller's `i + skip` raise an arithmetic error.
            return 0, false
         end
         local p = i + 2
         if not lex_hexadecimals[input:sub(p, p)] then
            return 2, false
         end
         while true do
            p = p + 1
            c = input:sub(p, p)
            if not lex_hexadecimals[c] then
               return p - i, c == "}"
            end
         end
      elseif lex_decimals[c] then
         -- \ddd: up to three decimal digits with value below 256
         local len = lex_decimals[input:sub(i + 1, i + 1)] and
            (lex_decimals[input:sub(i + 2, i + 2)] and 2 or 1) or
            0
         return len, tonumber(input:sub(i, i + len)) < 256
      else
         return 0, false
      end
   end

   --- Split Teal/Lua source text into tokens.
   -- Runs a character-at-a-time state machine over `input`. Each token is a
   -- table { x, y, i, tk, kind }: column and line of its first character
   -- (both 1-based), its byte index in `input`, its text, and its kind
   -- ("identifier", "keyword", "op", "string", "integer", "number", a
   -- punctuation kind such as "[" or "...", or one of "$invalid$",
   -- "$invalid_string$", "$invalid_number$"). Whitespace and comments yield
   -- no tokens; a leading "#!" line is skipped. A "$EOF$" token is always
   -- appended.
   -- @param input source text
   -- @return token array
   -- @return array of the invalid tokens, or false when there are none
   function tl.lex(input)
      local tokens = {}
      local state = "any"
      -- When false, the current character is examined again on the next
      -- iteration (a state that ends one character past its token sets it).
      local fwd = true
      local y = 1
      local x = 0
      local i = 0
      -- `=` counts of the opening/closing brackets of long comments (lc_*)
      -- and long strings (ls_*).
      local lc_open_lvl = 0
      local lc_close_lvl = 0
      local ls_open_lvl = 0
      local ls_close_lvl = 0
      local errs = {}
      local nt = 0
      -- start position of the token being built
      local tx
      local ty
      local ti
      local in_token = false

      local function begin_token()
         tx = x
         ty = y
         ti = i
         in_token = true
      end

      local function end_token(kind, tk)
         nt = nt + 1
         tokens[nt] = {
            x = tx,
            y = ty,
            i = ti,
            tk = tk,
            kind = kind,
         }
         in_token = false
      end

      local function end_token_identifier()
         local tk = input:sub(ti, i - 1)
         nt = nt + 1
         tokens[nt] = {
            x = tx,
            y = ty,
            i = ti,
            tk = tk,
            kind = keywords[tk] and "keyword" or "identifier",
         }
         in_token = false
      end

      -- token text ends at the previous character
      local function end_token_prev(kind)
         local tk = input:sub(ti, i - 1)
         nt = nt + 1
         tokens[nt] = {
            x = tx,
            y = ty,
            i = ti,
            tk = tk,
            kind = kind,
         }
         in_token = false
      end

      -- token text ends at the current character
      local function end_token_here(kind)
         local tk = input:sub(ti, i)
         nt = nt + 1
         tokens[nt] = {
            x = tx,
            y = ty,
            i = ti,
            tk = tk,
            kind = kind,
         }
         in_token = false
      end

      local function drop_token()
         in_token = false
      end

      local len = #input
      if input:sub(1, 2) == "#!" then
         i = input:find("\n")
         if not i then
            i = len + 1
         end
         y = 2
         x = 0
      end
      state = "any"

      while i <= len do
         if fwd then
            i = i + 1
            if i > len then
               break
            end
         end

         local c = input:sub(i, i)

         if fwd then
            if c == "\n" then
               y = y + 1
               x = 0
            else
               x = x + 1
            end
         else
            fwd = true
         end

         if state == "any" then
            local st = lex_any_char_states[c]
            if st then
               state = st
               begin_token()
            else
               local k = lex_any_char_kinds[c]
               if k then
                  begin_token()
                  end_token(k, c)
               elseif not lex_space[c] then
                  begin_token()
                  end_token_here("$invalid$")
                  table.insert(errs, tokens[#tokens])
               end
            end
         elseif state == "identifier" then
            if not lex_word[c] then
               end_token_identifier()
               fwd = false
               state = "any"
            end
         elseif state == "string double" then
            if c == "\\" then
               state = "string double got \\"
            elseif c == "\"" then
               end_token_here("string")
               state = "any"
            end
         elseif state == "comment short" then
            if c == "\n" then
               state = "any"
            end
         elseif state == "got =" then
            local t
            if c == "=" then
               t = "=="
            else
               t = "="
               fwd = false
            end
            end_token("op", t)
            state = "any"
         elseif state == "got ." then
            if c == "." then
               state = "got .."
            elseif lex_decimals[c] then
               state = "number decfloat"
            else
               end_token(".", ".")
               fwd = false
               state = "any"
            end
         elseif state == "got :" then
            local t
            if c == ":" then
               t = "::"
            else
               t = ":"
               fwd = false
            end
            end_token(t, t)
            state = "any"
         elseif state == "got [" then
            if c == "[" then
               state = "string long"
            elseif c == "=" then
               ls_open_lvl = ls_open_lvl + 1
            else
               end_token("[", "[")
               fwd = false
               state = "any"
               ls_open_lvl = 0
            end
         elseif state == "number dec" then
            if lex_decimals[c] then
               -- keep consuming digits
            elseif c == "." then
               state = "number decfloat"
            elseif c == "e" or c == "E" then
               state = "number powersign"
            else
               end_token_prev("integer")
               fwd = false
               state = "any"
            end
         elseif state == "got -" then
            if c == "-" then
               state = "got --"
            else
               end_token("op", "-")
               fwd = false
               state = "any"
            end
         elseif state == "got .." then
            if c == "." then
               end_token("...", "...")
            else
               end_token("op", "..")
               fwd = false
            end
            state = "any"
         elseif state == "number hex" then
            if lex_hexadecimals[c] then
               -- keep consuming hex digits
            elseif c == "." then
               state = "number hexfloat"
            elseif c == "p" or c == "P" then
               state = "number powersign"
            else
               end_token_prev("integer")
               fwd = false
               state = "any"
            end
         elseif state == "got --" then
            if c == "[" then
               state = "got --["
            else
               fwd = false
               state = "comment short"
               drop_token()
            end
         elseif state == "got 0" then
            if c == "x" or c == "X" then
               state = "number hex"
            elseif c == "e" or c == "E" then
               state = "number powersign"
            elseif lex_decimals[c] then
               state = "number dec"
            elseif c == "." then
               state = "number decfloat"
            else
               end_token_prev("integer")
               fwd = false
               state = "any"
            end
         elseif state == "got --[" then
            if c == "[" then
               state = "comment long"
            elseif c == "=" then
               lc_open_lvl = lc_open_lvl + 1
            else
               fwd = false
               state = "comment short"
               drop_token()
               lc_open_lvl = 0
            end
         elseif state == "comment long" then
            if c == "]" then
               state = "comment long got ]"
            end
         elseif state == "comment long got ]" then
            if c == "]" then
               if lc_close_lvl == lc_open_lvl then
                  drop_token()
                  state = "any"
                  lc_open_lvl = 0
                  lc_close_lvl = 0
               else
                  -- Level mismatch: this "]" may itself start the real
                  -- closing bracket (e.g. `--[=[ ]]=]`), so restart the count.
                  lc_close_lvl = 0
               end
            elseif c == "=" then
               lc_close_lvl = lc_close_lvl + 1
            else
               state = "comment long"
               lc_close_lvl = 0
            end
         elseif state == "string double got \\" then
            local skip, valid = lex_string_escape(input, i, c)
            i = i + skip
            if not valid then
               end_token_here("$invalid_string$")
               table.insert(errs, tokens[#tokens])
            end
            x = x + skip
            state = "string double"
         elseif state == "string single" then
            if c == "\\" then
               state = "string single got \\"
            elseif c == "'" then
               end_token_here("string")
               state = "any"
            end
         elseif state == "string single got \\" then
            local skip, valid = lex_string_escape(input, i, c)
            i = i + skip
            if not valid then
               end_token_here("$invalid_string$")
               table.insert(errs, tokens[#tokens])
            end
            x = x + skip
            state = "string single"
         elseif state == "got ~" then
            local t
            if c == "=" then
               t = "~="
            else
               t = "~"
               fwd = false
            end
            end_token("op", t)
            state = "any"
         elseif state == "got <" then
            local t
            if c == "=" then
               t = "<="
            elseif c == "<" then
               t = "<<"
            else
               t = "<"
               fwd = false
            end
            end_token("op", t)
            state = "any"
         elseif state == "got >" then
            local t
            if c == "=" then
               t = ">="
            elseif c == ">" then
               t = ">>"
            else
               t = ">"
               fwd = false
            end
            end_token("op", t)
            state = "any"
         elseif state == "got /" then
            local t
            if c == "/" then
               t = "//"
            else
               t = "/"
               fwd = false
            end
            end_token("op", t)
            state = "any"
         elseif state == "string long" then
            if c == "]" then
               state = "string long got ]"
            end
         elseif state == "string long got ]" then
            if c == "]" then
               if ls_close_lvl == ls_open_lvl then
                  end_token_here("string")
                  state = "any"
                  ls_open_lvl = 0
                  ls_close_lvl = 0
               else
                  -- Level mismatch: this "]" may itself start the real
                  -- closing bracket (e.g. `[=[ ]]=]`), so restart the count.
                  ls_close_lvl = 0
               end
            elseif c == "=" then
               ls_close_lvl = ls_close_lvl + 1
            else
               state = "string long"
               ls_close_lvl = 0
            end
         elseif state == "number hexfloat" then
            if c == "p" or c == "P" then
               state = "number powersign"
            elseif not lex_hexadecimals[c] then
               end_token_prev("number")
               fwd = false
               state = "any"
            end
         elseif state == "number decfloat" then
            if c == "e" or c == "E" then
               state = "number powersign"
            elseif not lex_decimals[c] then
               end_token_prev("number")
               fwd = false
               state = "any"
            end
         elseif state == "number powersign" then
            if c == "-" or c == "+" then
               state = "number power"
            elseif lex_decimals[c] then
               state = "number power"
            else
               end_token_here("$invalid_number$")
               table.insert(errs, tokens[#tokens])
               state = "any"
            end
         elseif state == "number power" then
            if not lex_decimals[c] then
               end_token_prev("number")
               fwd = false
               state = "any"
            end
         end
      end

      -- Input ended in the middle of a token: finish it or drop it.
      if in_token then
         if last_token_kind[state] then
            end_token_prev(last_token_kind[state])
            if keywords[tokens[nt].tk] then
               tokens[nt].kind = "keyword"
            end
         else
            drop_token()
         end
      end

      table.insert(tokens, { x = x + 1, y = y, i = i, tk = "$EOF$", kind = "$EOF$" })

      return tokens, (#errs > 0) and errs
   end
end
--- Binary search driven by a predicate.
-- Returns an index `mid` (and list[mid]) where cmp(list[mid], item) is true
-- and either mid is the last index or cmp(list[mid + 1], item) is false.
-- When the predicate holds for a prefix of the list, this is the last
-- element of that prefix. Returns nothing when no such index is found.
local function binary_search(list, item, cmp)
   local count = #list
   local lo, hi = 1, count
   while lo <= hi do
      local mid = math.floor((lo + hi) / 2)
      local candidate = list[mid]
      if not cmp(candidate, item) then
         hi = mid - 1
      else
         if mid == count or not cmp(list[mid + 1], item) then
            return mid, candidate
         end
         lo = mid + 1
      end
   end
end
--- Return the text of the token that covers column `x` on line `y`.
-- Returns nothing if no token covers that position.
-- @param tks token array, assumed ordered by position as tl.lex emits it
function tl.get_token_at(tks, y, x)
   -- true for tokens that start at or before (y, x)
   local function at_or_before(tk)
      if tk.y == y then
         return tk.x <= x
      end
      return tk.y < y
   end

   local _, found = binary_search(tks, nil, at_or_before)
   if not found or found.y ~= y then
      return
   end
   if found.x <= x and x < found.x + #found.tk then
      return found.tk
   end
end
-- Most recently issued type id (0 before any type is created).
local last_typeid = 0

--- Allocate the next type id; ids increase by one on each call.
local function new_typeid()
   local id = last_typeid + 1
   last_typeid = id
   return id
end
-- Placeholders for type declarations (presumably erased Teal record/enum
-- types; they carry no runtime data here).
local TypeName = {}

-- Typenames that denote table-shaped types.
local table_types = {}
for _, name in ipairs({ "array", "map", "arrayrecord", "record", "emptytable" }) do
   table_types[name] = true
end

local Type, Operator, NodeKind = {}, {}, {}
local FactType, Fact, KeyParsed = {}, {}, {}
local Node = { ExpectedContext = {} }

-- Typename predicates.
local function is_array_type(t)
   local name = t.typename
   return name == "array" or name == "arrayrecord"
end

local function is_record_type(t)
   local name = t.typename
   return name == "record" or name == "arrayrecord"
end

local function is_number_type(t)
   local name = t.typename
   return name == "number" or name == "integer"
end

local function is_typetype(t)
   local name = t.typename
   return name == "typetype" or name == "nestedtype"
end

local ParseState, ParseTypeListMode = {}, {}

-- Forward declarations for the mutually recursive parser functions; they
-- are assigned later in the file (e.g. parse_type, parse_type_list).
local parse_type_list, parse_expression, parse_expression_and_tk
local parse_statements, parse_argument_list, parse_argument_type_list
local parse_type, parse_newtype, parse_enum_body, parse_record_body
--- Record a syntax error at token `i` (or at EOF when `i` is past the end).
-- @return the token index to resume from: i + 1, capped at the last token
local function fail(ps, i, msg)
   local tokens = ps.tokens
   local tok = tokens[i]
   if not tok then
      local eof = tokens[#tokens]
      table.insert(ps.errs, { filename = ps.filename, y = eof.y, x = eof.x, msg = msg or "unexpected end of file" })
      return #tokens
   end
   table.insert(ps.errs, { filename = ps.filename, y = tok.y, x = tok.x, msg = assert(msg, "syntax error, but no error message provided") })
   return math.min(#tokens, i + 1)
end

--- Set a node's end position to the last character of token `tk`.
local function end_at(node, tk)
   local last_col = tk.x + #tk.tk - 1
   node.yend = tk.y
   node.xend = last_col
end

--- Expect token text `tk` at `i`; return the next index or fail.
local function verify_tk(ps, i, tk)
   if ps.tokens[i].tk ~= tk then
      return fail(ps, i, "syntax error, expected '" .. tk .. "'")
   end
   return i + 1
end

--- Expect `end` at `i`, closing the construct that began at token `istart`.
-- Sets node.yend/node.xend in either case.
local function verify_end(ps, i, istart, node)
   local tok = ps.tokens[i]
   if tok.tk ~= "end" then
      end_at(node, tok)
      local start = ps.tokens[istart]
      return fail(ps, i, "syntax error, expected 'end' to close construct started at " .. ps.filename .. ":" .. start.y .. ":" .. start.x .. ":")
   end
   node.yend = tok.y
   node.xend = tok.x + 2 -- last column of the 3-character `end`
   return i + 1
end

--- Create an AST node positioned at token `i`; kind defaults to the token's.
local function new_node(tokens, i, kind)
   local tok = tokens[i]
   local node = { y = tok.y, x = tok.x, tk = tok.tk }
   node.kind = kind or tok.kind
   return node
end

--- Stamp a fresh type id on `typ` and return it.
local function a_type(typ)
   typ.typeid = new_typeid()
   return typ
end

--- Create a type of `typename` positioned at token `i`.
local function new_type(ps, i, typename)
   local tok = ps.tokens[i]
   local typ = {
      typename = assert(typename),
      filename = ps.filename,
      y = tok.y,
      x = tok.x,
      tk = tok.tk,
   }
   return a_type(typ)
end

--- Expect a token of `kind` at `i`; on success also return a node for it.
local function verify_kind(ps, i, kind, node_kind)
   if ps.tokens[i].kind ~= kind then
      return fail(ps, i, "syntax error, expected " .. kind)
   end
   return i + 1, new_node(ps.tokens, i, node_kind)
end
local SkipFunction = {}

--- Report `msg` at `starti` (default `i`), then skip whatever construct
-- `skip_fn` parses from there. Errors produced while skipping go to a
-- scratch list and are discarded.
-- @return index after the skipped construct, or i + 1 if skip_fn gave none
local function failskip(ps, i, msg, skip_fn, starti)
   local at = starti or i
   local scratch = {
      tokens = ps.tokens,
      errs = {},
      required_modules = {},
   }
   local skip_i = skip_fn(scratch, at)
   fail(ps, at, msg)
   return skip_i or (i + 1)
end

-- Skip helpers: step past the keyword at `i`, then parse (and discard) the body.
local function skip_record(ps, i)
   return parse_record_body(ps, i + 1, {}, {})
end

local function skip_enum(ps, i)
   return parse_enum_body(ps, i + 1, {}, {})
end
--- Parse the value part of a table constructor entry.
-- The removed `record`/`enum` value syntax is reported and skipped.
-- @return next index and the value node; when no value was parsed, an
--   "error_node" placed at the token before the returned index
local function parse_table_value(ps, i)
   local word = ps.tokens[i].tk
   local value
   if word == "record" then
      i = failskip(ps, i, "syntax error: this syntax is no longer valid; declare nested record inside a record", skip_record)
   elseif word == "enum" then
      i = failskip(ps, i, "syntax error: this syntax is no longer valid; declare nested enum inside a record", skip_enum)
   else
      i, value = parse_expression(ps, i)
   end
   return i, value or new_node(ps.tokens, i - 1, "error_node")
end
local parse_table_item
do
   -- Parse identifier token `i` as a string key node: kind "string",
   -- `conststr` is the bare name and `tk` the double-quoted name.
   local function parse_short_key(ps, i)
      local next_i, key = verify_kind(ps, i, "identifier", "string")
      key.conststr = key.tk
      key.tk = '"' .. key.tk .. '"'
      return next_i, key
   end

   --- Parse one entry of a table constructor.
   -- Key forms: `[expr] = v` (key_parsed "long"), `name = v` and
   -- `name: Type = v` (key_parsed "short"), and positional values
   -- (key_parsed "implicit", keyed by the running counter `n`).
   -- @return next index, the "table_item" node, and the positional counter
   --   (advanced only after a positional item)
   parse_table_item = function(ps, i, n)
      local tokens = ps.tokens
      local node = new_node(tokens, i, "table_item")
      local tok = tokens[i]

      if tok.kind == "$EOF$" then
         return fail(ps, i, "unexpected eof")
      end

      if tok.tk == "[" then
         node.key_parsed = "long"
         i, node.key = parse_expression_and_tk(ps, i + 1, "]")
         i = verify_tk(ps, i, "=")
         i, node.value = parse_table_value(ps, i)
         return i, node, n
      end

      if tok.kind == "identifier" then
         local after = tokens[i + 1].tk
         if after == "=" then
            node.key_parsed = "short"
            i, node.key = parse_short_key(ps, i)
            i = verify_tk(ps, i, "=")
            i, node.value = parse_table_value(ps, i)
            return i, node, n
         elseif after == ":" then
            -- Attempt `name: Type = value` with a scratch error list; its
            -- errors are kept only if the whole form parses. Otherwise fall
            -- back to a positional item starting at the same token.
            node.key_parsed = "short"
            local scratch = {
               filename = ps.filename,
               tokens = tokens,
               errs = {},
               required_modules = ps.required_modules,
            }
            local j
            j, node.key = parse_short_key(scratch, i)
            j = verify_tk(scratch, j, ":")
            j, node.decltype = parse_type(scratch, j)
            if node.decltype and tokens[j].tk == "=" then
               j = verify_tk(scratch, j, "=")
               j, node.value = parse_table_value(scratch, j)
               if node.value then
                  for _, err in ipairs(scratch.errs) do
                     table.insert(ps.errs, err)
                  end
                  return j, node, n
               end
            end
            node.decltype = nil
         end
      end

      local key = new_node(tokens, i, "integer")
      key.constnum = n
      key.tk = tostring(n)
      node.key = key
      node.key_parsed = "implicit"
      local value
      i, value = parse_expression(ps, i)
      node.value = value
      if not value then
         return fail(ps, i, "expected an expression")
      end
      return i, node, n + 1
   end
end
local ParseItem = {}
local SeparatorMode = {}

--- Parse items with `parse_item` until a token in `close` (not consumed) or EOF.
-- @param list node that receives the items; its end is set at the closer
-- @param close set of closing token texts
-- @param sep "sep": items separated by ',' and a trailing ',' before the
--   closer is an error; "term": items may also be ended by ';'
-- @param parse_item function(ps, i, n) -> next index, item[, next n]
-- @return index of the closing token (or where parsing stopped), and list
local function parse_list(ps, i, list, close, sep, parse_item)
   local n = 1
   while ps.tokens[i].kind ~= "$EOF$" do
      local tok = ps.tokens[i]
      if close[tok.tk] then
         end_at(list, tok)
         break
      end

      local item, next_n
      i, item, next_n = parse_item(ps, i, n)
      if next_n then
         n = next_n
      end
      table.insert(list, item)

      local tk = ps.tokens[i].tk
      if tk == "," then
         i = i + 1
         local following = ps.tokens[i].tk
         if sep == "sep" and close[following] then
            fail(ps, i, "unexpected '" .. following .. "'")
            return i, list
         end
      elseif sep == "term" and tk == ";" then
         i = i + 1
      elseif not close[tk] then
         local options = {}
         for k in pairs(close) do
            options[#options + 1] = "'" .. k .. "'"
         end
         table.sort(options)
         options[#options + 1] = "','"
         fail(ps, i, "syntax error, expected one of: " .. table.concat(options, ", "))

         -- If the offending token is on a new line, assume the closer was
         -- forgotten at the end of the previous line: insert a synthetic
         -- closer token there and stop.
         local first = options[1]:sub(2, -2)
         if first ~= "}" and ps.tokens[i].y ~= ps.tokens[i - 1].y then
            local prev = ps.tokens[i - 1]
            table.insert(ps.tokens, i, { tk = first, y = prev.y, x = prev.x + 1, kind = "keyword" })
            return i, list
         end
      end
   end
   return i, list
end
--- Parse `open item, ... close` into `list` using parse_list.
-- @return index after the closing token, and list
local function parse_bracket_list(ps, i, list, open, close, sep, parse_item)
   local closers = { [close] = true }
   i = verify_tk(ps, i, open)
   i = parse_list(ps, i, list, closers, sep, parse_item)
   return verify_tk(ps, i, close), list
end

--- Parse a `{ ... }` table constructor into a "table_literal" node.
local function parse_table_literal(ps, i)
   local node = new_node(ps.tokens, i, "table_literal")
   return parse_bracket_list(ps, i, node, "{", "}", "term", parse_table_item)
end
--- Parse a comma-separated list whose first item is optional.
-- The first item is tried against a scratch error list; if it yields no
-- item, nothing is consumed. Otherwise its errors are kept and further
-- items are parsed after each ','.
-- @return next index and list
local function parse_trying_list(ps, i, list, parse_item)
   local scratch = {
      filename = ps.filename,
      tokens = ps.tokens,
      errs = {},
      required_modules = ps.required_modules,
   }
   local after_first, first = parse_item(scratch, i)
   if not first then
      return i, list
   end

   for _, err in ipairs(scratch.errs) do
      table.insert(ps.errs, err)
   end
   table.insert(list, first)
   i = after_first

   while ps.tokens[i].tk == "," do
      local item
      i, item = parse_item(ps, i + 1)
      table.insert(list, item)
   end
   return i, list
end
--- Parse a type argument declaration: `T` or `` `T ``.
-- The position is that of the argument's first token (the backtick when
-- present, otherwise the identifier). It used to be read from tokens[i - 2],
-- which without a backtick is the preceding `<` or `,`.
-- @return next index and a "typearg" type
local function parse_typearg_type(ps, i)
   local start = ps.tokens[i]
   local backtick = false
   if start.tk == "`" then
      i = verify_tk(ps, i, "`")
      backtick = true
   end
   i = verify_kind(ps, i, "identifier")
   return i, a_type({
      y = start.y,
      x = start.x,
      typename = "typearg",
      typearg = (backtick and "`" or "") .. ps.tokens[i - 1].tk,
   })
end

--- Parse a type variable reference `` `T `` (the backtick is required).
-- @return next index and a "typevar" type positioned at the backtick
local function parse_typevar_type(ps, i)
   i = verify_tk(ps, i, "`")
   i = verify_kind(ps, i, "identifier")
   return i, a_type({
      y = ps.tokens[i - 2].y,
      x = ps.tokens[i - 2].x,
      typename = "typevar",
      typevar = "`" .. ps.tokens[i - 1].tk,
   })
end
local parse_typearg_list, parse_typeval_list
do
   -- Shared implementation: parse `< item, ... >` into a "tuple" type,
   -- rejecting an empty `<>`.
   local function parse_angle_list(ps, i, parse_item)
      if ps.tokens[i + 1].tk == ">" then
         return fail(ps, i + 1, "type argument list cannot be empty")
      end
      local typ = new_type(ps, i, "tuple")
      return parse_bracket_list(ps, i, typ, "<", ">", "sep", parse_item)
   end

   --- Parse declared type arguments, e.g. the `<K, V>` of a function type.
   parse_typearg_list = function(ps, i)
      return parse_angle_list(ps, i, parse_typearg_type)
   end

   --- Parse type values applied to a generic type, e.g. `Map<string, T>`.
   parse_typeval_list = function(ps, i)
      return parse_angle_list(ps, i, parse_type)
   end
end
--- Parse an optional `: T1, T2...` return-type list.
local function parse_return_types(ps, i)
   return parse_type_list(ps, i, "rets")
end

--- Parse a function type starting at the `function` token.
-- A bare `function` with no `(` gets `(any...)` arguments and returns.
-- @return next index and a "function" type
local function parse_function_type(ps, i)
   local typ = new_type(ps, i, "function")
   i = i + 1
   if ps.tokens[i].tk == "<" then
      i, typ.typeargs = parse_typearg_list(ps, i)
   end

   if ps.tokens[i].tk ~= "(" then
      local any_arg = a_type({ typename = "any" })
      typ.args = a_type({ typename = "tuple", is_va = true, any_arg })
      local any_ret = a_type({ typename = "any" })
      typ.rets = a_type({ typename = "tuple", is_va = true, any_ret })
      return i, typ
   end

   i, typ.args = parse_argument_type_list(ps, i)
   i, typ.rets = parse_return_types(ps, i)
   return i, typ
end
-- Shared singleton types. They are created one per statement so type ids
-- are allocated in this order: nil, any, table, number, string, thread,
-- boolean, integer.
local NIL = a_type({ typename = "nil" })
local ANY = a_type({ typename = "any" })
local TABLE = a_type({ typename = "map", keys = ANY, values = ANY })
local NUMBER = a_type({ typename = "number" })
local STRING = a_type({ typename = "string" })
local THREAD = a_type({ typename = "thread" })
local BOOLEAN = a_type({ typename = "boolean" })
local INTEGER = a_type({ typename = "integer" })

-- Built-in type names resolved directly to the singletons above.
local simple_types = {}
simple_types["nil"] = NIL
simple_types.any = ANY
simple_types.table = TABLE
simple_types.number = NUMBER
simple_types.string = STRING
simple_types.thread = THREAD
simple_types.boolean = BOOLEAN
simple_types.integer = INTEGER
local parse_base_type
do
   -- Parse a possibly dotted, possibly generic type name such as `a.b.C<T>`
   -- into a "nominal" type.
   local function parse_nominal_type(ps, i)
      local typ = new_type(ps, i, "nominal")
      typ.names = { ps.tokens[i].tk }
      i = i + 1
      while ps.tokens[i].tk == "." do
         i = i + 1
         if ps.tokens[i].kind ~= "identifier" then
            return fail(ps, i, "syntax error, expected identifier")
         end
         table.insert(typ.names, ps.tokens[i].tk)
         i = i + 1
      end
      if ps.tokens[i].tk == "<" then
         i, typ.typevals = parse_typeval_list(ps, i)
      end
      return i, typ
   end

   -- Parse a table type starting at `{`: `{T}` (array), `{T1, T2, ...}`
   -- (tupletable) or `{K: V}` (map).
   local function parse_table_type(ps, i)
      i = i + 1
      local decl = new_type(ps, i, "array")
      local t
      i, t = parse_type(ps, i)
      if not t then
         return i
      end

      local tk = ps.tokens[i].tk
      if tk == "}" then
         decl.elements = t
         end_at(decl, ps.tokens[i])
         i = verify_tk(ps, i, "}")
      elseif tk == "," then
         decl.typename = "tupletable"
         decl.types = { t }
         local n = 2
         repeat
            i = i + 1
            i, decl.types[n] = parse_type(ps, i)
            if not decl.types[n] then
               break
            end
            n = n + 1
         until ps.tokens[i].tk ~= ","
         end_at(decl, ps.tokens[i])
         i = verify_tk(ps, i, "}")
      elseif tk == ":" then
         decl.typename = "map"
         i = i + 1
         decl.keys = t
         i, decl.values = parse_type(ps, i)
         if not decl.values then
            return i
         end
         end_at(decl, ps.tokens[i])
         i = verify_tk(ps, i, "}")
      else
         return fail(ps, i, "syntax error; did you forget a '}'?")
      end
      return i, decl
   end

   --- Parse one non-union type.
   -- Identifiers naming built-in types return the shared singleton from
   -- `simple_types` (this includes `table`; the old separate `tk == "table"`
   -- branch could never be reached because `table` lexes as an identifier).
   -- @return next index and the type, or only an index on failure
   parse_base_type = function(ps, i)
      local tok = ps.tokens[i]
      local tk = tok.tk
      if tok.kind == "identifier" then
         local st = simple_types[tk]
         if st then
            return i + 1, st
         end
         return parse_nominal_type(ps, i)
      elseif tk == "{" then
         return parse_table_type(ps, i)
      elseif tk == "function" then
         return parse_function_type(ps, i)
      elseif tk == "nil" then
         return i + 1, simple_types["nil"]
      elseif tk == "`" then
         return parse_typevar_type(ps, i)
      end
      return fail(ps, i, "expected a type")
   end
end
--- Parse a type: a parenthesized type, a base type, or a `A | B | ...` union.
-- @return next index and the type, or only an index on failure
parse_type = function(ps, i)
   if ps.tokens[i].tk == "(" then
      local inner
      i, inner = parse_type(ps, i + 1)
      return verify_tk(ps, i, ")"), inner
   end

   local istart = i
   local first
   i, first = parse_base_type(ps, i)
   if not first then
      return i
   end
   if ps.tokens[i].tk ~= "|" then
      return i, first
   end

   -- union positioned at its first member
   local union = new_type(ps, istart, "union")
   union.types = { first }
   repeat
      local alt
      i, alt = parse_base_type(ps, i + 1)
      if not alt then
         return i
      end
      table.insert(union.types, alt)
   until ps.tokens[i].tk ~= "|"
   return i, union
end
--- Parse a list of types into a "tuple" type.
-- In "rets" and "decltype" modes the list must begin with ':'; if it does
-- not, an empty tuple is returned without consuming anything. The list may
-- be wrapped in parentheses. In "rets" mode a trailing '...' marks the
-- tuple as variadic (an error if the list is empty).
-- @return next index and the tuple type
parse_type_list = function(ps, i, mode)
   local list = new_type(ps, i, "tuple")

   if mode == "rets" or mode == "decltype" then
      if ps.tokens[i].tk ~= ":" then
         return i, list
      end
      i = i + 1
   end

   local parenthesized = ps.tokens[i].tk == "("
   if parenthesized then
      i = i + 1
   end

   local before = i
   i = parse_trying_list(ps, i, list, parse_type)
   if i == before and ps.tokens[i].tk ~= ")" then
      fail(ps, i - 1, "expected a type list")
   end

   if mode == "rets" and ps.tokens[i].tk == "..." then
      i = i + 1
      if #list > 0 then
         list.is_va = true
      else
         fail(ps, i, "unexpected '...'")
      end
   end

   if parenthesized then
      i = verify_tk(ps, i, ")")
   end
   return i, list
end
--- Parse `[<typeargs>] (args) [: rets] body end` into `node`.
-- @param i index of the first token after `function` (or after the name)
-- @return index after `end`, and node
local function parse_function_args_rets_body(ps, i, node)
   -- token just before the signature; cited if the closing `end` is missing
   local istart = i - 1
   if ps.tokens[i].tk == "<" then
      i, node.typeargs = parse_typearg_list(ps, i)
   end
   i, node.args = parse_argument_list(ps, i)
   i, node.rets = parse_return_types(ps, i)
   i, node.body = parse_statements(ps, i)
   end_at(node, ps.tokens[i])
   local after_end = verify_end(ps, i, istart, node)
   assert(node.rets.typename == "tuple")
   return after_end, node
end

--- Parse an anonymous `function ... end` expression.
local function parse_function_value(ps, i)
   local node = new_node(ps.tokens, i, "function")
   local after_keyword = verify_tk(ps, i, "function")
   return parse_function_args_rets_body(ps, after_keyword, node)
end
--- Strip the delimiters from a string token's text.
-- Escape sequences in short strings are NOT decoded. For long strings, a
-- newline ("\n", "\r", "\r\n" or "\n\r") immediately after the opening
-- bracket is dropped, matching Lua's own rule. Previously it was kept in
-- the constant value.
-- @param str token text, e.g. '"abc"' or "[==[abc]==]"
-- @return the string contents
-- @return true if `str` was a long-bracket string
local function unquote(str)
   local f = str:sub(1, 1)
   if f == '"' or f == "'" then
      return str:sub(2, -2), false
   end
   local open = str:match("^%[=*%[")
   local l = #open + 1
   local body = str:sub(l, -l)
   local newline = body:match("^\r\n") or body:match("^\n\r") or body:match("^[\r\n]")
   if newline then
      body = body:sub(#newline + 1)
   end
   return body, true
end
--- Parse a literal or primary value: variable name, string, number,
-- boolean, nil, function expression, table constructor, or `...`.
-- Malformed string/number tokens are reported as such.
-- @return next index and the node, or only an index on failure
local function parse_literal(ps, i)
   local tok = ps.tokens[i]
   local tk, kind = tok.tk, tok.kind

   if kind == "identifier" then
      return verify_kind(ps, i, "identifier", "variable")
   end
   if kind == "string" then
      local node = new_node(ps.tokens, i, "string")
      node.conststr, node.is_longstring = unquote(tk)
      return i + 1, node
   end
   if kind == "number" or kind == "integer" then
      local value = tonumber(tk)
      local node
      i, node = verify_kind(ps, i, kind)
      node.constnum = value
      return i, node
   end
   if tk == "true" or tk == "false" then
      return verify_kind(ps, i, "keyword", "boolean")
   end
   if tk == "nil" then
      return verify_kind(ps, i, "keyword", "nil")
   end
   if tk == "function" then
      return parse_function_value(ps, i)
   end
   if tk == "{" then
      return parse_table_literal(ps, i)
   end
   if kind == "..." then
      return verify_kind(ps, i, "...")
   end
   if kind == "$invalid_string$" then
      return fail(ps, i, "malformed string")
   end
   if kind == "$invalid_number$" then
      return fail(ps, i, "malformed number")
   end
   return fail(ps, i, "syntax error")
end
--- If `n` is a call of the form `require("mod")` / `require "mod"` or
-- `pcall(require, "mod")`, return the module name; otherwise return nil.
-- Always returns exactly one value: callers hand the result directly to
-- `table.insert`.
local function node_is_require_call(n)
   local callee, args = n.e1, n.e2
   -- direct form
   local direct = callee and args
      and callee.kind == "variable" and callee.tk == "require"
      and args.kind == "expression_list" and #args == 1
      and args[1].kind == "string"
   if direct then
      return args[1].conststr
   end
   -- protected form
   local protected = n.op and n.op.op == "@funcall"
      and callee and callee.tk == "pcall"
      and args and #args == 2
      and args[1].kind == "variable" and args[1].tk == "require"
      and args[2].kind == "string" and args[2].conststr
   if protected then
      return args[2].conststr
   end
   return nil
end
local an_operator | |
do | |
-- Operator precedence, indexed first by arity (1 = unary, 2 = binary) and
-- then by operator text; a larger number binds more tightly.
-- "@funcall" and "@index" name the call and index suffixes that P builds
-- with explicit operator names.
local precedences = { | |
-- unary operators: tighter than every binary operator except `^` (12)
[1] = { | |
["not"] = 11, | |
["#"] = 11, | |
["-"] = 11, | |
["~"] = 11, | |
}, | |
-- binary operators (`~` here is bitwise xor)
[2] = { | |
["or"] = 1, | |
["and"] = 2, | |
["is"] = 3, | |
["<"] = 3, | |
[">"] = 3, | |
["<="] = 3, | |
[">="] = 3, | |
["~="] = 3, | |
["=="] = 3, | |
["|"] = 4, | |
["~"] = 5, | |
["&"] = 6, | |
["<<"] = 7, | |
[">>"] = 7, | |
[".."] = 8, | |
["+"] = 9, | |
["-"] = 9, | |
["*"] = 10, | |
["/"] = 10, | |
["//"] = 10, | |
["%"] = 10, | |
["^"] = 12, | |
["as"] = 50, | |
["@funcall"] = 100, | |
["@index"] = 100, | |
["."] = 100, | |
[":"] = 100, | |
}, | |
} | |
-- Binary operators that group to the right; E lets an operator of equal
-- precedence from this set bind into the right operand.
local is_right_assoc = { | |
["^"] = true, | |
[".."] = true, | |
} | |
-- Build an operator record positioned at token `tk`. The operator name
-- defaults to the token text; its precedence is looked up for `arity`.
local function new_operator(tk, arity, op)
   local name = op or tk.tk
   return {
      y = tk.y,
      x = tk.x,
      arity = arity,
      op = name,
      prec = precedences[arity][name],
   }
end
-- Build an operator record positioned at an existing AST node, for
-- operators that have no source token of their own (e.g. the `.` links
-- synthesized for `function a.b.c()` names).
an_operator = function(node, arity, op)
   local prec = precedences[arity][op]
   return { y = node.y, x = node.x, arity = arity, op = op, prec = prec }
end
-- Token kinds that can start call arguments: `(`, a table constructor or a
-- string literal. P requires one of these right after `obj:name`.
local args_starters = { | |
["("] = true, | |
["{"] = true, | |
["string"] = true, | |
} | |
-- Forward declaration, so the function assigned to E further below can
-- call itself recursively.
local E | |
-- Tell whether `prevnode`, which ends just before token `i`, may be
-- followed by a call, index or method suffix: true after a closing `)`,
-- after a name, or after another call/index/field/method access.
local function after_valid_prefixexp(ps, prevnode, i)
   if ps.tokens[i - 1].kind == ")" then
      return true
   end
   local kind = prevnode.kind
   if kind == "op" then
      local name = prevnode.op.op
      if name == "@funcall" or name == "@index" or name == "." or name == ":" then
         return true
      end
   end
   return kind == "identifier" or kind == "variable"
end
-- Recovery node returned by P when a suffix after `e1` is malformed: a
-- "paren" node at the offending operator token that wraps what was parsed
-- so far and is flagged with `failstore`.
local function failstore(tkop, e1)
   return {
      y = tkop.y,
      x = tkop.x,
      kind = "paren",
      e1 = e1,
      failstore = true,
   }
end
--- Parse a primary expression together with its suffixes.
-- The head is one of: a prefix unary operator (`not`, `#`, `-`, `~`)
-- applied to another primary expression, a parenthesized expression, or an
-- atom from parse_literal. It is followed by any number of suffixes:
-- `.name`, `:name` (which must be followed by call arguments), `(args)`,
-- `[index]`, a string or table call argument, and `as`/`is` casts.
-- Returns the next token index and the node, or only the index when no
-- expression could be parsed. When a suffix is malformed, returns a
-- `failstore` paren node wrapping what was parsed so far.
local function P(ps, i)
   if ps.tokens[i].kind == "$EOF$" then
      return i
   end
   local e1
   local t1 = ps.tokens[i]
   if precedences[1][ps.tokens[i].tk] ~= nil then
      local op = new_operator(ps.tokens[i], 1)
      i = i + 1
      local prev_i = i
      i, e1 = P(ps, i)
      if not e1 then
         fail(ps, prev_i, "expected an expression")
         return i
      end
      -- `^` binds tighter than the unary operators (12 vs 11 in the
      -- precedence table), so `-x ^ y` must mean `-(x ^ y)`: let the
      -- operand absorb any exponentiation chain before applying the
      -- unary operator.
      if ps.tokens[i].tk == "^" then
         i, e1 = E(ps, i, e1, precedences[2]["^"])
         if not e1 then
            return i
         end
      end
      e1 = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = e1 }
   elseif ps.tokens[i].tk == "(" then
      i = i + 1
      local prev_i = i
      i, e1 = parse_expression_and_tk(ps, i, ")")
      if not e1 then
         fail(ps, prev_i, "expected an expression")
         return i
      end
      e1 = { y = t1.y, x = t1.x, kind = "paren", e1 = e1 }
   else
      i, e1 = parse_literal(ps, i)
   end
   if not e1 then
      return i
   end
   -- suffix loop: keep extending e1 until a token that cannot continue it
   while true do
      local tkop = ps.tokens[i]
      if tkop.kind == "," or tkop.kind == ")" then
         break
      end
      if tkop.tk == "." or tkop.tk == ":" then
         local op = new_operator(tkop, 2)
         local prev_i = i
         local key
         i = i + 1
         i, key = verify_kind(ps, i, "identifier")
         if not key then
            return i, failstore(tkop, e1)
         end
         if op.op == ":" then
            if not args_starters[ps.tokens[i].kind] then
               fail(ps, i, "expected a function call for a method")
               return i, failstore(tkop, e1)
            end
            if not after_valid_prefixexp(ps, e1, prev_i) then
               fail(ps, prev_i, "cannot call a method on this expression")
               return i, failstore(tkop, e1)
            end
         end
         e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = key }
      elseif tkop.tk == "(" then
         local op = new_operator(tkop, 2, "@funcall")
         local prev_i = i
         local args = new_node(ps.tokens, i, "expression_list")
         i, args = parse_bracket_list(ps, i, args, "(", ")", "sep", parse_expression)
         if not after_valid_prefixexp(ps, e1, prev_i) then
            fail(ps, prev_i, "cannot call this expression")
            return i, failstore(tkop, e1)
         end
         e1 = { y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args }
         -- records the module name for require(...) calls (nil otherwise)
         table.insert(ps.required_modules, node_is_require_call(e1))
      elseif tkop.tk == "[" then
         local op = new_operator(tkop, 2, "@index")
         local prev_i = i
         local idx
         i = i + 1
         i, idx = parse_expression_and_tk(ps, i, "]")
         if not after_valid_prefixexp(ps, e1, prev_i) then
            fail(ps, prev_i, "cannot index this expression")
            return i, failstore(tkop, e1)
         end
         e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = idx }
      elseif tkop.kind == "string" or tkop.kind == "{" then
         -- call with a single string or table argument: f "x" / f { ... }
         local op = new_operator(tkop, 2, "@funcall")
         local prev_i = i
         local args = new_node(ps.tokens, i, "expression_list")
         local argument
         if tkop.kind == "string" then
            argument = new_node(ps.tokens, i)
            argument.conststr = unquote(tkop.tk)
            i = i + 1
         else
            i, argument = parse_table_literal(ps, i)
         end
         if not after_valid_prefixexp(ps, e1, prev_i) then
            if tkop.kind == "string" then
               fail(ps, prev_i, "cannot use a string here; if you're trying to call the previous expression, wrap it in parentheses")
            else
               fail(ps, prev_i, "cannot use a table here; if you're trying to call the previous expression, wrap it in parentheses")
            end
            return i, failstore(tkop, e1)
         end
         table.insert(args, argument)
         e1 = { y = args.y, x = args.x, kind = "op", op = op, e1 = e1, e2 = args }
         table.insert(ps.required_modules, node_is_require_call(e1))
      elseif tkop.tk == "as" or tkop.tk == "is" then
         -- cast / type test: the right-hand side is a type (or a
         -- parenthesized type list), not an expression
         local op = new_operator(tkop, 2, tkop.tk)
         i = i + 1
         local cast = new_node(ps.tokens, i, "cast")
         if ps.tokens[i].tk == "(" then
            i, cast.casttype = parse_type_list(ps, i, "casttype")
         else
            i, cast.casttype = parse_type(ps, i)
         end
         if not cast.casttype then
            return i, failstore(tkop, e1)
         end
         e1 = { y = tkop.y, x = tkop.x, kind = "op", op = op, e1 = e1, e2 = cast, conststr = e1.conststr }
      else
         break
      end
   end
   return i, e1
end
-- Precedence climbing over binary operators. Starting from the already
-- parsed left operand `lhs`, consume every binary operator whose
-- precedence is at least `min_precedence`, folding each into a new "op"
-- node with `lhs` on the left. Operands are parsed with P. If the operator
-- after an operand binds tighter (or equally tight and is right-
-- associative), a recursive call absorbs it into the right operand first.
-- Returns the next index and the resulting node, or only the index if an
-- operand is missing.
E = function(ps, i, lhs, min_precedence) | |
local lookahead = ps.tokens[i].tk | |
while precedences[2][lookahead] and precedences[2][lookahead] >= min_precedence do | |
local t1 = ps.tokens[i] | |
local op = new_operator(t1, 2) | |
i = i + 1 | |
local rhs | |
i, rhs = P(ps, i) | |
if not rhs then | |
fail(ps, i, "expected an expression") | |
return i | |
end | |
lookahead = ps.tokens[i].tk | |
-- let tighter-binding (or right-associative equal) operators claim rhs
while precedences[2][lookahead] and ((precedences[2][lookahead] > (precedences[2][op.op])) or | |
(is_right_assoc[lookahead] and (precedences[2][lookahead] == precedences[2][op.op]))) do | |
i, rhs = E(ps, i, rhs, precedences[2][lookahead]) | |
if not rhs then | |
fail(ps, i, "expected an expression") | |
return i | |
end | |
lookahead = ps.tokens[i].tk | |
end | |
lhs = { y = t1.y, x = t1.x, kind = "op", op = op, e1 = lhs, e2 = rhs } | |
end | |
return i, lhs | |
end | |
-- Parse a complete expression at token `i`: a primary expression (P)
-- followed by any binary operators (E). On success returns the next
-- index, the node and a trailing 0 (an extra value apparently expected by
-- the list-parsing helpers — TODO confirm against parse_list). On failure
-- returns only the index, recording an error if nothing was consumed.
parse_expression = function(ps, i)
   local start = i
   local next_i, node = P(ps, i)
   if node then
      next_i, node = E(ps, next_i, node, 0)
      if node then
         return next_i, node, 0
      end
   end
   if next_i == start then
      next_i = fail(ps, next_i, "expected an expression")
   end
   return next_i
end
end | |
-- Parse an expression that must be followed by token `tk` (such as
-- `then`, `do`, `]` or `)`). A node is always returned: an "error_node"
-- stands in for a missing expression. If `tk` is not next, an error is
-- recorded; when `tk` occurs within the next 20 tokens (before EOF),
-- parsing resumes just past it.
parse_expression_and_tk = function(ps, i, tk)
   local e
   i, e = parse_expression(ps, i)
   e = e or new_node(ps.tokens, i - 1, "error_node")
   if ps.tokens[i].tk == tk then
      return i + 1, e
   end
   local msg = "syntax error, expected '" .. tk .. "'"
   for ahead = i, i + 19 do
      local t = ps.tokens[ahead]
      if t.kind == "$EOF$" then
         break
      end
      if t.tk == tk then
         fail(ps, i, msg)
         return ahead + 1, e
      end
   end
   return fail(ps, i, msg), e
end
-- Parse a variable name with an optional `<const>` annotation. Returns the
-- next index and the identifier node (with `is_const` set), or only the
-- index when no identifier is present. Unknown or missing annotations are
-- reported as errors.
local function parse_variable_name(ps, i)
   local node
   i, node = verify_kind(ps, i, "identifier")
   if not node then
      return i
   end
   node.is_const = false
   if ps.tokens[i].tk ~= "<" then
      return i, node
   end
   local annotation
   i, annotation = verify_kind(ps, i + 1, "identifier")
   if not annotation then
      fail(ps, i, "expected a variable annotation")
   elseif annotation.tk ~= "const" then
      fail(ps, i, "unknown variable annotation: " .. annotation.tk)
   else
      node.is_const = true
   end
   i = verify_tk(ps, i, ">")
   return i, node
end
-- Parse one function argument: a name or `...`, optionally followed by
-- `: type`. Returns the next index, the "argument" node (nil on error) and
-- a trailing 0.
local function parse_argument(ps, i)
   local expected = (ps.tokens[i].tk == "...") and "..." or "identifier"
   local node
   i, node = verify_kind(ps, i, expected, "argument")
   if ps.tokens[i].tk == ":" then
      local decltype
      i, decltype = parse_type(ps, i + 1)
      if node then
         node.decltype = decltype
      end
   end
   return i, node, 0
end
-- Parse a parenthesized, comma-separated argument list. A `...` argument
-- anywhere except in the last position is reported as an error.
parse_argument_list = function(ps, i)
   local list = new_node(ps.tokens, i, "argument_list")
   i, list = parse_bracket_list(ps, i, list, "(", ")", "sep", parse_argument)
   local count = #list
   for pos, arg in ipairs(list) do
      if arg.tk == "..." and pos ~= count then
         fail(ps, i, "'...' can only be last argument")
      end
   end
   return i, list
end
-- Parse one entry of a function type's argument list: an optional `name:`
-- or `...:` prefix followed by a type. An untyped `...` is an error.
-- Returns the next index, the type (with `is_va` set) and a trailing 0.
local function parse_argument_type(ps, i)
   local tokens = ps.tokens
   local is_va = false
   if tokens[i].kind == "identifier" and tokens[i + 1].tk == ":" then
      i = i + 2 -- skip `name:`
   elseif tokens[i].tk == "..." then
      if tokens[i + 1].tk ~= ":" then
         return fail(ps, i, "cannot have untyped '...' when declaring the type of an argument")
      end
      is_va = true
      i = i + 2 -- skip `...:`
   end
   local arg_type
   i, arg_type = parse_type(ps, i)
   if arg_type then
      arg_type.is_va = is_va
   end
   return i, arg_type, 0
end
-- Parse a parenthesized list of argument types into a "tuple" type. If the
-- last entry is variadic, its `is_va` flag moves to the tuple itself.
parse_argument_type_list = function(ps, i)
   local tuple = new_type(ps, i, "tuple")
   i = parse_bracket_list(ps, i, tuple, "(", ")", "sep", parse_argument_type)
   local last = tuple[#tuple]
   if last and last.is_va then
      last.is_va = nil
      tuple.is_va = true
   end
   return i, tuple
end
-- Parse a single identifier. On failure, records an error and returns an
-- "error_node" at the recovery index, so callers always receive a node.
local function parse_identifier(ps, i)
   if ps.tokens[i].kind ~= "identifier" then
      local next_i = fail(ps, i, "syntax error, expected identifier")
      return next_i, new_node(ps.tokens, next_i, "error_node")
   end
   return i + 1, new_node(ps.tokens, i, "identifier")
end
-- Parse `local function name(...) ... end`; token `i` is `local`.
local function parse_local_function(ps, i)
   local after_local = verify_tk(ps, i, "local")
   local name_i = verify_tk(ps, after_local, "function")
   local fn = new_node(ps.tokens, name_i, "local_function")
   local next_i
   next_i, fn.name = parse_identifier(ps, name_i)
   return parse_function_args_rets_body(ps, next_i, fn)
end
-- Parse `function name[.field]*[:method](...) ... end`; token `i` is the
-- `function` keyword (used both for a bare `function` statement and, via
-- parse_global, for `global function`).
-- A single name yields a "global_function" node. A dotted or method name
-- yields a "record_function" node: its `fn_owner` is the chain of "."
-- operator nodes over the leading names and its `name` is the last one.
-- Methods (`a:b`) get an implicit `self` argument inserted first.
local function parse_global_function(ps, i) | |
local orig_i = i | |
i = verify_tk(ps, i, "function") | |
local fn = new_node(ps.tokens, i, "global_function") | |
local names = {} | |
i, names[1] = parse_identifier(ps, i) | |
while ps.tokens[i].tk == "." do | |
i = i + 1 | |
i, names[#names + 1] = parse_identifier(ps, i) | |
end | |
if ps.tokens[i].tk == ":" then | |
i = i + 1 | |
i, names[#names + 1] = parse_identifier(ps, i) | |
fn.is_method = true | |
end | |
if #names > 1 then | |
fn.kind = "record_function" | |
local owner = names[1] | |
owner.kind = "type_identifier" | |
-- build the owner expression a.b.c as left-nested "." op nodes
for i2 = 2, #names - 1 do | |
local dot = an_operator(names[i2], 2, ".") | |
names[i2].kind = "identifier" | |
owner = { y = names[i2].y, x = names[i2].x, kind = "op", op = dot, e1 = owner, e2 = names[i2] } | |
end | |
fn.fn_owner = owner | |
end | |
fn.name = names[#names] | |
-- position of the token after the name; reused for the implicit `self`
local selfx, selfy = ps.tokens[i].x, ps.tokens[i].y | |
i = parse_function_args_rets_body(ps, i, fn) | |
if fn.is_method then | |
table.insert(fn.args, 1, { x = selfx, y = selfy, tk = "self", kind = "identifier" }) | |
end | |
-- NOTE(review): parse_identifier always returns a node, so this guard
-- looks unreachable here — confirm before relying on it
if not fn.name then | |
return orig_i + 1 | |
end | |
return i, fn | |
end | |
-- Parse one branch of an if statement: `if`/`elseif` with a condition, or
-- `else` when `is_else` is true. Token `i` is the branch keyword, and `n`
-- is the branch's 1-based position. The branch is appended to
-- `node.if_blocks`. Returns the next index and `node`, or only the index
-- on failure.
local function parse_if_block(ps, i, n, node, is_else)
   local block = new_node(ps.tokens, i, "if_block")
   block.if_parent = node
   block.if_block_n = n
   i = i + 1 -- skip the `if` / `elseif` / `else` keyword
   if not is_else then
      i, block.exp = parse_expression_and_tk(ps, i, "then")
      if not block.exp then
         return i
      end
   end
   local body
   i, body = parse_statements(ps, i)
   block.body = body
   if not body then
      return i
   end
   end_at(body, ps.tokens[i - 1])
   block.yend, block.xend = body.yend, body.xend
   table.insert(node.if_blocks, block)
   return i, node
end
-- Parse `if c then ... [elseif c then ...]* [else ...] end` into an "if"
-- node whose `if_blocks` lists the branches in order.
local function parse_if(ps, i)
   local start = i
   local if_node = new_node(ps.tokens, i, "if")
   if_node.if_blocks = {}
   local branch, is_else = 1, false
   repeat
      i, if_node = parse_if_block(ps, i, branch, if_node, is_else)
      if not if_node then
         return i
      end
      -- nothing may follow an `else` branch except `end`
      local done = is_else
      if not done then
         local tk = ps.tokens[i].tk
         if tk == "elseif" or tk == "else" then
            branch = branch + 1
            is_else = (tk == "else")
         else
            done = true
         end
      end
   until done
   i = verify_end(ps, i, start, if_node)
   return i, if_node
end
-- Parse `while <exp> do <body> end`.
local function parse_while(ps, i)
   local start = i
   local node = new_node(ps.tokens, i, "while")
   i = verify_tk(ps, i, "while")
   i, node.exp = parse_expression_and_tk(ps, i, "do")
   i, node.body = parse_statements(ps, i)
   return verify_end(ps, i, start, node), node
end
-- Parse a numeric for loop: `for v = from, to [, step] do <body> end`.
local function parse_fornum(ps, i)
   local start = i
   local node = new_node(ps.tokens, i, "fornum")
   i, node.var = parse_identifier(ps, i + 1)
   i = verify_tk(ps, i, "=")
   i, node.from = parse_expression_and_tk(ps, i, ",")
   i, node.to = parse_expression(ps, i)
   if ps.tokens[i].tk ~= "," then
      i = verify_tk(ps, i, "do")
   else
      i, node.step = parse_expression_and_tk(ps, i + 1, "do")
   end
   i, node.body = parse_statements(ps, i)
   return verify_end(ps, i, start, node), node
end
-- Parse a generic for loop: `for a, b in e1[, e2[, e3]] do <body> end`.
-- One to three iterator expressions are accepted.
local function parse_forin(ps, i)
   local start = i
   local node = new_node(ps.tokens, i, "forin")
   i = i + 1
   local vars = new_node(ps.tokens, i, "variable_list")
   i, node.vars = parse_list(ps, i, vars, { ["in"] = true }, "sep", parse_variable_name)
   i = verify_tk(ps, i, "in")
   local exps = new_node(ps.tokens, i, "expression_list")
   node.exps = exps
   i = parse_list(ps, i, exps, { ["do"] = true }, "sep", parse_expression)
   local count = #exps
   if count < 1 then
      return fail(ps, i, "missing iterator expression in generic for")
   end
   if count > 3 then
      return fail(ps, i, "too many expressions in generic for")
   end
   i = verify_tk(ps, i, "do")
   i, node.body = parse_statements(ps, i)
   return verify_end(ps, i, start, node), node
end
-- Dispatch a `for` statement: `for name =` starts a numeric loop; anything
-- else is a generic `for ... in` loop.
local function parse_for(ps, i)
   local tokens = ps.tokens
   local numeric = tokens[i + 1].kind == "identifier" and tokens[i + 2].tk == "="
   local parser = numeric and parse_fornum or parse_forin
   return parser(ps, i)
end
-- Parse `repeat <body> until <exp>`; the body is flagged with `is_repeat`.
local function parse_repeat(ps, i)
   local node = new_node(ps.tokens, i, "repeat")
   local body
   i = verify_tk(ps, i, "repeat")
   i, body = parse_statements(ps, i)
   node.body = body
   body.is_repeat = true
   i = verify_tk(ps, i, "until")
   i, node.exp = parse_expression(ps, i)
   end_at(node, ps.tokens[i - 1])
   return i, node
end
-- Parse a `do <body> end` block.
local function parse_do(ps, i)
   local node = new_node(ps.tokens, i, "do")
   local after_do = verify_tk(ps, i, "do")
   local after_body
   after_body, node.body = parse_statements(ps, after_do)
   return verify_end(ps, after_body, i, node), node
end
-- Parse a `break` statement.
local function parse_break(ps, i)
   local node = new_node(ps.tokens, i, "break")
   return verify_tk(ps, i, "break"), node
end
-- Parse `goto <label>`; the label text is read from the token after `goto`.
local function parse_goto(ps, i)
   local node = new_node(ps.tokens, i, "goto")
   local label_i = verify_tk(ps, i, "goto")
   node.label = ps.tokens[label_i].tk
   return verify_kind(ps, label_i, "identifier"), node
end
-- Parse a `::name::` label declaration.
local function parse_label(ps, i)
   local node = new_node(ps.tokens, i, "label")
   local name_i = verify_tk(ps, i, "::")
   node.label = ps.tokens[name_i].tk
   local close_i = verify_kind(ps, name_i, "identifier")
   return verify_tk(ps, close_i, "::"), node
end
-- Keywords that close a block; a nested statement list stops before them.
local stop_statement_list = {}
for _, word in ipairs({ "end", "else", "elseif", "until" }) do
   stop_statement_list[word] = true
end
-- A return's expression list also stops at `;` and at end of file.
local stop_return_list = { [";"] = true, ["$EOF$"] = true }
for word in pairs(stop_statement_list) do
   stop_return_list[word] = true
end
-- Parse `return [exp {, exp}] [;]`.
local function parse_return(ps, i)
   local node = new_node(ps.tokens, i, "return")
   i = verify_tk(ps, i, "return")
   local exps = new_node(ps.tokens, i, "expression_list")
   node.exps = exps
   i = parse_list(ps, i, exps, stop_return_list, "sep", parse_expression)
   -- an optional trailing semicolon is consumed as part of the statement
   if ps.tokens[i].kind == ";" then
      i = i + 1
   end
   return i, node
end
-- Register field `field_name` of type `t` in a record's `fields` map and
-- append new names to `field_order`. Redeclaring a function field turns
-- it into a "poly" overload set, or extends an existing one. Any other
-- redeclaration records an error and returns false; otherwise returns true.
local function store_field_in_record(ps, i, field_name, t, fields, field_order)
   local prev_t = fields[field_name]
   if not prev_t then
      fields[field_name] = t
      table.insert(field_order, field_name)
      return true
   end
   if t.typename == "function" and prev_t.typename == "function" then
      local poly = new_type(ps, i, "poly")
      fields[field_name] = poly
      poly.types = { prev_t, t }
      return true
   end
   if t.typename == "function" and prev_t.typename == "poly" then
      table.insert(prev_t.types, t)
      return true
   end
   fail(ps, i, "attempt to redeclare field '" .. field_name .. "' (only functions can be overloaded)")
   return false
end
local ParseBody = {}
-- Parse a nested `record Name ... end` or `enum Name ... end` inside a
-- record body, where token `i` is the `record`/`enum` keyword, and store
-- it as a type-valued field of `def`. Returns the index after the body.
local function parse_nested_type(ps, i, def, typename, parse_body)
   local name
   i, name = verify_kind(ps, i + 1, "identifier", "type_identifier")
   if not name then
      return fail(ps, i, "expected a variable name")
   end
   local nt = new_node(ps.tokens, i, "newtype")
   local typetype = new_type(ps, i, "typetype")
   nt.newtype = typetype
   local nested_def = new_type(ps, i, typename)
   local after_body = parse_body(ps, i, nested_def, nt)
   if after_body then
      i = after_body
      typetype.def = nested_def
   end
   store_field_in_record(ps, i, name.tk, typetype, def.fields, def.field_order)
   return i
end
-- Parse the body of an enum declaration: string literals up to `end`.
-- Each literal is appended to `node` as an "enum_item", and its unquoted
-- value is added to the `def.enumset` set. Returns the index after `end`
-- and the node.
parse_enum_body = function(ps, i, def, node)
   local istart = i - 1
   def.enumset = {}
   -- Detect end of input by token kind, as parse_record_body and
   -- parse_statements do, so termination never depends on the EOF token's
   -- text.
   while ps.tokens[i].kind ~= "$EOF$" and ps.tokens[i].tk ~= "end" do
      local item
      i, item = verify_kind(ps, i, "string", "enum_item")
      if item then
         table.insert(node, item)
         def.enumset[unquote(item.tk)] = true
      end
   end
   i = verify_end(ps, i, istart, node)
   return i, node
end
-- Names accepted after `metamethod` in a record body; parse_record_body
-- reports any other name as "not a valid metamethod".
local metamethod_names = { | |
["__add"] = true, | |
["__sub"] = true, | |
["__mul"] = true, | |
["__div"] = true, | |
["__mod"] = true, | |
["__pow"] = true, | |
["__unm"] = true, | |
["__idiv"] = true, | |
["__band"] = true, | |
["__bor"] = true, | |
["__bxor"] = true, | |
["__bnot"] = true, | |
["__shl"] = true, | |
["__shr"] = true, | |
["__concat"] = true, | |
["__len"] = true, | |
["__eq"] = true, | |
["__lt"] = true, | |
["__le"] = true, | |
["__index"] = true, | |
["__newindex"] = true, | |
["__call"] = true, | |
["__tostring"] = true, | |
["__pairs"] = true, | |
["__gc"] = true, | |
} | |
-- Parse a record body, starting at the token right after the record's
-- name (or after `record` for anonymous records), up to the closing `end`.
-- Fills `def.fields`/`def.field_order` and, when present, `def.typeargs`,
-- `def.is_userdata`, `def.elements` (with `def.typename` switched to
-- "arrayrecord") and `def.meta_fields`/`def.meta_field_order`.
-- Entries: `userdata`, `{ T }`, `type Name = ...`, nested `record`/`enum`,
-- `metamethod name: T`, and `name: T` / `["name"]: T` fields. The legacy
-- `name = ...` form is rejected with a migration hint.
parse_record_body = function(ps, i, def, node) | |
local istart = i - 1 | |
def.fields = {} | |
def.field_order = {} | |
if ps.tokens[i].tk == "<" then | |
i, def.typeargs = parse_typearg_list(ps, i) | |
end | |
while not (ps.tokens[i].kind == "$EOF$" or ps.tokens[i].tk == "end") do | |
-- `userdata` marker (a field named "userdata" is `userdata: T`)
if ps.tokens[i].tk == "userdata" and ps.tokens[i + 1].tk ~= ":" then | |
if def.is_userdata then | |
fail(ps, i, "duplicated 'userdata' declaration in record") | |
else | |
def.is_userdata = true | |
end | |
i = i + 1 | |
-- `{ T }`: array element type, turning the record into an arrayrecord
elseif ps.tokens[i].tk == "{" then | |
if def.typename == "arrayrecord" then | |
i = failskip(ps, i, "duplicated declaration of array element type in record", parse_type) | |
else | |
i = i + 1 | |
local t | |
i, t = parse_type(ps, i) | |
if ps.tokens[i].tk == "}" then | |
i = verify_tk(ps, i, "}") | |
else | |
return fail(ps, i, "expected an array declaration") | |
end | |
def.typename = "arrayrecord" | |
def.elements = t | |
end | |
-- `type Name = <newtype>`: nested type alias stored as a field
elseif ps.tokens[i].tk == "type" and ps.tokens[i + 1].tk ~= ":" then | |
i = i + 1 | |
local iv = i | |
local v | |
i, v = verify_kind(ps, i, "identifier", "type_identifier") | |
if not v then | |
return fail(ps, i, "expected a variable name") | |
end | |
i = verify_tk(ps, i, "=") | |
local nt | |
i, nt = parse_newtype(ps, i) | |
if not nt or not nt.newtype then | |
return fail(ps, i, "expected a type definition") | |
end | |
store_field_in_record(ps, iv, v.tk, nt.newtype, def.fields, def.field_order) | |
elseif ps.tokens[i].tk == "record" and ps.tokens[i + 1].tk ~= ":" then | |
i = parse_nested_type(ps, i, def, "record", parse_record_body) | |
elseif ps.tokens[i].tk == "enum" and ps.tokens[i + 1].tk ~= ":" then | |
i = parse_nested_type(ps, i, def, "enum", parse_enum_body) | |
else | |
-- ordinary or metamethod field: `[metamethod] name: T` / `["name"]: T`
local is_metamethod = false | |
if ps.tokens[i].tk == "metamethod" and ps.tokens[i + 1].tk ~= ":" then | |
is_metamethod = true | |
i = i + 1 | |
end | |
local v | |
if ps.tokens[i].tk == "[" then | |
i, v = parse_literal(ps, i + 1) | |
if v and not v.conststr then | |
return fail(ps, i, "expected a string literal") | |
end | |
i = verify_tk(ps, i, "]") | |
else | |
i, v = verify_kind(ps, i, "identifier", "variable") | |
end | |
local iv = i | |
if not v then | |
return fail(ps, i, "expected a variable name") | |
end | |
if ps.tokens[i].tk == ":" then | |
i = i + 1 | |
local t | |
i, t = parse_type(ps, i) | |
if not t then | |
return fail(ps, i, "expected a type") | |
end | |
local field_name = v.conststr or v.tk | |
local fields = def.fields | |
local field_order = def.field_order | |
if is_metamethod then | |
if not def.meta_fields then | |
def.meta_fields = {} | |
def.meta_field_order = {} | |
end | |
fields = def.meta_fields | |
field_order = def.meta_field_order | |
if not metamethod_names[field_name] then | |
fail(ps, i - 1, "not a valid metamethod: " .. field_name) | |
end | |
end | |
store_field_in_record(ps, iv, field_name, t, fields, field_order) | |
-- legacy `name = record/enum/functiontype/...` syntax: explain the new form
elseif ps.tokens[i].tk == "=" then | |
local next_word = ps.tokens[i + 1].tk | |
if next_word == "record" or next_word == "enum" then | |
return fail(ps, i, "syntax error: this syntax is no longer valid; use '" .. next_word .. " " .. v.tk .. "'") | |
elseif next_word == "functiontype" then | |
return fail(ps, i, "syntax error: this syntax is no longer valid; use 'type " .. v.tk .. " = function('...") | |
else | |
return fail(ps, i, "syntax error: this syntax is no longer valid; use 'type " .. v.tk .. " = '...") | |
end | |
else | |
fail(ps, i, "syntax error: expected ':' for an attribute or '=' for a nested type") | |
end | |
end | |
end | |
i = verify_end(ps, i, istart, node) | |
return i, node | |
end | |
-- Parse the right-hand side of a type declaration: `record ... end`,
-- `enum ... end`, or any other type expression. Returns the next index and
-- a "newtype" node whose `newtype` is a "typetype" with its `def` set, or
-- only the index if the type could not be parsed.
parse_newtype = function(ps, i)
   local node = new_node(ps.tokens, i, "newtype")
   node.newtype = new_type(ps, i, "typetype")
   local word = ps.tokens[i].tk
   if word == "record" or word == "enum" then
      local def = new_type(ps, i, word)
      local parse_body = (word == "record") and parse_record_body or parse_enum_body
      i = parse_body(ps, i + 1, def, node)
      node.newtype.def = def
      return i, node
   end
   i, node.newtype.def = parse_type(ps, i)
   if not node.newtype.def then
      return i
   end
   return i, node
end
-- Parse the comma-separated right-hand side of an assignment or
-- declaration into `asgn.exps`; token `i` is the `=` sign. If an
-- expression is missing, returns only the index and drops `asgn.exps`
-- when nothing was parsed.
local function parse_assignment_expression_list(ps, i, asgn)
   local exps = new_node(ps.tokens, i, "expression_list")
   asgn.exps = exps
   while true do
      local val
      i, val = parse_expression(ps, i + 1) -- skip the `=` or `,`
      if not val then
         if #exps == 0 then
            asgn.exps = nil
         end
         return i
      end
      exps[#exps + 1] = val
      if ps.tokens[i].tk ~= "," then
         return i, asgn
      end
   end
end
-- Statement-level parsing of statements that start with an expression:
-- function calls and assignments.
local parse_call_or_assignment | |
do | |
-- Only a plain variable, an index `a[k]` or a field access `a.k` can be
-- assigned to.
local function is_lvalue(node) | |
return node.kind == "variable" or | |
(node.kind == "op" and (node.op.op == "@index" or node.op.op == ".")) | |
end | |
-- Parse one further assignment target; errors unless it is an lvalue.
local function parse_variable(ps, i) | |
local node | |
i, node = parse_expression(ps, i) | |
if not (node and is_lvalue(node)) then | |
return fail(ps, i, "expected a variable") | |
end | |
return i, node | |
end | |
-- A call, or an expression already flagged `failstore` by the parser, is
-- returned as the statement itself. Otherwise the expression must start an
-- assignment `v1 [, v2 ...] = e1 [, e2 ...]`.
parse_call_or_assignment = function(ps, i) | |
local exp | |
local istart = i | |
i, exp = parse_expression(ps, i) | |
if not exp then | |
return i | |
end | |
if (exp.op and exp.op.op == "@funcall") or exp.failstore then | |
return i, exp | |
end | |
if not is_lvalue(exp) then | |
return fail(ps, i, "syntax error") | |
end | |
local asgn = new_node(ps.tokens, istart, "assignment") | |
asgn.vars = new_node(ps.tokens, istart, "variable_list") | |
asgn.vars[1] = exp | |
if ps.tokens[i].tk == "," then | |
i = i + 1 | |
i = parse_trying_list(ps, i, asgn.vars, parse_variable) | |
if #asgn.vars < 2 then | |
return fail(ps, i, "syntax error") | |
end | |
end | |
-- report the missing `=` without consuming the token
if ps.tokens[i].tk ~= "=" then | |
verify_tk(ps, i, "=") | |
return i | |
end | |
i, asgn = parse_assignment_expression_list(ps, i, asgn) | |
return i, asgn | |
end | |
end | |
-- Parse `name [<const>] {, name} [: types] [= exps]` after `local`/`global`.
-- `node_name` is "local_declaration" or "global_declaration". The legacy
-- `= record` / `= enum` / `= functiontype` forms are rejected with a hint
-- and skipped.
local function parse_variable_declarations(ps, i, node_name)
   local asgn = new_node(ps.tokens, i, node_name)
   local vars = new_node(ps.tokens, i, "variable_list")
   asgn.vars = vars
   i = parse_trying_list(ps, i, vars, parse_variable_name)
   if #vars == 0 then
      return fail(ps, i, "expected a local variable definition")
   end
   i, asgn.decltype = parse_type_list(ps, i, "decltype")
   if ps.tokens[i].tk ~= "=" then
      return i, asgn
   end
   local scope = (node_name == "local_declaration") and "local" or "global"
   local next_word = ps.tokens[i + 1].tk
   if next_word == "record" then
      return failskip(ps, i + 1, "syntax error: this syntax is no longer valid; use '" .. scope .. " record " .. vars[1].tk .. "'", skip_record)
   elseif next_word == "enum" then
      return failskip(ps, i + 1, "syntax error: this syntax is no longer valid; use '" .. scope .. " enum " .. vars[1].tk .. "'", skip_enum)
   elseif next_word == "functiontype" then
      return failskip(ps, i + 1, "syntax error: this syntax is no longer valid; use '" .. scope .. " type " .. vars[1].tk .. " = function('...", parse_function_type)
   end
   i, asgn = parse_assignment_expression_list(ps, i, asgn)
   return i, asgn
end
-- Parse `<scope> type Name = <newtype>`; the name is expected two tokens
-- after `i`. If the parsed type definition carries no `names` yet, it is
-- given the declared name.
local function parse_type_declaration(ps, i, node_name)
   local name_i = i + 2
   local asgn = new_node(ps.tokens, name_i, node_name)
   local var
   i, var = parse_variable_name(ps, name_i)
   asgn.var = var
   if not var then
      return fail(ps, i, "expected a type name")
   end
   i = verify_tk(ps, i, "=")
   local value
   i, value = parse_newtype(ps, i)
   asgn.value = value
   if not value then
      return i
   end
   local def = value.newtype.def
   if not def.names then
      def.names = { var.tk }
   end
   return i, asgn
end
-- Parse `<scope> record Name ... end` or `<scope> enum Name ... end`; token
-- `i` is the scope keyword. A fresh `type_name` definition named after
-- Name is filled by `parse_body`.
local function parse_type_constructor(ps, i, node_name, type_name, parse_body)
   local asgn = new_node(ps.tokens, i, node_name)
   local nt = new_node(ps.tokens, i, "newtype")
   asgn.value = nt
   local typetype = new_type(ps, i, "typetype")
   nt.newtype = typetype
   local def = new_type(ps, i, type_name)
   typetype.def = def
   local name
   i, name = verify_kind(ps, i + 2, "identifier")
   asgn.var = name
   if not name then
      return fail(ps, i, "expected a type name")
   end
   def.names = { name.tk }
   i = parse_body(ps, i, def, nt)
   return i, asgn
end
-- failskip helper: parse a `type Name = ...` declaration that starts at
-- the `type` keyword (no `local`/`global` before it) and return only the
-- index after it.
local function skip_type_declaration(ps, i)
   local next_i = parse_type_declaration(ps, i - 1, "local_type")
   return next_i
end
-- Parse a statement starting with `local`: a function, a type alias, a
-- record, an enum, or a variable declaration.
local function parse_local(ps, i)
   local tokens = ps.tokens
   local ntk = tokens[i + 1].tk
   if ntk == "function" then
      return parse_local_function(ps, i)
   end
   if (ntk == "type" or ntk == "record" or ntk == "enum") and tokens[i + 2].kind == "identifier" then
      if ntk == "type" then
         return parse_type_declaration(ps, i, "local_type")
      end
      local body = (ntk == "record") and parse_record_body or parse_enum_body
      return parse_type_constructor(ps, i, "local_type", ntk, body)
   end
   return parse_variable_declarations(ps, i + 1, "local_declaration")
end
-- Parse a statement starting with `global`: a function, a type alias, a
-- record, an enum, or a variable declaration. Otherwise `global` is treated
-- as an ordinary name in a call or assignment.
local function parse_global(ps, i)
   local tokens = ps.tokens
   local ntk = tokens[i + 1].tk
   if ntk == "function" then
      return parse_global_function(ps, i + 1)
   end
   if (ntk == "type" or ntk == "record" or ntk == "enum") and tokens[i + 2].kind == "identifier" then
      if ntk == "type" then
         return parse_type_declaration(ps, i, "global_type")
      end
      local body = (ntk == "record") and parse_record_body or parse_enum_body
      return parse_type_constructor(ps, i, "global_type", ntk, body)
   end
   if tokens[i + 1].kind == "identifier" then
      return parse_variable_declarations(ps, i + 1, "global_declaration")
   end
   return parse_call_or_assignment(ps, i)
end
-- A statement beginning with `type`. `type Name ...` is rejected with a
-- hint to use `local type` or `global type`, and the declaration is
-- skipped. Otherwise `type` is an ordinary name in a call or assignment.
local function parse_type_statement(ps, i)
   if ps.tokens[i + 1].kind ~= "identifier" then
      return parse_call_or_assignment(ps, i)
   end
   return failskip(ps, i, "types need to be declared with 'local type' or 'global type'", skip_type_declaration)
end
-- Statement parsers keyed by the statement's first token text;
-- parse_statements falls back to parse_call_or_assignment for any other
-- token.
local parse_statement_fns = { | |
["::"] = parse_label, | |
["do"] = parse_do, | |
["if"] = parse_if, | |
["for"] = parse_for, | |
["goto"] = parse_goto, | |
["type"] = parse_type_statement, | |
["local"] = parse_local, | |
["while"] = parse_while, | |
["break"] = parse_break, | |
["global"] = parse_global, | |
["repeat"] = parse_repeat, | |
["return"] = parse_return, | |
["function"] = parse_global_function, | |
} | |
-- Parse a sequence of statements starting at token `i` into a "statements"
-- node. Always stops at EOF. Unless `toplevel` is set, it also stops
-- before `end`, `else`, `elseif` or `until`. Semicolons are skipped and
-- mark the preceding statement with `semicolon = true`. When a statement
-- fails to parse, the rest of its source line is skipped to resynchronize.
parse_statements = function(ps, i, toplevel) | |
local node = new_node(ps.tokens, i, "statements") | |
local item | |
while true do | |
while ps.tokens[i].kind == ";" do | |
i = i + 1 | |
if item then | |
item.semicolon = true | |
end | |
end | |
if ps.tokens[i].kind == "$EOF$" then | |
break | |
end | |
if (not toplevel) and stop_statement_list[ps.tokens[i].tk] then | |
break | |
end | |
local parse_statement_fn = parse_statement_fns[ps.tokens[i].tk] or parse_call_or_assignment | |
i, item = parse_statement_fn(ps, i) | |
if item then | |
table.insert(node, item) | |
elseif i > 1 then | |
-- error recovery: skip the remaining tokens on the same line
local lasty = ps.tokens[i - 1].y | |
while ps.tokens[i].kind ~= "$EOF$" and ps.tokens[i].y == lasty do | |
i = i + 1 | |
end | |
end | |
end | |
end_at(node, ps.tokens[i]) | |
return i, node | |
end | |
--- Sort `errors` in place by filename, line, column and, for ties,
-- original order. Then remove every error that has the same filename and
-- position as the error kept immediately before it.
-- @param errors list of `{ filename?, y, x, msg }` records
local function clear_redundant_errors(errors)
   for i, err in ipairs(errors) do
      err.i = i -- original position; keeps the sort stable
   end
   table.sort(errors, function(a, b)
      local af = a.filename or ""
      local bf = b.filename or ""
      return af < bf or
      (af == bf and (a.y < b.y or
      (a.y == b.y and (a.x < b.x or
      (a.x == b.x and (a.i < b.i))))))
   end)
   local redundant = {}
   -- Compare the filename as well as the position: errors from different
   -- files that share a line/column are not duplicates. `lastf` starts as
   -- nil so the first error is never treated as a duplicate, whatever its
   -- position.
   local lastf, lasty, lastx = nil, nil, nil
   for i, err in ipairs(errors) do
      err.i = nil
      local f = err.filename or ""
      if f == lastf and err.y == lasty and err.x == lastx then
         table.insert(redundant, i)
      end
      lastf, lasty, lastx = f, err.y, err.x
   end
   -- remove from the back so the remaining indices stay valid
   for k = #redundant, 1, -1 do
      table.remove(errors, redundant[k])
   end
end
-- Parse a whole token stream as a top-level chunk.
-- Parse errors are appended to `errs` (a new table when nil), which is then
-- sorted and de-duplicated.
-- Returns: index where parsing stopped, the root "statements" node, and the
-- parser state's `required_modules` list (NOTE(review): presumably populated
-- by the statement parsers for require calls — confirm).
function tl.parse_program(tokens, errs, filename)
   local errors = errs or {}
   local state = {
      tokens = tokens,
      errs = errors,
      filename = filename or "",
      required_modules = {},
   }
   local last_index, root = parse_statements(state, 1, true)
   clear_redundant_errors(errors)
   return last_index, root, state.required_modules
end
-- Empty tables standing in for type declarations (visitor callback sets, the
-- visitor itself, and the meta mode). NOTE(review): these look like residue of
-- compiling Teal type declarations to Lua — confirm nothing reads them.
local VisitorCallbacks, VisitorExtraCallback, Visitor, MetaMode = {}, {}, {}, {}
-- Iterator over a record type's fields in declaration order.
-- When `meta` is truthy, walks meta_field_order/meta_fields instead (each
-- falling back to field_order/fields when absent).
-- Each step yields (field name, field type); iteration stops at the first
-- missing name in the order list.
local function fields_of(t, meta)
   local order, by_name
   if meta then
      order = t.meta_field_order or t.field_order
      by_name = t.meta_fields or t.fields
   else
      order = t.field_order
      by_name = t.fields
   end
   local pos = 0
   return function()
      local name = order[pos + 1]
      if not name then
         return nil
      end
      pos = pos + 1
      return name, by_name[name]
   end
end
-- Visit the type tree `ast` depth-first with the type visitor `visit`.
-- Order of events for one node:
--   1. visit.cbs[ast.typename].before(ast), if defined;
--   2. recurse into child types in this fixed order, collecting results in
--      `xs`: typeargs, array part, types, def, keys, values, elements, fields,
--      meta_fields, args, rets, typevals, ktype, vtype;
--   3. ret = visit.cbs[ast.typename].after(ast, xs), if defined;
--   4. ret = visit.after(ast, xs, ret), if defined.
-- Returns `ret` (nil when neither `after` callback exists).
-- Raises an error if visit.cbs is set but has no entry for ast.typename.
-- Note: table.insert of a nil child result adds nothing, so later results
-- shift down in `xs`.
local function recurse_type(ast, visit) | |
local kind = ast.typename | |
local cbs = visit.cbs | |
local cbkind = cbs and cbs[kind] | |
do | |
if cbkind then | |
local cbkind_before = cbkind.before | |
if cbkind_before then | |
cbkind_before(ast) | |
end | |
else | |
if cbs then | |
error("internal compiler error: no visitor for " .. kind) | |
end | |
end | |
end | |
local xs = {} | |
if ast.typeargs then | |
for _, child in ipairs(ast.typeargs) do | |
table.insert(xs, recurse_type(child, visit)) | |
end | |
end | |
-- The array part is stored by index, so it overwrites any typeargs results
-- inserted above. NOTE(review): presumably no type has both — confirm.
for i, child in ipairs(ast) do | |
xs[i] = recurse_type(child, visit) | |
end | |
if ast.types then | |
for _, child in ipairs(ast.types) do | |
table.insert(xs, recurse_type(child, visit)) | |
end | |
end | |
if ast.def then | |
table.insert(xs, recurse_type(ast.def, visit)) | |
end | |
if ast.keys then | |
table.insert(xs, recurse_type(ast.keys, visit)) | |
end | |
if ast.values then | |
table.insert(xs, recurse_type(ast.values, visit)) | |
end | |
if ast.elements then | |
table.insert(xs, recurse_type(ast.elements, visit)) | |
end | |
-- Record fields and metafields are visited in declaration order.
if ast.fields then | |
for _, child in fields_of(ast) do | |
table.insert(xs, recurse_type(child, visit)) | |
end | |
end | |
if ast.meta_fields then | |
for _, child in fields_of(ast, "meta") do | |
table.insert(xs, recurse_type(child, visit)) | |
end | |
end | |
-- For methods (ast.is_method), the first argument is not visited.
if ast.args then | |
for i, child in ipairs(ast.args) do | |
if i > 1 or not ast.is_method then | |
table.insert(xs, recurse_type(child, visit)) | |
end | |
end | |
end | |
if ast.rets then | |
for _, child in ipairs(ast.rets) do | |
table.insert(xs, recurse_type(child, visit)) | |
end | |
end | |
if ast.typevals then | |
for _, child in ipairs(ast.typevals) do | |
table.insert(xs, recurse_type(child, visit)) | |
end | |
end | |
if ast.ktype then | |
table.insert(xs, recurse_type(ast.ktype, visit)) | |
end | |
if ast.vtype then | |
table.insert(xs, recurse_type(ast.vtype, visit)) | |
end | |
-- Per-kind `after` runs first; the visitor-wide `after` may override its result.
local ret | |
do | |
local cbkind_after = cbkind and cbkind.after | |
if cbkind_after then | |
ret = cbkind_after(ast, xs) | |
end | |
local visit_after = visit.after | |
if visit_after then | |
ret = visit_after(ast, xs, ret) | |
end | |
end | |
return ret | |
end | |
-- Run the type visitor over every type argument of `ast`, if it has any.
-- Results are discarded; only the visitor callbacks' side effects matter.
local function recurse_typeargs(ast, visit_type)
   local typeargs = ast.typeargs
   if not typeargs then
      return
   end
   for _, typearg in ipairs(typeargs) do
      recurse_type(typearg, visit_type)
   end
end
-- Fire the optional hook visit_node.cbs[ast.kind][name](ast, xs), if every
-- level of that lookup exists. Used for mid-walk hooks such as
-- "before_statements", "before_expressions" and "before_e2".
-- Returns nothing.
local function extra_callback(name, ast, xs, visit_node)
   local hook = visit_node.cbs
   hook = hook and hook[ast.kind]
   hook = hook and hook[name]
   if hook then
      hook(ast, xs)
   end
end
-- Node kinds that recurse_node treats as leaves. They have no walker, and
-- recurse_node asserts that every kind lacking a walker appears here.
local no_recurse_node = {
   ["..."] = true, ["nil"] = true, ["cast"] = true, ["goto"] = true,
   ["break"] = true, ["label"] = true, ["number"] = true, ["string"] = true,
   ["boolean"] = true, ["integer"] = true, ["variable"] = true,
   ["error_node"] = true, ["identifier"] = true, ["type_identifier"] = true,
}
-- Visit the AST rooted at `root` depth-first and return the result computed
-- for the root (nil when `root` is nil).
-- `visit_node` supplies node callbacks. `visit_type` is passed to recurse_type
-- for type annotations (decltype, rets, casttype, typeargs, newtype).
-- For every node:
--   1. visit_node.cbs[kind].before(ast), if defined;
--   2. the walker for `kind` (table below) fills `xs` with child results in a
--      fixed positional layout; kinds without a walker must be listed in
--      no_recurse_node (asserted);
--   3. ret = visit_node.cbs[kind].after(ast, xs), if defined;
--   4. ret = visit_node.after(ast, xs, ret), if defined.
-- Raises an error if visit_node.cbs is absent and visit_node.allow_missing_cbs
-- is not set, or if cbs is present but lacks an entry for a node kind.
local function recurse_node(root, | |
visit_node, | |
visit_type) | |
if not root then | |
return | |
end | |
-- Forward-declared so the walkers below can call it; assigned further down.
local recurse | |
-- Walker for list-like nodes: child results are stored by position.
local function walk_children(ast, xs) | |
for i, child in ipairs(ast) do | |
xs[i] = recurse(child) | |
end | |
end | |
-- Walker for assignment/declarations: xs = { vars, decltype?, exps? }.
-- before_expressions fires after the left-hand side (and its declared type),
-- before the right-hand side.
local function walk_vars_exps(ast, xs) | |
xs[1] = recurse(ast.vars) | |
if ast.decltype then | |
xs[2] = recurse_type(ast.decltype, visit_type) | |
end | |
extra_callback("before_expressions", ast, xs, visit_node) | |
if ast.exps then | |
xs[3] = recurse(ast.exps) | |
end | |
end | |
-- Walker for local_type/global_type: xs = { var, value }.
local function walk_var_value(ast, xs) | |
xs[1] = recurse(ast.var) | |
xs[2] = recurse(ast.value) | |
end | |
-- Walker for local_function/global_function: xs = { name, args, rets, body }.
-- Type arguments are visited first (results discarded); before_statements
-- fires just before the body.
local function walk_named_function(ast, xs) | |
recurse_typeargs(ast, visit_type) | |
xs[1] = recurse(ast.name) | |
xs[2] = recurse(ast.args) | |
xs[3] = recurse_type(ast.rets, visit_type) | |
extra_callback("before_statements", ast, xs, visit_node) | |
xs[4] = recurse(ast.body) | |
end | |
-- Per-kind walkers; each documents the `xs` layout its `after` callbacks see.
local walkers = { | |
-- op: xs = { e1, prec(e1), e2, prec(e2) }, where prec(e) is e.op.prec when
-- e is itself an operator node, else nil. A string literal left of ':' gets
-- -999, so a precedence comparison always wraps it in parentheses
-- (as in ("x"):upper()). Binary only: before_e2 fires between the operands,
-- and for 'is'/'as' the right operand is a type visited with recurse_type.
["op"] = function(ast, xs) | |
xs[1] = recurse(ast.e1) | |
local p1 = ast.e1.op and ast.e1.op.prec or nil | |
if ast.op.op == ":" and ast.e1.kind == "string" then | |
p1 = -999 | |
end | |
xs[2] = p1 | |
if ast.op.arity == 2 then | |
extra_callback("before_e2", ast, xs, visit_node) | |
if ast.op.op == "is" or ast.op.op == "as" then | |
xs[3] = recurse_type(ast.e2.casttype, visit_type) | |
else | |
xs[3] = recurse(ast.e2) | |
end | |
xs[4] = (ast.e2.op and ast.e2.op.prec) | |
end | |
end, | |
["statements"] = walk_children, | |
["argument_list"] = walk_children, | |
["table_literal"] = walk_children, | |
["variable_list"] = walk_children, | |
["expression_list"] = walk_children, | |
-- table_item: xs = { key, value, decltype? }.
["table_item"] = function(ast, xs) | |
xs[1] = recurse(ast.key) | |
xs[2] = recurse(ast.value) | |
if ast.decltype then | |
xs[3] = recurse_type(ast.decltype, visit_type) | |
end | |
end, | |
["assignment"] = walk_vars_exps, | |
["local_declaration"] = walk_vars_exps, | |
["global_declaration"] = walk_vars_exps, | |
["local_type"] = walk_var_value, | |
["global_type"] = walk_var_value, | |
-- if: one result per if/elseif/else block, in source order.
["if"] = function(ast, xs) | |
for _, e in ipairs(ast.if_blocks) do | |
table.insert(xs, recurse(e)) | |
end | |
end, | |
-- if_block: xs = { condition (absent for else), body }.
["if_block"] = function(ast, xs) | |
if ast.exp then | |
xs[1] = recurse(ast.exp) | |
end | |
extra_callback("before_statements", ast, xs, visit_node) | |
xs[2] = recurse(ast.body) | |
end, | |
-- while: xs = { condition, body }.
["while"] = function(ast, xs) | |
xs[1] = recurse(ast.exp) | |
extra_callback("before_statements", ast, xs, visit_node) | |
xs[2] = recurse(ast.body) | |
end, | |
-- repeat: xs = { body, condition }; body is visited before the condition.
["repeat"] = function(ast, xs) | |
xs[1] = recurse(ast.body) | |
xs[2] = recurse(ast.exp) | |
end, | |
-- anonymous function: xs = { args, rets, body }.
["function"] = function(ast, xs) | |
recurse_typeargs(ast, visit_type) | |
xs[1] = recurse(ast.args) | |
xs[2] = recurse_type(ast.rets, visit_type) | |
extra_callback("before_statements", ast, xs, visit_node) | |
xs[3] = recurse(ast.body) | |
end, | |
["local_function"] = walk_named_function, | |
["global_function"] = walk_named_function, | |
-- record_function: xs = { owner, name, args, rets, body }.
["record_function"] = function(ast, xs) | |
recurse_typeargs(ast, visit_type) | |
xs[1] = recurse(ast.fn_owner) | |
xs[2] = recurse(ast.name) | |
xs[3] = recurse(ast.args) | |
xs[4] = recurse_type(ast.rets, visit_type) | |
extra_callback("before_statements", ast, xs, visit_node) | |
xs[5] = recurse(ast.body) | |
end, | |
-- generic for: xs = { vars, exps, body }.
["forin"] = function(ast, xs) | |
xs[1] = recurse(ast.vars) | |
xs[2] = recurse(ast.exps) | |
extra_callback("before_statements", ast, xs, visit_node) | |
xs[3] = recurse(ast.body) | |
end, | |
-- numeric for: xs = { var, from, to, step (nil when absent), body }.
["fornum"] = function(ast, xs) | |
xs[1] = recurse(ast.var) | |
xs[2] = recurse(ast.from) | |
xs[3] = recurse(ast.to) | |
xs[4] = ast.step and recurse(ast.step) | |
extra_callback("before_statements", ast, xs, visit_node) | |
xs[5] = recurse(ast.body) | |
end, | |
["return"] = function(ast, xs) | |
xs[1] = recurse(ast.exps) | |
end, | |
["do"] = function(ast, xs) | |
xs[1] = recurse(ast.body) | |
end, | |
["paren"] = function(ast, xs) | |
xs[1] = recurse(ast.e1) | |
end, | |
-- Type declaration expressions and arguments only carry types.
["newtype"] = function(ast, xs) | |
xs[1] = recurse_type(ast.newtype, visit_type) | |
end, | |
["argument"] = function(ast, xs) | |
if ast.decltype then | |
xs[1] = recurse_type(ast.decltype, visit_type) | |
end | |
end, | |
} | |
if not visit_node.allow_missing_cbs and not visit_node.cbs then | |
error("missing cbs in visit_node") | |
end | |
local visit_after = visit_node.after | |
-- Visit one node: before callback, kind walker, then after callbacks.
recurse = function(ast) | |
local xs = {} | |
local kind = assert(ast.kind) | |
local cbs = visit_node.cbs | |
local cbkind = cbs and cbs[kind] | |
do | |
if cbkind then | |
if cbkind.before then | |
cbkind.before(ast) | |
end | |
else | |
if cbs then | |
error("internal compiler error: no visitor for " .. kind) | |
end | |
end | |
end | |
local fn = walkers[kind] | |
if fn then | |
fn(ast, xs) | |
else | |
assert(no_recurse_node[kind]) | |
end | |
local ret | |
do | |
local cbkind_after = cbkind and cbkind.after | |
if cbkind_after then | |
ret = cbkind_after(ast, xs) | |
end | |
if visit_after then | |
ret = visit_after(ast, xs, ret) | |
end | |
end | |
return ret | |
end | |
return recurse(root) | |
end | |
-- Operators the pretty printer emits without surrounding spaces, indexed by
-- arity: [1] unary operators, [2] binary operators.
local tight_op = {
   { ["-"] = true, ["~"] = true, ["#"] = true },
   { ["."] = true, [":"] = true },
}
-- Operators the pretty printer surrounds with single spaces, indexed by arity.
local spaced_op = {
   { ["not"] = true },
   {
      ["or"] = true, ["and"] = true,
      ["<"] = true, [">"] = true, ["<="] = true, [">="] = true,
      ["~="] = true, ["=="] = true,
      ["|"] = true, ["~"] = true, ["&"] = true, ["<<"] = true, [">>"] = true,
      [".."] = true,
      ["+"] = true, ["-"] = true, ["*"] = true, ["/"] = true, ["//"] = true,
      ["%"] = true, ["^"] = true,
   },
}
-- Empty table standing in for the option record type
-- (NOTE(review): appears to be residue of a Teal type declaration).
local PrettyPrintOpts = {}
-- Option presets for tl.pretty_print_ast: the default keeps source
-- indentation and line breaks; the fast preset keeps only line breaks.
local default_pretty_print_ast_opts = { preserve_indent = true, preserve_newlines = true }
local fast_pretty_print_ast_opts = { preserve_indent = false, preserve_newlines = true }
-- Runtime Lua type name for each primitive type kind (enums are strings,
-- integers are numbers). Kinds not listed are printed as "table" (or
-- "userdata") by the pretty printer.
local primitive = {
   ["function"] = "function", ["enum"] = "string",
   ["boolean"] = "boolean", ["string"] = "string",
   ["nil"] = "nil", ["number"] = "number",
   ["integer"] = "number", ["thread"] = "thread",
}
--- Render an AST back to Lua source text.
-- Teal-only constructs are lowered while printing:
-- * `e as T` prints just `e`;
-- * `e is T` becomes `type(e) == "<lua type>"`, or
--   `math.type(e) == "integer"` when T is integer;
-- * type declarations print as a table constructor containing only nested
--   record declarations, or as the dotted name for aliases.
-- @param ast   root node to print (e.g. the node returned by tl.parse_program)
-- @param mode  option table `{ preserve_indent = bool, preserve_newlines = bool }`;
--              `true` selects fast_pretty_print_ast_opts; any other value
--              selects default_pretty_print_ast_opts.
-- @return the generated Lua code as a string
function tl.pretty_print_ast(ast, mode)
   -- Current indentation depth, adjusted by the before/after callbacks of
   -- block-like nodes.
   local indent = 0
   local opts
   if type(mode) == "table" then
      opts = mode
   elseif mode == true then
      opts = fast_pretty_print_ast_opts
   else
      opts = default_pretty_print_ast_opts
   end

   -- Indentation levels saved when a block's body starts on the same line as
   -- its header (indentation is reset to 0 until that block closes).
   local save_indent = {}

   -- `before` callback for block-like nodes: go one level deeper when the body
   -- (or first element) starts on a later line than the node itself.
   local function increment_indent(node)
      local child = node.body or node[1]
      if not child then
         return
      end
      if child.y ~= node.y then
         if indent == 0 and #save_indent > 0 then
            indent = save_indent[#save_indent] + 1
         else
            indent = indent + 1
         end
      else
         table.insert(save_indent, indent)
         indent = 0
      end
   end

   -- Undo increment_indent once the node's children have been printed.
   local function decrement_indent(node, child)
      if child.y ~= node.y then
         indent = indent - 1
      else
         indent = table.remove(save_indent)
      end
   end

   if not opts.preserve_indent then
      increment_indent = nil
      decrement_indent = function() end
   end

   -- Output fragments are tables { y = first source line, h = number of line
   -- breaks inside the fragment, <strings and nested fragments...> }.

   -- Append string `s` to `out`, counting its newlines into out.h.
   local function add_string(out, s)
      table.insert(out, s)
      if string.find(s, "\n", 1, true) then
         for _nl in s:gmatch("\n") do
            out.h = out.h + 1
         end
      end
   end

   -- Append fragment `child` to `out`. Empty fragments are ignored.
   -- With preserve_newlines, newlines are inserted until `out` reaches the
   -- child's source line. Otherwise `space` is emitted (when given), and a
   -- given `space` also suppresses indentation. `current_indent` (a level or
   -- false/nil) adds indentation when preserve_indent is on.
   local function add_child(out, child, space, current_indent)
      if #child == 0 then
         return
      end
      -- y == -1 marks fragments without a source position (e.g. types).
      if child.y ~= -1 and child.y < out.y then
         out.y = child.y
      end
      if child.y > out.y + out.h and opts.preserve_newlines then
         local delta = child.y - (out.y + out.h)
         out.h = out.h + delta
         table.insert(out, ("\n"):rep(delta))
      else
         if space then
            if space ~= "" then
               table.insert(out, space)
            end
            current_indent = nil
         end
      end
      if current_indent and opts.preserve_indent then
         table.insert(out, (" "):rep(current_indent))
      end
      table.insert(out, child)
      out.h = out.h + child.h
   end

   -- Flatten a tree of fragments into one string.
   local function concat_output(out)
      for i, s in ipairs(out) do
         if type(s) == "table" then
            out[i] = concat_output(s)
         end
      end
      return table.concat(out)
   end

   -- Print a record type as a table constructor listing only its fields that
   -- are themselves record type declarations (recursively); other fields are
   -- omitted.
   local function print_record_def(typ)
      local out = { "{" }
      for _, name in ipairs(typ.field_order) do
         if is_typetype(typ.fields[name]) and is_record_type(typ.fields[name].def) then
            table.insert(out, name)
            table.insert(out, " = ")
            table.insert(out, print_record_def(typ.fields[name].def))
            table.insert(out, ", ")
         end
      end
      table.insert(out, "}")
      return table.concat(out)
   end

   -- Node printers. Each `after` receives the child results laid out by
   -- recurse_node's walkers and returns a fragment.
   local visit_node = {}
   visit_node.cbs = {
      ["statements"] = {
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            local space
            for i, child in ipairs(children) do
               add_child(out, child, space, indent)
               if node[i].semicolon then
                  table.insert(out, ";")
                  space = " "
               else
                  -- Statements sharing a line are separated by "; ".
                  space = "; "
               end
            end
            return out
         end,
      },
      ["local_declaration"] = {
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            table.insert(out, "local")
            add_child(out, children[1], " ")
            if children[3] then
               table.insert(out, " =")
               add_child(out, children[3], " ")
            end
            return out
         end,
      },
      ["local_type"] = {
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            table.insert(out, "local")
            add_child(out, children[1], " ")
            table.insert(out, " =")
            add_child(out, children[2], " ")
            return out
         end,
      },
      ["global_type"] = {
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            add_child(out, children[1], " ")
            table.insert(out, " =")
            add_child(out, children[2], " ")
            return out
         end,
      },
      ["global_declaration"] = {
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            -- A global declared without a value prints nothing.
            if children[3] then
               add_child(out, children[1])
               table.insert(out, " =")
               add_child(out, children[3], " ")
            end
            return out
         end,
      },
      ["assignment"] = {
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            add_child(out, children[1])
            table.insert(out, " =")
            add_child(out, children[3], " ")
            return out
         end,
      },
      ["if"] = {
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            for i, child in ipairs(children) do
               add_child(out, child, i > 1 and " ", child.y ~= node.y and indent)
            end
            add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
            return out
         end,
      },
      ["if_block"] = {
         before = increment_indent,
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            if node.if_block_n == 1 then
               table.insert(out, "if")
            elseif not node.exp then
               table.insert(out, "else")
            else
               table.insert(out, "elseif")
            end
            if node.exp then
               add_child(out, children[1], " ")
               table.insert(out, " then")
            end
            add_child(out, children[2], " ")
            decrement_indent(node, node.body)
            return out
         end,
      },
      ["while"] = {
         before = increment_indent,
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            table.insert(out, "while")
            add_child(out, children[1], " ")
            table.insert(out, " do")
            add_child(out, children[2], " ")
            decrement_indent(node, node.body)
            add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
            return out
         end,
      },
      ["repeat"] = {
         before = increment_indent,
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            table.insert(out, "repeat")
            add_child(out, children[1], " ")
            decrement_indent(node, node.body)
            add_child(out, { y = node.yend, h = 0, [1] = "until " }, " ", indent)
            add_child(out, children[2])
            return out
         end,
      },
      ["do"] = {
         before = increment_indent,
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            table.insert(out, "do")
            add_child(out, children[1], " ")
            decrement_indent(node, node.body)
            add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
            return out
         end,
      },
      ["forin"] = {
         before = increment_indent,
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            table.insert(out, "for")
            add_child(out, children[1], " ")
            table.insert(out, " in")
            add_child(out, children[2], " ")
            table.insert(out, " do")
            add_child(out, children[3], " ")
            decrement_indent(node, node.body)
            add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
            return out
         end,
      },
      ["fornum"] = {
         before = increment_indent,
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            table.insert(out, "for")
            add_child(out, children[1], " ")
            table.insert(out, " =")
            add_child(out, children[2], " ")
            table.insert(out, ",")
            add_child(out, children[3], " ")
            if children[4] then
               table.insert(out, ",")
               add_child(out, children[4], " ")
            end
            table.insert(out, " do")
            add_child(out, children[5], " ")
            decrement_indent(node, node.body)
            add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
            return out
         end,
      },
      ["return"] = {
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            table.insert(out, "return")
            if #children[1] > 0 then
               add_child(out, children[1], " ")
            end
            return out
         end,
      },
      ["break"] = {
         after = function(node, _children)
            local out = { y = node.y, h = 0 }
            table.insert(out, "break")
            return out
         end,
      },
      ["variable_list"] = {
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            local space
            for i, child in ipairs(children) do
               if i > 1 then
                  table.insert(out, ",")
                  space = " "
               end
               add_child(out, child, space, child.y ~= node.y and indent)
            end
            return out
         end,
      },
      ["table_literal"] = {
         before = increment_indent,
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            -- increment_indent did nothing for an empty table, so there is
            -- nothing to undo here.
            if #children == 0 then
               table.insert(out, "{}")
               return out
            end
            table.insert(out, "{")
            local n = #children
            for i, child in ipairs(children) do
               add_child(out, child, " ", child.y ~= node.y and indent)
               -- Multi-line tables get a trailing comma after the last item.
               if i < n or node.yend ~= node.y then
                  table.insert(out, ",")
               end
            end
            decrement_indent(node, node[1])
            add_child(out, { y = node.yend, h = 0, [1] = "}" }, " ", indent)
            return out
         end,
      },
      ["table_item"] = {
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            if node.key_parsed ~= "implicit" then
               if node.key_parsed == "short" then
                  -- `name = v`: strip the quotes from the stored string key.
                  children[1][1] = children[1][1]:sub(2, -2)
                  add_child(out, children[1])
                  table.insert(out, " = ")
               else
                  table.insert(out, "[")
                  -- Pad long strings so "[[" / "]]" cannot merge with the brackets.
                  if node.key_parsed == "long" and node.key.is_longstring then
                     table.insert(children[1], 1, " ")
                     table.insert(children[1], " ")
                  end
                  add_child(out, children[1])
                  table.insert(out, "] = ")
               end
            end
            add_child(out, children[2])
            return out
         end,
      },
      ["local_function"] = {
         before = increment_indent,
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            table.insert(out, "local function")
            add_child(out, children[1], " ")
            table.insert(out, "(")
            add_child(out, children[2])
            table.insert(out, ")")
            add_child(out, children[4], " ")
            decrement_indent(node, node.body)
            add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
            return out
         end,
      },
      ["global_function"] = {
         before = increment_indent,
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            table.insert(out, "function")
            add_child(out, children[1], " ")
            table.insert(out, "(")
            add_child(out, children[2])
            table.insert(out, ")")
            add_child(out, children[4], " ")
            decrement_indent(node, node.body)
            add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
            return out
         end,
      },
      ["record_function"] = {
         before = increment_indent,
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            table.insert(out, "function")
            add_child(out, children[1], " ")
            table.insert(out, node.is_method and ":" or ".")
            add_child(out, children[2])
            table.insert(out, "(")
            if node.is_method then
               -- Drop the implicit first argument (and its ", " separator)
               -- because ":" already supplies it.
               table.remove(children[3], 1)
               if children[3][1] == "," then
                  table.remove(children[3], 1)
                  table.remove(children[3], 1)
               end
            end
            add_child(out, children[3])
            table.insert(out, ")")
            add_child(out, children[5], " ")
            decrement_indent(node, node.body)
            add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
            return out
         end,
      },
      ["function"] = {
         before = increment_indent,
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            table.insert(out, "function(")
            add_child(out, children[1])
            table.insert(out, ")")
            add_child(out, children[3], " ")
            decrement_indent(node, node.body)
            add_child(out, { y = node.yend, h = 0, [1] = "end" }, " ", indent)
            return out
         end,
      },
      ["cast"] = {},
      ["paren"] = {
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            table.insert(out, "(")
            add_child(out, children[1], "", indent)
            table.insert(out, ")")
            return out
         end,
      },
      ["op"] = {
         -- children = { e1, prec(e1), e2, prec(e2) } (see recurse_node).
         after = function(node, children)
            local out = { y = node.y, h = 0 }
            if node.op.op == "@funcall" then
               add_child(out, children[1], "", indent)
               table.insert(out, "(")
               add_child(out, children[3], "", indent)
               table.insert(out, ")")
            elseif node.op.op == "@index" then
               add_child(out, children[1], "", indent)
               table.insert(out, "[")
               if node.e2.is_longstring then
                  table.insert(children[3], 1, " ")
                  table.insert(children[3], " ")
               end
               add_child(out, children[3], "", indent)
               table.insert(out, "]")
            elseif node.op.op == "as" then
               -- Casts are erased: only the expression remains.
               add_child(out, children[1], "", indent)
            elseif node.op.op == "is" then
               if node.e2.casttype.typename == "integer" then
                  table.insert(out, "math.type(")
                  add_child(out, children[1], "", indent)
                  table.insert(out, ") == \"integer\"")
               else
                  table.insert(out, "type(")
                  add_child(out, children[1], "", indent)
                  table.insert(out, ") == \"")
                  add_child(out, children[3], "", indent)
                  table.insert(out, "\"")
               end
            elseif spaced_op[node.op.arity][node.op.op] or tight_op[node.op.arity][node.op.op] then
               local space = spaced_op[node.op.arity][node.op.op] and " " or ""
               -- Parenthesize an operand that binds more loosely than this operator.
               if children[2] and node.op.prec > tonumber(children[2]) then
                  table.insert(children[1], 1, "(")
                  table.insert(children[1], ")")
               end
               if node.op.arity == 1 then
                  -- A tight unary minus applied to another unary minus would
                  -- print as "--x", which Lua lexes as the start of a comment.
                  -- Keep the two signs apart: "- -x".
                  if node.op.op == "-" and node.e1.kind == "op"
                     and node.e1.op.arity == 1 and node.e1.op.op == "-" then
                     space = " "
                  end
                  table.insert(out, node.op.op)
                  add_child(out, children[1], space, indent)
               elseif node.op.arity == 2 then
                  add_child(out, children[1], "", indent)
                  if space == " " then
                     table.insert(out, " ")
                  end
                  table.insert(out, node.op.op)
                  if children[4] and node.op.prec > tonumber(children[4]) then
                     table.insert(children[3], 1, "(")
                     table.insert(children[3], ")")
                  end
                  add_child(out, children[3], space, indent)
               end
            else
               error("unknown node op " .. node.op.op)
            end
            return out
         end,
      },
      ["variable"] = {
         after = function(node, _children)
            local out = { y = node.y, h = 0 }
            add_string(out, node.tk)
            return out
         end,
      },
      ["newtype"] = {
         after = function(node, _children)
            local out = { y = node.y, h = 0 }
            if node.is_alias then
               table.insert(out, table.concat(node.newtype.def.names, "."))
            elseif is_record_type(node.newtype.def) then
               table.insert(out, print_record_def(node.newtype.def))
            else
               table.insert(out, "{}")
            end
            return out
         end,
      },
      ["goto"] = {
         after = function(node, _children)
            local out = { y = node.y, h = 0 }
            table.insert(out, "goto ")
            table.insert(out, node.label)
            return out
         end,
      },
      ["label"] = {
         after = function(node, _children)
            local out = { y = node.y, h = 0 }
            table.insert(out, "::")
            table.insert(out, node.label)
            table.insert(out, "::")
            return out
         end,
      },
   }

   -- Type printer: a type prints as the Lua type it has at runtime:
   -- primitive[typename] of the resolved type, "userdata" for userdata types,
   -- and "table" for everything else.
   local visit_type = {}
   visit_type.cbs = {
      ["string"] = {
         after = function(typ, _children)
            local out = { y = typ.y or -1, h = 0 }
            local r = typ.resolved or typ
            local lua_type = primitive[r.typename] or
               (r.is_userdata and "userdata") or
               "table"
            table.insert(out, lua_type)
            return out
         end,
      },
   }
   for _, typename in ipairs({
      "typetype", "typevar", "typearg", "function", "thread", "array", "map",
      "tupletable", "arrayrecord", "record", "enum", "boolean", "nil", "number",
      "integer", "union", "nominal", "bad_nominal", "emptytable", "table_item",
      "unresolved_emptytable_value", "tuple", "poly", "any", "unknown",
      "invalid", "unresolved", "none",
   }) do
      visit_type.cbs[typename] = visit_type.cbs["string"]
   end

   -- List-like nodes print like variable lists; leaf nodes print their token.
   for _, kind in ipairs({ "expression_list", "argument_list" }) do
      visit_node.cbs[kind] = visit_node.cbs["variable_list"]
   end
   for _, kind in ipairs({
      "identifier", "number", "integer", "string", "nil", "boolean", "...",
      "argument", "type_identifier",
   }) do
      visit_node.cbs[kind] = visit_node.cbs["variable"]
   end

   local out = recurse_node(ast, visit_node, visit_type)
   local code
   if opts.preserve_newlines then
      -- Anchor at line 1 so leading blank lines of the source are reproduced.
      code = { y = 1, h = 0 }
      add_child(code, out)
   else
      code = out
   end
   return concat_output(code)
end
-- Turn the array `t` (mutated in place) into a variadic tuple type and pass
-- it through a_type.
local function VARARG(t)
   t.typename = "tuple"
   t.is_va = true
   return a_type(t)
end
-- Turn the array `t` (mutated in place) into a fixed-size tuple type and
-- pass it through a_type.
local function TUPLE(t)
   t.typename = "tuple"
   return a_type(t)
end
-- Build a union type whose members are the types in array `t`.
local function UNION(t)
   local union = { typename = "union", types = t }
   return a_type(union)
end
-- Preconstructed type values shared by later code (e.g. FUNCTION is used in
-- binop_types). NOTE(review): a_type is defined outside this view; if it
-- assigns ids or has other side effects, the declaration order matters — keep it.
local NONE = a_type({ typename = "none" }) | |
local INVALID = a_type({ typename = "invalid" }) | |
local UNKNOWN = a_type({ typename = "unknown" }) | |
-- Type variables @a/@b and the matching type-argument declarations,
-- presumably for describing generic built-in signatures.
local ALPHA = a_type({ typename = "typevar", typevar = "@a" }) | |
local BETA = a_type({ typename = "typevar", typevar = "@b" }) | |
local ARG_ALPHA = a_type({ typename = "typearg", typearg = "@a" }) | |
local ARG_BETA = a_type({ typename = "typearg", typearg = "@b" }) | |
local ARRAY_OF_ALPHA = a_type({ typename = "array", elements = ALPHA }) | |
local MAP_OF_ALPHA_TO_BETA = a_type({ typename = "map", keys = ALPHA, values = BETA }) | |
local NOMINAL_METATABLE_OF_ALPHA = a_type({ typename = "nominal", names = { "metatable" }, typevals = { ALPHA } }) | |
local ARRAY_OF_STRING = a_type({ typename = "array", elements = STRING }) | |
-- Any function: variadic `any` arguments and variadic `any` returns.
local FUNCTION = a_type({ typename = "function", args = VARARG({ ANY }), rets = VARARG({ ANY }) }) | |
local NOMINAL_FILE = a_type({ typename = "nominal", names = { "FILE" } }) | |
-- One `any` argument, no returns (name suggests xpcall's message handler).
local XPCALL_MSGH_FUNCTION = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({}) }) | |
-- Userdata has no dedicated type here; it is typed as `any`.
local USERDATA = ANY | |
-- Result types of numeric binary operators, indexed as
-- tbl[left typename][right typename].
-- numeric_binop: integer only when both operands are integers.
local numeric_binop = {
   ["number"] = { ["number"] = NUMBER, ["integer"] = NUMBER },
   ["integer"] = { ["integer"] = INTEGER, ["number"] = NUMBER },
}
-- float_binop: always number (float), even for two integers.
local float_binop = {
   ["number"] = { ["number"] = NUMBER, ["integer"] = NUMBER },
   ["integer"] = { ["integer"] = NUMBER, ["number"] = NUMBER },
}
-- integer_binop: always integer (bitwise operators).
local integer_binop = {
   ["number"] = { ["number"] = INTEGER, ["integer"] = INTEGER },
   ["integer"] = { ["integer"] = INTEGER, ["number"] = INTEGER },
}
-- Result types of the ordering operators (< <= > >=), indexed as
-- relational_binop[left typename][right typename].
-- Lua orders only numbers and strings natively; any other operands (booleans
-- included) raise "attempt to compare ..." at runtime unless a metamethod
-- applies. Booleans are therefore deliberately absent, so `a < b` on two
-- booleans is rejected at type-check time instead of failing at runtime.
local relational_binop = {
   ["number"] = {
      ["integer"] = BOOLEAN,
      ["number"] = BOOLEAN,
   },
   ["integer"] = {
      ["number"] = BOOLEAN,
      ["integer"] = BOOLEAN,
   },
   ["string"] = {
      ["string"] = BOOLEAN,
   },
}
-- Operand combinations accepted by == and ~= (the result is always boolean),
-- indexed as equality_binop[left typename][right typename].
local equality_binop = {
   ["number"] = { ["number"] = BOOLEAN, ["integer"] = BOOLEAN, ["nil"] = BOOLEAN },
   ["integer"] = { ["number"] = BOOLEAN, ["integer"] = BOOLEAN, ["nil"] = BOOLEAN },
   ["string"] = { ["string"] = BOOLEAN, ["nil"] = BOOLEAN },
   ["boolean"] = { ["boolean"] = BOOLEAN, ["nil"] = BOOLEAN },
   ["record"] = {
      ["emptytable"] = BOOLEAN, ["arrayrecord"] = BOOLEAN,
      ["record"] = BOOLEAN, ["nil"] = BOOLEAN,
   },
   ["array"] = {
      ["emptytable"] = BOOLEAN, ["arrayrecord"] = BOOLEAN,
      ["array"] = BOOLEAN, ["nil"] = BOOLEAN,
   },
   ["arrayrecord"] = {
      ["emptytable"] = BOOLEAN, ["arrayrecord"] = BOOLEAN,
      ["record"] = BOOLEAN, ["array"] = BOOLEAN, ["nil"] = BOOLEAN,
   },
   ["map"] = { ["emptytable"] = BOOLEAN, ["map"] = BOOLEAN, ["nil"] = BOOLEAN },
   ["thread"] = { ["thread"] = BOOLEAN, ["nil"] = BOOLEAN },
}
-- Result types of unary operators, indexed as unop_types[op][operand typename].
local unop_types = {
   ["#"] = {
      ["arrayrecord"] = INTEGER, ["string"] = INTEGER, ["array"] = INTEGER,
      ["tupletable"] = INTEGER, ["map"] = INTEGER, ["emptytable"] = INTEGER,
   },
   ["-"] = { ["number"] = NUMBER, ["integer"] = INTEGER },
   ["~"] = { ["number"] = INTEGER, ["integer"] = INTEGER },
   ["not"] = {
      ["string"] = BOOLEAN, ["number"] = BOOLEAN, ["integer"] = BOOLEAN,
      ["boolean"] = BOOLEAN, ["record"] = BOOLEAN, ["arrayrecord"] = BOOLEAN,
      ["array"] = BOOLEAN, ["tupletable"] = BOOLEAN, ["map"] = BOOLEAN,
      ["emptytable"] = BOOLEAN, ["thread"] = BOOLEAN,
   },
}
-- Lua metamethod names for the unary operators that have one.
local unop_to_metamethod = { ["#"] = "__len", ["-"] = "__unm", ["~"] = "__bnot" }
-- Result types of binary operators, indexed as
-- binop_types[op][left typename][right typename].
local binop_types = {
   -- arithmetic
   ["+"] = numeric_binop, ["-"] = numeric_binop, ["*"] = numeric_binop,
   ["%"] = numeric_binop, ["/"] = float_binop, ["//"] = numeric_binop,
   ["^"] = float_binop,
   -- bitwise
   ["&"] = integer_binop, ["|"] = integer_binop, ["<<"] = integer_binop,
   [">>"] = integer_binop, ["~"] = integer_binop,
   -- comparison
   ["=="] = equality_binop, ["~="] = equality_binop,
   ["<="] = relational_binop, [">="] = relational_binop,
   ["<"] = relational_binop, [">"] = relational_binop,
   -- `or`
   ["or"] = {
      ["boolean"] = { ["boolean"] = BOOLEAN, ["function"] = FUNCTION },
      ["number"] = { ["integer"] = NUMBER, ["number"] = NUMBER, ["boolean"] = BOOLEAN },
      ["integer"] = { ["integer"] = INTEGER, ["number"] = NUMBER, ["boolean"] = BOOLEAN },
      ["string"] = { ["string"] = STRING, ["boolean"] = BOOLEAN, ["enum"] = STRING },
      ["function"] = { ["boolean"] = BOOLEAN },
      ["array"] = { ["boolean"] = BOOLEAN },
      ["record"] = { ["boolean"] = BOOLEAN },
      ["arrayrecord"] = { ["boolean"] = BOOLEAN },
      ["map"] = { ["boolean"] = BOOLEAN },
      ["enum"] = { ["string"] = STRING },
      ["thread"] = { ["boolean"] = BOOLEAN },
   },
   -- concatenation: strings, enums and numbers all produce a string
   [".."] = {
      ["string"] = { ["string"] = STRING, ["enum"] = STRING, ["number"] = STRING, ["integer"] = STRING },
      ["number"] = { ["integer"] = STRING, ["number"] = STRING, ["string"] = STRING, ["enum"] = STRING },
      ["integer"] = { ["integer"] = STRING, ["number"] = STRING, ["string"] = STRING, ["enum"] = STRING },
      ["enum"] = { ["number"] = STRING, ["integer"] = STRING, ["string"] = STRING, ["enum"] = STRING },
   },
}
-- Metamethod name consulted for each overloadable binary operator.
local binop_to_metamethod = {
   -- arithmetic
   ["+"] = "__add",
   ["-"] = "__sub",
   ["*"] = "__mul",
   ["/"] = "__div",
   ["//"] = "__idiv",
   ["%"] = "__mod",
   ["^"] = "__pow",
   -- bitwise
   ["&"] = "__band",
   ["|"] = "__bor",
   ["~"] = "__bxor",
   ["<<"] = "__shl",
   [">>"] = "__shr",
   -- concatenation and comparison
   [".."] = "__concat",
   ["=="] = "__eq",
   ["<"] = "__lt",
   ["<="] = "__le",
}
--- Tell whether `t` is a not-yet-known type.
-- @return true for typename "unknown" or "unresolved_emptytable_value"
local function is_unknown(t)
   local name = t.typename
   return name == "unknown" or name == "unresolved_emptytable_value"
end
local show_type

--- Render type `t` as a human-readable string (the worker behind show_type).
-- @param t the type to render
-- @param short when truthy, records render as just "record" and string
--        types omit their token
-- @param seen table mapping visited types to their rendering; a type whose
--        rendering is in progress maps to "...", which is what a cyclic
--        reference to it prints
-- @return the rendering; unrecognized typenames fall back to tostring(t)
local function show_type_base(t, short, seen)
   local cached = seen[t]
   if cached then
      return cached
   end
   seen[t] = "..."

   -- Nested types go back through show_type with the same options and seen set.
   local function show(typ)
      return show_type(typ, short, seen)
   end

   -- Render every element of `list` and join the pieces with `sep`.
   local function show_all(list, sep)
      local parts = {}
      for i, item in ipairs(list) do
         parts[i] = show(item)
      end
      return table.concat(parts, sep)
   end

   -- Append "<A, B, ...>" to `parts` when generic type arguments are present.
   local function append_typeargs(parts, typeargs)
      if typeargs then
         parts[#parts + 1] = "<" .. show_all(typeargs, ", ") .. ">"
      end
   end

   local tn = t.typename
   if tn == "nominal" then
      local name = table.concat(t.names, ".")
      if not t.typevals then
         return name
      end
      return name .. "<" .. show_all(t.typevals, ", ") .. ">"
   elseif tn == "tuple" then
      return "(" .. show_all(t, ", ") .. ")"
   elseif tn == "tupletable" then
      return "{" .. show_all(t.types, ", ") .. "}"
   elseif tn == "poly" then
      return show_all(t.types, " and ")
   elseif tn == "union" then
      return show_all(t.types, " | ")
   elseif tn == "emptytable" then
      return "{}"
   elseif tn == "map" then
      return "{" .. show(t.keys) .. " : " .. show(t.values) .. "}"
   elseif tn == "array" then
      return "{" .. show(t.elements) .. "}"
   elseif tn == "enum" then
      if t.names then
         return table.concat(t.names, ".")
      end
      return "enum"
   elseif is_record_type(t) then
      if short then
         return "record"
      end
      local parts = { "record" }
      append_typeargs(parts, t.typeargs)
      parts[#parts + 1] = " ("
      if t.elements then
         parts[#parts + 1] = "{" .. show(t.elements) .. "}"
      end
      local fields = {}
      for _, field_name in ipairs(t.field_order) do
         fields[#fields + 1] = field_name .. ": " .. show(t.fields[field_name])
      end
      parts[#parts + 1] = table.concat(fields, "; ")
      parts[#parts + 1] = ")"
      return table.concat(parts)
   elseif tn == "function" then
      local parts = { "function" }
      append_typeargs(parts, t.typeargs)
      local fn_args = t.args
      local last_arg = #fn_args
      local shown_args = {}
      if t.is_method then
         -- the receiver is printed as "self" instead of its first-arg type
         shown_args[1] = "self"
      end
      for i, arg in ipairs(fn_args) do
         if i > 1 or not t.is_method then
            local prefix = (i == last_arg and fn_args.is_va) and "...: " or ""
            shown_args[#shown_args + 1] = prefix .. show(arg)
         end
      end
      parts[#parts + 1] = "(" .. table.concat(shown_args, ", ") .. ")"
      local rets = t.rets
      if #rets > 0 then
         local shown_rets = {}
         for i, ret in ipairs(rets) do
            local suffix = (i == #rets and rets.is_va) and "..." or ""
            shown_rets[i] = show(ret) .. suffix
         end
         parts[#parts + 1] = ": " .. table.concat(shown_rets, ", ")
      end
      return table.concat(parts)
   elseif tn == "number" or tn == "integer" or tn == "boolean" or tn == "thread" then
      return tn
   elseif tn == "string" then
      if short or not t.tk then
         return "string"
      end
      return "string " .. t.tk
   elseif tn == "typevar" then
      return t.typevar
   elseif tn == "typearg" then
      return t.typearg
   elseif is_unknown(t) then
      return "<unknown type>"
   elseif tn == "invalid" then
      return "<invalid type>"
   elseif tn == "any" then
      return "<any type>"
   elseif tn == "nil" then
      return "nil"
   elseif tn == "none" then
      return ""
   elseif is_typetype(t) then
      return "type " .. show(t.def)
   elseif tn == "bad_nominal" then
      return table.concat(t.names, ".") .. " (an unknown type)"
   end
   return tostring(t)
end
--- Build the " (inferred at FILE:Y:X)" suffix for a type with an inference site.
-- Reads `t.inferred_at_file` and the `y`/`x` fields of `t.inferred_at`.
local function inferred_msg(t)
   local where = t.inferred_at
   return " (inferred at " .. t.inferred_at_file .. ":" .. where.y .. ":" .. where.x .. ")"
end
-- Render `t`, appending its inference site when known, and remember the
-- final text in `seen` so repeated references to `t` reuse it.
show_type = function(t, short, seen)
   if not seen then
      seen = {}
   end
   local rendered = show_type_base(t, short, seen)
   if t.inferred_at then
      rendered = rendered .. inferred_msg(t)
   end
   seen[t] = rendered
   return rendered
end
--- Look for `module_name` along a `;`-separated search path.
-- Each path entry has its "?" replaced by the module name (dots become "/"),
-- and a trailing ".lua" is swapped for `suffix`.
-- @param tried list of "no file '...'" messages, appended to in place
-- @return filename and open read handle of the first existing candidate plus
--         `tried`, or nil, nil, tried when none exists
local function search_for(module_name, suffix, path, tried)
   for template in path:gmatch("[^;]+") do
      local slashed = module_name:gsub("%.", "/")
      local candidate = template:gsub("%?", slashed)
      candidate = candidate:gsub("%.lua$", suffix)
      local handle = io.open(candidate, "r")
      if handle then
         return candidate, handle, tried
      end
      tried[#tried + 1] = "no file '" .. candidate .. "'"
   end
   return nil, nil, tried
end
--- Locate the source of `module_name` on TL_PATH (falling back to package.path).
-- Candidates are tried per suffix in priority order: ".d.tl" (only when
-- `search_dtl` is set), then ".tl", then ".lua".
-- @return filename and open handle on success; otherwise nil, nil and the
--         list of "no file" messages for every path tried
function tl.search_module(module_name, search_dtl)
   local path = os.getenv("TL_PATH") or package.path
   local suffixes
   if search_dtl then
      suffixes = { ".d.tl", ".tl", ".lua" }
   else
      suffixes = { ".tl", ".lua" }
   end
   local tried = {}
   for _, suffix in ipairs(suffixes) do
      local found, fd
      found, fd, tried = search_for(module_name, suffix, path, tried)
      if found then
         return found, fd
      end
   end
   return nil, nil, tried
end
-- Empty table left where a Teal type declaration (`Variable`) was compiled;
-- it appears to carry no runtime data (NOTE(review): confirm before use).
local Variable = {}

--- Return the keys of table `m` as a new, sorted array.
-- Keys must be mutually comparable (e.g. all strings) for table.sort.
local function sorted_keys(m)
   local keys = {}
   for key in pairs(m) do
      keys[#keys + 1] = key
   end
   table.sort(keys)
   return keys
end
--- Set `t.field_order` to the sorted field names of a "record" type.
-- Types with any other typename are left untouched.
local function fill_field_order(t)
   if t.typename ~= "record" then
      return
   end
   t.field_order = sorted_keys(t.fields)
end
--- Resolve the type of module `module_name`, type-checking its source if needed.
-- Results are memoized in `env.modules`. The entry is set to INVALID before
-- processing, so a re-entrant request for the same module returns that
-- placeholder instead of processing it again.
-- @param module_name dotted module name, e.g. "foo.bar"
-- @param lax when truthy, a plain .lua module is processed too; otherwise
--        only files whose name ends in "tl" (.tl / .d.tl) are processed
-- @param env environment whose `modules` cache is read and updated
-- @return the module's type (INVALID when it was not processed), and whether
--         a file for the module was found at all
local function require_module(module_name, lax, env)
   local modules = env.modules
   if modules[module_name] then
      return modules[module_name], true
   end
   modules[module_name] = INVALID

   local found, fd = tl.search_module(module_name, true)
   -- search_module returns an open handle, but tl.process below is given the
   -- path, not the handle. Release it unconditionally: previously it leaked
   -- whenever a file was found but not processed (a .lua file in strict mode).
   if fd then
      fd:close()
   end

   if found and (lax or found:match("tl$")) then
      local found_result, err = tl.process(found, env)
      assert(found_result, err)
      if not found_result.type then
         -- No type was produced for the module. Presumably this mirrors Lua's
         -- `require`, which yields true for a module that returns nothing.
         found_result.type = BOOLEAN
      end
      env.modules[module_name] = found_result.type
      return found_result.type, true
   end
   return INVALID, found ~= nil
end
-- Parsed compat preludes, keyed by entry name ("bit32", "mt", "compat", ...)
-- and reused by every later call to add_compat_entries.
local compat_code_cache = {}

--- Prepend the compatibility prelude statements for `used_set` to `program`.
-- @param program list of top-level statement nodes; preludes are inserted at
--        the front in sorted entry order, and `program.y` is set to 1
-- @param used_set set of entry names: "table.unpack", "bit32", "mt", or a
--        standard-library name that is rebound through `_tl_compat`
-- @param gen_compat "off" disables generation; "optional" wraps each require
--        in pcall; any other value emits a plain `true, require(...)`
local function add_compat_entries(program, used_set, gen_compat)
   if gen_compat == "off" or not next(used_set) then
      return
   end

   local used_list = sorted_keys(used_set)
   local compat_loaded = false
   local n = 1

   -- Parse and type-check `text` once per `name` (cached), then insert its
   -- statements into `program` at position n, advancing n.
   -- NOTE: parse errors are collected into a throwaway table and ignored, so
   -- every snippet below must be complete, valid Lua on its own.
   local function load_code(name, text)
      local code = compat_code_cache[name]
      if not code then
         local tokens = tl.lex(text)
         local _
         _, code = tl.parse_program(tokens, {}, "@internal")
         tl.type_check(code, { filename = "<internal>", lax = false, gen_compat = "off" })
         compat_code_cache[name] = code
      end
      for _, c in ipairs(code) do
         table.insert(program, n, c)
         n = n + 1
      end
   end

   -- Lua source yielding `ok, module` for `m`; optional mode tolerates absence.
   local function req(m)
      return (gen_compat == "optional") and
         "pcall(require, '" .. m .. "')" or
         "true, require('" .. m .. "')"
   end

   for _, name in ipairs(used_list) do
      if name == "table.unpack" then
         load_code(name, "local _tl_table_unpack = unpack or table.unpack")
      elseif name == "bit32" then
         -- both `if` blocks need their own `end`
         load_code(name, "local bit32 = bit32; if not bit32 then local p, m = " .. req("bit32") .. "; if p then bit32 = m end end")
      elseif name == "mt" then
         -- the outer parentheses truncate the metamethod call to one result
         load_code(name, "local _tl_mt = function(m, s, a, b) return (getmetatable(s == 1 and a or b)[m](a, b)) end")
      else
         if not compat_loaded then
            load_code("compat", "local _tl_compat; if (tonumber((_VERSION or ''):match('[%d.]*$')) or 0) < 5.3 then local p, m = " .. req("compat53.module") .. "; if p then _tl_compat = m end end")
            compat_loaded = true
         end
         load_code(name, (("local $NAME = _tl_compat and _tl_compat.$NAME or $NAME"):gsub("$NAME", name)))
      end
   end
   program.y = 1
end
--- Return a fresh set of standard-library names that get compat treatment.
-- In lax mode only "utf8" is included; otherwise the full list below.
local function get_stdlib_compat(lax)
   if lax then
      return { ["utf8"] = true }
   end
   local compat = {}
   for _, name in ipairs({
      "io", "math", "string", "table", "utf8", "coroutine", "os", "package",
      "debug", "load", "loadfile", "assert", "pairs", "ipairs", "pcall",
      "xpcall", "rawlen",
   }) do
      compat[name] = true
   end
   return compat
end
-- Name of the bit32 library function corresponding to each bitwise operator.
local bit_operators = {
   ["<<"] = "lshift",
   [">>"] = "rshift",
   ["&"] = "band",
   ["|"] = "bor",
   ["~"] = "bxor",
}
--- Rewrite operator `node` in place into the call `mod_name.fn_name(e1, e2)`.
-- Every synthesized node reuses the original node's y/x position.
local function convert_node_to_compat_call(node, mod_name, fn_name, e1, e2)
   local y, x = node.y, node.x
   local op = node.op
   op.op = "@funcall"
   op.arity = 2
   op.prec = 100

   local callee = { y = y, x = x, kind = "op", op = an_operator(node, 2, ".") }
   callee.e1 = { y = y, x = x, kind = "identifier", tk = mod_name }
   callee.e2 = { y = y, x = x, kind = "identifier", tk = fn_name }
   node.e1 = callee

   local call_args = { y = y, x = x, kind = "expression_list" }
   call_args[1] = e1
   call_args[2] = e2
   node.e2 = call_args
end
--- Rewrite operator `node` in place into `_tl_mt("<mt_name>", which_self, e1, e2)`.
-- `mt_name` becomes a string-literal node and `which_self` an integer-literal
-- node; every synthesized node reuses the original node's y/x position.
local function convert_node_to_compat_mt_call(node, mt_name, which_self, e1, e2)
   local y, x = node.y, node.x
   local op = node.op
   op.op = "@funcall"
   op.arity = 2
   op.prec = 100

   node.e1 = { y = y, x = x, kind = "identifier", tk = "_tl_mt" }

   local call_args = { y = y, x = x, kind = "expression_list" }
   call_args[1] = { y = y, x = x, kind = "string", tk = "\"" .. mt_name .. "\"" }
   call_args[2] = { y = y, x = x, kind = "integer", tk = tostring(which_self) }
   call_args[3] = e1
   call_args[4] = e2
   node.e2 = call_args
end

-- Snapshot of last_typeid taken by the first init_globals call; later calls
-- rewind last_typeid to this value before building the global types again.
local globals_typeid
local function init_globals(lax) | |
local globals = {} | |
local stdlib_compat = get_stdlib_compat(lax) | |
local is_first_init = globals_typeid == nil | |
local save_typeid = last_typeid | |
if is_first_init then | |
globals_typeid = last_typeid | |
else | |
last_typeid = globals_typeid | |
end | |
local LOAD_FUNCTION = a_type({ typename = "function", args = {}, rets = TUPLE({ STRING }) }) | |
local OS_DATE_TABLE = a_type({ | |
typename = "record", | |
fields = { | |
["year"] = INTEGER, | |
["month"] = INTEGER, | |
["day"] = INTEGER, | |
["hour"] = INTEGER, | |
["min"] = INTEGER, | |
["sec"] = INTEGER, | |
["wday"] = INTEGER, | |
["yday"] = INTEGER, | |
["isdst"] = BOOLEAN, | |
}, | |
}) | |
local OS_DATE_TABLE_FORMAT = a_type({ typename = "enum", enumset = { ["!*t"] = true, ["*t"] = true } }) | |
local DEBUG_GETINFO_TABLE = a_type({ | |
typename = "record", | |
fields = { | |
["name"] = STRING, | |
["namewhat"] = STRING, | |
["source"] = STRING, | |
["short_src"] = STRING, | |
["linedefined"] = INTEGER, | |
["lastlinedefined"] = INTEGER, | |
["what"] = STRING, | |
["currentline"] = INTEGER, | |
["istailcall"] = BOOLEAN, | |
["nups"] = INTEGER, | |
["nparams"] = INTEGER, | |
["isvararg"] = BOOLEAN, | |
["func"] = ANY, | |
["activelines"] = a_type({ typename = "map", keys = INTEGER, values = BOOLEAN }), | |
}, | |
}) | |
local DEBUG_HOOK_EVENT = a_type({ | |
typename = "enum", | |
enumset = { | |
["call"] = true, | |
["tail call"] = true, | |
["return"] = true, | |
["line"] = true, | |
["count"] = true, | |
}, | |
}) | |
local DEBUG_HOOK_FUNCTION = a_type({ | |
typename = "function", | |
args = TUPLE({ DEBUG_HOOK_EVENT, INTEGER }), | |
rets = TUPLE({}), | |
}) | |
local TABLE_SORT_FUNCTION = a_type({ typename = "function", typeargs = TUPLE({ ARG_ALPHA }), args = TUPLE({ ALPHA, ALPHA }), rets = TUPLE({ BOOLEAN }) }) | |
local OPT_NUMBER = NUMBER | |
local OPT_STRING = STRING | |
local OPT_THREAD = THREAD | |
local OPT_ALPHA = ALPHA | |
local OPT_BETA = BETA | |
local OPT_TABLE = TABLE | |
local OPT_UNION = UNION | |
local OPT_BOOLEAN = BOOLEAN | |
local OPT_NOMINAL_FILE = NOMINAL_FILE | |
local OPT_TABLE_SORT_FUNCTION = TABLE_SORT_FUNCTION | |
local standard_library = { | |
["..."] = VARARG({ STRING }), | |
["any"] = a_type({ typename = "typetype", def = ANY }), | |
["arg"] = ARRAY_OF_STRING, | |
["assert"] = a_type({ typename = "function", typeargs = TUPLE({ ARG_ALPHA, ARG_BETA }), args = TUPLE({ ALPHA, OPT_BETA }), rets = TUPLE({ ALPHA }) }), | |
["collectgarbage"] = a_type({ | |
typename = "poly", | |
types = { | |
a_type({ typename = "function", args = TUPLE({ a_type({ typename = "enum", enumset = { ["collect"] = true, ["count"] = true, ["stop"] = true, ["restart"] = true } }) }), rets = TUPLE({ NUMBER }) }), | |
a_type({ typename = "function", args = TUPLE({ a_type({ typename = "enum", enumset = { ["step"] = true, ["setpause"] = true, ["setstepmul"] = true } }), NUMBER }), rets = TUPLE({ NUMBER }) }), | |
a_type({ typename = "function", args = TUPLE({ a_type({ typename = "enum", enumset = { ["isrunning"] = true } }) }), rets = TUPLE({ BOOLEAN }) }), | |
a_type({ typename = "function", args = TUPLE({ STRING, OPT_NUMBER }), rets = TUPLE({ a_type({ typename = "union", types = { BOOLEAN, NUMBER } }) }) }), | |
}, | |
}), | |
["dofile"] = a_type({ typename = "function", args = TUPLE({ OPT_STRING }), rets = VARARG({ ANY }) }), | |
["error"] = a_type({ typename = "function", args = TUPLE({ ANY, NUMBER }), rets = TUPLE({}) }), | |
["getmetatable"] = a_type({ typename = "function", typeargs = TUPLE({ ARG_ALPHA }), args = TUPLE({ ALPHA }), rets = TUPLE({ NOMINAL_METATABLE_OF_ALPHA }) }), | |
["ipairs"] = a_type({ typename = "function", typeargs = TUPLE({ ARG_ALPHA }), args = TUPLE({ ARRAY_OF_ALPHA }), rets = TUPLE({ | |
a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ INTEGER, ALPHA }) }), | |
}), }), | |
["load"] = a_type({ typename = "function", args = TUPLE({ UNION({ STRING, LOAD_FUNCTION }), OPT_STRING, OPT_STRING, OPT_TABLE }), rets = TUPLE({ FUNCTION, STRING }) }), | |
["loadfile"] = a_type({ typename = "function", args = TUPLE({ OPT_STRING, OPT_STRING, OPT_TABLE }), rets = TUPLE({ FUNCTION, STRING }) }), | |
["next"] = a_type({ | |
typename = "poly", | |
types = { | |
a_type({ typeargs = TUPLE({ ARG_ALPHA, ARG_BETA }), typename = "function", args = TUPLE({ MAP_OF_ALPHA_TO_BETA, OPT_ALPHA }), rets = TUPLE({ ALPHA, BETA }) }), | |
a_type({ typeargs = TUPLE({ ARG_ALPHA }), typename = "function", args = TUPLE({ ARRAY_OF_ALPHA, OPT_ALPHA }), rets = TUPLE({ INTEGER, ALPHA }) }), | |
}, | |
}), | |
["pairs"] = a_type({ typename = "function", typeargs = TUPLE({ ARG_ALPHA, ARG_BETA }), args = TUPLE({ a_type({ typename = "map", keys = ALPHA, values = BETA }) }), rets = TUPLE({ | |
a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ ALPHA, BETA }) }), | |
}), }), | |
["pcall"] = a_type({ typename = "function", args = VARARG({ FUNCTION, ANY }), rets = TUPLE({ BOOLEAN, ANY }) }), | |
["xpcall"] = a_type({ typename = "function", args = VARARG({ FUNCTION, XPCALL_MSGH_FUNCTION, ANY }), rets = TUPLE({ BOOLEAN, ANY }) }), | |
["print"] = a_type({ typename = "function", args = VARARG({ ANY }), rets = TUPLE({}) }), | |
["rawequal"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ BOOLEAN }) }), | |
["rawget"] = a_type({ typename = "function", args = TUPLE({ TABLE, ANY }), rets = TUPLE({ ANY }) }), | |
["rawlen"] = a_type({ typename = "function", args = TUPLE({ UNION({ TABLE, STRING }) }), rets = TUPLE({ INTEGER }) }), | |
["rawset"] = a_type({ | |
typename = "poly", | |
types = { | |
a_type({ typeargs = TUPLE({ ARG_ALPHA, ARG_BETA }), typename = "function", args = TUPLE({ MAP_OF_ALPHA_TO_BETA, ALPHA, BETA }), rets = TUPLE({}) }), | |
a_type({ typeargs = TUPLE({ ARG_ALPHA }), typename = "function", args = TUPLE({ ARRAY_OF_ALPHA, NUMBER, ALPHA }), rets = TUPLE({}) }), | |
a_type({ typename = "function", args = TUPLE({ TABLE, ANY, ANY }), rets = TUPLE({}) }), | |
}, | |
}), | |
["require"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({}) }), | |
["select"] = a_type({ | |
typename = "poly", | |
types = { | |
a_type({ typename = "function", typeargs = TUPLE({ ARG_ALPHA }), args = VARARG({ NUMBER, ALPHA }), rets = TUPLE({ ALPHA }) }), | |
a_type({ typename = "function", args = VARARG({ NUMBER, ANY }), rets = TUPLE({ ANY }) }), | |
a_type({ typename = "function", args = VARARG({ STRING, ANY }), rets = TUPLE({ INTEGER }) }), | |
}, | |
}), | |
["setmetatable"] = a_type({ typeargs = TUPLE({ ARG_ALPHA }), typename = "function", args = TUPLE({ ALPHA, NOMINAL_METATABLE_OF_ALPHA }), rets = TUPLE({ ALPHA }) }), | |
["tonumber"] = a_type({ | |
typename = "poly", | |
types = { | |
a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ NUMBER }) }), | |
a_type({ typename = "function", args = TUPLE({ ANY, NUMBER }), rets = TUPLE({ INTEGER }) }), | |
}, | |
}), | |
["tostring"] = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ STRING }) }), | |
["type"] = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ STRING }) }), | |
["FILE"] = a_type({ | |
typename = "typetype", | |
def = a_type({ | |
typename = "record", | |
is_userdata = true, | |
fields = { | |
["close"] = a_type({ typename = "function", args = TUPLE({ NOMINAL_FILE }), rets = TUPLE({ BOOLEAN, STRING }) }), | |
["flush"] = a_type({ typename = "function", args = TUPLE({ NOMINAL_FILE }), rets = TUPLE({}) }), | |
["lines"] = a_type({ typename = "function", args = VARARG({ NOMINAL_FILE, a_type({ typename = "union", types = { STRING, NUMBER } }) }), rets = TUPLE({ | |
a_type({ typename = "function", args = TUPLE({}), rets = VARARG({ STRING }) }), | |
}), }), | |
["read"] = a_type({ typename = "function", args = TUPLE({ NOMINAL_FILE, UNION({ STRING, NUMBER }) }), rets = TUPLE({ STRING, STRING }) }), | |
["seek"] = a_type({ typename = "function", args = TUPLE({ NOMINAL_FILE, OPT_STRING, OPT_NUMBER }), rets = TUPLE({ INTEGER, STRING }) }), | |
["setvbuf"] = a_type({ typename = "function", args = TUPLE({ NOMINAL_FILE, STRING, OPT_NUMBER }), rets = TUPLE({}) }), | |
["write"] = a_type({ typename = "function", args = VARARG({ NOMINAL_FILE, STRING }), rets = TUPLE({ NOMINAL_FILE, STRING }) }), | |
}, | |
}), | |
}), | |
["metatable"] = a_type({ | |
typename = "typetype", | |
def = a_type({ | |
typename = "record", | |
typeargs = TUPLE({ ARG_ALPHA }), | |
fields = { | |
["__call"] = a_type({ typename = "function", args = VARARG({ ALPHA, ANY }), rets = VARARG({ ANY }) }), | |
["__gc"] = a_type({ typename = "function", args = TUPLE({ ALPHA }), rets = TUPLE({}) }), | |
["__index"] = ANY, | |
["__len"] = a_type({ typename = "function", args = TUPLE({ ALPHA }), rets = TUPLE({ ANY }) }), | |
["__mode"] = a_type({ typename = "enum", enumset = { ["k"] = true, ["v"] = true, ["kv"] = true } }), | |
["__newindex"] = ANY, | |
["__pairs"] = a_type({ typename = "function", typeargs = TUPLE({ ARG_ALPHA, ARG_BETA }), | |
args = TUPLE({ a_type({ typename = "map", keys = ALPHA, values = BETA }) }), | |
rets = TUPLE({ a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ ALPHA, BETA }) }) }), | |
}), | |
["__tostring"] = a_type({ typename = "function", args = TUPLE({ ALPHA }), rets = TUPLE({ STRING }) }), | |
["__name"] = STRING, | |
["__add"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), | |
["__sub"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), | |
["__mul"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), | |
["__div"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), | |
["__idiv"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), | |
["__mod"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), | |
["__pow"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), | |
["__unm"] = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ ANY }) }), | |
["__band"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), | |
["__bor"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), | |
["__bxor"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), | |
["__bnot"] = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ ANY }) }), | |
["__shl"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), | |
["__shr"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), | |
["__concat"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ ANY }) }), | |
["__eq"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ BOOLEAN }) }), | |
["__lt"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ BOOLEAN }) }), | |
["__le"] = a_type({ typename = "function", args = TUPLE({ ANY, ANY }), rets = TUPLE({ BOOLEAN }) }), | |
}, | |
}), | |
}), | |
["coroutine"] = a_type({ | |
typename = "record", | |
fields = { | |
["create"] = a_type({ typename = "function", args = TUPLE({ FUNCTION }), rets = TUPLE({ THREAD }) }), | |
["close"] = a_type({ typename = "function", args = TUPLE({ THREAD }), rets = TUPLE({ BOOLEAN, STRING }) }), | |
["isyieldable"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ BOOLEAN }) }), | |
["resume"] = a_type({ typename = "function", args = VARARG({ THREAD, ANY }), rets = VARARG({ BOOLEAN, ANY }) }), | |
["running"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ THREAD, BOOLEAN }) }), | |
["status"] = a_type({ typename = "function", args = TUPLE({ THREAD }), rets = TUPLE({ STRING }) }), | |
["wrap"] = a_type({ typename = "function", args = TUPLE({ FUNCTION }), rets = TUPLE({ FUNCTION }) }), | |
["yield"] = a_type({ typename = "function", args = VARARG({ ANY }), rets = VARARG({ ANY }) }), | |
}, | |
}), | |
["debug"] = a_type({ | |
typename = "record", | |
fields = { | |
["Info"] = a_type({ | |
typename = "typetype", | |
def = DEBUG_GETINFO_TABLE, | |
}), | |
["Hook"] = a_type({ | |
typename = "typetype", | |
def = DEBUG_HOOK_FUNCTION, | |
}), | |
["HookEvent"] = a_type({ | |
typename = "typetype", | |
def = DEBUG_HOOK_EVENT, | |
}), | |
["debug"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({}) }), | |
["gethook"] = a_type({ typename = "function", args = TUPLE({ OPT_THREAD }), rets = TUPLE({ DEBUG_HOOK_FUNCTION, INTEGER }) }), | |
["getlocal"] = a_type({ | |
typename = "poly", | |
types = { | |
a_type({ typename = "function", args = TUPLE({ THREAD, FUNCTION, NUMBER }), rets = TUPLE({}) }), | |
a_type({ typename = "function", args = TUPLE({ FUNCTION, NUMBER }), rets = TUPLE({}) }), | |
}, | |
}), | |
["getmetatable"] = a_type({ typename = "function", typeargs = TUPLE({ ARG_ALPHA }), args = TUPLE({ ALPHA }), rets = TUPLE({ NOMINAL_METATABLE_OF_ALPHA }) }), | |
["getregistry"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ TABLE }) }), | |
["getupvalue"] = a_type({ typename = "function", args = TUPLE({ FUNCTION, NUMBER }), rets = TUPLE({ ANY }) }), | |
["getuservalue"] = a_type({ typename = "function", args = TUPLE({ USERDATA, NUMBER }), rets = TUPLE({ ANY }) }), | |
["sethook"] = a_type({ | |
typename = "poly", | |
types = { | |
a_type({ typename = "function", args = TUPLE({ THREAD, DEBUG_HOOK_FUNCTION, STRING, NUMBER }), rets = TUPLE({}) }), | |
a_type({ typename = "function", args = TUPLE({ DEBUG_HOOK_FUNCTION, STRING, NUMBER }), rets = TUPLE({}) }), | |
}, | |
}), | |
["setlocal"] = a_type({ | |
typename = "poly", | |
types = { | |
a_type({ typename = "function", args = TUPLE({ THREAD, NUMBER, NUMBER, ANY }), rets = TUPLE({ STRING }) }), | |
a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER, ANY }), rets = TUPLE({ STRING }) }), | |
}, | |
}), | |
["setmetatable"] = a_type({ typeargs = TUPLE({ ARG_ALPHA }), typename = "function", args = TUPLE({ ALPHA, NOMINAL_METATABLE_OF_ALPHA }), rets = TUPLE({ ALPHA }) }), | |
["setupvalue"] = a_type({ typename = "function", args = TUPLE({ FUNCTION, NUMBER, ANY }), rets = TUPLE({ STRING }) }), | |
["setuservalue"] = a_type({ typename = "function", args = TUPLE({ USERDATA, ANY, NUMBER }), rets = TUPLE({ USERDATA }) }), | |
["traceback"] = a_type({ | |
typename = "poly", | |
types = { | |
a_type({ typename = "function", args = TUPLE({ THREAD, STRING, NUMBER }), rets = TUPLE({ STRING }) }), | |
a_type({ typename = "function", args = TUPLE({ STRING, NUMBER }), rets = TUPLE({ STRING }) }), | |
}, | |
}), | |
["upvalueid"] = a_type({ typename = "function", args = TUPLE({ FUNCTION, NUMBER }), rets = TUPLE({ USERDATA }) }), | |
["upvaluejoin"] = a_type({ typename = "function", args = TUPLE({ FUNCTION, NUMBER, FUNCTION, NUMBER }), rets = TUPLE({}) }), | |
["getinfo"] = a_type({ | |
typename = "poly", | |
types = { | |
a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ DEBUG_GETINFO_TABLE }) }), | |
a_type({ typename = "function", args = TUPLE({ ANY, STRING }), rets = TUPLE({ DEBUG_GETINFO_TABLE }) }), | |
a_type({ typename = "function", args = TUPLE({ ANY, ANY, STRING }), rets = TUPLE({ DEBUG_GETINFO_TABLE }) }), | |
}, | |
}), | |
}, | |
}), | |
["io"] = a_type({ | |
typename = "record", | |
fields = { | |
["close"] = a_type({ typename = "function", args = TUPLE({ OPT_NOMINAL_FILE }), rets = TUPLE({ BOOLEAN, STRING }) }), | |
["flush"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({}) }), | |
["input"] = a_type({ typename = "function", args = TUPLE({ OPT_UNION({ STRING, NOMINAL_FILE }) }), rets = TUPLE({ NOMINAL_FILE }) }), | |
["lines"] = a_type({ typename = "function", args = VARARG({ OPT_STRING, a_type({ typename = "union", types = { STRING, NUMBER } }) }), rets = TUPLE({ | |
a_type({ typename = "function", args = TUPLE({}), rets = VARARG({ STRING }) }), | |
}), }), | |
["open"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING }), rets = TUPLE({ NOMINAL_FILE, STRING }) }), | |
["output"] = a_type({ typename = "function", args = TUPLE({ OPT_UNION({ STRING, NOMINAL_FILE }) }), rets = TUPLE({ NOMINAL_FILE }) }), | |
["popen"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING }), rets = TUPLE({ NOMINAL_FILE, STRING }) }), | |
["read"] = a_type({ typename = "function", args = TUPLE({ UNION({ STRING, NUMBER }) }), rets = TUPLE({ STRING, STRING }) }), | |
["stderr"] = NOMINAL_FILE, | |
["stdin"] = NOMINAL_FILE, | |
["stdout"] = NOMINAL_FILE, | |
["tmpfile"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ NOMINAL_FILE }) }), | |
["type"] = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ STRING }) }), | |
["write"] = a_type({ typename = "function", args = VARARG({ STRING }), rets = TUPLE({ NOMINAL_FILE, STRING }) }), | |
}, | |
}), | |
["math"] = a_type({ | |
typename = "record", | |
fields = { | |
["abs"] = a_type({ | |
typename = "poly", | |
types = { | |
a_type({ typename = "function", args = TUPLE({ INTEGER }), rets = TUPLE({ INTEGER }) }), | |
a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), | |
}, | |
}), | |
["acos"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), | |
["asin"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), | |
["atan"] = a_type({ typename = "function", args = TUPLE({ NUMBER, OPT_NUMBER }), rets = TUPLE({ NUMBER }) }), | |
["atan2"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ NUMBER }) }), | |
["ceil"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ INTEGER }) }), | |
["cos"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), | |
["cosh"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), | |
["deg"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), | |
["exp"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), | |
["floor"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ INTEGER }) }), | |
["fmod"] = a_type({ | |
typename = "poly", | |
types = { | |
a_type({ typename = "function", args = TUPLE({ INTEGER, INTEGER }), rets = TUPLE({ INTEGER }) }), | |
a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ NUMBER }) }), | |
}, | |
}), | |
["frexp"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER, NUMBER }) }), | |
["huge"] = NUMBER, | |
["ldexp"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ NUMBER }) }), | |
["log"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ NUMBER }) }), | |
["log10"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), | |
["max"] = a_type({ | |
typename = "poly", | |
types = { | |
a_type({ typename = "function", args = VARARG({ INTEGER }), rets = TUPLE({ INTEGER }) }), | |
a_type({ typename = "function", typeargs = TUPLE({ ARG_ALPHA }), args = VARARG({ ALPHA }), rets = TUPLE({ ALPHA }) }), | |
a_type({ typename = "function", args = VARARG({ a_type({ typename = "union", types = { NUMBER, INTEGER } }) }), rets = TUPLE({ NUMBER }) }), | |
a_type({ typename = "function", args = VARARG({ ANY }), rets = TUPLE({ ANY }) }), | |
}, | |
}), | |
["maxinteger"] = INTEGER, | |
["min"] = a_type({ | |
typename = "poly", | |
types = { | |
a_type({ typename = "function", args = VARARG({ INTEGER }), rets = TUPLE({ INTEGER }) }), | |
a_type({ typename = "function", typeargs = TUPLE({ ARG_ALPHA }), args = VARARG({ ALPHA }), rets = TUPLE({ ALPHA }) }), | |
a_type({ typename = "function", args = VARARG({ a_type({ typename = "union", types = { NUMBER, INTEGER } }) }), rets = TUPLE({ NUMBER }) }), | |
a_type({ typename = "function", args = VARARG({ ANY }), rets = TUPLE({ ANY }) }), | |
}, | |
}), | |
["mininteger"] = INTEGER, | |
["modf"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ INTEGER, NUMBER }) }), | |
["pi"] = NUMBER, | |
["pow"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ NUMBER }) }), | |
["rad"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), | |
["random"] = a_type({ | |
typename = "poly", | |
types = { | |
a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ INTEGER }) }), | |
a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ NUMBER }) }), | |
}, | |
}), | |
["randomseed"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ INTEGER, INTEGER }) }), | |
["sin"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), | |
["sinh"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), | |
["sqrt"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), | |
["tan"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), | |
["tanh"] = a_type({ typename = "function", args = TUPLE({ NUMBER }), rets = TUPLE({ NUMBER }) }), | |
["tointeger"] = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ INTEGER }) }), | |
["type"] = a_type({ typename = "function", args = TUPLE({ ANY }), rets = TUPLE({ STRING }) }), | |
["ult"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ BOOLEAN }) }), | |
}, | |
}), | |
["os"] = a_type({ | |
typename = "record", | |
fields = { | |
["clock"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ NUMBER }) }), | |
["date"] = a_type({ | |
typename = "poly", | |
types = { | |
a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ STRING }) }), | |
a_type({ typename = "function", args = TUPLE({ OS_DATE_TABLE_FORMAT, NUMBER }), rets = TUPLE({ OS_DATE_TABLE }) }), | |
a_type({ typename = "function", args = TUPLE({ OPT_STRING, OPT_NUMBER }), rets = TUPLE({ STRING }) }), | |
}, | |
}), | |
["difftime"] = a_type({ typename = "function", args = TUPLE({ NUMBER, NUMBER }), rets = TUPLE({ NUMBER }) }), | |
["execute"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ BOOLEAN, STRING, INTEGER }) }), | |
["exit"] = a_type({ typename = "function", args = TUPLE({ UNION({ NUMBER, BOOLEAN }), BOOLEAN }), rets = TUPLE({}) }), | |
["getenv"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ STRING }) }), | |
["remove"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ BOOLEAN, STRING }) }), | |
["rename"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING }), rets = TUPLE({ BOOLEAN, STRING }) }), | |
["setlocale"] = a_type({ typename = "function", args = TUPLE({ STRING, OPT_STRING }), rets = TUPLE({ STRING }) }), | |
["time"] = a_type({ | |
typename = "poly", | |
types = { | |
a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ INTEGER }) }), | |
a_type({ typename = "function", args = TUPLE({ OS_DATE_TABLE }), rets = TUPLE({ INTEGER }) }), | |
}, | |
}), | |
["tmpname"] = a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ STRING }) }), | |
}, | |
}), | |
["package"] = a_type({ | |
typename = "record", | |
fields = { | |
["config"] = STRING, | |
["cpath"] = STRING, | |
["loaded"] = a_type({ | |
typename = "map", | |
keys = STRING, | |
values = ANY, | |
}), | |
["loaders"] = a_type({ | |
typename = "array", | |
elements = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ ANY }) }), | |
}), | |
["loadlib"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING }), rets = TUPLE({ FUNCTION }) }), | |
["path"] = STRING, | |
["preload"] = TABLE, | |
["searchers"] = a_type({ | |
typename = "array", | |
elements = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ ANY }) }), | |
}), | |
["searchpath"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING, OPT_STRING, OPT_STRING }), rets = TUPLE({ STRING, STRING }) }), | |
}, | |
}), | |
["string"] = a_type({ | |
typename = "record", | |
fields = { | |
["byte"] = a_type({ | |
typename = "poly", | |
types = { | |
a_type({ typename = "function", args = TUPLE({ STRING, OPT_NUMBER }), rets = TUPLE({ INTEGER }) }), | |
a_type({ typename = "function", args = TUPLE({ STRING, NUMBER, NUMBER }), rets = VARARG({ INTEGER }) }), | |
}, | |
}), | |
["char"] = a_type({ typename = "function", args = VARARG({ NUMBER }), rets = TUPLE({ STRING }) }), | |
["dump"] = a_type({ typename = "function", args = TUPLE({ FUNCTION, OPT_BOOLEAN }), rets = TUPLE({ STRING }) }), | |
["find"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING, OPT_NUMBER, OPT_BOOLEAN }), rets = VARARG({ INTEGER, INTEGER, STRING }) }), | |
["format"] = a_type({ typename = "function", args = VARARG({ STRING, ANY }), rets = TUPLE({ STRING }) }), | |
["gmatch"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING }), rets = TUPLE({ | |
a_type({ typename = "function", args = TUPLE({}), rets = VARARG({ STRING }) }), | |
}), }), | |
["gsub"] = a_type({ | |
typename = "poly", | |
types = { | |
a_type({ typename = "function", args = TUPLE({ STRING, STRING, STRING, NUMBER }), rets = TUPLE({ STRING, INTEGER }) }), | |
a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "map", keys = STRING, values = STRING }), NUMBER }), rets = TUPLE({ STRING, INTEGER }) }), | |
a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "function", args = VARARG({ STRING }), rets = TUPLE({ STRING }) }) }), rets = TUPLE({ STRING, INTEGER }) }), | |
a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "function", args = VARARG({ STRING }), rets = TUPLE({ NUMBER }) }) }), rets = TUPLE({ STRING, INTEGER }) }), | |
a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "function", args = VARARG({ STRING }), rets = TUPLE({ BOOLEAN }) }) }), rets = TUPLE({ STRING, INTEGER }) }), | |
a_type({ typename = "function", args = TUPLE({ STRING, STRING, a_type({ typename = "function", args = VARARG({ STRING }), rets = TUPLE({}) }) }), rets = TUPLE({ STRING, INTEGER }) }), | |
}, | |
}), | |
["len"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ INTEGER }) }), | |
["lower"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ STRING }) }), | |
["match"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING, NUMBER }), rets = VARARG({ STRING }) }), | |
["pack"] = a_type({ typename = "function", args = VARARG({ STRING, ANY }), rets = TUPLE({ STRING }) }), | |
["packsize"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ INTEGER }) }), | |
["rep"] = a_type({ typename = "function", args = TUPLE({ STRING, NUMBER }), rets = TUPLE({ STRING }) }), | |
["reverse"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ STRING }) }), | |
["sub"] = a_type({ typename = "function", args = TUPLE({ STRING, NUMBER, NUMBER }), rets = TUPLE({ STRING }) }), | |
["unpack"] = a_type({ typename = "function", args = TUPLE({ STRING, STRING, OPT_NUMBER }), rets = VARARG({ ANY }) }), | |
["upper"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ STRING }) }), | |
}, | |
}), | |
["table"] = a_type({ | |
typename = "record", | |
fields = { | |
["concat"] = a_type({ typename = "function", args = TUPLE({ ARRAY_OF_STRING, OPT_STRING, OPT_NUMBER, OPT_NUMBER }), rets = TUPLE({ STRING }) }), | |
["insert"] = a_type({ | |
typename = "poly", | |
types = { | |
a_type({ typename = "function", typeargs = TUPLE({ ARG_ALPHA }), args = TUPLE({ ARRAY_OF_ALPHA, NUMBER, ALPHA }), rets = TUPLE({}) }), | |
a_type({ typename = "function", typeargs = TUPLE({ ARG_ALPHA }), args = TUPLE({ ARRAY_OF_ALPHA, ALPHA }), rets = TUPLE({}) }), | |
}, | |
}), | |
["move"] = a_type({ | |
typename = "poly", | |
types = { | |
a_type({ typename = "function", typeargs = TUPLE({ ARG_ALPHA }), args = TUPLE({ ARRAY_OF_ALPHA, NUMBER, NUMBER, NUMBER }), rets = TUPLE({ ARRAY_OF_ALPHA }) }), | |
a_type({ typename = "function", typeargs = TUPLE({ ARG_ALPHA }), args = TUPLE({ ARRAY_OF_ALPHA, NUMBER, NUMBER, NUMBER, ARRAY_OF_ALPHA }), rets = TUPLE({ ARRAY_OF_ALPHA }) }), | |
}, | |
}), | |
["pack"] = a_type({ typename = "function", args = VARARG({ ANY }), rets = TUPLE({ TABLE }) }), | |
["remove"] = a_type({ typename = "function", typeargs = TUPLE({ ARG_ALPHA }), args = TUPLE({ ARRAY_OF_ALPHA, OPT_NUMBER }), rets = TUPLE({ ALPHA }) }), | |
["sort"] = a_type({ typename = "function", typeargs = TUPLE({ ARG_ALPHA }), args = TUPLE({ ARRAY_OF_ALPHA, OPT_TABLE_SORT_FUNCTION }), rets = TUPLE({}) }), | |
["unpack"] = a_type({ typename = "function", needs_compat = true, typeargs = TUPLE({ ARG_ALPHA }), args = TUPLE({ ARRAY_OF_ALPHA, NUMBER, NUMBER }), rets = VARARG({ ALPHA }) }), | |
}, | |
}), | |
["utf8"] = a_type({ | |
typename = "record", | |
fields = { | |
["char"] = a_type({ typename = "function", args = VARARG({ NUMBER }), rets = TUPLE({ STRING }) }), | |
["charpattern"] = STRING, | |
["codepoint"] = a_type({ typename = "function", args = TUPLE({ STRING, OPT_NUMBER, OPT_NUMBER }), rets = VARARG({ INTEGER }) }), | |
["codes"] = a_type({ typename = "function", args = TUPLE({ STRING }), rets = TUPLE({ | |
a_type({ typename = "function", args = TUPLE({}), rets = TUPLE({ NUMBER, STRING }) }), | |
}), }), | |
["len"] = a_type({ typename = "function", args = TUPLE({ STRING, NUMBER, NUMBER }), rets = TUPLE({ INTEGER }) }), | |
["offset"] = a_type({ typename = "function", args = TUPLE({ STRING, NUMBER, NUMBER }), rets = TUPLE({ INTEGER }) }), | |
}, | |
}), | |
["_VERSION"] = STRING, | |
} | |
for _, t in pairs(standard_library) do | |
fill_field_order(t) | |
if is_typetype(t) then | |
fill_field_order(t.def) | |
end | |
end | |
fill_field_order(OS_DATE_TABLE) | |
fill_field_order(DEBUG_GETINFO_TABLE) | |
NOMINAL_FILE.found = standard_library["FILE"] | |
NOMINAL_METATABLE_OF_ALPHA.found = standard_library["metatable"] | |
for name, typ in pairs(standard_library) do | |
globals[name] = { t = typ, needs_compat = stdlib_compat[name], is_const = true } | |
end | |
globals["@is_va"] = { t = ANY } | |
if not is_first_init then | |
last_typeid = save_typeid | |
end | |
return globals, standard_library | |
end | |
--- Create a fresh type-checking environment.
-- @param lax forwarded to init_globals and require_module.
-- @param gen_compat true or nil becomes "optional", false becomes "off";
--   any other value is kept unchanged.
-- @param gen_target when nil, "5.1" is chosen on Lua 5.1/5.2 hosts and "5.3" otherwise.
-- @param predefined optional array of module names to load up front.
-- @return the environment table, or nil plus an error message when a
--   predefined module resolves to INVALID.
tl.init_env = function(lax, gen_compat, gen_target, predefined)
   -- Normalize the boolean shorthands for gen_compat into its string modes.
   if gen_compat == false then
      gen_compat = "off"
   elseif gen_compat == nil or gen_compat == true then
      gen_compat = "optional"
   end

   if not gen_target then
      local legacy_host = (_VERSION == "Lua 5.1" or _VERSION == "Lua 5.2")
      if legacy_host then
         gen_target = "5.1"
      else
         gen_target = "5.3"
      end
   end

   local globals, standard_library = init_globals(lax)
   local env = {
      ok = true,
      modules = {},
      loaded = {},
      loaded_order = {},
      globals = globals,
      gen_compat = gen_compat,
      gen_target = gen_target,
   }

   -- Every record-typed standard library entry is also exposed as a module.
   for module_name, module_type in pairs(standard_library) do
      if module_type.typename == "record" then
         env.modules[module_name] = module_type
      end
   end

   for _, name in ipairs(predefined or {}) do
      if require_module(name, lax, env) == INVALID then
         return nil, string.format("Error: could not predefine module '%s'", name)
      end
   end

   return env
end
tl.type_check = function(ast, opts) | |
opts = opts or {} | |
local env = opts.env or tl.init_env(opts.lax, opts.gen_compat, opts.gen_target) | |
local lax = opts.lax | |
local filename = opts.filename | |
local st = { env.globals } | |
local symbol_list = {} | |
local symbol_list_n = 0 | |
local all_needs_compat = {} | |
local dependencies = {} | |
local warnings = {} | |
local errors = {} | |
local module_type | |
--- Look up a variable by name, from the innermost scope outward.
-- Unless `raw` is set, the entry found is marked as used. When the match is
-- in the global scope (index 1) and flagged needs_compat, the name is
-- recorded in all_needs_compat.
-- @return the scope entry, or nothing when no scope declares `name`.
local function find_var(name, raw)
   local depth = #st
   while depth >= 1 do
      local entry = st[depth][name]
      if entry then
         if depth == 1 and entry.needs_compat then
            all_needs_compat[name] = true
         end
         if not raw then
            entry.used = true
         end
         return entry
      end
      depth = depth - 1
   end
end
--- Build a record type describing the variables of the global scope.
-- Internal entries whose names start with "@" are left out.
-- @return the record type and `false`.
local function simulate_g()
   local fields = {}
   for gname, gvar in pairs(st[1]) do
      if gname:sub(1, 1) ~= "@" then
         fields[gname] = gvar.t
      end
   end
   local g_record = a_type({
      typename = "record",
      field_order = sorted_keys(fields),
      fields = fields,
   })
   return g_record, false
end
--- Return the type and const-ness of a visible variable.
-- Returns no values when the variable is not declared in any scope.
local function find_var_type(name, raw)
   local found = find_var(name, raw)
   if not found then
      return
   end
   return found.t, found.is_const
end
--- Build an error record positioned at `where`.
-- When extra arguments are given, `msg` is a format string and each extra
-- argument (a type) is rendered with show_type. If any of them is the
-- "invalid" type, no error is produced (returns nil).
-- The filename falls back to the file being checked.
local function error_in_type(where, msg, ...)
   local count = select("#", ...)
   if count > 0 then
      local rendered = {}
      for idx = 1, count do
         local typ = (select(idx, ...))
         if typ.typename == "invalid" then
            return nil
         end
         rendered[idx] = show_type(typ)
      end
      msg = msg:format(_tl_table_unpack(rendered))
   end
   return {
      y = where.y,
      x = where.x,
      msg = msg,
      filename = where.filename or filename,
   }
end
--- Append a type error located at `t` to the error list.
-- @return true if an error was recorded; false if error_in_type suppressed it.
local function type_error(t, msg, ...)
   local err = error_in_type(t, msg, ...)
   if not err then
      return false
   end
   table.insert(errors, err)
   return true
end
--- Resolve a dotted type name (an array such as {"mod", "Rec"}) to a type.
-- Each step follows `.found` when present and descends through `fields`
-- (or the fields of a type definition's `def`).
-- @return the resolved type if it is a type definition, or a "typearg" when
--   `accept_typearg` is set; otherwise nil or nothing.
local function find_type(names, accept_typearg)
   local typ = find_var_type(names[1])
   if not typ then
      return nil
   end
   typ = typ.found or typ

   for idx = 2, #names do
      local fields = typ.fields or (typ.def and typ.def.fields)
      if not fields then
         return nil
      end
      typ = fields[names[idx]]
      if typ == nil then
         return nil
      end
      typ = typ.found or typ
   end

   if is_typetype(typ) or (accept_typearg and typ.typename == "typearg") then
      return typ
   end
end
--- Classify a type by the runtime category used to discriminate union members.
-- Type definitions and tuples (first element) are unwrapped. Nominals are
-- resolved, and unresolvable ones count as "table". Records report
-- "userdata" when flagged is_userdata, else "table". Kinds listed in
-- table_types report "table". Anything else reports its own typename.
local function union_type(t)
   if is_typetype(t) then
      return union_type(t.def)
   end

   local tn = t.typename
   if tn == "tuple" then
      return union_type(t[1])
   elseif tn == "nominal" then
      local typetype = t.found or find_type(t.names)
      if not typetype then
         return "table"
      end
      return union_type(typetype)
   elseif tn == "record" then
      if t.is_userdata then
         return "userdata"
      end
      return "table"
   elseif table_types[tn] then
      return "table"
   end
   return tn
end
--- Check that the members of a union can be told apart at runtime.
-- At most one member may be classified as userdata, table, or function.
-- At most one member may be string-like (an enum or `string`); once a
-- primitive `string` has been seen, further `string` members are not counted.
-- @return true, or false plus a format string (with one %s) describing the problem.
local function is_valid_union(typ)
   local single_only = {
      userdata = "cannot discriminate a union between multiple userdata types: %s",
      table = "cannot discriminate a union between multiple table types: %s",
      ["function"] = "cannot discriminate a union between multiple function types: %s",
   }
   local counts = {}
   local string_like = 0
   local saw_string = false

   for _, member in ipairs(typ.types) do
      local kind = union_type(member)
      local limit_msg = single_only[kind]
      if limit_msg then
         counts[kind] = (counts[kind] or 0) + 1
         if counts[kind] > 1 then
            return false, limit_msg
         end
      elseif kind == "enum" or (kind == "string" and not saw_string) then
         string_like = string_like + 1
         if string_like > 1 then
            return false, "cannot discriminate a union between multiple string/enum types: %s"
         end
         if kind == "string" then
            saw_string = true
         end
      end
   end
   return true
end
--- Unwrap a type definition to the type it defines.
-- Any other type is returned unchanged.
local function resolve_typetype(t)
   if not is_typetype(t) then
      return t
   end
   return t.def
end
-- Set of type kinds that contain no nested types. resolve_typevars returns
-- these unchanged instead of copying them.
local no_nested_types = {
   any = true,
   boolean = true,
   enum = true,
   integer = true,
   ["nil"] = true,
   number = true,
   string = true,
   thread = true,
   unknown = true,
}
--- Produce a copy of `typ` with type variables replaced by their bindings.
-- Bindings are looked up with find_var_type; unbound type variables are kept.
-- Kinds listed in no_nested_types, and nominals without type values, are
-- returned as-is rather than copied. A `seen` map returns the same copy for a
-- type met more than once, which also stops recursion on cyclic types.
-- Union copies are re-validated with is_valid_union, because substituting a
-- type variable can make a union impossible to discriminate.
-- @return true and the copy on success; false, INVALID and a list of errors
--   if any union became invalid.
local function resolve_typevars(typ)
   local errs
   local seen = {}

   local function resolve(t)
      if no_nested_types[t.typename] or (t.typename == "nominal" and not t.typevals) then
         return t
      end
      if seen[t] then
         return seen[t]
      end

      local orig_t = t
      if t.typename == "typevar" then
         t = find_var_type(t.typevar)
         local rt
         if not t then
            -- unbound type variable: keep it as-is
            rt = orig_t
         elseif t.typename == "string" then
            -- string-typed bindings map to the shared STRING type
            rt = STRING
         elseif no_nested_types[t.typename] or
            (t.typename == "nominal" and not t.typevals) then
            rt = t
         end
         if rt then
            seen[orig_t] = rt
            return rt
         end
         -- otherwise the bound type has nested parts: fall through and copy it
      end

      local copy = {}
      -- register before recursing so cycles resolve to this copy
      seen[orig_t] = copy
      copy.typename = t.typename
      copy.filename = t.filename
      copy.typeid = t.typeid
      copy.x = t.x
      copy.y = t.y
      copy.yend = t.yend
      copy.xend = t.xend
      copy.names = t.names

      for i, tf in ipairs(t) do
         copy[i] = resolve(tf)
      end

      if t.typename == "array" then
         copy.elements = resolve(t.elements)
      elseif t.typename == "typearg" then
         copy.typearg = t.typearg
      elseif t.typename == "typevar" then
         copy.typevar = t.typevar
      elseif is_typetype(t) then
         copy.def = resolve(t.def)
      elseif t.typename == "nominal" then
         copy.typevals = resolve(t.typevals)
         copy.found = t.found
      elseif t.typename == "function" then
         if t.typeargs then
            copy.typeargs = {}
            for i, tf in ipairs(t.typeargs) do
               copy.typeargs[i] = resolve(tf)
            end
         end
         copy.is_method = t.is_method
         copy.args = resolve(t.args)
         copy.rets = resolve(t.rets)
      elseif t.typename == "record" or t.typename == "arrayrecord" then
         if t.typeargs then
            copy.typeargs = {}
            for i, tf in ipairs(t.typeargs) do
               copy.typeargs[i] = resolve(tf)
            end
         end
         if t.elements then
            copy.elements = resolve(t.elements)
         end
         -- Preserve the userdata flag: union_type inspects is_userdata on
         -- records, and losing it would re-classify the copy as a "table".
         copy.is_userdata = t.is_userdata
         copy.fields = {}
         for _, k in ipairs(t.field_order) do
            copy.fields[k] = resolve(t.fields[k])
         end
         copy.field_order = t.field_order
         if t.meta_fields then
            copy.meta_fields = {}
            for _, k in ipairs(t.meta_field_order) do
               copy.meta_fields[k] = resolve(t.meta_fields[k])
            end
            copy.meta_field_order = t.meta_field_order
         end
      elseif t.typename == "map" then
         copy.keys = resolve(t.keys)
         copy.values = resolve(t.values)
      elseif t.typename == "union" then
         copy.types = {}
         for i, tf in ipairs(t.types) do
            copy.types[i] = resolve(tf)
         end
         local ok, err = is_valid_union(copy)
         if not ok then
            errs = errs or {}
            table.insert(errs, error_in_type(t, err, t))
         end
      elseif t.typename == "poly" or t.typename == "tupletable" then
         copy.types = {}
         for i, tf in ipairs(t.types) do
            copy.types[i] = resolve(tf)
         end
      elseif t.typename == "tuple" then
         copy.is_va = t.is_va
      end

      return copy
   end

   local copy = resolve(typ)
   if errs then
      return false, INVALID, errs
   end
   return true, copy
end
--- Replace the variable an empty table was assigned to with inferred type `t`.
-- Scopes are scanned from the innermost one (or only the global scope, when
-- the empty table came from a global declaration) down to scope 1. Every
-- scope declaring the name gets a fresh, non-const entry, and `t` records
-- where the inference happened.
-- NOTE(review): the scan does not stop at the first match, so an outer
-- variable shadowed by the same name is overwritten too; confirm this is intended.
local function infer_var(emptytable, t, node)
   local decl = emptytable.declared_at
   local start = #st
   if decl and decl.kind == "global_declaration" then
      start = 1
   end

   local name = emptytable.assigned_to
   for depth = start, 1, -1 do
      local scope = st[depth]
      if scope[name] then
         scope[name] = { t = t, is_const = false }
         t.inferred_at = node
         t.inferred_at_file = filename
      end
   end
end
--- Look up `name` in the global scope only.
-- @return its type and const-ness, or nothing when it is not a global.
local function find_global(name)
   local entry = st[1][name]
   if entry then
      return entry.t, entry.is_const
   end
end
--- Collapse a tuple type to its first element.
-- An empty tuple becomes NIL. Non-tuple types are returned unchanged.
local function resolve_tuple(t)
   local first = t
   if t.typename == "tuple" then
      first = t[1]
   end
   if first ~= nil then
      return first
   end
   return NIL
end
--- Append a warning of kind `tag` at `node`'s position.
-- `fmt` is formatted with the remaining arguments. The warning is attributed
-- to the file currently being checked.
local function node_warning(tag, node, fmt, ...)
   local warning = {
      y = node.y,
      x = node.x,
      msg = fmt:format(...),
      filename = filename,
      tag = tag,
   }
   warnings[#warnings + 1] = warning
end
--- Report a type error at `node` and mark the node's type as INVALID.
-- @return INVALID (the node's new type).
local function node_error(node, msg, ...)
   type_error(node, msg, ...)
   node.type = INVALID
   return INVALID
end
--- Wrap the error built by error_in_type in a list.
-- The list is empty when error_in_type returns nil.
local function terr(t, s, ...)
   local wrapped = { error_in_type(t, s, ...) }
   return wrapped
end
--- Emit an "unknown" warning naming a variable at `node`.
-- add_var calls this in lax mode when a variable's type is unknown.
local function add_unknown(node, name)
   local fmt = "unknown variable: %s"
   node_warning("unknown", node, fmt, name)
end
--- Warn that `node` (a name token) redeclares an existing variable.
-- Names starting with "_" are exempt. The original declaration position is
-- included when the old variable records one.
local function redeclaration_warning(node, old_var)
   local name = node.tk
   if name:sub(1, 1) == "_" then
      return
   end
   local origin = old_var.declared_at
   if not origin then
      node_warning("redeclaration", node, "redeclaration of variable '%s'", name)
      return
   end
   node_warning("redeclaration", node, "redeclaration of variable '%s' (originally declared at %d:%d)", name, origin.y, origin.x)
end
--- Warn at `at` if `new_name` is already visible in some scope.
-- The lookup is raw, so the existing variable is not marked as used.
local function check_if_redeclaration(new_name, at)
   local existing = find_var(new_name, true)
   if not existing then
      return
   end
   redeclaration_warning(at, existing)
end
--- Warn about a declared variable or label that was never used.
-- Skipped for entries without a declaration node, narrowed entries, and
-- names starting with "_" or "@". Names starting with "::" are reported as
-- labels. Otherwise the warning names the kind (argument, function, type or
-- variable) and shows the type.
local function unused_warning(name, var)
   local prefix = name:sub(1, 1)
   if not var.declared_at or var.is_narrowed or prefix == "_" or prefix == "@" then
      return
   end

   if name:sub(1, 2) == "::" then
      node_warning("unused", var.declared_at, "unused label %s", name)
      return
   end

   local what
   if var.is_func_arg then
      what = "argument"
   elseif var.t.typename == "function" then
      what = "function"
   elseif is_typetype(var.t) then
      what = "type"
   else
      what = "variable"
   end
   node_warning("unused", var.declared_at, "unused %s %s: %s", what, name, show_type(var.t))
end
--- Return a new table holding the same key/value pairs as `t`.
-- The copy is one level deep and gets no metatable.
local function shallow_copy(t)
   local dup = {}
   for key, value in pairs(t) do
      dup[key] = value
   end
   return dup
end
--- Claim the next symbol-list slot for `node`.
-- add_var later fills the slot stored in node.symbol_list_slot.
local function reserve_symbol_list_slot(node)
   local slot = symbol_list_n + 1
   symbol_list_n = slot
   node.symbol_list_slot = slot
end
--- Declare or narrow variable `var` in the innermost scope with type `valtype`.
-- Non-const types are shallow-copied and their `tk` is cleared.
-- When narrowing an entry that already exists in the innermost scope, that
-- entry is updated in place, and its pre-narrowing type is remembered in
-- narrowed_from. Otherwise a new entry replaces any previous one: a
-- redeclaration check runs first, and an unused warning is emitted for the
-- replaced entry if it was never used.
-- When `node` is given and the type is neither "unresolved" nor "none", the
-- declaration is recorded in the symbol list, reusing a reserved slot if any.
-- @return the scope entry for `var`.
local function add_var(node, var, valtype, is_const, is_narrowing, dont_check_redeclaration)
   if lax and node and is_unknown(valtype) then
      local is_special = (var == "self" or var == "...")
      if not is_special and not is_narrowing then
         add_unknown(node, var)
      end
   end

   local scope = st[#st]
   local prev = scope[var]

   if not is_const then
      -- work on a copy so clearing `tk` does not mutate the caller's type
      valtype = shallow_copy(valtype)
      valtype.tk = nil
   end

   if prev and is_narrowing then
      if not prev.is_narrowed then
         prev.narrowed_from = prev.t
      end
      prev.is_narrowed = true
      prev.t = valtype
   else
      local wants_check = not dont_check_redeclaration and node and not is_narrowing
      if wants_check and var ~= "self" and var ~= "..." and var:sub(1, 1) ~= "@" then
         -- must run before the new entry is stored, or the lookup would find it
         check_if_redeclaration(var, node)
      end
      scope[var] = { t = valtype, is_const = is_const, is_narrowed = is_narrowing, declared_at = node }
      if prev and not prev.used then
         unused_warning(var, prev)
      end
   end

   if node and valtype.typename ~= "unresolved" and valtype.typename ~= "none" then
      node.type = node.type or valtype
      local slot = node.symbol_list_slot
      if not slot then
         symbol_list_n = symbol_list_n + 1
         slot = symbol_list_n
      end
      symbol_list[slot] = { y = node.y, x = node.x, name = var, typ = assert(scope[var].t) }
   end

   return scope[var]
end
local CompareTypes = {}
--- Compare t1 and t2 when at least one of them is a type variable.
-- Identical typevar names match trivially. Otherwise the typevar (t2's if it
-- has one, else t1's) is looked up. If it is bound, the bound type is compared
-- using `comp`, keeping the original argument order. If it is unbound, the
-- other side is resolved and, unless it resolves to "unknown", bound to the
-- typevar in the current scope.
-- Returns true, the results of `comp`, or false plus resolve_typevars' errors.
local function compare_and_infer_typevars(t1, t2, comp)
   if t1.typevar == t2.typevar then
      return true
   end

   local t2_typevar = t2.typevar
   local name = t2_typevar or t1.typevar
   local bound = find_var_type(name)
   if bound then
      if t2_typevar then
         return comp(t1, bound)
      end
      return comp(bound, t2)
   end

   local concrete = t2
   if t2_typevar then
      concrete = t1
   end
   local ok, resolved, errs = resolve_typevars(concrete)
   if not ok then
      return false, errs
   end
   if resolved.typename ~= "unknown" then
      add_var(nil, name, resolve_typetype(resolved))
   end
   return true
end
--- Prefix each error in `src` with `prefix` and append it to `dst`.
-- The error tables are modified in place. When `node` has a position, an error
-- is re-located to `node` if it comes from another file, has no line, or is
-- positioned before `node`. `src` may be nil, in which case nothing happens.
local function add_errs_prefixing(src, dst, prefix, node)
   if not src then
      return
   end
   for _, err in ipairs(src) do
      err.msg = prefix .. err.msg
      local relocate = false
      if node and node.y then
         relocate = err.filename ~= filename
            or not err.y
            or node.y > err.y
            or (node.y == err.y and node.x > err.x)
      end
      if relocate then
         err.y, err.x, err.filename = node.y, node.x, filename
      end
      table.insert(dst, err)
   end
end
local is_a
local TypeGetter = {}
--- Check every field of record type `t1` against a field lookup function.
-- t2: a function mapping a field name to the corresponding type, or nil if absent.
-- cmp: comparison function; defaults to is_a.
-- Missing fields are errors unless in lax mode. Mismatching fields contribute
-- cmp's errors, prefixed with the field name.
-- Returns true, or false plus the list of field errors.
local function match_record_fields(t1, t2, cmp)
   local compare = cmp or is_a
   local errs = {}
   for _, field_name in ipairs(t1.field_order) do
      local field_type = t1.fields[field_name]
      local other = t2(field_name)
      if other ~= nil then
         local _, field_errs = compare(field_type, other)
         add_errs_prefixing(field_errs, errs, "record field doesn't match: " .. field_name .. ": ")
      elseif not lax then
         table.insert(errs, error_in_type(field_type, "unknown field " .. field_name))
      end
   end
   if #errs == 0 then
      return true
   end
   return false, errs
end
--- Check that every field of record type `t1` exists in, and matches, record type `t2`.
-- cmp: optional field comparison function (match_record_fields defaults it to is_a).
-- Returns true, or false plus a list of errors. Each error is the per-field
-- error prefixed with "<t1> is not a <t2>: ".
local function match_fields_to_record(t1, t2, cmp)
   local ok, fielderrs = match_record_fields(t1, function(k) return t2.fields[k] end, cmp)
   if not ok then
      -- Copy the per-field errors (source) into a fresh list (destination),
      -- adding a summary prefix. add_errs_prefixing's signature is
      -- (src, dst, prefix, node); passing them the other way round would
      -- return an empty list and silently drop every field error.
      local errs = {}
      add_errs_prefixing(fielderrs, errs, show_type(t1) .. " is not a " .. show_type(t2) .. ": ")
      return false, errs
   end
   return true
end
--- Check that record `t1` can be viewed as map `t2`.
-- Every field must match the map's value type (using match_record_fields'
-- default comparison). Returns true, or false plus a single summary error.
local function match_fields_to_map(t1, t2)
   local value_type_for = function(_)
      return t2.values
   end
   if match_record_fields(t1, value_type_for) then
      return true
   end
   return false, { error_in_type(t1, "record is not a valid map; not all fields have the same type") }
end
--- Compare argument type `a` against parameter type `b` using `cmp`.
-- On mismatch, cmp's errors are appended to `errs`, prefixed with
-- "argument <n>: " and positioned at `at`.
-- Returns true on match, false otherwise.
local function arg_check(cmp, a, b, at, n, errs)
   local ok, why = cmp(a, b)
   if ok then
      return true
   end
   add_errs_prefixing(why, errs, "argument " .. n .. ": ", at)
   return false
end
local same_type
--- Return true if every type in `t1s` matches some type in `t2s`.
-- cmp is called as cmp(candidate_from_t2s, wanted_from_t1s); only its first
-- (truthy/falsy) result is used.
local function has_all_types_of(t1s, t2s, cmp)
   local function found_in_t2s(wanted)
      for _, candidate in ipairs(t2s) do
         if cmp(candidate, wanted) then
            return true
         end
      end
      return false
   end
   for _, wanted in ipairs(t1s) do
      if not found_in_t2s(wanted) then
         return false
      end
   end
   return true
end
--- Turn an accumulated error list into a (ok, errs) result.
-- Returns true when the list is empty, otherwise false and the list itself.
local function any_errors(all_errs)
   if #all_errs > 0 then
      return false, all_errs
   end
   return true
end
--- Mark every type definition nested in record `t` as closed, recursively.
-- Fields that hold type definitions get `closed = true`. When such a
-- definition is itself a record type, its own nested definitions are closed too.
local function close_nested_records(t)
   for _, field in pairs(t.fields) do
      if is_typetype(field) then
         field.closed = true
         local def = field.def
         if is_record_type(def) then
            close_nested_records(def)
         end
      end
   end
end
--- Mark every type definition declared in scope `vars` as closed.
-- Record definitions also have their nested type definitions closed.
local function close_types(vars)
   for _, entry in pairs(vars) do
      local t = entry.t
      if is_typetype(t) then
         t.closed = true
         if is_record_type(t.def) then
            close_nested_records(t.def)
         end
      end
   end
end
local Unused = {}
--- Emit "unused" warnings for every declared-but-unused entry in scope `vars`.
-- Entries without a declaration node or already marked `used` are skipped.
-- Warnings are emitted in source order (line, then column), so the output
-- does not depend on `pairs` iteration order. unused_warning applies further
-- filtering (e.g. names starting with "_").
local function check_for_unused_vars(vars)
   local list = {}
   for name, var in pairs(vars) do
      if var.declared_at and not var.used then
         table.insert(list, { y = var.declared_at.y, x = var.declared_at.x, name = name, var = var })
      end
   end
   if list[1] then
      -- Order by line, then column. table.sort needs a consistent strict
      -- ordering; comparing `a.y` with itself (the previous code) was not one.
      table.sort(list, function(a, b)
         return a.y < b.y or (a.y == b.y and a.x < b.x)
      end)
      for _, u in ipairs(list) do
         unused_warning(u.name, u.var)
      end
   end
end
--- Push a new, empty scope onto the scope stack `st`.
-- When `node` is given, an opening "@{" marker at the node's position is
-- appended to the symbol list; end_scope later closes or removes it.
local function begin_scope(node)
   st[#st + 1] = {}
   if not node then
      return
   end
   local slot = symbol_list_n + 1
   symbol_list_n = slot
   symbol_list[slot] = { y = node.y, x = node.x, name = "@{" }
end
--- Close the innermost scope (st[#st]).
-- Pending "@unresolved" data (goto labels and nominal type names not yet
-- found) is merged into the parent scope's "@unresolved" entry, or moved
-- there if the parent has none. Then type definitions in the scope are marked
-- closed, unused-variable warnings are emitted, and the scope is popped.
-- NOTE(review): indexes st[#st - 1] when "@unresolved" is present, so this
-- assumes an enclosing scope always exists in that case.
-- node: optional; when given, the symbol list is updated (see below).
local function end_scope(node) | |
local unresolved = st[#st]["@unresolved"] | |
if unresolved then | |
local upper = st[#st - 1]["@unresolved"] | |
if upper then | |
-- Append this scope's pending labels and nominals to the parent's lists
-- (a parent list is only created when there is something to add).
for name, nodes in pairs(unresolved.t.labels) do | |
for _, n in ipairs(nodes) do | |
upper.t.labels[name] = upper.t.labels[name] or {} | |
table.insert(upper.t.labels[name], n) | |
end | |
end | |
for name, types in pairs(unresolved.t.nominals) do | |
for _, typ in ipairs(types) do | |
upper.t.nominals[name] = upper.t.nominals[name] or {} | |
table.insert(upper.t.nominals[name], typ) | |
end | |
end | |
else | |
-- Parent has no pending entry: hand this one over unchanged.
st[#st - 1]["@unresolved"] = unresolved | |
end | |
end | |
close_types(st[#st]) | |
check_for_unused_vars(st[#st]) | |
table.remove(st) | |
if node then | |
-- If the newest symbol-list entry is still an opening "@{" marker (presumably
-- this scope's, with nothing recorded after it), drop it. Otherwise append a
-- closing "@}" marker at the node's end position.
if symbol_list[symbol_list_n].name == "@{" then | |
symbol_list[symbol_list_n] = nil | |
symbol_list_n = symbol_list_n - 1 | |
else | |
symbol_list_n = symbol_list_n + 1 | |
symbol_list[symbol_list_n] = { y = assert(node.yend), x = assert(node.xend), name = "@}" } | |
end | |
end | |
end | |
--- Close the current scope for `node`, give the node the NONE type, and return that type.
-- The second parameter (child results) is accepted but unused.
local end_scope_and_none_type = function(node, _children)
   end_scope(node)
   local result = NONE
   node.type = result
   return result
end
--- Resolve the type variables in `t`, reporting any failure at node `where`.
-- Errors from resolve_typevars are appended to the global `errors` list,
-- positioned at `where` (which must have a line). The resolved type is
-- returned either way.
local function resolve_typevars_at(t, where)
   assert(where)
   local ok, resolved, errs = resolve_typevars(t)
   if ok then
      return resolved
   end
   assert(where.y)
   add_errs_prefixing(errs, errors, "", where)
   return resolved
end
local resolve_nominal | |
do | |
--- Instantiate type definition `def` with the type arguments given on nominal `t`.
-- If both sides are generic, t.typevals are bound to def.typeargs in a
-- temporary scope and def's type variables are resolved. Errors are reported
-- (and nil returned) on an argument-count mismatch, on type arguments given to
-- a non-generic type, or on a generic type used without arguments.
-- A non-generic def with no arguments is returned unchanged.
local function match_typevals(t, def) | |
if t.typevals and def.typeargs then | |
if #t.typevals ~= #def.typeargs then | |
type_error(t, "mismatch in number of type arguments") | |
return nil | |
end | |
-- The bindings live only in this scope, which is closed right after resolution.
begin_scope() | |
for i, tt in ipairs(t.typevals) do | |
add_var(nil, def.typeargs[i].typearg, tt) | |
end | |
local ret = resolve_typevars_at(def, t) | |
end_scope() | |
return ret | |
elseif t.typevals then | |
type_error(t, "spurious type arguments") | |
return nil | |
elseif def.typeargs then | |
type_error(t, "missing type arguments in %s", def) | |
return nil | |
else | |
return def | |
end | |
end | |
--- Resolve nominal type `t` (a reference by name) to the type it denotes.
-- The result is cached in t.resolved, and the type definition found in
-- t.found. Aliases are followed one level (asserting the target is a type
-- definition). On failure an error is reported and a "bad_nominal" type is
-- returned. If `t` has no filename, its filename (and position, when it has
-- none) is copied from the resolved type.
resolve_nominal = function(t) | |
if t.resolved then | |
return t.resolved | |
end | |
local resolved | |
local typetype = t.found or find_type(t.names) | |
if not typetype then | |
type_error(t, "unknown type %s", t) | |
elseif is_typetype(typetype) then | |
if typetype.is_alias then | |
typetype = typetype.def.found | |
assert(is_typetype(typetype)) | |
end | |
assert(typetype.def.typename ~= "nominal") | |
resolved = match_typevals(t, typetype.def) | |
else | |
type_error(t, table.concat(t.names, ".") .. " is not a type") | |
end | |
-- Any failure above yields a placeholder that is_a treats as never matching.
if not resolved then | |
resolved = a_type({ typename = "bad_nominal", names = t.names }) | |
end | |
if not t.filename then | |
t.filename = resolved.filename | |
if t.x == nil and t.y == nil then | |
t.x = resolved.x | |
t.y = resolved.y | |
end | |
end | |
t.found = typetype | |
t.resolved = resolved | |
return resolved | |
end | |
end | |
--- Decide whether nominal types t1 and t2 refer to the same named type.
-- Names are compared by the typeid of their type definitions (t.found, or
-- looked up via find_type). Unknown names are reported with type_error and
-- yield `false, {}`. For the same name, type arguments are compared pairwise
-- with same_type. For different names, an error is returned. When the
-- displayed names are identical, the error includes each definition's
-- filename and line so the message is not confusing.
-- NOTE(review): when the names match but only one side has typevals, or the
-- counts differ, no branch returns and the function yields nil (falsy, no errors).
local function are_same_nominals(t1, t2) | |
local same_names | |
if t1.found and t2.found then | |
same_names = t1.found.typeid == t2.found.typeid | |
else | |
local ft1 = t1.found or find_type(t1.names) | |
local ft2 = t2.found or find_type(t2.names) | |
if ft1 and ft2 then | |
same_names = ft1.typeid == ft2.typeid | |
else | |
if not ft1 then | |
type_error(t1, "unknown type %s", t1) | |
end | |
if not ft2 then | |
type_error(t2, "unknown type %s", t2) | |
end | |
return false, {} | |
end | |
end | |
if same_names then | |
if t1.typevals == nil and t2.typevals == nil then | |
return true | |
elseif t1.typevals and t2.typevals and #t1.typevals == #t2.typevals then | |
-- Generic instances are the same only if every type argument is the same type.
local all_errs = {} | |
for i = 1, #t1.typevals do | |
local _, errs = same_type(t1.typevals[i], t2.typevals[i]) | |
add_errs_prefixing(errs, all_errs, "type parameter <" .. show_type(t2.typevals[i]) .. ">: ", t1) | |
end | |
if #all_errs == 0 then | |
return true | |
else | |
return false, all_errs | |
end | |
end | |
else | |
local t1name = show_type(t1) | |
local t2name = show_type(t2) | |
-- Distinct types that print identically: add the definition locations.
if t1name == t2name then | |
local t1r = resolve_nominal(t1) | |
if t1r.filename then | |
t1name = t1name .. " (defined in " .. t1r.filename .. ":" .. t1r.y .. ")" | |
end | |
local t2r = resolve_nominal(t2) | |
if t2r.filename then | |
t2name = t2name .. " (defined in " .. t2r.filename .. ":" .. t2r.y .. ")" | |
end | |
end | |
return false, terr(t1, t1name .. " is not a " .. t2name) | |
end | |
end | |
local is_known_table_type
local resolve_tuple_and_nominal = nil

--- Compare the fields of two record-like types in both directions.
-- Returns true, or false plus the error list of the first direction that fails.
-- (`a and b` would keep only the first value of `a`, dropping its error list.)
local function same_fields_both_ways(t1, t2)
   local ok, errs = match_fields_to_record(t1, t2, same_type)
   if not ok then
      return false, errs
   end
   return match_fields_to_record(t2, t1, same_type)
end

--- Check that t1 and t2 denote the same type (type equality, not subtyping).
-- Type variables are handled via compare_and_infer_typevars, and `{}` matches
-- any known table type. Otherwise the typenames must match, and composite
-- types are compared component-wise.
-- Returns true, or false plus a list of errors. Some nested failures may
-- produce no list (e.g. are_same_nominals in some cases).
same_type = function(t1, t2)
   assert(type(t1) == "table")
   assert(type(t2) == "table")

   if t1.typename == "typevar" or t2.typename == "typevar" then
      return compare_and_infer_typevars(t1, t2, same_type)
   end

   if t1.typename == "emptytable" and is_known_table_type(resolve_tuple_and_nominal(t2)) then
      return true
   end

   if t1.typename ~= t2.typename then
      return false, terr(t1, "got %s, expected %s", t1, t2)
   end

   if t1.typename == "array" then
      return same_type(t1.elements, t2.elements)
   elseif t1.typename == "tupletable" then
      local all_errs = {}
      for i = 1, math.min(#t1.types, #t2.types) do
         local ok, err = same_type(t1.types[i], t2.types[i])
         if not ok then
            add_errs_prefixing(err, all_errs, "tuple entry " .. i .. ": ", t1)
         end
      end
      return any_errors(all_errs)
   elseif t1.typename == "map" then
      local all_errs = {}
      local k_ok, k_errs = same_type(t1.keys, t2.keys)
      if not k_ok then
         add_errs_prefixing(k_errs, all_errs, "keys: ", t1)
      end
      local v_ok, v_errs = same_type(t1.values, t2.values)
      if not v_ok then
         add_errs_prefixing(v_errs, all_errs, "values: ", t1)
      end
      return any_errors(all_errs)
   elseif t1.typename == "union" then
      -- Same members in both directions, regardless of order.
      if has_all_types_of(t1.types, t2.types, same_type) and
         has_all_types_of(t2.types, t1.types, same_type) then
         return true
      else
         return false, terr(t1, "got %s, expected %s", t1, t2)
      end
   elseif t1.typename == "nominal" then
      return are_same_nominals(t1, t2)
   elseif t1.typename == "record" then
      return same_fields_both_ways(t1, t2)
   elseif t1.typename == "function" then
      if #t1.args ~= #t2.args then
         return false, terr(t1, "different number of input arguments: got " .. #t1.args .. ", expected " .. #t2.args)
      end
      if #t1.rets ~= #t2.rets then
         -- Report the return counts (previously this printed the argument counts).
         return false, terr(t1, "different number of return values: got " .. #t1.rets .. ", expected " .. #t2.rets)
      end
      local all_errs = {}
      for i = 1, #t1.args do
         arg_check(same_type, t1.args[i], t2.args[i], t1, i, all_errs)
      end
      for i = 1, #t1.rets do
         local _, errs = same_type(t1.rets[i], t2.rets[i])
         add_errs_prefixing(errs, all_errs, "return " .. i .. ": ", t1)
      end
      return any_errors(all_errs)
   elseif t1.typename == "arrayrecord" then
      local ok, errs = same_type(t1.elements, t2.elements)
      if not ok then
         return ok, errs
      end
      return same_fields_both_ways(t1, t2)
   end
   -- Remaining kinds (primitives etc.) are equal when their typenames are.
   return true
end
--- Build the union of the types in `types`, flattening nested unions and removing duplicates.
-- A single-element input is returned as-is, without resolving or flattening.
-- Otherwise each element goes through resolve_tuple. Union members are pushed
-- onto a stack and popped LIFO, so they are visited in reverse order. Dedup rules:
--  * primitive types (when flatten_constants is set, or when they carry no
--    literal token `tk`) are deduplicated by typename;
--  * nominal types are deduplicated by the typeid of their resolved type;
--  * everything else is deduplicated by its own typeid.
-- `types_seen` is pre-seeded with NIL's typeid and "nil" so nil entries are
-- skipped (for the typename rule this assumes `primitive` includes "nil" —
-- defined elsewhere, confirm).
-- Returns the sole remaining type, or a new "union" type over all remaining
-- members (NOTE(review): this may be an empty union if every input was nil).
local function unite(types, flatten_constants) | |
if #types == 1 then | |
return types[1] | |
end | |
local ts = {} | |
local stack = {} | |
local types_seen = {} | |
types_seen[NIL.typeid] = true | |
types_seen["nil"] = true | |
local i = 1 | |
-- Drain the pending stack before taking the next top-level element.
while types[i] or stack[1] do | |
local t | |
if stack[1] then | |
t = table.remove(stack) | |
else | |
t = types[i] | |
i = i + 1 | |
end | |
t = resolve_tuple(t) | |
if t.typename == "union" then | |
for _, s in ipairs(t.types) do | |
table.insert(stack, s) | |
end | |
else | |
if primitive[t.typename] and (flatten_constants or not t.tk) then | |
if not types_seen[t.typename] then | |
types_seen[t.typename] = true | |
table.insert(ts, t) | |
end | |
else | |
local typeid = t.typeid | |
if t.typename == "nominal" then | |
typeid = resolve_nominal(t).typeid | |
end | |
if not types_seen[typeid] then | |
types_seen[typeid] = true | |
table.insert(ts, t) | |
end | |
end | |
end | |
end | |
if #ts == 1 then | |
return ts[1] | |
else | |
return a_type({ | |
typename = "union", | |
types = ts, | |
}) | |
end | |
end | |
--- Merge any number of optional error lists into one (ok, errs) result.
-- nil/false arguments are ignored. If at least one list was given (even an
-- empty one), returns false plus the concatenation of all given lists;
-- otherwise returns true.
local function combine_errs(...)
   local merged = nil
   local count = select("#", ...)
   for index = 1, count do
      local list = (select(index, ...))
      if list then
         merged = merged or {}
         for _, e in ipairs(list) do
            merged[#merged + 1] = e
         end
      end
   end
   if merged then
      return false, merged
   end
   return true
end
-- Type kinds that denote concrete table shapes (an empty `{}` can take any of them).
local known_table_types = {}
for _, kind in ipairs({ "array", "map", "record", "arrayrecord", "tupletable" }) do
   known_table_types[kind] = true
end
--- Return true if `t` is one of the known table-like kinds, nil otherwise.
is_known_table_type = function(t)
   local kind = t.typename
   return known_table_types[kind]
end
local expand_type
--- Convert tuple-table type `tupletype` into an array type.
-- If uniting the tuple's entry types gives a non-union, or a valid union, the
-- result is an array of that type. Otherwise the entries are folded one by
-- one with expand_type, starting from an array of the first entry.
-- Returns the array type, or nil plus an error if some entry cannot be merged.
local function arraytype_from_tuple(where, tupletype)
   local united = unite(tupletype.types)
   local valid = true
   if united.typename == "union" then
      valid = is_valid_union(united)
   end
   if valid then
      return a_type({
         elements = united,
         typename = "array",
      })
   end

   local entries = tupletype.types
   local result = a_type({
      elements = entries[1],
      typename = "array",
   })
   for idx = 2, #entries do
      local next_array = a_type({ elements = entries[idx], typename = "array" })
      result = expand_type(where, result, next_array)
      if not (result and result.elements) then
         return nil, terr(tupletype, "unable to convert tuple %s to array", tupletype)
      end
   end
   return result
end
--- Subtype check: can a value of type t1 be used where type t2 is expected?
-- for_equality: when truthy, some rules are stricter — enum vs string literal
--   membership and tuple sizes must match exactly.
-- Returns true, or false plus a list of errors. Some failure paths return
-- false without a list (e.g. bad_nominal below).
is_a = function(t1, t2, for_equality) | |
assert(type(t1) == "table") | |
assert(type(t2) == "table") | |
if lax and (is_unknown(t1) or is_unknown(t2)) then | |
return true | |
end | |
-- Names that already failed to resolve never match (no additional error).
if t1.typename == "bad_nominal" or t2.typename == "bad_nominal" then | |
return false | |
end | |
-- nil is accepted wherever any type is expected.
if t1.typename == "nil" then | |
return true | |
end | |
if t2.typename ~= "tuple" then | |
t1 = resolve_tuple(t1) | |
end | |
-- Compare a single value against a tuple by wrapping it in a 1-tuple.
if t2.typename == "tuple" and t1.typename ~= "tuple" then | |
t1 = a_type({ | |
typename = "tuple", | |
[1] = t1, | |
}) | |
end | |
if t1.typename == "typevar" or t2.typename == "typevar" then | |
return compare_and_infer_typevars(t1, t2, is_a) | |
end | |
if t2.typename == "any" then | |
return true | |
-- A union value fits only if every member fits.
elseif t1.typename == "union" then | |
for _, t in ipairs(t1.types) do | |
if not is_a(t, t2, for_equality) then | |
return false, terr(t1, "got %s, expected %s", t1, t2) | |
end | |
end | |
return true | |
-- Fitting any member of an expected union suffices; otherwise fall through to the final error.
elseif t2.typename == "union" then | |
for _, t in ipairs(t2.types) do | |
if is_a(t1, t, for_equality) then | |
return true | |
end | |
end | |
-- An expected polymorphic type must be matched by all of its alternatives.
elseif t2.typename == "poly" then | |
for _, t in ipairs(t2.types) do | |
if not is_a(t1, t, for_equality) then | |
return false, terr(t1, "cannot match against all alternatives of the polymorphic type") | |
end | |
end | |
return true | |
-- A polymorphic value fits if any of its alternatives fits.
elseif t1.typename == "poly" then | |
for _, t in ipairs(t1.types) do | |
if is_a(t, t2, for_equality) then | |
return true | |
end | |
end | |
return false, terr(t1, "cannot match against any alternatives of the polymorphic type") | |
-- Two named types: same name wins. Records are compared by name only;
-- other named types are compared by their resolved structure.
elseif t1.typename == "nominal" and t2.typename == "nominal" then | |
local same, err = are_same_nominals(t1, t2) | |
if same then | |
return true | |
end | |
local t1r = resolve_tuple_and_nominal(t1) | |
local t2r = resolve_tuple_and_nominal(t2) | |
if is_record_type(t1r) and is_record_type(t2r) then | |
return same, err | |
else | |
return is_a(t1r, t2r, for_equality) | |
end | |
-- enum -> string always fits, except under for_equality where t2 must be a member literal.
elseif t1.typename == "enum" and t2.typename == "string" then | |
local ok | |
if for_equality then | |
ok = t2.tk and t1.enumset[unquote(t2.tk)] | |
else | |
ok = true | |
end | |
if ok then | |
return true | |
else | |
return false, terr(t1, "enum is incompatible with %s", t2) | |
end | |
elseif t1.typename == "integer" and t2.typename == "number" then | |
return true | |
-- string -> enum fits only for a literal (`tk`) that is a member of the enum.
elseif t1.typename == "string" and t2.typename == "enum" then | |
local ok = t1.tk and t2.enumset[unquote(t1.tk)] | |
if ok then | |
return true | |
else | |
if t1.tk then | |
return false, terr(t1, "%s is not a member of %s", t1, t2) | |
else | |
return false, terr(t1, "string is not a %s", t2) | |
end | |
end | |
-- One side named: compare resolved types, but report a "got ..." error with the original names.
elseif t1.typename == "nominal" or t2.typename == "nominal" then | |
local t1r = resolve_tuple_and_nominal(t1) | |
local t2r = resolve_tuple_and_nominal(t2) | |
local ok, errs = is_a(t1r, t2r, for_equality) | |
if errs and #errs == 1 then | |
if errs[1].msg:match("^got ") then | |
errs = terr(t1, "got %s, expected %s", t1, t2) | |
end | |
end | |
return ok, errs | |
elseif t1.typename == "emptytable" and is_known_table_type(t2) then | |
return true | |
-- Expected array: accept arrays with fitting elements, tuples convertible to
-- a fitting array, or integer-keyed maps. Other mismatches fall through.
elseif t2.typename == "array" then | |
if is_array_type(t1) then | |
if is_a(t1.elements, t2.elements) then | |
return true | |
end | |
elseif t1.typename == "tupletable" then | |
if t2.inferred_len and t2.inferred_len > #t1.types then | |
return false, terr(t1, "incompatible length, expected maximum length of " .. tostring(#t1.types) .. ", got " .. tostring(t2.inferred_len)) | |
end | |
local t1a, err = arraytype_from_tuple(t1.inferred_at, t1) | |
if not t1a then | |
return false, err | |
end | |
if not is_a(t1a, t2) then | |
return false, terr(t2, "got %s (from %s), expected %s", t1a, t1, t2) | |
end | |
return true | |
elseif t1.typename == "map" then | |
local _, errs_keys, errs_values | |
_, errs_keys = is_a(t1.keys, INTEGER) | |
_, errs_values = is_a(t1.values, t2.elements) | |
return combine_errs(errs_keys, errs_values) | |
end | |
-- Expected record: structural field check (also accepts a record type definition).
elseif t2.typename == "record" then | |
if is_record_type(t1) then | |
return match_fields_to_record(t1, t2) | |
elseif is_typetype(t1) and is_record_type(t1.def) then | |
return is_a(t1.def, t2, for_equality) | |
end | |
-- Expected arrayrecord: compare array parts and/or record fields as applicable.
elseif t2.typename == "arrayrecord" then | |
if t1.typename == "array" then | |
return is_a(t1.elements, t2.elements) | |
elseif t1.typename == "tupletable" then | |
if t2.inferred_len and t2.inferred_len > #t1.types then | |
return false, terr(t1, "incompatible length, expected maximum length of " .. tostring(#t1.types) .. ", got " .. tostring(t2.inferred_len)) | |
end | |
local t1a, err = arraytype_from_tuple(t1.inferred_at, t1) | |
if not t1a then | |
return false, err | |
end | |
if not is_a(t1a, t2) then | |
return false, terr(t2, "got %s (from %s), expected %s", t1a, t1, t2) | |
end | |
return true | |
elseif t1.typename == "record" then | |
return match_fields_to_record(t1, t2) | |
elseif t1.typename == "arrayrecord" then | |
if not is_a(t1.elements, t2.elements) then | |
return false, terr(t1, "array parts have incompatible element types") | |
end | |
return match_fields_to_record(t1, t2) | |
elseif is_typetype(t1) and is_record_type(t1.def) then | |
return is_a(t1.def, t2, for_equality) | |
end | |
-- Expected map: map keys/values must be the same types (unless t2's side is
-- `any`); arrays/tuples need integer keys; records need string (or enum) keys.
elseif t2.typename == "map" then | |
if t1.typename == "map" then | |
local _, errs_keys, errs_values | |
if t2.keys.typename ~= "any" then | |
_, errs_keys = same_type(t2.keys, t1.keys) | |
end | |
if t2.values.typename ~= "any" then | |
_, errs_values = same_type(t1.values, t2.values) | |
end | |
return combine_errs(errs_keys, errs_values) | |
elseif t1.typename == "array" or t1.typename == "tupletable" then | |
local elements | |
if t1.typename == "tupletable" then | |
local arr_type = arraytype_from_tuple(t1.inferred_at, t1) | |
if not arr_type then | |
return false, terr(t1, "Unable to convert tuple %s to map", t1) | |
end | |
elements = arr_type.elements | |
else | |
elements = t1.elements | |
end | |
local _, errs_keys, errs_values | |
_, errs_keys = is_a(INTEGER, t2.keys) | |
_, errs_values = is_a(elements, t2.values) | |
return combine_errs(errs_keys, errs_values) | |
elseif is_record_type(t1) then | |
if not is_a(t2.keys, STRING) then | |
return false, terr(t1, "can't match a record to a map with non-string keys") | |
end | |
if t2.keys.typename == "enum" then | |
for _, k in ipairs(t1.field_order) do | |
if not t2.keys.enumset[k] then | |
return false, terr(t1, "key is not an enum value: " .. k) | |
end | |
end | |
end | |
return match_fields_to_map(t1, t2) | |
end | |
-- Expected tuple table: entry-wise check; t1 may be shorter unless for_equality.
elseif t2.typename == "tupletable" then | |
if t1.typename == "tupletable" then | |
for i = 1, math.min(#t1.types, #t2.types) do | |
if not is_a(t1.types[i], t2.types[i], for_equality) then | |
return false, terr(t1, "in tuple entry " .. tostring(i) .. ": got %s, expected %s", t1.types[i], t2.types[i]) | |
end | |
end | |
if for_equality and #t1.types ~= #t2.types then | |
return false, terr(t1, "tuples are not the same size") | |
end | |
if #t1.types > #t2.types then | |
return false, terr(t1, "tuple %s is too big for tuple %s", t1, t2) | |
end | |
return true | |
elseif is_array_type(t1) then | |
if t1.inferred_len and t1.inferred_len > #t2.types then | |
return false, terr(t1, "incompatible length, expected maximum length of " .. tostring(#t2.types) .. ", got " .. tostring(t1.inferred_len)) | |
end | |
local len = (t1.inferred_len and t1.inferred_len > 0) and | |
t1.inferred_len or | |
#t2.types | |
for i = 1, len do | |
if not is_a(t1.elements, t2.types[i], for_equality) then | |
return false, terr(t1, "tuple entry " .. tostring(i) .. " of type %s does not match type of array elements, which is %s", t2.types[i], t1.elements) | |
end | |
end | |
return true | |
end | |
-- Functions: t1 may not take more arguments than t2 (unless t2 is variadic).
-- Arguments (skipping self for methods) and returns are compared with is_a.
-- t1 must provide at least t2's return count, except a trailing vararg in t2.
elseif t1.typename == "function" and t2.typename == "function" then | |
local all_errs = {} | |
if (not t2.args.is_va) and #t1.args > #t2.args then | |
table.insert(all_errs, error_in_type(t1, "incompatible number of arguments: got " .. #t1.args .. " %s, expected " .. #t2.args .. " %s", t1.args, t2.args)) | |
else | |
for i = (t1.is_method and 2 or 1), #t1.args do | |
arg_check(is_a, t1.args[i], t2.args[i] or ANY, nil, i, all_errs) | |
end | |
end | |
local diff_by_va = #t2.rets - #t1.rets == 1 and t2.rets.is_va | |
if #t1.rets < #t2.rets and not diff_by_va then | |
table.insert(all_errs, error_in_type(t1, "incompatible number of returns: got " .. #t1.rets .. " %s, expected " .. #t2.rets .. " %s", t1.rets, t2.rets)) | |
else | |
local nrets = #t2.rets | |
if diff_by_va then | |
nrets = nrets - 1 | |
end | |
for i = 1, nrets do | |
local _, errs = is_a(t1.rets[i], t2.rets[i]) | |
add_errs_prefixing(errs, all_errs, "return " .. i .. ": ") | |
end | |
end | |
if #all_errs == 0 then | |
return true | |
else | |
return false, all_errs | |
end | |
-- Lax mode lets anything be used as a boolean (outside equality checks).
elseif lax and ((not for_equality) and t2.typename == "boolean") then | |
return true | |
elseif t1.typename == t2.typename then | |
return true | |
end | |
-- Every branch that did not return above ends here.
return false, terr(t1, "got %s, expected %s", t1, t2) | |
end | |
--- Check that t1 can be assigned to t2, reporting any error at `node`.
-- context/name form the error prefix "<context>: <name>: " (name is optional).
-- Special cases: unknown types in lax mode and nil always pass. Assigning to a
-- pending `{}` value or variable infers its type from t1 rather than checking.
-- Returns the truthiness result of is_a, or true for the special cases.
local function assert_is_a(node, t1, t2, context, name)
   t1 = resolve_tuple(t1)
   t2 = resolve_tuple(t2)
   if lax and (is_unknown(t1) or is_unknown(t2)) then
      return true
   end
   if t1.typename == "nil" then
      return true
   end

   local target_kind = t2.typename
   if target_kind == "unresolved_emptytable_value" then
      -- Storing into a `{}` whose shape is still unknown: number keys make an
      -- array, any other key type makes a map.
      local tbl = t2.emptytable_type
      local inferred
      if is_number_type(tbl.keys) then
         inferred = a_type({ typename = "array", elements = t1 })
      else
         inferred = a_type({ typename = "map", keys = tbl.keys, values = t1 })
      end
      infer_var(tbl, inferred, node)
      return true
   end

   if target_kind == "emptytable" then
      if is_known_table_type(t1) then
         infer_var(t2, shallow_copy(t1), node)
      elseif t1.typename ~= "emptytable" then
         local label = name and (name .. ": ") or ""
         node_error(node, context .. ": " .. label .. "assigning %s to a variable declared with {}", t1)
      end
      return true
   end

   local ok, match_errs = is_a(t1, t2)
   local label = name and (name .. ": ") or ""
   add_errs_prefixing(match_errs, errors, context .. ": " .. label, node)
   return ok
end
-- Dotted names already reported via add_unknown_dot (each is reported once).
local unknown_dots = {}
--- Report an unknown dotted access such as "a.b" through add_unknown, at most once per name.
local function add_unknown_dot(node, name)
   if unknown_dots[name] then
      return
   end
   unknown_dots[name] = true
   add_unknown(node, name)
end
local type_check_function_call | |
do | |
--- Normalize the callee type `func` before matching a call against it.
-- In lax mode an unknown callee becomes a variadic unknown -> unknown function
-- (and an unknown `x:y` method is recorded). A record type definition is
-- replaced by its definition. A callee with a `__call` metamethod is called
-- through it: the callee itself is inserted as the first argument and
-- is_method becomes true.
-- Returns the (possibly replaced) function type and the is_method flag.
local function resolve_for_call(node, func, args, is_method) | |
if lax and is_unknown(func) then | |
func = a_type({ typename = "function", args = VARARG({ UNKNOWN }), rets = VARARG({ UNKNOWN }) }) | |
if node.e1.op and node.e1.op.op == ":" and node.e1.e1.kind == "variable" then | |
add_unknown_dot(node, node.e1.e1.tk .. "." .. node.e1.e2.tk) | |
end | |
end | |
func = resolve_tuple_and_nominal(func) | |
if func.typename ~= "function" and func.typename ~= "poly" then | |
if is_typetype(func) and func.def.typename == "record" then | |
func = func.def | |
end | |
if func.meta_fields and func.meta_fields["__call"] then | |
table.insert(args, 1, func) | |
func = func.meta_fields["__call"] | |
is_method = true | |
end | |
end | |
return func, is_method | |
end | |
--- Bind each of f's type arguments that is not yet visible to INVALID
-- (UNKNOWN in lax mode), so later resolution of f's types finds a binding.
local function mark_invalid_typeargs(f) | |
if f.typeargs then | |
for _, a in ipairs(f.typeargs) do | |
if not find_var(a.typearg) then | |
add_var(nil, a.typearg, lax and UNKNOWN or INVALID) | |
end | |
end | |
end | |
end | |
--- Try to match argument types `args` against function type `f`.
-- For variadic functions, extra arguments are checked against the last
-- parameter type. Stops at the first mismatch and returns nil plus errors.
-- On success, `{}` arguments get their type inferred from the parameter type,
-- and f's return types are returned with type variables resolved.
-- argdelta adjusts the argument number shown in messages and, for the `{}`
-- inference, which node.e2 entry is used as the position.
local function try_match_func_args(node, f, args, argdelta) | |
local errs = {} | |
local given = #args | |
local expected = #f.args | |
local va = f.args.is_va | |
local nargs = va and | |
math.max(given, expected) or | |
math.min(given, expected) | |
for a = 1, nargs do | |
local argument = args[a] | |
local farg = f.args[a] or (va and f.args[expected]) | |
if argument == nil then | |
if va then | |
break | |
end | |
else | |
local at = node.e2 and node.e2[a] or node | |
if not arg_check(is_a, argument, farg, at, (a + argdelta), errs) then | |
return nil, errs | |
end | |
end | |
end | |
mark_invalid_typeargs(f) | |
for a = 1, given do | |
local argument = args[a] | |
if argument.typename == "emptytable" then | |
local farg = f.args[a] or (va and f.args[expected]) | |
local where = node.e2[a + argdelta] or node.e2 | |
infer_var(argument, resolve_typevars_at(farg, where), where) | |
end | |
end | |
return resolve_typevars_at(f.rets, node) | |
end | |
--- Remove `func`'s type-argument bindings from the current scope, undoing
-- what a failed match attempt against one poly alternative may have bound.
local function revert_typeargs(func) | |
if func.typeargs then | |
for _, fnarg in ipairs(func.typeargs) do | |
if st[#st][fnarg.typearg] then | |
st[#st][fnarg.typearg] = nil | |
end | |
end | |
end | |
end | |
--- Report a failed call and return a best-effort result type.
-- Collected match errors are reported as-is. Without them, a
-- "wrong number of arguments" error lists the accepted argument counts
-- (deduplicated; the counts are strings, so the sort is lexicographic).
-- Returns the resolved return types of `func` (the first alternative for poly types).
local function fail_call(node, func, nargs, errs) | |
if errs then | |
for _, err in ipairs(errs) do | |
table.insert(errors, err) | |
end | |
else | |
local expects = {} | |
if func.typename == "poly" then | |
for _, f in ipairs(func.types) do | |
table.insert(expects, tostring(#f.args or 0)) | |
end | |
table.sort(expects) | |
for i = #expects, 1, -1 do | |
if expects[i] == expects[i + 1] then | |
table.remove(expects, i) | |
end | |
end | |
else | |
table.insert(expects, tostring(#func.args or 0)) | |
end | |
node_error(node, "wrong number of arguments (given " .. nargs .. ", expects " .. table.concat(expects, " or ") .. ")") | |
end | |
local f = func.typename == "poly" and func.types[1] or func | |
mark_invalid_typeargs(f) | |
return resolve_typevars_at(f.rets, node) | |
end | |
--- Type-check a call of `func` with argument types `args`.
-- A plain function is tried once. A poly type's alternatives are tried in
-- three passes: exact arity, fewer arguments than parameters, then variadic
-- alternatives with extra arguments. Each alternative is tried at most once.
-- The first successful match's return types are returned; otherwise
-- fail_call reports the first collected errors.
local function check_call(node, func, args, is_method, argdelta) | |
assert(type(func) == "table") | |
assert(type(args) == "table") | |
if not (func.typename == "function" or func.typename == "poly") then | |
func, is_method = resolve_for_call(node, func, args, is_method) | |
end | |
-- For method calls the argument numbering in messages is shifted by one (the receiver is args[1]).
argdelta = is_method and -1 or argdelta or 0 | |
local is_func = func.typename == "function" | |
local is_poly = func.typename == "poly" | |
if not (is_func or is_poly) then | |
return node_error(node, "not a function: %s", func) | |
end | |
local passes, n = 1, 1 | |
if is_poly then | |
passes, n = 3, #func.types | |
end | |
local given = #args | |
local tried | |
local first_errs | |
for pass = 1, passes do | |
for i = 1, n do | |
if (not tried) or not tried[i] then | |
local f = is_func and func or func.types[i] | |
if f.is_method and not is_method and not (args[1] and is_a(args[1], f.args[1])) then | |
return node_error(node, "invoked method as a regular function: use ':' instead of '.'") | |
end | |
local expected = #f.args | |
if (is_func and (given <= expected or (f.args.is_va and given > expected))) or | |
(is_poly and ((pass == 1 and given == expected) or | |
(pass == 2 and given < expected) or | |
(pass == 3 and f.args.is_va and given > expected))) then | |
local matched, errs = try_match_func_args(node, f, args, argdelta) | |
if matched then | |
return matched | |
end | |
first_errs = first_errs or errs | |
if is_poly then | |
tried = tried or {} | |
tried[i] = true | |
revert_typeargs(f) | |
end | |
end | |
end | |
end | |
end | |
return fail_call(node, func, given, first_errs) | |
end | |
--- Public entry point: check a call inside a temporary scope, so type-variable
-- bindings made while matching are dropped afterwards. Returns check_call's result.
type_check_function_call = function(node, func, args, is_method, argdelta) | |
begin_scope() | |
local ret = check_call(node, func, args, is_method, argdelta) | |
end_scope() | |
return ret | |
end | |
end | |
--- Look up field `key` of table type `tbl` for the indexing expression `node`.
-- node.e1 is the indexed expression. orig_tbl is used only to describe the
-- type in the error message. Strings and enums are looked up in the type bound
-- to "string" (find_var_type("string")).
-- Returns the field's type. In lax mode, UNKNOWN is returned for unknown
-- tables or missing keys. Otherwise the result of node_error (or INVALID for
-- non-indexable types) is returned.
local function match_record_key(node, tbl, key, orig_tbl) | |
assert(type(tbl) == "table") | |
assert(type(key) == "table") | |
tbl = resolve_tuple_and_nominal(tbl) | |
-- Remember the original kind for the error message before tbl is replaced below.
local type_description = tbl.typename | |
if tbl.typename == "string" or tbl.typename == "enum" then | |
tbl = find_var_type("string") | |
end | |
if lax and (is_unknown(tbl) or tbl.typename == "typevar") then | |
if node.e1.kind == "variable" and node.op.op ~= "@funcall" then | |
add_unknown_dot(node, node.e1.tk .. "." .. key.tk) | |
end | |
return UNKNOWN | |
end | |
if tbl.is_alias then | |
return node_error(key, "cannot use a nested type alias as a concrete value") | |
end | |
tbl = resolve_typetype(tbl) | |
-- An empty table has no fields: intentionally falls through to the missing-key handling below.
if tbl.typename == "emptytable" then | |
elseif is_record_type(tbl) then | |
assert(tbl.fields, "record has no fields!?") | |
if key.kind == "string" or key.kind == "identifier" then | |
if tbl.fields[key.tk] then | |
return tbl.fields[key.tk] | |
end | |
end | |
else | |
if is_unknown(tbl) then | |
if not lax then | |
node_error(key, "cannot index a value of unknown type") | |
end | |
else | |
node_error(key, "cannot index something that is not a record: %s", tbl) | |
end | |
return INVALID | |
end | |
-- Key not found: tolerated in lax mode, otherwise reported with a description of the indexed value.
if lax then | |
if node.e1.kind == "variable" and node.op.op ~= "@funcall" then | |
add_unknown_dot(node, node.e1.tk .. "." .. key.tk) | |
end | |
return UNKNOWN | |
end | |
local description | |
if node.e1.kind == "variable" then | |
description = type_description .. " '" .. node.e1.tk .. "' of type " .. show_type(resolve_tuple(orig_tbl)) | |
else | |
description = "type " .. show_type(resolve_tuple(orig_tbl)) | |
end | |
return node_error(key, "invalid key '" .. key.tk .. "' in " .. description) | |
end | |
--- Undo narrowing of `var` in `scope` (the entry must exist).
-- If the entry was narrowed from an earlier type, that type is restored.
-- A narrowed entry with no recorded original type is removed. Returns true if
-- anything was widened, false if the entry was not narrowed.
local function widen_in_scope(scope, var)
   local entry = scope[var]
   if not entry.is_narrowed then
      return false
   end
   local original = entry.narrowed_from
   if original then
      entry.t = original
      entry.narrowed_from = nil
      entry.is_narrowed = false
   else
      scope[var] = nil
   end
   return true
end
--- Widen `var` in every scope from innermost outward, until an entry that is not narrowed.
-- Scopes without an entry for `var` are skipped. Returns true if any scope was widened.
local function widen_back_var(var)
   local any_widened = false
   local level = #st
   while level >= 1 do
      local scope = st[level]
      if scope[var] then
         if not widen_in_scope(scope, var) then
            break
         end
         any_widened = true
      end
      level = level - 1
   end
   return any_widened
end
--- Undo narrowing of every variable in every scope, innermost scope first.
local function widen_all_unions()
   for level = #st, 1, -1 do
      local scope = st[level]
      for name in pairs(scope) do
         widen_in_scope(scope, name)
      end
   end
end
--- Declare `var` in the outermost (global) scope st[1].
-- In lax mode, unknown-typed globals (other than self/varargs) are recorded
-- via add_unknown. When `node` is given and has no type yet, it gets `valtype`.
local function add_global(node, var, valtype, is_const)
   local is_special = var == "self" or var == "..."
   if lax and is_unknown(valtype) and not is_special then
      add_unknown(node, var)
   end
   local globals = st[1]
   globals[var] = { t = valtype, is_const = is_const }
   if node then
      node.type = node.type or valtype
   end
end
-- Normalize a function's declared return list into a tuple type.
-- In lax mode an empty return list becomes a vararg of unknowns; a plain
-- list (one without a `typename`) is wrapped with TUPLE.
local function get_rets(rets)
   if lax and #rets == 0 then
      return VARARG({ UNKNOWN })
   end
   local tuple = rets.typename and rets or TUPLE(rets)
   assert(tuple.typeid)
   return tuple
end
-- Register the hidden per-function variables used by the checker:
--   "@is_va"  : ANY when the function's argument list is variadic, else NIL
--   "@return" : the declared return types (an empty tuple if none declared)
local function add_internal_function_variables(node)
   local va_marker = NIL
   if node.args.type.is_va then
      va_marker = ANY
   end
   add_var(nil, "@is_va", va_marker)
   local declared_rets = node.rets
   if not declared_rets then
      declared_rets = a_type({ typename = "tuple" })
   end
   add_var(nil, "@return", declared_rets)
end
-- Pre-declare the function described by `node` in the current scope (as the
-- name says, so it is visible for recursion). Its type is built from the
-- argument types and the normalized return types of `node`.
local function add_function_definition_for_recursion(node)
   local arg_types = a_type({ typename = "tuple" })
   for _, param in ipairs(node.args) do
      arg_types[#arg_types + 1] = param.type
   end
   local fn_type = a_type({
      typename = "function",
      args = arg_types,
      rets = get_rets(node.rets),
   })
   add_var(nil, node.name.tk, fn_type)
end
-- Report everything still pending in the innermost scope's "@unresolved"
-- entry: `goto` statements whose label was never seen, and nominal types
-- that were never declared. The entry is removed from the scope first.
local function fail_unresolved()
   local top = st[#st]
   local pending = top["@unresolved"]
   if not pending then
      return
   end
   top["@unresolved"] = nil
   for label, gotos in pairs(pending.t.labels) do
      for _, goto_node in ipairs(gotos) do
         node_error(goto_node, "no visible label '" .. label .. "' for goto")
      end
   end
   for _, nominal_list in pairs(pending.t.nominals) do
      for _, nominal in ipairs(nominal_list) do
         assert(nominal.x)
         assert(nominal.y)
         type_error(nominal, "unknown type %s", nominal)
      end
   end
end
-- Close a function body's scope. Anything still unresolved in it (gotos
-- without labels, undeclared nominal types) is reported before the scope
-- that records it is popped.
local function end_function_scope(node)
   fail_unresolved()
   end_scope(node)
end
-- Unwrap a tuple type via resolve_tuple, then resolve a nominal type
-- reference to its definition. The result is never a "nominal".
resolve_tuple_and_nominal = function(t)
   local resolved = resolve_tuple(t)
   if resolved.typename == "nominal" then
      resolved = resolve_nominal(resolved)
   end
   assert(resolved.typename ~= "nominal")
   return resolved
end
-- Flatten a list of expression types. Every element except the last is
-- resolved with resolve_tuple_and_nominal. A trailing tuple is expanded so
-- that all of its values are included; a non-tuple last element is kept as is.
local function flatten_list(list)
   local flat = {}
   local count = #list
   if count == 0 then
      return flat
   end
   for i = 1, count - 1 do
      flat[#flat + 1] = resolve_tuple_and_nominal(list[i])
   end
   local tail = list[count]
   if tail.typename ~= "tuple" then
      flat[#flat + 1] = tail
      return flat
   end
   for _, value in ipairs(tail) do
      flat[#flat + 1] = value
   end
   return flat
end
-- Expand the types produced by the right-hand side of an assignment or
-- declaration into one entry per target.
--   vals   : list of expression types, or nil when there are no expressions
--   wanted : number of targets on the left-hand side
-- All but the last value are resolved with resolve_tuple. A trailing tuple
-- is expanded in place. If the list (or the trailing tuple) is variadic,
-- `last` (the final element of `vals`, i.e. the whole tuple when it is one)
-- is appended until `wanted` entries exist.
-- Returns a plain list of types, which may be shorter than `wanted`.
local function get_assignment_values(vals, wanted)
   local ret = {}
   -- An empty list has no `last` value to inspect; bail out instead of
   -- indexing nil below.
   if vals == nil or #vals == 0 then
      return ret
   end
   local is_va = vals.is_va
   for i = 1, #vals - 1 do
      ret[i] = resolve_tuple(vals[i])
   end
   local last = vals[#vals]
   if last.typename == "tuple" then
      is_va = last.is_va
      for _, v in ipairs(last) do
         table.insert(ret, v)
      end
   else
      table.insert(ret, last)
   end
   if is_va then
      while #ret < wanted do
         table.insert(ret, last)
      end
   end
   return ret
end
-- Check that every field listed in `field_names` has the same type in
-- record `a`. Returns that common type; otherwise (types differ, or no
-- type was found) reports `errmsg` at `node` and returns node_error's result.
local function match_all_record_field_names(node, a, field_names, errmsg)
   local common
   local mismatch = false
   for _, name in ipairs(field_names) do
      local ftype = a.fields[name]
      if not common then
         common = ftype
      elseif not same_type(ftype, common) then
         mismatch = true
         break
      end
   end
   if common and not mismatch then
      return common
   end
   return node_error(node, errmsg)
end
-- Compute the type of indexing a value of type `a` with a key of type `b`
-- (the index expression `node`; `idxnode` is the key node). Cases, in order:
--   * tupletable with an integer key: constant keys must be in range;
--     variable keys need an array view of the tuple
--   * array with an integer key: the element type
--   * emptytable: the first use fixes the key type; later keys must match
--   * map: the key must match the map's key type
--   * string / enum_item literal key: record field lookup
--   * record indexed by an enum: all enum values must name fields of one type
-- Errors are reported via node_error and its result returned. As a final
-- fallback, lax mode yields UNKNOWN for an unknown `a`.
local function type_check_index(node, idxnode, a, b)
   local orig_a, orig_b = a, b
   a = resolve_tuple_and_nominal(a)
   b = resolve_tuple_and_nominal(b)

   if a.typename == "tupletable" and is_a(b, INTEGER) then
      local k = idxnode.constnum
      if not k then
         -- a variable index needs all tuple elements to fit one array type
         local array_type = arraytype_from_tuple(idxnode, a)
         if not array_type then
            type_error(a, "cannot index this tuple with a variable because it would produce a union type that cannot be discriminated at runtime")
            return INVALID
         end
         return array_type.elements
      end
      local in_range = k <= #a.types and k >= 1 and k == math.floor(k)
      if not in_range then
         return node_error(idxnode, "index " .. tostring(k) .. " out of range for tuple %s", a)
      end
      return a.types[k]
   end

   if is_array_type(a) and is_a(b, INTEGER) then
      return a.elements
   end

   if a.typename == "emptytable" then
      if a.keys == nil then
         -- first indexing of an empty table fixes its key type
         a.keys = resolve_tuple(orig_b)
         a.keys_inferred_at = assert(node)
         a.keys_inferred_at_file = filename
      elseif not is_a(b, a.keys) then
         local inferred = " (type of keys inferred at " .. a.keys_inferred_at_file .. ":" .. a.keys_inferred_at.y .. ":" .. a.keys_inferred_at.x .. ": )"
         return node_error(idxnode, "inconsistent index type: %s, expected %s" .. inferred, orig_b, a.keys)
      end
      return a_type({ y = node.y, x = node.x, typename = "unresolved_emptytable_value", emptytable_type = a })
   end

   if a.typename == "map" then
      if not is_a(b, a.keys) then
         return node_error(idxnode, "wrong index type: %s, expected %s", orig_b, a.keys)
      end
      return a.values
   end

   local key_kind = node.e2.kind
   if key_kind == "string" or key_kind == "enum_item" then
      local key_node = node.e2
      return match_record_key(node, a, { y = key_node.y, x = key_node.x, kind = "string", tk = assert(key_node.conststr) }, orig_a)
   end

   if is_record_type(a) then
      if b.typename == "enum" then
         local field_names = sorted_keys(b.enumset)
         for _, k in ipairs(field_names) do
            if not a.fields[k] then
               return node_error(idxnode, "enum value '" .. k .. "' is not a field in %s", a)
            end
         end
         return match_all_record_field_names(idxnode, a, field_names,
            "cannot index, not all enum values map to record fields of the same type")
      end
      if is_a(b, STRING) then
         return node_error(idxnode, "cannot index object of type %s with a string, consider using an enum", orig_a)
      end
   end

   if lax and is_unknown(a) then
      return UNKNOWN
   end
   return node_error(idxnode, "cannot index object of type %s with %s", orig_a, orig_b)
end
-- Widen type `old` so that it also accepts `new`. Used while inferring
-- table-literal types (array elements, map keys and values).
--   * no `old` (or a nil type): `new` is returned
--   * `new` already is-a `old`: `old` is returned unchanged
--   * map with string keys vs. record: the record's field types are merged
--     into the map's value type (other key types are an error)
--   * record vs. record: `old` is converted in place into a {string:V} map,
--     where V merges the field types of both records
--   * union: `new` is appended to its member types
--   * otherwise: a new union of `old` and `new` is returned
-- Note that `old` (and the `tk` of `old`/`new`) may be mutated in place.
expand_type = function(where, old, new)
   if not old or old.typename == "nil" then
      return new
   end
   if is_a(new, old) then
      return old
   end
   if old.typename == "map" and is_record_type(new) then
      if old.keys.typename == "string" then
         for _, ftype in fields_of(new) do
            old.values = expand_type(where, old.values, ftype)
         end
      else
         node_error(where, "cannot determine table literal type")
      end
   elseif is_record_type(old) and is_record_type(new) then
      old.typename = "map"
      old.keys = STRING
      for _, ftype in fields_of(old) do
         if not old.values then
            old.values = ftype
         else
            old.values = expand_type(where, old.values, ftype)
         end
      end
      -- Fold the new record's fields into `old.values` too; writing them to
      -- `new.values` would drop them from the resulting map type.
      for _, ftype in fields_of(new) do
         if not old.values then
            old.values = ftype
         else
            old.values = expand_type(where, old.values, ftype)
         end
      end
      old.fields = nil
      old.field_order = nil
   elseif old.typename == "union" then
      new.tk = nil
      table.insert(old.types, new)
   else
      old.tk = nil
      new.tk = nil
      return unite({ old, new })
   end
   return old
end
-- Resolve the expression naming a record (a type identifier, or a chain of
-- `.` lookups starting from one) to the type it denotes, presumably so new
-- fields can be attached to it.
-- For a type identifier: returns its `def` when neither it nor the def is
-- closed, else the type itself when it is not closed.
-- For a `.` chain: follows `fields` step by step.
-- Returns nil when the name is unknown, a lookup step is missing, or the
-- identifier's type is closed.
local function find_record_to_extend(exp)
   if exp.kind == "type_identifier" then
      local t = find_var_type(exp.tk)
      -- find_var_type yields nil for names not in scope; report "not found"
      -- to the caller instead of indexing nil.
      if not t then
         return nil
      end
      if t.def then
         if not t.def.closed and not t.closed then
            return t.def
         end
      end
      if not t.closed then
         return t
      end
   elseif exp.kind == "op" and exp.op.op == "." then
      local t = find_record_to_extend(exp.e1)
      if not t then
         return nil
      end
      while exp.e2.kind == "op" and exp.e2.op.op == "." do
         t = t.fields and t.fields[exp.e2.e1.tk]
         if not t then
            return nil
         end
         exp = exp.e2
      end
      t = t.fields and t.fields[exp.e2.tk]
      return t
   end
end
-- Forward declarations for the flow-typing "facts" helpers; they are
-- assigned inside the `do` block that follows.
local facts_and, facts_or, facts_not, apply_facts, FACT_TRUTHY
do | |
-- Make `Fact` callable as a constructor: `Fact({ fact = ..., ... })` attaches
-- a metatable whose __tostring renders the fact, e.g. "(x is string)",
-- "(x == T)", "(not F)", "(F1 or F2)", "(F1 and F2)", and "*" for the truthy
-- fact. All facts share one metatable; previously every constructor call
-- allocated a fresh metatable and __tostring closure.
local fact_mt = {
   __tostring = function(f)
      local kind = f.fact
      if kind == "is" then
         return ("(%s is %s)"):format(f.var, show_type(f.typ))
      elseif kind == "==" then
         return ("(%s == %s)"):format(f.var, show_type(f.typ))
      elseif kind == "truthy" then
         return "*"
      elseif kind == "not" then
         return ("(not %s)"):format(tostring(f.f1))
      elseif kind == "or" then
         return ("(%s or %s)"):format(tostring(f.f1), tostring(f.f2))
      elseif kind == "and" then
         return ("(%s and %s)"):format(tostring(f.f1), tostring(f.f2))
      end
   end,
}
setmetatable(Fact, {
   __call = function(_, fact)
      return setmetatable(fact, fact_mt)
   end,
})
-- The fact attached to `assert(...)` call nodes (see special_functions).
FACT_TRUTHY = Fact({ fact = "truthy" })
-- Conjunction of two facts, recorded at `where`. Either side may be nil.
facts_and = function(f1, f2, where)
   local conj = { fact = "and", f1 = f1, f2 = f2, where = where }
   return Fact(conj)
end
-- Disjunction of two facts. If either side is unknown (nil), nothing can
-- be concluded and nil is returned.
facts_or = function(f1, f2, where)
   if not (f1 and f2) then
      return nil
   end
   local disj = { fact = "or", f1 = f1, f2 = f2, where = where }
   return Fact(disj)
end
-- Negation of a fact; nil when the fact itself is unknown (nil).
facts_not = function(f1, where)
   if not f1 then
      return nil
   end
   local neg = { fact = "not", f1 = f1, where = where }
   return Fact(neg)
end
-- Union of two types; note that `t2` is listed first among the members.
local function unite_types(t1, t2)
   local members = { t2, t1 }
   return unite(members)
end
-- Intersect two types. If either one is a union (the second is checked
-- first), keep only the union's members that are-a the other type.
-- Otherwise return whichever type is-a the other, or INVALID if neither is.
local function intersect_types(t1, t2)
   local a, b = t1, t2
   if t2.typename == "union" then
      a, b = t2, t1
   end
   if a.typename == "union" then
      local kept = {}
      for _, member in ipairs(a.types) do
         if is_a(member, b) then
            kept[#kept + 1] = member
         end
      end
      return unite(kept)
   end
   if is_a(a, b) then
      return a
   end
   if is_a(b, a) then
      return b
   end
   return INVALID
end
-- Return the resolved form of `t` only when it resolves to a union;
-- otherwise return the original `t` unchanged.
local function resolve_if_union(t)
   local resolved = resolve_tuple_and_nominal(t)
   return resolved.typename == "union" and resolved or t
end
-- Remove from union `t1` every member type that also occurs in `t2`. Used to
-- narrow a variable in the negated branch of an `is` check.
-- If `t2` is itself a union, all of its members are removed; any other `t2`
-- is removed as a single type.
-- A `t1` that does not resolve to a union is returned unchanged.
-- Returns INVALID when no member remains.
local function subtract_types(t1, t2)
   local from = resolve_if_union(t1)
   if from.typename ~= "union" then
      return from
   end
   local to_remove = resolve_if_union(t2)
   -- Only a union's `types` lists alternatives. Other types also carry a
   -- `types` field (e.g. a tupletable's element types), and must be treated
   -- as one type rather than as a list of types to remove.
   local removed
   if to_remove.typename == "union" then
      removed = to_remove.types
   else
      removed = { to_remove }
   end
   local remaining = {}
   for _, member in ipairs(from.types) do
      local found = false
      for _, r in ipairs(removed) do
         if same_type(member, r) then
            found = true
            break
         end
      end
      if not found then
         remaining[#remaining + 1] = member
      end
   end
   if #remaining == 0 then
      return INVALID
   end
   return unite(remaining)
end
-- Forward declarations: the fact evaluators below are mutually recursive.
local eval_not, not_facts, or_facts, and_facts, eval_fact
-- An "is" fact stating that `f.var` has the INVALID type, located at f.where.
local function invalid_from(f)
   local var, where = f.var, f.where
   return Fact({ fact = "is", var = var, typ = INVALID, where = where })
end
-- Negate a set of facts (a map from variable name to fact). Per variable:
--   * not in scope                      -> "==" INVALID
--   * "is T" where the var is a typevar -> "==" its own type, no location
--   * "is T" where T can never match    -> warning, "==" INVALID
--   * "is T" otherwise                  -> "is" (var's type minus T)
--   * "==" facts                         -> "==" the var's type, no location
not_facts = function(fs)
   local negated = {}
   for var, f in pairs(fs) do
      local declared = find_var_type(f.var, true)
      local kind, typ, where = "==", declared, f.where
      if not declared then
         typ = INVALID
      elseif f.fact == "is" then
         if declared.typename == "typevar" then
            where = nil
         elseif not is_a(f.typ, declared) then
            node_warning("branch", f.where, f.var .. " (of type %s) can never be a %s", show_type(declared), show_type(f.typ))
            typ = INVALID
         else
            kind = "is"
            typ = subtract_types(declared, f.typ)
         end
      elseif f.fact == "==" then
         where = nil
      end
      negated[var] = Fact({ fact = kind, var = var, typ = typ, where = where })
   end
   return negated
end
-- Evaluate the negation of fact `f` into a map from variable to fact.
-- "and"/"or" are negated via De Morgan: the negations of both sides are
-- combined with or_facts / and_facts respectively. When the right-hand side
-- is the truthy fact, only the left-hand side is considered.
eval_not = function(f)
   if not f then
      return {}
   end
   local kind = f.fact
   if kind == "is" then
      return not_facts({ [f.var] = f })
   end
   if kind == "not" then
      return eval_fact(f.f1)
   end
   local rhs_truthy = f.f2 and f.f2.fact == "truthy"
   if kind == "and" then
      if rhs_truthy then
         return eval_not(f.f1)
      end
      return or_facts(not_facts(eval_fact(f.f1)), not_facts(eval_fact(f.f2)))
   end
   if kind == "or" then
      if rhs_truthy then
         return eval_fact(f.f1)
      end
      return and_facts(not_facts(eval_fact(f.f1)), not_facts(eval_fact(f.f2)))
   end
   return not_facts(eval_fact(f))
end
-- Combine two fact maps under "or": only variables constrained on both
-- sides survive, typed as the union of both types. The result is "is" only
-- when both sides were "is" facts; its location comes from `fs2`.
or_facts = function(fs1, fs2)
   local merged = {}
   for var, rhs in pairs(fs2) do
      local lhs = fs1[var]
      if lhs then
         local kind = "=="
         if lhs.fact == "is" and rhs.fact == "is" then
            kind = "is"
         end
         merged[var] = Fact({ fact = kind, var = var, typ = unite_types(rhs.typ, lhs.typ), where = rhs.where })
      end
   end
   return merged
end
-- Combine two fact maps under "and". Variables on both sides get the
-- intersection of their types ("is" only if both sides were "is"); variables
-- on one side only keep their type as an "==" fact. If the result mixes
-- "is" and "==" facts, all of them are downgraded to "==".
and_facts = function(fs1, fs2)
   local combined = {}
   local seen_is, seen_eq = false, false
   for var, lhs in pairs(fs1) do
      local rhs = fs2[var]
      local kind, typ
      if rhs then
         kind = (rhs.fact == "is" and lhs.fact == "is") and "is" or "=="
         typ = intersect_types(lhs.typ, rhs.typ)
      else
         kind, typ = "==", lhs.typ
      end
      combined[var] = Fact({ fact = kind, var = var, typ = typ, where = lhs.where })
      if kind == "is" then
         seen_is = true
      else
         seen_eq = true
      end
   end
   for var, rhs in pairs(fs2) do
      if not fs1[var] then
         combined[var] = Fact({ fact = "==", var = var, typ = rhs.typ, where = rhs.where })
         seen_eq = true
      end
   end
   if seen_is and seen_eq then
      for _, fact in pairs(combined) do
         fact.fact = "=="
      end
   end
   return combined
end
-- Evaluate fact `f` into a map from variable name to the fact holding for it.
-- "is" facts are checked against the variable's current type:
--   * unknown variable           -> INVALID
--   * always true                -> "branch" warning, fact kept
--   * can never be true          -> error, INVALID
-- (checks are skipped for typevars). "and"/"or" combine both sides with
-- and_facts / or_facts; a truthy right-hand side reduces them to the left.
-- Unrecognized fact kinds yield no value.
eval_fact = function(f)
   if not f then
      return {}
   end
   local kind = f.fact
   if kind == "is" then
      local typ = find_var_type(f.var, true)
      if not typ then
         return { [f.var] = invalid_from(f) }
      end
      if typ.typename ~= "typevar" then
         if is_a(typ, f.typ) then
            node_warning("branch", f.where, f.var .. " (of type %s) is always a %s", show_type(typ), show_type(f.typ))
            return { [f.var] = f }
         end
         if not is_a(f.typ, typ) then
            node_error(f.where, f.var .. " (of type %s) can never be a %s", typ, f.typ)
            return { [f.var] = invalid_from(f) }
         end
      end
      return { [f.var] = f }
   end
   if kind == "==" then
      return { [f.var] = f }
   end
   if kind == "not" then
      return eval_not(f.f1)
   end
   if kind == "truthy" then
      return {}
   end
   local rhs_truthy = f.f2 and f.f2.fact == "truthy"
   if kind == "and" then
      if rhs_truthy then
         return eval_fact(f.f1)
      end
      return and_facts(eval_fact(f.f1), eval_fact(f.f2))
   end
   if kind == "or" then
      if rhs_truthy then
         return eval_not(f.f1)
      end
      return or_facts(eval_fact(f.f1), eval_fact(f.f2))
   end
end
-- Evaluate the fact `known` and add every resulting variable to the current
-- scope with a copy of the fact's type. When the fact has a location, the
-- copy's `inferred_at` is set to `where`. Facts whose type is INVALID
-- produce an error at `where` but are still added.
apply_facts = function(where, known)
   if not known then
      return
   end
   for name, fact in pairs(eval_fact(known)) do
      if fact.typ.typename == "invalid" then
         node_error(where, "cannot resolve a type for " .. name .. " here")
      end
      local narrowed = shallow_copy(fact.typ)
      narrowed.inferred_at = fact.where and where
      narrowed.inferred_at_file = filename
      add_var(nil, name, narrowed, true, true)
   end
end
end | |
-- After `name` has been declared, resolve every nominal reference to it
-- recorded in the innermost scope's "@unresolved" entry, then remove them
-- from the pending list. Does nothing if the scope has no such entry.
local function dismiss_unresolved(name)
   local pending = st[#st]["@unresolved"]
   if not pending then
      return
   end
   local nominals = pending.t.nominals
   local waiting = nominals[name]
   if waiting then
      for _, nominal in ipairs(waiting) do
         resolve_nominal(nominal)
      end
   end
   nominals[name] = nil
end
local type_check_funcall
-- Type-check `pcall(f, ...)` / `xpcall(f, msgh, ...)`.
-- The first argument type (and, for xpcall, the message handler type, which
-- must match XPCALL_MSGH_FUNCTION) is removed from `b`. The remaining
-- argument types are checked as a call of `f` through a synthetic @funcall
-- node. The result is f's return tuple with BOOLEAN prepended.
local function special_pcall_xpcall(node, _a, b, argdelta)
   local is_xpcall = node.e1.tk == "xpcall"
   local base_nargs = is_xpcall and 2 or 1
   if #node.e2 < base_nargs then
      node_error(node, "wrong number of arguments (given " .. #node.e2 .. ", expects at least " .. base_nargs .. ")")
      return TUPLE({ BOOLEAN })
   end
   local ftype = table.remove(b, 1)
   if is_xpcall then
      local msgh = table.remove(b, 1)
      assert_is_a(node.e2[2], msgh, XPCALL_MSGH_FUNCTION, "in message handler")
   end
   -- argument nodes actually passed on to `f`
   local forwarded = {}
   for i = base_nargs + 1, #node.e2 do
      forwarded[#forwarded + 1] = node.e2[i]
   end
   local call_node = {
      y = node.y,
      x = node.x,
      kind = "op",
      op = { op = "@funcall" },
      e1 = node.e2[1],
      e2 = forwarded,
   }
   local rets = type_check_funcall(call_node, ftype, b, argdelta + base_nargs)
   if rets.typename ~= "tuple" then
      rets = a_type({ typename = "tuple", rets })
   end
   table.insert(rets, 1, BOOLEAN)
   return rets
end
-- Type-checking overrides for calls to specific functions, keyed by the
-- callee's variable name. Each handler receives (node, func_type,
-- arg_types, argdelta) and returns the call's result type.
local special_functions = {}

-- rawget(t, k): checked like the index expression t[k]; a record indexed by
-- a constant string is looked up by field name.
special_functions["rawget"] = function(node, _a, b, _argdelta)
   if #b ~= 2 then
      node_error(node, "rawget expects two arguments")
      return INVALID
   end
   local tbl_type = resolve_tuple_and_nominal(b[1])
   local key_type = resolve_tuple_and_nominal(b[2])
   local knode = node.e2[2]
   if is_record_type(tbl_type) and knode.conststr then
      return match_record_key(node, tbl_type, { y = knode.y, x = knode.x, kind = "string", tk = assert(knode.conststr) }, tbl_type)
   end
   return type_check_index(node, knode, tbl_type, key_type)
end

-- print_type(): print every variable of every open scope.
-- print_type(x): print x's type and emit it as a "debug" warning.
special_functions["print_type"] = function(node, _a, b, _argdelta)
   if #b > 0 then
      local shown = show_type(b[1])
      print(shown)
      node_warning("debug", node.e2[1], "type is: %s", shown)
      return b
   end
   print("-----------------------------------------")
   for level, scope in ipairs(st) do
      for name, entry in pairs(scope) do
         print(("%2d %-14s %-11s %s"):format(level, name, entry.t.typename, show_type(entry.t):sub(1, 50)))
      end
   end
   print("-----------------------------------------")
   return NONE
end

-- require("mod"): the argument must be a single string literal. The
-- module's type is looked up via require_module, and its filename is
-- recorded in `dependencies`.
special_functions["require"] = function(node, _a, b, _argdelta)
   if #b ~= 1 then
      return node_error(node, "require expects one literal argument")
   end
   local arg_node = node.e2[1]
   if arg_node.kind ~= "string" then
      return node_error(node, "don't know how to resolve a dynamic require")
   end
   local module_name = assert(arg_node.conststr)
   local t, found = require_module(module_name, lax, env)
   if not found then
      return node_error(node, "module not found: '" .. module_name .. "'")
   end
   if t.typename == "invalid" then
      if lax then
         return UNKNOWN
      end
      return node_error(node, "no type information for required module: '" .. module_name .. "'")
   end
   dependencies[module_name] = t.filename
   return t
end

special_functions["pcall"] = special_pcall_xpcall
special_functions["xpcall"] = special_pcall_xpcall

-- assert(...): tags the call node with the truthy fact, then checks it as
-- an ordinary call.
special_functions["assert"] = function(node, a, b, argdelta)
   node.known = FACT_TRUTHY
   return type_check_function_call(node, a, b, false, argdelta)
end
-- Type-check the call expression `node`, where `a` is the callee's type and
-- `b` the argument types. `argdelta` defaults to 0.
--   * calls through a plain variable with an entry in special_functions are
--     handled by that entry
--   * method calls (`obj:m(...)`) get the receiver type prepended to `b`
--     (in place) and are checked as methods; argdelta is not forwarded
--   * everything else is an ordinary call
type_check_funcall = function(node, a, b, argdelta)
   argdelta = argdelta or 0
   local callee = node.e1
   if callee.kind == "variable" then
      local special = special_functions[callee.tk]
      if special then
         return special(node, a, b, argdelta)
      end
      return type_check_function_call(node, a, b, false, argdelta)
   end
   if callee.op and callee.op.op == ":" then
      table.insert(b, 1, callee.e1.type)
      return type_check_function_call(node, a, b, true)
   end
   return type_check_function_call(node, a, b, false, argdelta)
end
-- True when the i-th declared variable is initialized from a plain variable
-- of the same name (e.g. `local x = x`). Returns nil when there is no i-th
-- expression, and false when it is not such a variable.
local function is_localizing_a_variable(node, i)
   local exps = node.exps
   local exp = exps and exps[i]
   return exp and exp.kind == "variable" and exp.tk == node.vars[i].tk
end
-- Resolve a type declaration (`typetype`) whose definition is a nominal
-- reference. Returns the resulting type plus a flag that is true when the
-- declaration is an alias of another named type:
--   * nominal with type values: resolved in place (typeargs cleared);
--     returns typetype, false
--   * plain nominal: returns the named type, true; if the name is not a
--     type, reports an error and returns a "bad_nominal" placeholder, true
--   * anything else: returns typetype, false unchanged
local function resolve_nominal_typetype(typetype)
   local def = typetype.def
   if def.typename ~= "nominal" then
      return typetype, false
   end
   if def.typevals then
      typetype.def = resolve_nominal(def)
      typetype.def.typeargs = nil
      return typetype, false
   end
   local names = def.names
   local found = find_type(names)
   if not found or not is_typetype(found) then
      type_error(typetype, "%s is not a type", typetype)
      found = a_type({ typename = "bad_nominal", names = names })
   end
   return found, true
end
-- Type to use for the i-th declared variable `name` when it received no
-- value. Lax mode silently uses UNKNOWN. Otherwise an error is reported on
-- the variable node (its wording depends on whether the declaration had
-- any expressions) and INVALID is returned.
local function missing_initializer(node, i, name)
   if lax then
      return UNKNOWN
   end
   local var_node = node.vars[i]
   if node.exps then
      node_error(var_node, "assignment in declaration did not produce an initial value for variable '" .. name .. "'")
   else
      node_error(var_node, "variable '" .. name .. "' has no type or initial value")
   end
   return INVALID
end
-- Before checking the expressions of a declaration or assignment, annotate
-- each expression with the type it is expected to produce. The declared
-- types come from children[1] for assignments (when present), and from
-- node.decltype otherwise. If there are more targets than expressions, the
-- last expression is expected to produce a tuple of all remaining types.
local function set_expected_types_to_decltypes(node, children)
   local decls = node.decltype
   if node.kind == "assignment" and children[1] then
      decls = children[1]
   end
   if not (decls and node.exps) then
      return
   end
   local ndecl, nexps = #decls, #node.exps
   for i = 1, nexps do
      local expected = decls[i]
      if expected then
         if i == nexps and ndecl > nexps then
            expected = a_type({ y = node.y, x = node.x, filename = filename, typename = "tuple", types = {} })
            for j = i, ndecl do
               table.insert(expected.types, decls[j])
            end
         end
         local exp = node.exps[i]
         exp.expected = expected
         exp.expected_context = { kind = node.kind, name = node.vars[i].tk }
      end
   end
end
-- True when `n` is a whole number >= 1. A nil `n` is returned as is (falsy).
local function is_positive_int(n)
   if not n then
      return n
   end
   return n >= 1 and math.floor(n) == n
end
-- Error-message prefixes, keyed by statement kind.
local context_name = {
   ["local_declaration"] = "in local declaration",
   ["global_declaration"] = "in global declaration",
   ["assignment"] = "in assignment",
}
-- Prefix `msg` with where it happened, e.g. "in assignment: x: <msg>".
-- `msg` is returned unchanged when `ctx` is nil or its kind is not listed
-- in context_name.
local function in_context(ctx, msg)
   local where = ctx and context_name[ctx.kind]
   if not where then
      return msg
   end
   local name_part = ""
   if ctx.name then
      name_part = ctx.name .. ": "
   end
   return where .. ": " .. name_part .. msg
end
-- Detect a key appearing twice in a table literal. The key is the string
-- key `ck` if present, else the constant number `n`; entries with neither
-- are ignored. The first occurrence is remembered in `seen_keys` (key ->
-- node). A repeat is reported at `where` with the earlier position.
local function check_redeclared_key(ctx, where, seen_keys, ck, n)
   local key = ck or n
   if not key then
      return
   end
   local previous = seen_keys[key]
   if not previous then
      seen_keys[key] = where
      return
   end
   node_error(where, in_context(ctx, "redeclared key " .. tostring(key) .. " (previously declared at " .. filename .. ":" .. previous.y .. ":" .. previous.x .. ")"))
end
-- Infer the type of the table literal `node` from its already-checked items.
-- `children` holds one "table_item" type per entry (asserted below), with
-- the string key name (`kname`), key type (`ktype`) and value type
-- (`vtype`); `node[i].key` is the corresponding key node.
-- The literal starts as "emptytable" and is classified at the end as map,
-- arrayrecord, array, tupletable or record. Array items mixed with map
-- keys, or record fields mixed with non-string map keys, are reported as
-- "cannot determine type of table literal".
local function infer_table_literal(node, children) | |
local typ = a_type({ | |
filename = filename, | |
y = node.y, | |
x = node.x, | |
typename = "emptytable", | |
}) | |
-- Which kinds of entries have been seen so far.
local is_record = false | |
local is_array = false | |
local is_map = false | |
local is_tuple = false | |
local is_not_tuple = false | |
-- Next position for an implicitly-keyed (positional) item.
local last_array_idx = 1 | |
-- Max of last_array_idx and the largest explicit positive-integer key.
local largest_array_idx = -1 | |
local seen_keys = {} | |
for i, child in ipairs(children) do | |
assert(child.typename == "table_item") | |
local ck = child.kname | |
local n = node[i].key.constnum | |
-- Report string keys or constant-number keys that appear twice.
check_redeclared_key(nil, node[i], seen_keys, ck, n) | |
local uvtype = resolve_tuple(child.vtype) | |
if ck then | |
-- String-named key: a record field (declaration order is kept).
is_record = true | |
if not typ.fields then | |
typ.fields = {} | |
typ.field_order = {} | |
end | |
typ.fields[ck] = uvtype | |
table.insert(typ.field_order, ck) | |
elseif is_number_type(child.ktype) then | |
-- Numeric key: widens the array element type and records the
-- per-position type, in case the literal turns out to be a tuple.
is_array = true | |
if not is_not_tuple then | |
is_tuple = true | |
end | |
if not typ.types then | |
typ.types = {} | |
end | |
if node[i].key_parsed == "implicit" then | |
-- Positional item. If it is the last item and its value is a tuple
-- (e.g. multiple returns), every value of the tuple becomes an item.
if i == #children and child.vtype.typename == "tuple" then | |
for _, c in ipairs(child.vtype) do | |
typ.elements = expand_type(node, typ.elements, c) | |
typ.types[last_array_idx] = resolve_tuple(c) | |
last_array_idx = last_array_idx + 1 | |
end | |
else | |
typ.types[last_array_idx] = uvtype | |
last_array_idx = last_array_idx + 1 | |
typ.elements = expand_type(node, typ.elements, uvtype) | |
end | |
else | |
-- Explicit numeric key: anything other than a positive integer
-- rules out a tuple type for the literal.
if not is_positive_int(n) then | |
typ.elements = expand_type(node, typ.elements, uvtype) | |
is_not_tuple = true | |
elseif n then | |
typ.types[n] = uvtype | |
if n > largest_array_idx then | |
largest_array_idx = n | |
end | |
typ.elements = expand_type(node, typ.elements, uvtype) | |
end | |
end | |
if last_array_idx > largest_array_idx then | |
largest_array_idx = last_array_idx | |
end | |
-- No element type gathered yet (e.g. a trailing tuple with no values):
-- this item alone does not make the literal an array.
if not typ.elements then | |
is_array = false | |
end | |
else | |
-- Any other key type: a map entry; key and value types are widened.
is_map = true | |
child.ktype.tk = nil | |
typ.keys = expand_type(node, typ.keys, child.ktype) | |
typ.values = expand_type(node, typ.values, uvtype) | |
end | |
end | |
-- Classify the literal from the kinds of entries seen.
if is_array and is_map then | |
-- Array items plus map keys: widened into a map (integer keys folded
-- into the key type), but still reported as an error.
typ.typename = "map" | |
typ.keys = expand_type(node, typ.keys, INTEGER) | |
typ.values = expand_type(node, typ.values, typ.elements) | |
typ.elements = nil | |
node_error(node, "cannot determine type of table literal") | |
elseif is_record and is_array then | |
typ.typename = "arrayrecord" | |
elseif is_record and is_map then | |
-- Record fields and map entries combine into a map only when the map
-- keys are strings; the field types are folded into the value type.
if typ.keys.typename == "string" then | |
typ.typename = "map" | |
for _, ftype in fields_of(typ) do | |
typ.values = expand_type(node, typ.values, ftype) | |
end | |
typ.fields = nil | |
typ.field_order = nil | |
else | |
node_error(node, "cannot determine type of table literal") | |
end | |
-- NOTE(review): when the largest explicit key n exceeds the positional
-- count, largest_array_idx holds n itself, so the inferred_len values
-- below become n - 1, whereas positional items give their exact count.
-- Confirm this is intended.
elseif is_array then | |
if is_not_tuple then | |
typ.typename = "array" | |
typ.inferred_len = largest_array_idx - 1 | |
else | |
-- All positional types identical -> array; otherwise a tuple table.
local pure_array = true | |
local last_t | |
for _, current_t in pairs(typ.types) do | |
if last_t then | |
if not same_type(last_t, current_t) then | |
pure_array = false | |
break | |
end | |
end | |
last_t = current_t | |
end | |
if not pure_array then | |
typ.typename = "tupletable" | |
else | |
typ.typename = "array" | |
typ.inferred_len = largest_array_idx - 1 | |
end | |
end | |
elseif is_record then | |
typ.typename = "record" | |
elseif is_map then | |
typ.typename = "map" | |
elseif is_tuple then | |
-- Numeric items were seen, but none produced an element type.
typ.typename = "tupletable" | |
if not typ.types or #typ.types == 0 then | |
node_error(node, "cannot determine type of tuple elements") | |
end | |
end | |
return typ | |
end | |
local visit_node = {} | |
visit_node.cbs = { | |
["statements"] = { | |
before = function(node) | |
begin_scope(node) | |
end, | |
after = function(node, _children) | |
if #st == 2 then | |
fail_unresolved() | |
end | |
if not node.is_repeat then | |
end_scope(node) | |
end | |
node.type = NONE | |
return node.type | |
end, | |
}, | |
["local_type"] = { | |
before = function(node) | |
node.value.type, node.value.is_alias = resolve_nominal_typetype(node.value.newtype) | |
add_var(node.var, node.var.tk, node.value.type, node.var.is_const) | |
end, | |
after = function(node, _children) | |
dismiss_unresolved(node.var.tk) | |
node.type = NONE | |
return node.type | |
end, | |
}, | |
["global_type"] = { | |
before = function(node) | |
node.value.newtype, node.value.is_alias = resolve_nominal_typetype(node.value.newtype) | |
add_global(node.var, node.var.tk, node.value.newtype, node.var.is_const) | |
end, | |
after = function(node, _children) | |
local existing, existing_is_const = find_global(node.var.tk) | |
local var = node.var | |
if existing then | |
if existing_is_const == true and not var.is_const then | |
node_error(var, "global was previously declared as <const>: " .. var.tk) | |
end | |
if existing_is_const == false and var.is_const then | |
node_error(var, "global was previously declared as not <const>: " .. var.tk) | |
end | |
if not same_type(existing, node.value.newtype) then | |
node_error(var, "cannot redeclare global with a different type: previous type of " .. var.tk .. " is %s", existing) | |
end | |
end | |
dismiss_unresolved(var.tk) | |
node.type = NONE | |
return node.type | |
end, | |
}, | |
["local_declaration"] = { | |
before = function(node) | |
for _, var in ipairs(node.vars) do | |
reserve_symbol_list_slot(var) | |
end | |
end, | |
before_expressions = set_expected_types_to_decltypes, | |
after = function(node, children) | |
local vals = get_assignment_values(children[3], #node.vars) | |
for i, var in ipairs(node.vars) do | |
local decltype = node.decltype and node.decltype[i] | |
local infertype = vals and vals[i] | |
if lax and infertype and infertype.typename == "nil" then | |
infertype = nil | |
end | |
if decltype and infertype then | |
assert_is_a(node.vars[i], infertype, decltype, "in local declaration", var.tk) | |
end | |
local t = decltype or infertype | |
if t == nil then | |
t = missing_initializer(node, i, var.tk) | |
elseif t.typename == "emptytable" then | |
t.declared_at = node | |
t.assigned_to = var.tk | |
end | |
t.inferred_len = nil | |
assert(var) | |
add_var(var, var.tk, t, var.is_const, is_localizing_a_variable(node, i)) | |
dismiss_unresolved(var.tk) | |
end | |
node.type = NONE | |
return node.type | |
end, | |
}, | |
["global_declaration"] = { | |
before_expressions = set_expected_types_to_decltypes, | |
after = function(node, children) | |
local vals = get_assignment_values(children[3], #node.vars) | |
for i, var in ipairs(node.vars) do | |
local decltype = node.decltype and node.decltype[i] | |
local infertype = vals and vals[i] | |
if lax and infertype and infertype.typename == "nil" then | |
infertype = nil | |
end | |
if decltype and infertype then | |
assert_is_a(node.vars[i], infertype, decltype, "in global declaration", var.tk) | |
end | |
local t = decltype or infertype | |
local existing, existing_is_const = find_global(var.tk) | |
if existing then | |
if infertype and existing_is_const then | |
node_error(var, "cannot reassign to <const> global: " .. var.tk) | |
end | |
if existing_is_const == true and not var.is_const then | |
node_error(var, "global was previously declared as <const>: " .. var.tk) | |
end | |
if existing_is_const == false and var.is_const then | |
node_error(var, "global was previously declared as not <const>: " .. var.tk) | |
end | |
if t and not same_type(existing, t) then | |
node_error(var, "cannot redeclare global with a different type: previous type of " .. var.tk .. " is %s", existing) | |
end | |
else | |
if t == nil then | |
t = missing_initializer(node, i, var.tk) | |
elseif t.typename == "emptytable" then | |
t.declared_at = node | |
t.assigned_to = var.tk | |
end | |
t.inferred_len = nil | |
add_global(var, var.tk, t, var.is_const) | |
var.type = t | |
dismiss_unresolved(var.tk) | |
end | |
end | |
node.type = NONE | |
return node.type | |
end, | |
}, | |
["assignment"] = { | |
before_expressions = set_expected_types_to_decltypes, | |
after = function(node, children) | |
local vals = get_assignment_values(children[3], #children[1]) | |
local exps = flatten_list(vals) | |
for i, vartype in ipairs(children[1]) do | |
local varnode = node.vars[i] | |
local is_const = varnode.is_const | |
if varnode.kind == "variable" then | |
if widen_back_var(varnode.tk) then | |
vartype, is_const = find_var_type(varnode.tk) | |
end | |
end | |
if is_const then | |
node_error(varnode, "cannot assign to <const> variable") | |
end | |
if vartype then | |
local val = exps[i] | |
if is_typetype(resolve_tuple_and_nominal(vartype)) then | |
node_error(varnode, "cannot reassign a type") | |
elseif val then | |
assert_is_a(varnode, val, vartype, "in assignment") | |
if varnode.kind == "variable" and vartype.typename == "union" then | |
add_var(varnode, varnode.tk, val, false, true) | |
end | |
else | |
node_error(varnode, "variable is not being assigned a value") | |
if #node.exps == 1 and node.exps[1].kind == "op" and node.exps[1].op.op == "@funcall" then | |
local rets = node.exps[1].type | |
if rets.typename == "tuple" then | |
local msg = #rets == 1 and | |
"only 1 value is returned by the function" or | |
("only " .. #rets .. " values are returned by the function") | |
node_warning("hint", varnode, msg) | |
end | |
end | |
end | |
else | |
node_error(varnode, "unknown variable") | |
end | |
end | |
node.type = NONE | |
return node.type | |
end, | |
}, | |
["if"] = { | |
after = function(node, _children) | |
node.type = NONE | |
return node.type | |
end, | |
}, | |
["if_block"] = { | |
before = function(node) | |
begin_scope(node) | |
if node.if_block_n > 1 then | |
local ifnode = node.if_parent | |
local f = facts_not(ifnode.if_blocks[1].exp.known, node) | |
for e = 2, node.if_block_n - 1 do | |
f = facts_and(f, facts_not(ifnode.if_blocks[e].exp.known, node), node) | |
end | |
apply_facts(node, f) | |
end | |
end, | |
before_statements = function(node) | |
if node.exp then | |
apply_facts(node.exp, node.exp.known) | |
end | |
end, | |
after = end_scope_and_none_type, | |
}, | |
["while"] = { | |
before = function() | |
widen_all_unions() | |
end, | |
before_statements = function(node) | |
begin_scope(node) | |
apply_facts(node.exp, node.exp.known) | |
end, | |
after = end_scope_and_none_type, | |
}, | |
["label"] = { | |
before = function(node) | |
widen_all_unions() | |
local label_id = "::" .. node.label .. "::" | |
if st[#st][label_id] then | |
node_error(node, "label '" .. node.label .. "' already defined at " .. filename) | |
end | |
local unresolved = st[#st]["@unresolved"] | |
node.type = a_type({ y = node.y, x = node.x, typename = "none" }) | |
local var = add_var(node, label_id, node.type) | |
if unresolved then | |
if unresolved.t.labels[node.label] then | |
var.used = true | |
end | |
unresolved.t.labels[node.label] = nil | |
end | |
end, | |
}, | |
["goto"] = { | |
after = function(node, _children) | |
if not find_var_type("::" .. node.label .. "::") then | |
local unresolved = st[#st]["@unresolved"] and st[#st]["@unresolved"].t | |
if not unresolved then | |
unresolved = { typename = "unresolved", labels = {}, nominals = {} } | |
add_var(node, "@unresolved", unresolved) | |
end | |
unresolved.labels[node.label] = unresolved.labels[node.label] or {} | |
table.insert(unresolved.labels[node.label], node) | |
end | |
node.type = NONE | |
return node.type | |
end, | |
}, | |
["repeat"] = {
   -- Closing scope and typing the statement as "none" is shared with the
   -- other block statements.
   after = end_scope_and_none_type,
   -- As with `while`, narrowed unions are widened before the loop body.
   before = function()
      widen_all_unions()
   end,
},
["forin"] = {
   before = function(node)
      begin_scope(node)
   end,
   -- Checks the iterator expression of a generic `for ... in` loop and
   -- declares the loop variables using the iterator function's return types.
   before_statements = function(node)
      local exp1 = node.exps[1]
      local exp1type = resolve_tuple_and_nominal(exp1.type)
      if exp1type.typename == "function" then
         -- Extra diagnostics when the iterator comes from a direct call to
         -- `pairs(...)` or `ipairs(...)`.
         if exp1.op and exp1.op.op == "@funcall" then
            local t = resolve_tuple_and_nominal(exp1.e2.type)
            if exp1.e1.tk == "pairs" and is_array_type(t) then
               node_warning("hint", exp1, "hint: applying pairs on an array: did you intend to apply ipairs?")
            end
            if exp1.e1.tk == "pairs" and t.typename ~= "map" then
               if not (lax and is_unknown(t)) then
                  if is_record_type(t) then
                     match_all_record_field_names(exp1.e2, t, t.field_order,
                        "attempting pairs loop on a record with attributes of different types")
                     local ct = t.typename == "record" and "{string:any}" or "{any:any}"
                     node_warning("hint", exp1.e2, "hint: if you want to iterate over fields of a record, cast it to " .. ct)
                  else
                     node_error(exp1.e2, "cannot apply pairs on values of type: %s", exp1.e2.type)
                  end
               end
            elseif exp1.e1.tk == "ipairs" then
               if t.typename == "tupletable" then
                  local arr_type = arraytype_from_tuple(exp1.e2, t)
                  if not arr_type then
                     node_error(exp1.e2, "attempting ipairs loop on tuple that's not a valid array: %s", exp1.e2.type)
                  end
               elseif not is_array_type(t) then
                  if not (lax and (is_unknown(t) or t.typename == "emptytable")) then
                     node_error(exp1.e2, "attempting ipairs loop on something that's not an array: %s", exp1.e2.type)
                  end
               end
            end
         end
         -- Declare the loop variables. Past the end of a variadic return
         -- list the last type repeats; otherwise extras are unknown (lax)
         -- or invalid.
         local last
         local rets = exp1type.rets
         for i, v in ipairs(node.vars) do
            local r = rets[i]
            if not r then
               if rets.is_va then
                  r = last
               else
                  r = lax and UNKNOWN or INVALID
               end
            end
            add_var(v, v.tk, r)
            last = r
         end
         if (not lax) and (not rets.is_va and #node.vars > #rets) then
            local nrets = #rets
            local at = node.vars[nrets + 1]
            -- Plain text only: node_error gets no format arguments here, so
            -- the plural must not contain a `%s` directive.
            local n_values = nrets == 1 and "1 value" or tostring(nrets) .. " values"
            node_error(at, "too many variables for this iterator; it produces " .. n_values)
         end
      else
         if not (lax and is_unknown(exp1type)) then
            node_error(exp1, "expression in for loop does not return an iterator")
         end
      end
   end,
   after = end_scope_and_none_type,
},
["fornum"] = {
   -- The numeric loop variable is an integer only when start, limit and
   -- (if present) step are all integers; otherwise it is a number.
   before_statements = function(node, children)
      begin_scope(node)
      local start_t = resolve_tuple_and_nominal(children[2])
      local limit_t = resolve_tuple_and_nominal(children[3])
      local step_t
      if children[4] then
         step_t = resolve_tuple_and_nominal(children[4])
      end
      local all_integer = start_t.typename == "integer"
         and limit_t.typename == "integer"
         and (not step_t or step_t.typename == "integer")
      local var_t = NUMBER
      if all_integer then
         var_t = INTEGER
      end
      add_var(node.var, node.var.tk, var_t)
   end,
   after = end_scope_and_none_type,
},
["return"] = {
   -- Checks returned values against the declared return types ("@return").
   after = function(node, children)
      local rets = find_var_type("@return")
      if not rets then
         -- No "@return" in scope (presumably a return at module level): the
         -- returned values become the module's type and are stored as
         -- "@return" in scope 2.
         rets = children[1]
         rets.inferred_at = node
         rets.inferred_at_file = filename
         module_type = resolve_tuple_and_nominal(rets)
         module_type.tk = nil
         st[2]["@return"] = { t = rets }
      end
      local what = "in return value"
      if rets.inferred_at then
         what = what .. inferred_msg(rets)
      end

      local got = children[1]
      local nrets = #rets
      local vatype
      if nrets > 0 then
         vatype = rets.is_va and rets[nrets]
      end

      if #got > nrets and (not lax) and not vatype then
         -- `what` already starts with "in", so it is used as the prefix as-is
         -- (previously "in " was prepended, giving "in in return value").
         node_error(node, what .. ": excess return values, expected " .. nrets .. " %s, got " .. #got .. " %s", rets, got)
      end
      for i = 1, #got do
         local expected = rets[i] or vatype
         if expected then
            expected = resolve_tuple(expected)
            -- Report at the i-th expression when it has a position,
            -- otherwise at the expression list itself.
            local where = (node.exps[i] and node.exps[i].x) and
               node.exps[i] or
               node.exps
            assert(where and where.x)
            assert_is_a(where, got[i], expected, what)
         end
      end
      node.type = NONE
      return node.type
   end,
},
["variable_list"] = {
   -- Types a list of expressions as a tuple. A trailing tuple (e.g. a call
   -- result) is spliced into `children` element by element.
   after = function(node, children)
      node.type = TUPLE(children)
      local count = #children
      if count == 0 then
         return node.type
      end
      local tail = children[count]
      if tail.typename ~= "tuple" then
         return node.type
      end
      if tail.is_va then
         node.type.is_va = true
      end
      for offset, elem in ipairs(tail) do
         children[count + offset - 1] = elem
      end
      return node.type
   end,
},
["table_literal"] = {
   -- Before the items are visited, push the declared (expected) type of the
   -- literal down to its keys and values so nested literals are checked
   -- against it.
   before = function(node)
      local expected = node.expected
      if not expected then
         return
      end
      if expected.typename == "tupletable" then
         for _, item in ipairs(node) do
            local idx = item.key.constnum
            if idx and is_positive_int(idx) then
               item.value.expected = expected.types[idx]
            end
         end
      elseif is_array_type(expected) then
         for _, item in ipairs(node) do
            if item.key.constnum then
               item.value.expected = expected.elements
            end
         end
      elseif expected.typename == "map" then
         for _, item in ipairs(node) do
            item.key.expected = expected.keys
            item.value.expected = expected.values
         end
      end
      -- Checked separately from the chain above (presumably so a type that
      -- is both array and record also gets field expectations).
      if is_record_type(expected) then
         for _, item in ipairs(node) do
            local field = item.key.conststr
            if field then
               item.value.expected = expected.fields[field]
            end
         end
      end
   end,
   -- Types the literal: against the expected table type when one is known,
   -- otherwise by inference from its items.
   after = function(node, children)
      node.known = FACT_TRUTHY
      if not node.expected then
         node.type = infer_table_literal(node, children)
         return node.type
      end

      local decltype = resolve_tuple_and_nominal(node.expected)
      if decltype.typename == "union" then
         -- Pick the first table-like member of the union as the target.
         for _, member in ipairs(decltype.types) do
            local resolved = resolve_tuple_and_nominal(member)
            if is_known_table_type(resolved) then
               node.expected = member
               decltype = resolved
               break
            end
         end
         if decltype.typename == "union" then
            node_error(node, "unexpected table literal, expected: %s", decltype)
         end
      end

      if not is_known_table_type(decltype) then
         node.type = infer_table_literal(node, children)
         return node.type
      end

      local ctx = node.expected_context
      local is_record = is_record_type(decltype)
      local is_array = is_array_type(decltype)
      local is_tupletable = decltype.typename == "tupletable"
      local is_map = decltype.typename == "map"
      local forced = nil
      local seen = {}

      for i, child in ipairs(children) do
         assert(child.typename == "table_item")
         local item = node[i]
         local value_t = resolve_tuple(child.vtype)
         local kname = child.kname
         local n = item.key.constnum
         check_redeclared_key(ctx, item, seen, kname, n)
         if is_record and kname then
            local field_t = decltype.fields[kname]
            if field_t then
               assert_is_a(item, value_t, field_t, "in record field", kname)
            else
               node_error(item, in_context(ctx, "unknown field " .. kname))
            end
         elseif is_tupletable and is_number_type(child.ktype) then
            local slot_t = decltype.types[n]
            if not n then
               node_error(item, in_context(ctx, "unknown index in tuple %s"), decltype)
            elseif not slot_t then
               node_error(item, in_context(ctx, "unexpected index " .. n .. " in tuple %s"), decltype)
            else
               assert_is_a(item, value_t, slot_t, in_context(ctx, "in tuple"), "at index " .. tostring(n))
            end
         elseif is_array and is_number_type(child.ktype) then
            local spreads = child.vtype.typename == "tuple"
               and i == #children
               and item.key_parsed == "implicit"
            if spreads then
               -- A trailing multi-value expression fills several slots.
               for ti, tt in ipairs(child.vtype) do
                  assert_is_a(item, tt, decltype.elements, in_context(ctx, "expected an array"), "at index " .. tostring(i + ti - 1))
               end
            else
               assert_is_a(item, value_t, decltype.elements, in_context(ctx, "expected an array"), "at index " .. tostring(n))
            end
         elseif item.key_parsed == "implicit" then
            forced = expand_type(item, forced, child.vtype)
         elseif is_map then
            assert_is_a(item, child.ktype, decltype.keys, in_context(ctx, "in map key"))
            assert_is_a(item, value_t, decltype.values, in_context(ctx, "in map value"))
         else
            node_error(item, in_context(ctx, "unexpected key of type %s in table of type %s"), child.ktype, decltype)
         end
      end

      if forced then
         node.type = a_type({
            inferred_at = node,
            inferred_at_file = filename,
            typename = "array",
            elements = forced,
         })
      else
         node.type = resolve_typevars_at(node.expected, node)
      end
      return node.type
   end,
},
["table_item"] = {
   -- Packs one `key = value` entry into a "table_item" type; an explicit
   -- declared type on the item overrides (after checking) the value's type.
   after = function(node, children)
      local kname = node.key.conststr
      local key_t, value_t = children[1], children[2]
      local declared = node.decltype
      if declared then
         assert_is_a(node.value, value_t, declared, "in table item")
         value_t = declared
      end
      node.type = a_type({
         y = node.y,
         x = node.x,
         typename = "table_item",
         kname = kname,
         ktype = key_t,
         vtype = value_t,
      })
      return node.type
   end,
},
["local_function"] = {
   before = function(node)
      reserve_symbol_list_slot(node)
      begin_scope(node)
   end,
   -- The function's own name is made visible inside its body so it can
   -- call itself.
   before_statements = function(node)
      add_internal_function_variables(node)
      add_function_definition_for_recursion(node)
   end,
   -- After the body: declare the local with its full function type.
   after = function(node, children)
      end_function_scope(node)
      local fn_t = a_type({
         y = node.y,
         x = node.x,
         typename = "function",
         typeargs = node.typeargs,
         args = children[2],
         rets = get_rets(children[3]),
         filename = filename,
      })
      add_var(node, node.name.tk, fn_t)
      return node.type
   end,
},
["global_function"] = {
   before = function(node)
      begin_scope(node)
   end,
   -- As for local functions, the name is visible inside the body for
   -- recursion.
   before_statements = function(node)
      add_internal_function_variables(node)
      add_function_definition_for_recursion(node)
   end,
   -- After the body: register the global with its full function type.
   after = function(node, children)
      end_function_scope(node)
      local fn_t = a_type({
         y = node.y,
         x = node.x,
         typename = "function",
         typeargs = node.typeargs,
         args = children[2],
         rets = get_rets(children[3]),
         filename = filename,
      })
      add_global(node, node.name.tk, fn_t)
      return node.type
   end,
},
["record_function"] = {
   before = function(node)
      begin_scope(node)
   end,
   -- Attaches a `function R.name` / `function R:name` definition to the
   -- record type R, checking that adding it is allowed.
   before_statements = function(node, children)
      add_internal_function_variables(node)
      local rtype = resolve_tuple_and_nominal(resolve_typetype(children[1]))
      local owner = find_record_to_extend(node.fn_owner)
      if node.is_method then
         -- `:` methods: the first parameter (self) is typed as the record.
         children[3][1] = rtype
         add_var(nil, "self", rtype)
      end
      -- An empty table being extended with functions becomes a record.
      if rtype.typename == "emptytable" then
         rtype.typename = "record"
         rtype.fields = {}
         rtype.field_order = {}
      end
      if is_record_type(rtype) then
         local fname = node.name.tk
         local fn_type = a_type({
            y = node.y,
            x = node.x,
            typename = "function",
            is_method = node.is_method,
            typeargs = node.typeargs,
            args = children[3],
            rets = get_rets(children[4]),
            filename = filename,
         })
         -- Allowed in lax mode, when the record already declares a compatible
         -- field with this name, or when the record is the one returned by
         -- find_record_to_extend for this definition's owner.
         local ok = false
         if lax then
            ok = true
         elseif rtype.fields[fname] and is_a(fn_type, rtype.fields[fname]) then
            ok = true
         elseif owner == rtype then
            ok = true
         end
         if ok then
            rtype.fields[fname] = fn_type
            -- field_order lists each field name once: implementing an
            -- already-declared field or redefining a function must not
            -- append a duplicate entry.
            local listed = false
            for _, existing in ipairs(rtype.field_order) do
               if existing == fname then
                  listed = true
                  break
               end
            end
            if not listed then
               table.insert(rtype.field_order, fname)
            end
            node.name.type = fn_type
         else
            local name = tl.pretty_print_ast(node.fn_owner, { preserve_indent = true, preserve_newlines = false })
            node_error(node, "cannot add undeclared function '" .. fname .. "' outside of the scope where '" .. name .. "' was originally declared")
         end
      else
         if not (lax and rtype.typename == "unknown") then
            node_error(node, "not a module: %s", rtype)
         end
      end
   end,
   after = function(node, _children)
      end_function_scope(node)
      node.type = NONE
      return node.type
   end,
},
["function"] = {
   before = function(node)
      begin_scope(node)
   end,
   before_statements = function(node)
      add_internal_function_variables(node)
   end,
   -- An anonymous function expression evaluates to its function type.
   after = function(node, children)
      end_function_scope(node)
      local fn_t = a_type({
         y = node.y,
         x = node.x,
         typename = "function",
         typeargs = node.typeargs,
         args = children[1],
         rets = children[2],
         filename = filename,
      })
      node.type = fn_t
      return fn_t
   end,
},
["cast"] = {
   -- A cast node simply takes on the type it casts to.
   after = function(node, _children)
      local target = node.casttype
      node.type = target
      return target
   end,
},
["paren"] = {
   -- Parentheses forward the inner expression's facts and collapse its
   -- type to a single value.
   after = function(node, children)
      local inner = node.e1
      node.known = inner and inner.known
      local t = resolve_tuple(children[1])
      node.type = t
      return t
   end,
},
["op"] = { | |
before = function()
   begin_scope()
end,
-- Runs after the left operand (e1) is checked and before the right operand
-- (e2) is visited.
before_e2 = function(node)
   local op = node.op.op
   if op == "and" then
      -- In `a and b`, b is only evaluated when a is truthy.
      apply_facts(node, node.e1.known)
   elseif op == "or" then
      -- In `a or b`, b is only evaluated when a is falsy.
      apply_facts(node, facts_not(node.e1.known, node))
   elseif op == "@funcall" then
      -- Propagate the callee's parameter types to the arguments as expected
      -- types. Arguments are evaluated unconditionally, so no facts about
      -- the callee expression are applied here (negating them wrongly
      -- narrowed variables used in the arguments).
      local fn_t = node.e1.type
      if fn_t.typename == "function" then
         -- For `obj:m(...)` the first declared parameter is presumably the
         -- implicit self, so argument indices shift by one.
         local argdelta = (node.e1.op and node.e1.op.op == ":") and -1 or 0
         for i, typ in ipairs(fn_t.args) do
            local arg = node.e2[i + argdelta]
            if arg then
               arg.expected = typ
            end
         end
      end
   elseif op == "@index" then
      if node.e1.type.typename == "map" then
         node.e2.expected = node.e1.type.keys
      end
   end
end,
after = function(node, children)
   end_scope()

   -- Operand types as produced by the children (they may be tuples).
   local left = children[1]
   local right = children[3]
   local left_orig = left
   local right_orig = right
   local left_res = left and resolve_tuple_and_nominal(left)
   local right_res = right and resolve_tuple_and_nominal(right)
   -- Operating on a record's type object works on the record definition.
   if left_res and is_typetype(left_res) and left_res.def.typename == "record" then
      left_res = left_res.def
   end
   if right_res and is_typetype(right_res) and right_res.def.typename == "record" then
      right_res = right_res.def
   end

   local op = node.op.op

   if op == "." then
      left = left_res
      if left.typename == "map" then
         if is_a(left.keys, STRING) or is_a(left.keys, ANY) then
            node.type = left.values
         else
            node_error(node, "cannot use . index, expects keys of type %s", left.keys)
         end
      else
         local field_key = { y = node.e2.y, x = node.e2.x, kind = "string", tk = node.e2.tk }
         node.type = match_record_key(node, left, field_key, left_orig)
         if node.type.needs_compat and opts.gen_compat ~= "off" then
            local e1, e2 = node.e1, node.e2
            -- Only simple `var.field` expressions are rewritten to the
            -- `_tl_var_field` compat local.
            if e1.kind == "variable" and e2.kind == "identifier" then
               local compat_key = e1.tk .. "." .. e2.tk
               node.kind = "variable"
               node.tk = "_tl_" .. e1.tk .. "_" .. e2.tk
               all_needs_compat[compat_key] = true
            end
         end
      end
   elseif op == "@funcall" then
      node.type = type_check_funcall(node, left, right)
   elseif op == "@index" then
      node.type = type_check_index(node, node.e2, left, right)
   elseif op == "as" then
      node.type = right
   elseif op == "is" then
      if left_res.typename == "typetype" then
         node_error(node, "can only use 'is' on variables, not types")
      elseif node.e1.kind == "variable" then
         node.known = Fact({ fact = "is", var = node.e1.tk, typ = right, where = node })
      else
         node_error(node, "can only use 'is' on variables")
      end
      node.type = BOOLEAN
   elseif op == ":" then
      node.type = match_record_key(node, node.e1.type, node.e2, left_orig)
   elseif op == "not" then
      node.known = facts_not(node.e1.known, node)
      node.type = BOOLEAN
   elseif op == "and" then
      node.known = facts_and(node.e1.known, node.e2.known, node)
      node.type = resolve_tuple(right)
   elseif op == "or" and is_known_table_type(left_res) and right.typename == "emptytable" then
      node.known = nil
      node.type = resolve_tuple(left)
   elseif op == "or" and is_a(right_res, left_res) then
      node.known = facts_or(node.e1.known, node.e2.known)
      node.type = resolve_tuple(left)
   elseif op == "or" and right.typename == "nil" then
      node.known = nil
      node.type = resolve_tuple(left)
   elseif op == "or"
      and ((left_res.typename == "enum" and right_res.typename == "string" and is_a(right_res, left_res))
         or (left_res.typename == "string" and right_res.typename == "enum" and is_a(left_res, right_res))) then
      -- enum `or` compatible string literal: the result is the enum.
      node.known = nil
      node.type = (left_res.typename == "enum" and left_res or right_res)
   elseif op == "or" and node.expected and node.expected.typename == "union" then
      node.known = facts_or(node.e1.known, node.e2.known)
      local u = unite({ left_res, right_res }, true)
      local valid, err = is_valid_union(u)
      node.type = valid and u or node_error(node, err)
   elseif op == "==" or op == "~=" then
      node.type = BOOLEAN
      if is_a(right, left, true) or left.typename == "typevar" then
         if op == "==" and node.e1.kind == "variable" then
            node.known = Fact({ fact = "==", var = node.e1.tk, typ = right, where = node })
         end
      elseif is_a(left, right, true) or right.typename == "typevar" then
         if op == "==" and node.e2.kind == "variable" then
            node.known = Fact({ fact = "==", var = node.e2.tk, typ = left, where = node })
         end
      elseif lax and (is_unknown(left) or is_unknown(right)) then
         node.type = UNKNOWN
      else
         node_error(node, "types are not comparable for equality: %s and %s", left, right)
      end
   elseif node.op.arity == 1 and unop_types[op] then
      left = left_res
      if left.typename == "union" then
         left = unite(left.types, true)
      end
      local by_type = unop_types[op]
      node.type = by_type[left.typename]
      local metamethod
      if node.type then
         if node.type.typename ~= "boolean" then
            node.known = FACT_TRUTHY
         end
      else
         metamethod = left.meta_fields and left.meta_fields[unop_to_metamethod[op] or ""]
         if metamethod then
            node.type = resolve_tuple_and_nominal(type_check_function_call(node, metamethod, { left }, false, 0))
         elseif lax and is_unknown(left) then
            node.type = UNKNOWN
         else
            node_error(node, "cannot use operator '" .. op:gsub("%%", "%%%%") .. "' on type %s", resolve_tuple(left_orig))
         end
      end
      -- Lua 5.1 has no `~` operator: rewrite to a metamethod or bit32 call.
      if op == "~" and env.gen_target == "5.1" then
         if metamethod then
            all_needs_compat["mt"] = true
            convert_node_to_compat_mt_call(node, unop_to_metamethod[op], 1, node.e1)
         else
            all_needs_compat["bit32"] = true
            convert_node_to_compat_call(node, "bit32", "bnot", node.e1)
         end
      end
   elseif node.op.arity == 2 and binop_types[op] then
      if op == "or" then
         node.known = facts_or(node.e1.known, node.e2.known)
      end
      left = left_res
      right = right_res
      if left.typename == "union" then
         left = unite(left.types, true)
      end
      if right.typename == "union" then
         right = unite(right.types, true)
      end
      local by_type = binop_types[op]
      node.type = by_type[left.typename] and by_type[left.typename][right.typename]
      local metamethod
      local meta_self = 1
      if node.type then
         if by_type == numeric_binop or op == ".." then
            node.known = FACT_TRUTHY
         end
      else
         -- Look for the metamethod on the left operand first, then the right.
         metamethod = left.meta_fields and left.meta_fields[binop_to_metamethod[op] or ""]
         if not metamethod then
            metamethod = right.meta_fields and right.meta_fields[binop_to_metamethod[op] or ""]
            meta_self = 2
         end
         if metamethod then
            node.type = resolve_tuple_and_nominal(type_check_function_call(node, metamethod, { left, right }, false, 0))
         elseif lax and (is_unknown(left) or is_unknown(right)) then
            node.type = UNKNOWN
         else
            node_error(node, "cannot use operator '" .. op:gsub("%%", "%%%%") .. "' for types %s and %s", resolve_tuple(left_orig), resolve_tuple(right_orig))
         end
      end
      -- Lua 5.1 lacks `//` and bitwise operators: rewrite them.
      if op == "//" and env.gen_target == "5.1" then
         if metamethod then
            all_needs_compat["mt"] = true
            convert_node_to_compat_mt_call(node, "__idiv", meta_self, node.e1, node.e2)
         else
            local div = { y = node.y, x = node.x, kind = "op", op = an_operator(node, 2, "/"), e1 = node.e1, e2 = node.e2 }
            convert_node_to_compat_call(node, "math", "floor", div)
         end
      elseif bit_operators[op] and env.gen_target == "5.1" then
         if metamethod then
            all_needs_compat["mt"] = true
            convert_node_to_compat_mt_call(node, binop_to_metamethod[op], meta_self, node.e1, node.e2)
         else
            all_needs_compat["bit32"] = true
            convert_node_to_compat_call(node, "bit32", bit_operators[op], node.e1, node.e2)
         end
      end
   else
      error("unknown node op " .. op)
   end
   return node.type
end,
}, | |
["variable"] = {
   -- Resolves a variable reference to its declared type. A reference to a
   -- type is wrapped in a nominal type; unknown names become "unknown".
   after = function(node, _children)
      local name = node.tk
      if name == "..." then
         local sentinel = find_var_type("@is_va")
         if not sentinel or sentinel.typename == "nil" then
            node_error(node, "cannot use '...' outside a vararg function")
         end
      end

      if name == "_G" then
         node.type, node.is_const = simulate_g()
      else
         node.type, node.is_const = find_var_type(name)
      end

      local found = node.type
      if found and is_typetype(found) then
         node.type = a_type({
            y = node.y,
            x = node.x,
            typename = "nominal",
            names = { name },
            found = found,
            resolved = found,
         })
      end

      if node.type == nil then
         node.type = a_type({ typename = "unknown" })
         if lax then
            add_unknown(node, name)
         else
            node_error(node, "unknown variable: " .. name)
         end
      end
      return node.type
   end,
},
["type_identifier"] = {
   -- Looks up a type name; in lax mode unknown names are recorded and typed
   -- as unknown, otherwise they are reported and the type stays nil.
   after = function(node, _children)
      local name = node.tk
      node.type, node.is_const = find_var_type(name)
      if node.type ~= nil then
         return node.type
      end
      if not lax then
         node_error(node, "unknown variable: " .. name)
         return node.type
      end
      node.type = UNKNOWN
      add_unknown(node, name)
      return node.type
   end,
},
["argument"] = {
   -- Declares a function parameter; undeclared types default to unknown and
   -- `...` becomes a variadic tuple of its element type.
   after = function(node, _children)
      local arg_t = node.decltype or UNKNOWN
      if node.tk == "..." then
         arg_t = a_type({ typename = "tuple", is_va = true, arg_t })
      end
      local var = add_var(node, node.tk, arg_t)
      var.is_func_arg = true
      return node.type
   end,
},
["identifier"] = {
   -- Identifiers keep a type already assigned to them; otherwise "none".
   after = function(node, _children)
      local t = node.type or NONE
      node.type = t
      return t
   end,
},
["newtype"] = {
   -- A type declaration node takes its declared type unless already typed.
   after = function(node, _children)
      local t = node.type or node.newtype
      node.type = t
      return t
   end,
},
["error_node"] = {
   -- Error placeholder nodes are always typed as invalid.
   after = function(node, _children)
      node.type = INVALID
      return INVALID
   end,
},
} | |
do
   local cbs = visit_node.cbs

   -- String/number/integer literals: the type is the node kind with the
   -- literal token attached; such literals are always truthy.
   cbs["string"] = {
      after = function(node, _children)
         local lit = a_type({
            y = node.y,
            x = node.x,
            typename = node.kind,
            tk = node.tk,
         })
         node.type = lit
         node.known = FACT_TRUTHY
         return lit
      end,
   }
   cbs["number"] = cbs["string"]
   cbs["integer"] = cbs["string"]

   -- boolean/nil literals: only the token `true` is known to be truthy.
   cbs["boolean"] = {
      after = function(node, _children)
         local lit = a_type({
            y = node.y,
            x = node.x,
            typename = node.kind,
            tk = node.tk,
         })
         node.type = lit
         if node.tk == "true" then
            node.known = FACT_TRUTHY
         end
         return lit
      end,
   }
   cbs["nil"] = cbs["boolean"]

   -- Node kinds that reuse another kind's callbacks.
   cbs["do"] = cbs["if"]
   cbs["break"] = cbs["if"]
   cbs["..."] = cbs["variable"]
   cbs["argument_list"] = cbs["variable_list"]
   cbs["expression_list"] = cbs["variable_list"]
end
-- Internal sanity check run after every node: each node must end up with a
-- type table that carries a string typename.
visit_node.after = function(node, _children)
   local t = node.type
   if type(t) ~= "table" then
      error(node.kind .. " did not produce a type")
   elseif type(t.typename) ~= "string" then
      error(node.kind .. " type does not have a typename")
   end
   return t
end
local visit_type = {
   cbs = {
      -- Pass-through: the type itself is the result.
      ["string"] = {
         after = function(typ, _children)
            return typ
         end,
      },
      -- Function types are checked inside their own scope.
      ["function"] = {
         before = function(_typ, _children)
            begin_scope()
         end,
         after = function(typ, _children)
            end_scope()
            return typ
         end,
      },
      ["record"] = {
         -- Nested type definitions of the record are declared in a fresh
         -- scope, and temporarily retagged "nestedtype" while inside it.
         before = function(typ, _children)
            begin_scope()
            for field_name, ftype in fields_of(typ) do
               if ftype.typename == "typetype" then
                  ftype.typename = "nestedtype"
                  local resolved, is_alias = resolve_nominal_typetype(ftype)
                  if is_alias then
                     ftype.is_alias = true
                     ftype.def.resolved = resolved
                  end
                  add_var(nil, field_name, resolved)
               end
            end
         end,
         -- Undo the "nestedtype" retagging done in `before`.
         after = function(typ, _children)
            end_scope()
            for _, ftype in fields_of(typ) do
               if ftype.typename == "nestedtype" then
                  ftype.typename = "typetype"
               end
            end
            return typ
         end,
      },
      -- Declaring a type argument makes its name visible in the scope.
      ["typearg"] = {
         after = function(typ, _children)
            local name = typ.typearg
            add_var(nil, name, a_type({
               y = typ.y,
               x = typ.x,
               typename = "typearg",
               typearg = name,
            }))
            return typ
         end,
      },
      ["typevar"] = {
         after = function(typ, _children)
            local name = typ.typevar
            if not find_var_type(name) then
               type_error(typ, "undefined type variable " .. name)
            end
            return typ
         end,
      },
      -- Nominal types are linked to their definition. Names that resolve to
      -- a type argument become type variables; names not found yet are
      -- queued under "@unresolved".
      ["nominal"] = {
         after = function(typ, _children)
            if typ.found then
               return typ
            end
            local target = find_type(typ.names, true)
            if not target then
               local head = typ.names[1]
               local unresolved = find_var_type("@unresolved")
               if not unresolved then
                  unresolved = { typename = "unresolved", labels = {}, nominals = {} }
                  add_var(nil, "@unresolved", unresolved)
               end
               local waiting = unresolved.nominals[head]
               if not waiting then
                  waiting = {}
                  unresolved.nominals[head] = waiting
               end
               table.insert(waiting, typ)
            elseif target.typename == "typearg" then
               typ.names = nil
               typ.typename = "typevar"
               typ.typevar = target.typearg
            else
               typ.found = target
            end
            return typ
         end,
      },
      ["union"] = {
         after = function(typ, _children)
            local ok, why = is_valid_union(typ)
            if not ok then
               type_error(typ, why, typ)
            end
            return typ
         end,
      },
   },
   -- Internal sanity check run after every type node.
   after = function(typ, _children, ret)
      if type(ret) ~= "table" then
         error(typ.typename .. " did not produce a type")
      end
      if type(ret.typename) ~= "string" then
         error("type node does not have a typename")
      end
      return ret
   end,
}
-- The `after` hooks of both visitors only perform internal consistency
-- checks; they are dropped unless those checks were requested.
if not opts.run_internal_compiler_checks then
   visit_node.after = nil
   visit_type.after = nil
end
do
   local type_cbs = visit_type.cbs
   -- Type kinds needing no special handling reuse the pass-through callback.
   local passthrough = type_cbs["string"]
   local plain_kinds = {
      "tupletable", "typetype", "nestedtype", "array", "map",
      "enum", "boolean", "nil", "number", "integer", "thread",
      "bad_nominal", "emptytable", "table_item",
      "unresolved_emptytable_value", "tuple", "poly", "any",
      "unknown", "invalid", "unresolved", "none",
   }
   for _, kind in ipairs(plain_kinds) do
      type_cbs[kind] = passthrough
   end
   -- Array-records get the same nested-type scoping as records.
   type_cbs["arrayrecord"] = type_cbs["record"]
end
assert(ast.kind == "statements")

-- Check the whole tree, then post-process the outermost scope and errors.
recurse_node(ast, visit_node, visit_type)
close_types(st[1])
check_for_unused_vars(st[1])
clear_redundant_errors(errors)
add_compat_entries(ast, all_needs_compat, env.gen_compat)

local result = {
   filename = filename,
   ast = ast,
   type = module_type,
   env = env,
   type_errors = errors,
   warnings = warnings,
   dependencies = dependencies,
   symbol_list = symbol_list,
}

-- Cache the result in the environment (tl.process and tl.process_string
-- return env.loaded[filename] directly when present).
env.loaded[filename] = result
table.insert(env.loaded_order, filename)
return result
end | |
-- Maps internal typenames to the public numeric codes in tl.typecodes, as
-- used by tl.get_types. Every typename that can reach get_typenum needs an
-- entry, because it asserts that the lookup succeeds.
local typename_to_typecode = {
   ["typevar"] = tl.typecodes.TYPE_VARIABLE,
   ["typearg"] = tl.typecodes.TYPE_VARIABLE,
   ["function"] = tl.typecodes.FUNCTION,
   ["array"] = tl.typecodes.ARRAY,
   ["map"] = tl.typecodes.MAP,
   ["tupletable"] = tl.typecodes.TUPLE,
   ["arrayrecord"] = tl.typecodes.ARRAYRECORD,
   ["record"] = tl.typecodes.RECORD,
   ["enum"] = tl.typecodes.ENUM,
   ["boolean"] = tl.typecodes.BOOLEAN,
   ["string"] = tl.typecodes.STRING,
   ["nil"] = tl.typecodes.NIL,
   ["thread"] = tl.typecodes.THREAD,
   ["number"] = tl.typecodes.NUMBER,
   ["integer"] = tl.typecodes.INTEGER,
   ["union"] = tl.typecodes.IS_UNION,
   ["nominal"] = tl.typecodes.NOMINAL,
   -- A nominal whose name could not be resolved is still a nominal.
   ["bad_nominal"] = tl.typecodes.NOMINAL,
   ["emptytable"] = tl.typecodes.EMPTY_TABLE,
   ["unresolved_emptytable_value"] = tl.typecodes.EMPTY_TABLE,
   ["poly"] = tl.typecodes.IS_POLY,
   ["any"] = tl.typecodes.ANY,
   ["unknown"] = tl.typecodes.UNKNOWN,
   ["invalid"] = tl.typecodes.INVALID,
   -- Internal typenames with no public code. They can still reach the type
   -- report (e.g. labels are declared with a "none" type, and tuples of
   -- length other than 1 are not unwrapped), so map them to UNKNOWN instead
   -- of failing the assert in get_typenum.
   ["none"] = tl.typecodes.UNKNOWN,
   ["tuple"] = tl.typecodes.UNKNOWN,
   ["table_item"] = tl.typecodes.UNKNOWN,
   ["unresolved"] = tl.typecodes.UNKNOWN,
   ["typetype"] = tl.typecodes.UNKNOWN,
   ["nestedtype"] = tl.typecodes.UNKNOWN,
}
-- Builds (or extends) a type report for a processed file.
-- `result` is the value returned by the type checker; `trenv` carries the
-- numbering state, so several files can be accumulated into one report by
-- passing back the `trenv` returned from the previous call.
-- Returns the report table `tr` and the (possibly new) `trenv`.
function tl.get_types(result, trenv)
   local filename = result.filename or "?"

   -- Sets index 0 to false on a list; presumably a marker telling the
   -- serializer to emit the table as an array (confirm against the writer).
   local function mark_array(x)
      local arr = x
      arr[0] = false
      return x
   end

   if not trenv then
      trenv = {
         next_num = 1,
         typeid_to_num = {},
         tr = {
            by_pos = {},
            types = {},
            symbols = mark_array({}),
            globals = {},
         },
      }
   end

   local tr = trenv.tr
   local typeid_to_num = trenv.typeid_to_num

   local get_typenum

   -- Records argument and return type numbers of a function type.
   local function store_function(ti, rt)
      local args = {}
      for _, fnarg in ipairs(rt.args) do
         table.insert(args, mark_array({ get_typenum(fnarg), nil }))
      end
      ti.args = mark_array(args)
      local rets = {}
      for _, fnarg in ipairs(rt.rets) do
         table.insert(rets, mark_array({ get_typenum(fnarg), nil }))
      end
      ti.rets = mark_array(rets)
      ti.vararg = not not rt.is_va
   end

   -- Returns the report number for type `t`, registering it (and,
   -- recursively, the types it refers to) the first time it is seen.
   get_typenum = function(t)
      assert(t.typeid)
      local n = typeid_to_num[t.typeid]
      if n then
         return n
      end

      n = trenv.next_num

      -- Describe type definitions by their definition, and 1-tuples by
      -- their single element.
      local rt = t
      if rt.typename == "typetype" or rt.typename == "nestedtype" then
         rt = rt.def
      elseif rt.typename == "tuple" and #rt == 1 then
         rt = rt[1]
      end

      local ti = {
         t = assert(typename_to_typecode[rt.typename]),
         str = show_type(t, true),
         file = t.filename,
         y = t.y,
         x = t.x,
      }
      -- Register before recursing so self-referential types terminate.
      tr.types[n] = ti
      typeid_to_num[t.typeid] = n
      trenv.next_num = trenv.next_num + 1

      if t.found then
         ti.ref = get_typenum(t.found)
      end

      if t.resolved then
         rt = t
      end
      assert(rt.typename ~= "typetype")

      if is_record_type(rt) then
         local r = {}
         for _, k in ipairs(rt.field_order) do
            local v = rt.fields[k]
            r[k] = get_typenum(v)
         end
         ti.fields = r
      end

      if is_array_type(rt) then
         ti.elements = get_typenum(rt.elements)
      end

      if rt.typename == "map" then
         ti.keys = get_typenum(rt.keys)
         ti.values = get_typenum(rt.values)
      elseif rt.typename == "enum" then
         ti.enums = mark_array(sorted_keys(rt.enumset))
      elseif rt.typename == "function" then
         store_function(ti, rt)
      elseif rt.typename == "poly" or rt.typename == "union" or rt.typename == "tupletable" then
         local tis = {}
         for _, pt in ipairs(rt.types) do
            table.insert(tis, get_typenum(pt))
         end
         ti.types = mark_array(tis)
      end

      return n
   end

   -- Record the type of every node/type by source position.
   local visit_node = { allow_missing_cbs = true }
   local visit_type = { allow_missing_cbs = true }

   local skip = {
      ["none"] = true,
      ["tuple"] = true,
      ["table_item"] = true,
   }

   local ft = {}
   tr.by_pos[filename] = ft

   local function store(y, x, typ)
      if not typ or skip[typ.typename] then
         return
      end
      local yt = ft[y]
      if not yt then
         yt = {}
         ft[y] = yt
      end
      yt[x] = get_typenum(typ)
   end

   visit_node.after = function(node)
      store(node.y, node.x, node.type)
   end

   visit_type.after = function(typ)
      store(typ.y or 0, typ.x or 0, typ)
   end

   recurse_node(result.ast, visit_node, visit_type)

   -- Entries stored without a position (line 0) are dropped.
   tr.by_pos[filename][0] = nil

   -- Pass 1: mark both markers of every scope that contains no typed
   -- symbol directly, so they are left out of the report.
   do
      local n = 0
      local p = 0
      local n_stack, p_stack = {}, {}
      local level = 0
      for i, s in ipairs(result.symbol_list) do
         if s.typ then
            n = n + 1
         elseif s.name == "@{" then
            level = level + 1
            n_stack[level], p_stack[level] = n, p
            n, p = 0, i
         else
            if n == 0 then
               result.symbol_list[p].skip = true
               s.skip = true
            end
            n, p = n_stack[level], p_stack[level]
            level = level - 1
         end
      end
   end

   -- Pass 2: append symbols as { y, x, name, id }. For "@{" the id becomes
   -- the index of its matching close; for a close it is the index just
   -- before the matching "@{".
   do
      local stack = {}
      local level = 0
      -- Indices are absolute positions in tr.symbols. When `trenv` is
      -- reused, tr.symbols already holds earlier files' entries, so
      -- numbering continues from there instead of restarting at 0 (which
      -- made the scope links overwrite earlier files' entries).
      local i = #tr.symbols
      for _, s in ipairs(result.symbol_list) do
         if not s.skip then
            i = i + 1
            local id
            if s.typ then
               id = get_typenum(s.typ)
            elseif s.name == "@{" then
               level = level + 1
               stack[level] = i
               id = -1
            else
               local other = stack[level]
               level = level - 1
               tr.symbols[other][4] = i
               id = other - 1
            end
            local sym = mark_array({ s.y, s.x, s.name, id })
            table.insert(tr.symbols, sym)
         end
      end
   end

   -- Globals, excluding internal "@" entries, in sorted name order.
   local gkeys = sorted_keys(result.env.globals)
   for _, name in ipairs(gkeys) do
      if name:sub(1, 1) ~= "@" then
         local var = result.env.globals[name]
         tr.globals[name] = get_typenum(var.t)
      end
   end

   return tr, trenv
end
--- Collect the symbols visible at a given source position.
-- Walks `tr.symbols` (entries of the form {y, x, name, id}) backwards,
-- starting from the entry that `find` locates for position (y, x).
-- Scope markers are handled specially:
--   * "@{" opens a scope; reaching it while walking back means we are
--     leaving a scope that encloses the query position, so we keep going.
--   * "@}" closes a scope; its id slot holds the index just before the
--     matching "@{", so jumping there skips every symbol declared inside
--     a scope that is already finished at the query position.
-- Because the walk proceeds from the query position outward, the first
-- entry met for a given name is its nearest declaration; any later entry
-- with the same name (an outer or earlier declaration) is shadowed and
-- must not overwrite it.
-- @param tr type report whose `symbols` list is ordered by position
-- @param y line of the query position
-- @param x column of the query position
-- @return table mapping each visible symbol name to its type number
function tl.symbols_in_scope(tr, y, x)
   local function find(symbols, at_y, at_x)
      -- Position ordering: (a.y, a.x) <= (b.y, b.x).
      local function le(a, b)
         return a[1] < b[1] or
            (a[1] == b[1] and a[2] <= b[2])
      end
      -- NOTE(review): binary_search is defined elsewhere in this file;
      -- presumably it yields the last entry ordered at or before the query
      -- position, or nil when there is none (mapped to 0 here).
      return binary_search(symbols, { at_y, at_x }, le) or 0
   end

   local ret = {}
   local symbols = tr.symbols
   local n = find(symbols, y, x)
   while n >= 1 do
      local s = symbols[n]
      local name = s[3]
      if name == "@{" then
         -- Leaving an enclosing scope: continue outward.
         n = n - 1
      elseif name == "@}" then
         -- Skip a scope that was closed before the query position.
         n = s[4]
      else
         -- Nearest declaration wins; outer/earlier ones are shadowed.
         if ret[name] == nil then
            ret[name] = s[4]
         end
         n = n - 1
      end
   end
   return ret
end
--- Read a source file from disk and process it with tl.process_string.
-- If `env.loaded` already holds a result for `filename`, it is returned
-- without touching the file system.
-- Lua (lax) vs. Teal mode is chosen from the file extension, compared
-- case-insensitively: ".tl" means Teal, ".lua" means Lua. Any other
-- extension falls back to checking for a "#!...lua" shebang first line.
-- @param filename path of the file to read
-- @param env optional environment, forwarded to tl.process_string
-- @return the processing result, or nil plus an error message when the
--   file cannot be opened or read
tl.process = function(filename, env)
   if env and env.loaded and env.loaded[filename] then
      return env.loaded[filename]
   end

   local fd, err = io.open(filename, "r")
   if not fd then
      return nil, "could not open " .. filename .. ": " .. err
   end

   local input; input, err = fd:read("*a")
   fd:close()
   if not input then
      return nil, "could not read " .. filename .. ": " .. tostring(err)
   end

   -- Capture upper-case letters too; otherwise ".TL"/".LUA" never match and
   -- the :lower() normalization below would be dead code.
   local _, extension = filename:match("(.*)%.([A-Za-z]+)$")
   extension = extension and extension:lower()

   local is_lua
   if extension == "tl" then
      is_lua = false
   elseif extension == "lua" then
      is_lua = true
   else
      -- Normalize the match result (string or nil) to a real boolean.
      is_lua = input:match("^#![^\n]*lua[^\n]*\n") ~= nil
   end
   return tl.process_string(input, is_lua, env, filename)
end
--- Lex, parse and type-check a chunk of source text.
-- Steps:
--   1. If `env` already has a result cached under `filename`, return it.
--   2. Convert lexer errors into syntax-error records.
--   3. Unless `env.keep_going` is set, stop after any syntax error. The
--      failed result is recorded in `env.loaded` and `env.loaded_order`.
--   4. Otherwise type-check the program and attach the syntax errors to
--      the type checker's result.
-- @param input source text
-- @param is_lua truthy to check the code in lax (Lua) mode
-- @param env optional environment; created with tl.init_env when omitted
-- @param filename optional name recorded in results and errors ("" if nil)
-- @return the result table. On an early syntax failure it has the fields
--   ok = false, filename, type_errors, syntax_errors and env.
function tl.process_string(input, is_lua, env, filename)
   env = env or tl.init_env(is_lua)

   -- The cache is queried with the caller's filename, which may be nil.
   -- Reading a nil key always yields nil, so anonymous input is never
   -- served from the cache.
   local cached = env.loaded and env.loaded[filename]
   if cached then
      return cached
   end

   local fname = filename or ""

   local syntax_errors = {}
   local tokens, lex_errors = tl.lex(input)
   for _, e in ipairs(lex_errors or {}) do
      syntax_errors[#syntax_errors + 1] = {
         y = e.y,
         x = e.x,
         msg = "invalid token '" .. e.tk .. "'",
         filename = fname,
      }
   end

   local _, program = tl.parse_program(tokens, syntax_errors, fname)

   local has_syntax_errors = #syntax_errors > 0
   if has_syntax_errors and not env.keep_going then
      local failed = {
         ok = false,
         filename = fname,
         type_errors = {},
         syntax_errors = syntax_errors,
         env = env,
      }
      env.loaded[fname] = failed
      table.insert(env.loaded_order, fname)
      return failed
   end

   local result = tl.type_check(program, {
      filename = fname,
      lax = is_lua,
      gen_compat = env.gen_compat,
      env = env,
   })
   result.syntax_errors = syntax_errors
   return result
end
--- Compile Teal source text into Lua source text.
-- @param input Teal source code
-- @param env optional environment (a fresh tl.init_env() when omitted)
-- @return the generated Lua code and the processing result. Returns nil
--   and the result instead when no AST was produced or any syntax error
--   was reported.
tl.gen = function(input, env)
   local result = tl.process_string(input, false, env or tl.init_env())

   local compiled = result.ast and #result.syntax_errors == 0
   if not compiled then
      return nil, result
   end

   return tl.pretty_print_ast(result.ast), result
end
--- Package searcher that lets `require` load Teal (and Lua) modules.
-- The module is located with tl.search_module, then parsed and
-- type-checked against a shared environment (tl.package_loader_env,
-- created on first use), and compiled to Lua.
-- The type-check result is not inspected, so type errors do not prevent
-- loading. Syntax errors raise.
-- @param module_name the name passed to `require`
-- @return a loader function when the module is found, otherwise a string
--   listing what was tried
local function tl_package_loader(module_name)
   local found_filename, fd, tried = tl.search_module(module_name, false)
   if not found_filename then
      return table.concat(tried, "\n\t")
   end

   local input = fd:read("*a")
   -- Close the handle before inspecting the result, so a failed read no
   -- longer leaks it.
   fd:close()
   if not input then
      -- NOTE(review): tl.search_module may not return a `tried` list when it
      -- did find a file, so guard against nil here.
      return table.concat(tried or {}, "\n\t")
   end

   local errs = {}
   local _, program = tl.parse_program(tl.lex(input), errs, module_name)
   if #errs > 0 then
      error(found_filename .. ":" .. errs[1].y .. ":" .. errs[1].x .. ": " .. errs[1].msg)
   end

   -- Files ending in "lua" are checked in lax mode.
   local lax = not not found_filename:match("lua$")
   if not tl.package_loader_env then
      tl.package_loader_env = tl.init_env(lax)
   end

   tl.type_check(program, {
      lax = lax,
      filename = found_filename,
      env = tl.package_loader_env,
      run_internal_compiler_checks = false,
   })

   local code = tl.pretty_print_ast(program, true)
   local chunk, err = load(code, module_name, "t")
   if not chunk then
      error("Internal Compiler Error: Teal generator produced invalid Lua. Please report a bug at https://github.com/teal-language/tl\n\n" .. err)
   end

   return function()
      local ret = chunk()
      package.loaded[module_name] = ret
      return ret
   end
end
--- Register the Teal package searcher so `require` can load Teal modules.
-- The searcher list is package.searchers (Lua 5.2+) or package.loaders
-- (Lua 5.1). The Teal searcher is inserted at position 2, directly after
-- the first existing searcher.
function tl.loader()
   local searchers = package.searchers or package.loaders
   table.insert(searchers, 2, tl_package_loader)
end
--- Compile Teal source text and load it as a Lua chunk, mirroring `load`.
-- Only syntax errors are reported; the program is not type-checked.
-- @param input Teal source code
-- @param chunkname optional chunk name, also used to prefix syntax errors
-- @param mode optional load mode, passed through to `load`
-- @param env optional environment for the loaded chunk
-- @return the compiled function, or nil plus an error message
tl.load = function(input, chunkname, mode, env)
   local tokens = tl.lex(input)
   local errs = {}
   local _, program = tl.parse_program(tokens, errs, chunkname)
   if #errs > 0 then
      return nil, (chunkname or "") .. ":" .. errs[1].y .. ":" .. errs[1].x .. ": " .. errs[1].msg
   end
   local code = tl.pretty_print_ast(program, true)
   -- In Lua 5.2+, `load` treats an explicitly passed nil `env` as "given".
   -- It then sets the chunk's _ENV to nil, so every global access in the
   -- loaded code would fail. Forward `env` only when the caller supplied one.
   if env == nil then
      return load(code, chunkname, mode)
   end
   return load(code, chunkname, mode, env)
end
-- Module export: the `tl` table holding the compiler's public functions.
return tl |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment