# Trying to write a borrow checker
# (GitHub gist rrika/fcfa43aee2ec95b0bd482e1c508f55b5, last active May 8, 2019 11:38)
class Expr:
    """Base class for expression nodes (Var, Deref, TakeRefMut, Const)."""


class Type:
    """Base class for structured types; primitive types are plain strings such as "i32"."""


class Stmt:
    """Base class for statement nodes stored in a BasicBlock."""
class RefMut(Type):
    """Mutable reference type ``&'lt mut ty``.

    Attributes:
        lt: lifetime name without the leading quote (e.g. "a").
        ty: referent type -- a primitive type name such as "i32", or another RefMut.
    """

    def __init__(self, lt, ty):
        self.lt = lt
        self.ty = ty

    def __repr__(self):
        # The referent is shown via repr(), so primitive names appear quoted.
        return "&'{lt} {ty!r}".format(lt=self.lt, ty=self.ty)
class NewTypeExpr(Expr):
    """Expression node wrapping a single operand, stored as ``value``."""

    def __init__(self, value):
        self.value = value
class Deref(NewTypeExpr):
    """Dereference expression ``*value``."""

    def __repr__(self):
        return "*" + repr(self.value)
class TakeRefMut(NewTypeExpr):
    """Mutable borrow expression ``&mut value``."""

    def __repr__(self):
        return "&mut " + repr(self.value)
class Var(NewTypeExpr):
    """Use of a named variable; ``value`` holds the variable name."""

    def __repr__(self):
        # Shown bare, as the name itself.
        name = self.value
        return name
class Const(Expr):
    """Constant ``value`` of primitive type ``ty``, e.g. ``Const("i32", 99)``."""

    def __init__(self, ty, value):
        self.ty = ty
        self.value = value

    def __repr__(self):
        return "({val!r}: {ty!r})".format(val=self.value, ty=self.ty)
class Assign(Stmt):
    """Store through a reference: ``*lhs = rhs``.

    ``lhs`` is expected to evaluate to a RefMut whose referent type matches
    the type of ``rhs``.
    """

    def __init__(self, lhs, rhs):
        self.lhs = lhs
        self.rhs = rhs

    def exprs(self):
        """Yield the operand expressions (target, then value), like Call/Return do."""
        yield self.lhs
        yield self.rhs

    def __repr__(self):
        return "*{!r} = {!r}".format(self.lhs, self.rhs)
class Call(Stmt):
    """Call of ``fun`` (a Function) with positional argument expressions ``args``."""

    def __init__(self, fun, *args):
        self.fun = fun
        self.args = args

    def exprs(self):
        """Yield the argument expressions in order."""
        yield from self.args

    def __repr__(self):
        # Like the other nodes, render in source-like syntax; fall back to the
        # callee object itself if it has no name.
        return "{}({})".format(
            getattr(self.fun, "name", self.fun),
            ", ".join(repr(arg) for arg in self.args),
        )
class Return(Stmt):
    """Return statement ``return arg``."""

    def __init__(self, arg):
        self.arg = arg

    def exprs(self):
        """Yield the single returned expression."""
        yield self.arg

    def __repr__(self):
        return "return {!r}".format(self.arg)
class Function:
    """A function under analysis: its signature plus one entry basic block.

    Args:
        name: function name; used in generated lifetime names and messages.
        lts: list of ``[longer, shorter]`` outlives bounds from the signature,
            or None when there are none.
        args: list of ``[param_name, type]`` pairs.
    """

    def __init__(self, name, lts, args):
        self.name, self.lts, self.args = name, lts, args
        # Types of variables in scope, seeded with the parameters;
        # Builder.declare_mut adds locals.
        self.vars = dict(args)
        self.entry = BasicBlock()
class BasicBlock:
    """Straight-line list of statements.

    ``succ`` is presumably meant for successor blocks; nothing populates it yet.
    """

    def __init__(self):
        self.stmts = list()
        self.succ = list()
class Builder:
    """Appends statements and local declarations to a Function, starting at its entry block."""

    def __init__(self, fun):
        self.fun = fun
        self.bb = fun.entry

    def stmt(self, stmt):
        """Append *stmt* to the current basic block."""
        current = self.bb
        current.stmts.append(stmt)

    def declare_mut(self, name, ty):
        """Declare (or redeclare) local variable *name* with type *ty*."""
        self.fun.vars.update({name: ty})
def visit_type(cs, ty):
    """Record well-formedness constraints of type *ty* and return its lifetime.

    Primitive types (plain strings such as "i32") have lifetime "static".
    For ``RefMut(lt, inner)`` the pointee must outlive the reference, so after
    recursing into *inner* the constraint ``(lifetime_of(inner), lt, reason)``
    is appended to *cs* and ``lt`` is returned.

    Args:
        cs: list receiving ``(longer, shorter, reason)`` constraint tuples.
        ty: a primitive type name (str) or a RefMut.

    Returns:
        The lifetime name of *ty*.

    Raises:
        TypeError: if *ty* is neither a str nor a RefMut (rather than returning
            None, which would end up inside the constraint tuples).
    """
    if isinstance(ty, str):
        return "static"
    if isinstance(ty, RefMut):
        inner_lt = visit_type(cs, ty.ty)
        cs.append((inner_lt, ty.lt, "ref only lives as long as pointed-to type"))
        return ty.lt
    raise TypeError("unsupported type: {!r}".format(ty))
def generate_lt_constraints(fun):
    """Generate lifetime constraints for the statements of *fun*'s entry block.

    Each constraint is a tuple ``(longer, shorter, reason)``, printed as
    ``longer: shorter`` in Rust "outlives" notation: *longer* must last at
    least until *shorter* ends. Program points are treated as lifetimes too:
    ``"begin"`` is function entry, ``"return"`` is function exit, and every
    statement gets a fresh point ``stmt_N`` ordered after the previous one.

    Fresh names (``stmt_N``, ``borrow_<var>_N``, ``call_<fun>_<lt>_N``,
    ``tmpderef_N``) share a single counter, so their numbering follows the
    order of the walk (which is deterministic). Constraints of the form
    ``'static: x`` or ``x: x`` hold trivially and are removed before returning.

    Args:
        fun: the Function to analyse.

    Returns:
        list of ``(longer, shorter, reason)`` tuples in generation order.

    Raises:
        TypeError: if an expression of an unsupported kind is encountered.
    """
    cs = []
    counter = 0
    time = "begin"  # current program point; advanced by gen_for_stmt
    # var name -> (most recent borrow was mutable?, borrow lifetimes still live)
    borrows = {}

    def tmp(kind="tmp"):
        # Mint a fresh, uniquely numbered lifetime name.
        nonlocal counter
        counter += 1
        return "{}{}".format(kind, counter)

    def record_borrow(var, now, mutable=True):
        # Start a borrow of `var` at program point `now`; return its lifetime.
        prior_mut, lts = borrows.get(var, (False, []))
        # A mutable borrow conflicts with every earlier borrow, and any borrow
        # conflicts with an earlier mutable one: those must have ended by `now`.
        if prior_mut or mutable:
            for plt in lts:
                cs.append((now, plt, "all prior borrows of {} must end".format(var)))
            lts = []
        lt = tmp("borrow_" + var + "_")
        if mutable:
            borrows[var] = True, [lt]
        else:
            borrows[var] = False, lts + [lt]
        return lt

    def visit_expr(expr):
        # Type `expr`, then record the well-formedness constraints of its type.
        lt, ty = visit_expr_inner(expr)
        visit_type(cs, ty)
        return lt, ty

    def visit_expr_inner(expr):
        # Return (lifetime for which the value is available, type of expr).
        if isinstance(expr, Deref):
            lt = tmp("tmpderef_")
            _vlt, vty = visit_expr(expr.value)
            return lt, vty.ty
        if isinstance(expr, TakeRefMut):
            # TODO: real borrowing. For now the reference simply reuses the
            # lifetime of the borrowed place.
            vlt, vty = visit_expr(expr.value)
            return vlt, RefMut(vlt, vty)
        if isinstance(expr, Var):
            lt = record_borrow(expr.value, time, True)
            return lt, fun.vars[expr.value]
        if isinstance(expr, Const):
            return "static", expr.ty
        raise TypeError("unsupported expression: {!r}".format(expr))

    def assign_cs(l, r, lremap=None):
        # Constraints for storing a value of type `r` into a place of type `l`.
        # `lremap` renames lifetimes of `l` (callee parameter lifetimes to their
        # call-site instances). Returns True iff both types are references.
        def remap(lt):
            return lremap[lt] if lremap else lt

        lr = isinstance(l, RefMut)
        rr = isinstance(r, RefMut)
        assert lr == rr, "reference/non-reference type mismatch"
        if not lr:
            return False
        cs.append((r.lt, remap(l.lt), "assignment may shorten lifetime"))
        if assign_cs(l.ty, r.ty, lremap):  # and mutable
            # Together with the "may shorten" constraint the recursive call just
            # added, this forces the pointee lifetimes to be equal, i.e. &mut T
            # is treated as invariant in T.
            cs.append((remap(l.ty.lt), r.ty.lt, "TODO explain this one"))
        return True

    def gen_for_stmt(stmt):
        # Advance to a fresh program point and emit this statement's constraints.
        nonlocal time
        before_stmt = time
        after_stmt = tmp("stmt_")
        cs.append((after_stmt, before_stmt, "stmt ordering"))
        time = after_stmt
        if isinstance(stmt, Assign):
            llt, lty = visit_expr(stmt.lhs)
            rlt, rty = visit_expr(stmt.rhs)
            cs.append((lty.lt, after_stmt, "assignment target must be valid at use"))
            cs.append((llt, after_stmt, "arg must be available at use (arg = ...)"))
            cs.append((rlt, after_stmt, "arg must be available at use (... = arg)"))
            assign_cs(lty.ty, rty)
        elif isinstance(stmt, Call):
            callee = stmt.fun
            bounds = callee.lts or []
            # Lifetimes named by the callee's signature, in order of first
            # appearance. An insertion-ordered dict (not a set, whose order for
            # str varies between runs) keeps the fresh-name numbering stable.
            plts = dict.fromkeys(lt for bound in bounds for lt in bound)
            for _pname, pty in callee.args:
                while isinstance(pty, RefMut):
                    plts.setdefault(pty.lt)
                    pty = pty.ty
            # Instantiate the callee's lifetimes at this call site: 'begin and
            # 'return map to the points around the call, the rest are fresh.
            lt_insta = {
                lt: before_stmt if lt == "begin"
                else after_stmt if lt == "return"
                else tmp("call_{}_{}_".format(callee.name, lt))
                for lt in plts
            }
            for lta, ltb in bounds:
                cs.append((lt_insta[lta], lt_insta[ltb],
                           "interface lifetime bound for {} ({}: {})".format(callee.name, lta, ltb)))
            for arg, (pname, pty) in zip(stmt.args, callee.args):
                alt, aty = visit_expr(arg)
                cs.append((alt, after_stmt,
                           "arg must be available at use ({}({}=arg, ...))".format(callee.name, pname)))
                # Passing an argument acts as an assignment into the parameter.
                assign_cs(pty, aty, lt_insta)
        elif isinstance(stmt, Return):
            alt, _aty = visit_expr(stmt.arg)
            cs.append((alt, after_stmt, "arg must be available at use (return)"))

    for stmt in fun.entry.stmts:
        gen_for_stmt(stmt)
    # Every borrow still live at the end must end before the function returns.
    for var, (_mut, lts) in borrows.items():
        for lt in lts:
            cs.append(("return", lt, "borrow of local var {} ends before end of function".format(var)))
    cs.append(("return", time, "stmt ordering"))
    # Drop trivially satisfied constraints: 'static: 'a and 'a: 'a.
    return [c for c in cs if c[0] != "static" and c[0] != c[1]]
def demo(fun):
    """Print the lifetime constraints of *fun* and check them against its signature.

    Steps:
      1. Generate and print every constraint for the body.
      2. Write the reversed constraint graph, with strongly connected
         components collapsed into single nodes, to ``<fun.name>_generated.dot``
         (uses networkx's pydot interface).
      3. Project the transitive closure onto the public lifetimes ("begin",
         "return", declared bounds, and lifetimes in argument types) and
         print each resulting requirement, and each argument-type validity
         requirement, as OK or MISSING depending on whether the signature's
         bounds imply it.

    Returns:
        The constraint list produced by ``generate_lt_constraints(fun)``.
    """
    print(fun.name)
    cs = generate_lt_constraints(fun)
    for c in cs:
        print(" {}: {} // {}".format(*c))
    print()

    # Lifetimes visible from outside the body.
    public_lts = {"begin", "return"}
    public_lts.update(lt for lt, _bound in fun.lts or [])
    for _argname, ty in fun.args:
        while isinstance(ty, RefMut):
            public_lts.add(ty.lt)
            ty = ty.ty

    import networkx as nx

    g = nx.DiGraph()
    g.add_edges_from((c[0], c[1]) for c in cs)

    # Dump the reversed graph with SCCs (lifetimes forced equal) condensed.
    rg = g.reverse()
    rgscc = list(nx.strongly_connected_components(rg))
    rgc = nx.condensation(rg, rgscc)
    for n in rgc:
        # sorted(): set iteration order of str differs between runs, which
        # would make the .dot output nondeterministic.
        rgc.nodes[n]["label"] = "/".join(sorted(rgscc[n]))
    nx.nx_pydot.write_dot(rgc, "{}_generated.dot".format(fun.name))

    # Relations between public lifetimes implied by the body. reflexive=None
    # stops cycles from adding self-loops, which transitive_reduction rejects.
    tg = nx.transitive_closure(g, reflexive=None)
    tgp = nx.subgraph(tg, public_lts)
    if nx.is_directed_acyclic_graph(tgp):
        tgpr = nx.transitive_reduction(tgp)
    else:
        # Distinct public lifetimes required to be equal form a cycle that
        # cannot be reduced; report the unreduced relation instead of crashing.
        tgpr = tgp

    # Bounds the signature provides: 'return: 'begin plus the declared ones.
    ig = nx.DiGraph()
    ig.add_edge("return", "begin")
    ig.add_edges_from(fun.lts or [])
    for c in ig.edges():
        print(" {}: {} // provided by signature".format(c[0], c[1]))
    print()

    def is_provided(longer, shorter):
        # Met if the signature's bounds imply it transitively.
        return longer in ig and shorter in nx.descendants(ig, longer)

    tyreqs = []
    for _argname, ty in fun.args:
        visit_type(tyreqs, ty)
    for c in tgpr.edges():
        status = "OK" if is_provided(c[0], c[1]) else "MISSING"
        print(" {}: {} // required from signature ({})".format(c[0], c[1], status))
    for c in tyreqs:
        if c[0] == "static":
            continue
        status = "OK" if is_provided(c[0], c[1]) else "MISSING"
        print(" {}: {} // required for argument type validity ({})".format(c[0], c[1], status))
    print()
    return cs
""" | |
fun_set_ptr<'a: 'return, 'b: 'a>(x: &'a mut &'b mut i32, y: &'b mut i32) | |
*x = y | |
fun_set_val<'c: 'return>(z: &'c mut i32)
*z = 99
fun_main() | |
let mut y: i32 = 0 | |
let mut x: &'tmpx mut i32 = undef | |
fun_set_ptr(&mut x, &mut y) | |
fun_set_val(x) | |
return y | |
""" | |
# Example 1: fun_set_ptr<'a: 'return, 'b: 'a>(x: &'a mut &'b mut i32, y: &'b mut i32) { *x = y }
fun_set_ptr = Function(
    "fun_set_ptr",
    [["a", "return"], ["b", "a"]],
    [
        ["x", RefMut("a", RefMut("b", "i32"))],
        ["y", RefMut("b", "i32")],
    ],
)
Builder(fun_set_ptr).stmt(Assign(Var("x"), Var("y")))

# Example 2: fun_set_val<'c: 'return>(z: &'c mut i32) { *z = 99 }
fun_set_val = Function(
    "fun_set_val",
    [["c", "return"]],
    [["z", RefMut("c", "i32")]],
)
Builder(fun_set_val).stmt(Assign(Var("z"), Const("i32", 99)))

# Example 3: fun_main() {
#     let mut y: i32; let mut x: &'tmpx mut i32;
#     fun_set_ptr(&mut x, &mut y); fun_set_val(x); return y
# }
fun_main = Function("fun_main", None, [])
b = Builder(fun_main)
b.declare_mut("y", "i32")
b.declare_mut("x", RefMut("tmpx", "i32"))
b.stmt(Call(fun_set_ptr, TakeRefMut(Var("x")), TakeRefMut(Var("y"))))
b.stmt(Call(fun_set_val, Var("x")))
b.stmt(Return(Var("y")))

# Analyse each example: prints a report and writes <name>_generated.dot.
cs_fun_set_ptr = demo(fun_set_ptr)
cs_fun_set_val = demo(fun_set_val)
cs_fun_main = demo(fun_main)
""" | |
fun_set_ptr | |
stmt_1: begin // stmt ordering | |
b: a // ref only lives as long as pointed-to type | |
a: stmt_1 // assignment target must be valid at use | |
borrow_x_2: stmt_1 // arg must be available at use (arg = ...) | |
borrow_y_3: stmt_1 // arg must be available at use (... = arg) | |
return: borrow_x_2 // borrow of local var x ends before end of function | |
return: borrow_y_3 // borrow of local var y ends before end of function | |
return: stmt_1 // stmt ordering | |
return: begin // provided by signature | |
a: return // provided by signature | |
b: a // provided by signature | |
b: a // required from signature (OK) | |
a: begin // required from signature (OK) | |
return: begin // required from signature (OK) | |
b: a // required for argument type validity (OK) | |
fun_set_val | |
stmt_1: begin // stmt ordering | |
c: stmt_1 // assignment target must be valid at use | |
borrow_z_2: stmt_1 // arg must be available at use (arg = ...) | |
return: borrow_z_2 // borrow of local var z ends before end of function | |
return: stmt_1 // stmt ordering | |
return: begin // provided by signature | |
c: return // provided by signature | |
c: begin // required from signature (OK) | |
return: begin // required from signature (OK) | |
fun_main | |
stmt_1: begin // stmt ordering | |
call_fun_set_ptr_a_3: stmt_1 // interface lifetime bound for fun_set_ptr (a: return) | |
call_fun_set_ptr_b_2: call_fun_set_ptr_a_3 // interface lifetime bound for fun_set_ptr (b: a) | |
tmpx: borrow_x_4 // ref only lives as long as pointed-to type | |
borrow_x_4: stmt_1 // arg must be available at use (fun_set_ptr(x=arg, ...)) | |
borrow_x_4: call_fun_set_ptr_a_3 // assignment may shorten lifetime | |
tmpx: call_fun_set_ptr_b_2 // assignment may shorten lifetime | |
call_fun_set_ptr_b_2: tmpx // TODO explain this one | |
borrow_y_5: stmt_1 // arg must be available at use (fun_set_ptr(y=arg, ...)) | |
borrow_y_5: call_fun_set_ptr_b_2 // assignment may shorten lifetime | |
stmt_6: stmt_1 // stmt ordering | |
call_fun_set_val_c_7: stmt_6 // interface lifetime bound for fun_set_val (c: return) | |
stmt_6: borrow_x_4 // all prior borrows of x must end | |
borrow_x_8: stmt_6 // arg must be available at use (fun_set_val(z=arg, ...)) | |
tmpx: call_fun_set_val_c_7 // assignment may shorten lifetime | |
stmt_9: stmt_6 // stmt ordering | |
stmt_9: borrow_y_5 // all prior borrows of y must end | |
borrow_y_10: stmt_9 // arg must be available at use (return) | |
return: borrow_x_8 // borrow of local var x ends before end of function | |
return: borrow_y_10 // borrow of local var y ends before end of function | |
return: stmt_9 // stmt ordering | |
return: begin // provided by signature | |
return: begin // required from signature (OK) | |
""" |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment