Created
October 31, 2020 08:42
-
-
Save mb64/06f0dfba4f8eebbfa976dc4399e52b4d to your computer and use it in GitHub Desktop.
Flat trees in ATS and C
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* A flat representation of
 *   data Tree = Leaf Int | Node Tree Tree
 *
 * Either:
 *   - *ft is LEAF and a single int follows
 *   - *ft is NODE and two subtrees follow
 */
/* Tag values stored in the first int of every flat tree. */
#define LEAF 0
#define NODE 1

/* A flat tree is referred to by a pointer to its tag int. */
typedef int *FlatTree;
/*
 * Size, in ints, of the flat tree that starts at ft, computed recursively.
 *
 * A leaf takes two ints (the tag plus its value). Any other tag is treated
 * as a node: one tag int, then the left subtree, then the right subtree.
 * Nothing is bounds-checked, so ft must point at a well-formed flat tree.
 */
int tree_size_rec(FlatTree ft) {
    if (*ft == LEAF) {
        // the tag, plus the single integer that follows it
        return 1 + 1;
    }

    // the two subtrees follow the tag back to back; the right one starts
    // immediately after the last int of the left one
    int sz_l = tree_size_rec(ft + 1);
    int sz_r = tree_size_rec(ft + 1 + sz_l);
    return 1 + sz_l + sz_r;
}
/*
 * Size, in ints, of the flat tree that starts at ft, computed without
 * recursion.
 *
 * Scans forward while counting how many trees still have to be skipped:
 * a leaf finishes one tree, any other tag (a node) replaces one pending tree
 * with its two subtrees, i.e. adds one. The scan ends when the count reaches
 * zero. Nothing is bounds-checked, so ft must point at a well-formed tree.
 */
int tree_size_fast(FlatTree ft) {
    const int *cursor = ft;
    int trees_remaining = 1;

    while (trees_remaining != 0) {
        int tag = *cursor++;
        if (tag == LEAF) {
            // skip over the value
            cursor++;
            trees_remaining--;
        } else {
            // one tree consumed, two subtrees now pending
            trees_remaining++;
        }
    }

    // pointer difference is ptrdiff_t; narrow explicitly to the int return type
    return (int)(cursor - ft);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* Standard ATS prelude setup, plus the unsafe module referenced as $UN. *)
#include "share/atspre_define.hats"
#include "share/atspre_staload.hats"
staload UN = "prelude/SATS/unsafe.sats"
(* Division that is exact at the type level: given c of type int(b*a) and a
 * positive b, return a value of type int(a).
 * The quotient c/b is computed at run time; its indexed type is asserted
 * with $UN.cast rather than proved. *)
fn div {a:nat} {b:pos} (c: int (b*a), b: int b):<> int a
  = $UN.cast{int(a)}(c/b)
(* Axiom, asserted without proof: sizeof(int) is strictly positive.
 * Invoking it in a prval makes that fact available to the constraint solver. *)
extern praxi ints_have_size (): [sizeof(int) > 0] void
(* A flat representation of
 *   data Tree = Leaf Int | Node Tree Tree
 *
 * Either:
 *   - tag is LEAF and a single int follows
 *   - tag is NODE and two subtrees follow
 *)
(* Tag values stored in the first int of every flat tree. *)
#define LEAF 0
#define NODE 1
(* FlatTree(l, s): a flat tree occupying s ints starts at address l.
 * It consists of a tag int at l, restricted to LEAF or NODE, followed
 * one int later by the contents selected by that tag. *)
dataview FlatTree(l:addr, s:int) = (* addr, size *)
  {l:addr} {tag:int | tag == LEAF || tag == NODE} {s:int}
  FlatTree(l, s + 1) of (
    int tag @ l,
    FlatTreeContents(l + sizeof(int), s, tag)
  )

(* FlatTreeContents(l, s, tag): what follows the tag int.
 *   Leaf - one int at l, so the size is 1
 *   Node - two flat trees back to back; the second one starts s1 ints
 *          after l, and the size is the sum of both subtree sizes *)
and FlatTreeContents(l:addr, s:int, tag:int) = (* addr, size, tag *)
  | {l:addr} Leaf(l, 1, LEAF) of (int @ l)
  | {l:addr} {s1:nat} {s2:nat}
    Node(l, s1 + s2, NODE) of (
      FlatTree(l, s1),
      FlatTree(l + s1*sizeof(int), s2)
    )
(* Recursive size computation: returns s, the number of ints occupied by the
 * flat tree at ptr.
 *   pf   - proof that a FlatTree of size s lives at l; the leading ! means
 *          the proof is handed back to the caller, so every branch rebuilds
 *          it before returning
 *   ptr  - pointer to the tag int
 * The .<s>. metric records termination: both recursive calls are made on
 * subtrees, whose sizes are strictly smaller than s. *)
fun tree_size_rec {l:addr} {s:nat} .<s>. (
  pf: !FlatTree(l, s) | ptr: ptr l
):<> int s =
  let
    (* split the view so the tag int can be read through ptr *)
    prval FlatTree(tagptr, contents) = pf
  in case+ !ptr of
    | LEAF => let
        (* a single integer follows *)
        prval Leaf(itemptr) = contents
        prval () = pf := FlatTree(tagptr, Leaf(itemptr))
      in 1 + 1 end
    | NODE => let
        (* the two subtrees follow *)
        prval Node(left, right) = contents
        val left_ptr = ptr_add<int>(ptr, 1)
        val left_size = tree_size_rec(left | left_ptr)
        (* the right subtree begins right after the left one ends *)
        val right_ptr = ptr_add<int>(left_ptr, left_size)
        val right_size = tree_size_rec(right | right_ptr)
        prval () = pf := FlatTree(tagptr, Node(left, right))
      in 1 + left_size + right_size end
  end
(* `n` trees in a row.
 * ManyTrees(l, n, s): n flat trees laid out consecutively from address l,
 * occupying s ints in total.
 *   NoTrees   - zero trees, zero ints
 *   SomeTrees - one FlatTree of s1 ints, then n more trees starting s1 ints
 *               further on *)
dataview ManyTrees(l:addr, n:int, s:int) = (* addr, count, size *)
  | {l:addr} NoTrees(l, 0, 0)
  | {l:addr} {s1:nat} {s2:nat} {n:nat}
    SomeTrees(l, n + 1, s1 + s2) of (
      FlatTree(l, s1),
      ManyTrees(l + s1*sizeof(int), n, s2)
    )
(* Skips over trees_remaining consecutive flat trees starting at ptr and
 * returns the pointer just past the last of them, typed as l + sizeof(int)*s.
 *   trees - proof that n trees totalling s ints start at l; preserved for
 *           the caller, so each branch reassembles it after the recursive call
 * Each step consumes at least the tag int, so the total size of the trees
 * still to be skipped shrinks on every call; that is the .<s>. metric. *)
fun loop {l:addr} {n:nat} {s:nat} .<s>. (
  trees: !ManyTrees(l, n, s) | trees_remaining: int n, ptr: ptr l
):<> ptr (l + sizeof(int)*s) =
  if trees_remaining = 0
  then let
    (* nothing left to skip: the view must be the empty one *)
    prval NoTrees() = trees
    prval () = trees := NoTrees()
  in ptr end
  else let
    (* expose the first tree's tag so it can be read through ptr *)
    prval SomeTrees(FlatTree(tagpf, contents), rest) = trees
  in case+ !ptr of
    | LEAF => let
        prval Leaf(itemptr) = contents
        (* step over the tag and the value; one tree fewer remains *)
        val result = loop (rest | trees_remaining - 1, ptr_add<int>(ptr, 2))
        prval () = trees := SomeTrees(FlatTree(tagpf, Leaf(itemptr)), rest)
      in result end
    | NODE => let
        prval Node(left, right) = contents
        (* step over the tag only; the node's two subtrees are queued in
         * front of the rest, so one more tree remains than before *)
        prval next = SomeTrees(left, SomeTrees(right, rest))
        val result = loop (next | trees_remaining + 1, ptr_add<int>(ptr, 1))
        prval SomeTrees(left, SomeTrees(right, rest)) = next
        prval () = trees := SomeTrees(FlatTree(tagpf, Node(left, right)), rest)
      in result end
  end
(* Non-recursive-on-structure size computation: returns s, the number of ints
 * occupied by the flat tree at ptr.
 * Wraps the single tree as a one-element ManyTrees, lets loop find the
 * pointer just past it, then converts the pointer distance into a count of
 * ints by dividing by sizeof<int>.
 *   pf  - proof that a FlatTree of size s lives at l; restored before return
 *   ptr - pointer to the tag int *)
fn tree_size_fast {l:addr} {s:nat} (
  pf: !FlatTree(l, s) | ptr: ptr l
):<> int s =
  let
    prval trees = SomeTrees(pf, NoTrees())
    val end_ptr = loop (trees | 1, ptr)
    (* unwrap the one-element sequence and give the tree proof back *)
    prval SomeTrees(pf1, NoTrees()) = trees
    prval () = pf := pf1
    (* div requires a positive divisor; this axiom supplies sizeof(int) > 0 *)
    prval () = ints_have_size ()
  in div (g1int2int (end_ptr - ptr), g1uint2int (sizeof<int>)) end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment