Created
March 13, 2021 00:27
-
-
Save mb64/7ec0d944cb5e37cf9f29b2e1dd2e8feb to your computer and use it in GitHub Desktop.
Heap benchmarks in MLton
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(** A small benchmark for heap implementations
 *
 * Task: (mutably) sort an array, using heap sort
 * Implementations:
 *  - BinaryHeap, a simple binary heap implemented with mutable array operations
 *  - PairingHeap, Okasaki's strict amortized pairing heap
 *  - SplayHeap, Okasaki's strict amortized splay heap
 *  - BinomialHeap, Okasaki's strict amortized binomial heap
 *  - LeftistHeap, Okasaki's strict rank-based leftist heap
 *
 * Compiled with 'mlton -codegen llvm' and measured with hyperfine, on my
 * (noisy) machine the results are:
 *
 *   './heap_bench binary 100000' ran
 *      8.20 ± 0.33 times faster than './heap_bench leftist 100000'
 *      8.86 ± 0.32 times faster than './heap_bench splay 100000'
 *      9.71 ± 0.39 times faster than './heap_bench pairing 100000'
 *     13.09 ± 0.51 times faster than './heap_bench binomial 100000'
 *
 *)
(* Interface shared by every heap implementation in this benchmark:
 * a single in-place sort over an int array. *)
signature SORTER =
sig
  val sort : int array -> unit
end
(* Array-backed binary max-heap used for an in-place heap sort.
 *
 * The array is treated as an implicit binary tree in which the children of
 * index i are stored at indices 2i+1 and 2i+2. `sort` first arranges the
 * whole array into a max-heap, then repeatedly swaps the current maximum
 * with the last slot of the shrinking heap prefix, which leaves the array
 * in ascending order. Arrays of length 0 or 1 are left untouched.
 *)
structure BinaryHeap : SORTER =
struct
  open Array

  (* Index arithmetic for the implicit tree. *)
  fun parent i = (i - 1) div 2
  fun left i = 2 * i + 1
  fun right i = 2 * i + 2

  (* Exchange the elements stored at indices i and j. *)
  fun swap i j arr =
    let val tmp = sub (arr, i)
    in update (arr, i, sub (arr, j)); update (arr, j, tmp) end

  (* Restore the max-heap property below index i, looking only at the first
   * `len` slots of arr: swap arr[i] with its larger child (preferring the
   * left child on ties) until neither child is larger. *)
  fun bubbleDown (i: int) (len: int) (arr: int array): unit =
    let
      val l = left i
      val r = right i
      val big = if l < len andalso sub (arr, i) < sub (arr, l) then l else i
      val big = if r < len andalso sub (arr, big) < sub (arr, r) then r else big
    in
      if big = i then () else (swap i big arr; bubbleDown big len arr)
    end

  (* Heapify the whole array, working from the last parent back to the root. *)
  fun makeHeap arr =
    let
      val len = length arr
      (* `div` rounds toward negative infinity, so parent (len - 1) is ~1
       * when len < 2; stopping on a negative index keeps such arrays from
       * ever being indexed out of bounds. *)
      fun loop i = if i < 0 then () else (bubbleDown i len arr; loop (i - 1))
    in
      loop (parent (len - 1))
    end

  (* Sort arr in place into ascending order. *)
  fun sort (arr: int array): unit =
    let
      (* i is the last index of the current heap prefix; i <= 0 means at
       * most one element remains, so nothing is left to move. *)
      fun loop i =
        if i <= 0 then ()
        else
          let val x = sub (arr, i)
          in
            update (arr, i, sub (arr, 0));
            update (arr, 0, x);
            bubbleDown 0 i arr;
            loop (i - 1)
          end
    in
      makeHeap arr; loop (length arr - 1)
    end
end
(* Pairing min-heap: a heap is either empty or a root value together with a
 * list of child heaps. `sort` loads the whole array into one heap and then
 * writes the elements back in ascending order. *)
structure PairingHeap : SORTER =
struct
  datatype Heap = E | T of int * Heap list

  (* Combine two heaps; the heap with the larger root becomes the first
   * child of the other. *)
  fun merge h1 h2 =
    case (h1, h2) of
      (_, E) => h1
    | (E, _) => h2
    | (T (x, kids1), T (y, kids2)) =>
        if x < y then T (x, h2 :: kids1) else T (y, h1 :: kids2)

  (* Heap holding exactly one element. *)
  fun one x = T (x, [])

  (* Heap holding exactly two elements, built without going through merge. *)
  fun two x y =
    let val (lo, hi) = if x < y then (x, y) else (y, x)
    in T (lo, [one hi]) end

  (* Merge a list of heaps two at a time, then fold the results together. *)
  fun mergePairs heaps =
    case heaps of
      [] => E
    | [h] => h
    | a :: b :: rest => merge (merge a b) (mergePairs rest)

  (* Remove the minimum; raises Empty on an empty heap. *)
  fun next heap =
    case heap of
      E => raise Empty
    | T (x, kids) => (x, mergePairs kids)

  (* Build a heap from arr by pairing up neighbouring elements; a trailing
   * unpaired element becomes a singleton heap. *)
  fun fromArr arr =
    let
      val n = Array.length arr
      fun chunks i =
        if i + 1 < n then
          two (Array.sub (arr, i)) (Array.sub (arr, i + 1)) :: chunks (i + 2)
        else if i + 1 = n then
          [one (Array.sub (arr, i))]
        else
          []
    in
      mergePairs (chunks 0)
    end

  (* Sort arr in place into ascending order. *)
  fun sort (arr: int array): unit =
    let
      val n = Array.length arr
      fun drain (i, heap) =
        if i = n then ()
        else
          let val (smallest, rest) = next heap
          in Array.update (arr, i, smallest); drain (i + 1, rest) end
    in
      drain (0, fromArr arr)
    end
end
(* Splay min-heap stored as a binary search tree. Insertion splits the tree
 * around the new element; removing the minimum walks down the left spine.
 * `sort` inserts every array element and then writes them back in
 * ascending order. *)
structure SplayHeap : SORTER =
struct
  datatype Heap = E | T of Heap * int * Heap

  (* Split tree into (elements <= pivot, elements > pivot), restructuring
   * two levels at a time along the search path. *)
  fun partition pivot tree =
    case tree of
      E => (E, E)
    | T (l, v, r) =>
        if v <= pivot then
          (case r of
             E => (tree, E)
           | T (rl, w, rr) =>
               if w <= pivot then
                 let val (lo, hi) = partition pivot rr
                 in (T (T (l, v, rl), w, lo), hi) end
               else
                 let val (lo, hi) = partition pivot rl
                 in (T (l, v, lo), T (hi, w, rr)) end)
        else
          (case l of
             E => (E, tree)
           | T (ll, w, lr) =>
               if w <= pivot then
                 let val (lo, hi) = partition pivot lr
                 in (T (ll, w, lo), T (hi, v, r)) end
               else
                 let val (lo, hi) = partition pivot ll
                 in (lo, T (hi, w, T (lr, v, r))) end)

  (* Insert x by making it the new root over the two partitions. *)
  fun insert (x, tree) =
    let val (lo, hi) = partition x tree
    in T (lo, x, hi) end

  (* Build a heap by inserting the array elements from left to right. *)
  fun fromArr arr = Array.foldl insert E arr

  (* Remove the minimum (the leftmost element); raises Empty on an empty
   * heap. *)
  fun next heap =
    case heap of
      E => raise Empty
    | T (E, v, r) => (v, r)
    | T (T (E, v, m), w, r) => (v, T (m, w, r))
    | T (T (ll, v, m), w, r) =>
        let val (smallest, ll') = next ll
        in (smallest, T (ll', v, T (m, w, r))) end

  (* Sort arr in place into ascending order. *)
  fun sort (arr: int array): unit =
    let
      val n = Array.length arr
      fun drain (i, heap) =
        if i = n then ()
        else
          let val (smallest, rest) = next heap
          in Array.update (arr, i, smallest); drain (i + 1, rest) end
    in
      drain (0, fromArr arr)
    end
end
(* Binomial min-heap: a list of trees kept in increasing rank order, where
 * each Node carries (rank, root value, children). `sort` inserts every
 * array element and then writes them back in ascending order. *)
structure BinomialHeap : SORTER =
struct
  datatype Tree = Node of int * int * Tree list
  type Heap = Tree list

  fun rank (Node (r, _, _)) = r
  fun root (Node (_, x, _)) = x

  (* Join two trees into one of the next rank; the tree with the larger
   * root becomes the first child of the other. The new rank is taken from
   * the first tree. *)
  fun link (a as Node (r, ka, kidsA)) (b as Node (_, kb, kidsB)) =
    let val r' = r + 1
    in if ka < kb then Node (r', ka, b :: kidsA) else Node (r', kb, a :: kidsB) end

  (* Insert tree t into a rank-ordered list, linking while ranks collide. *)
  fun insTree t trees =
    case trees of
      [] => [t]
    | first :: rest =>
        if rank t < rank first then t :: trees
        else insTree (link t first) rest

  (* Insert a single element as a rank-0 tree. *)
  fun insert (x, trees) = insTree (Node (0, x, [])) trees

  (* Build a heap by inserting the array elements from left to right. *)
  fun fromArr arr = Array.foldl insert [] arr

  (* Merge two rank-ordered tree lists, linking trees of equal rank. *)
  fun merge ts1 ts2 =
    case (ts1, ts2) of
      (_, []) => ts1
    | ([], _) => ts2
    | (t1 :: rest1, t2 :: rest2) =>
        if rank t1 < rank t2 then t1 :: merge rest1 ts2
        else if rank t2 < rank t1 then t2 :: merge ts1 rest2
        else insTree (link t1 t2) (merge rest1 rest2)

  (* Split off the tree with the smallest root (the earliest one on ties);
   * raises Empty on an empty heap. *)
  fun removeMinTree trees =
    case trees of
      [] => raise Empty
    | [t] => (t, [])
    | t :: rest =>
        let val (best, others) = removeMinTree rest
        in if root t <= root best then (t, rest) else (best, t :: others) end

  (* Remove the minimum: detach its tree, then merge that tree's children
   * (reversed, so they are in increasing rank order) back into the rest. *)
  fun next heap =
    case removeMinTree heap of
      (Node (_, smallest, kids), others) => (smallest, merge (rev kids) others)

  (* Sort arr in place into ascending order. *)
  fun sort (arr: int array): unit =
    let
      val n = Array.length arr
      fun drain (i, heap) =
        if i = n then ()
        else
          let val (smallest, rest) = next heap
          in Array.update (arr, i, smallest); drain (i + 1, rest) end
    in
      drain (0, fromArr arr)
    end
end
(* Rank-based leftist min-heap. Each node stores (rank, value, left, right),
 * where the rank is one more than the rank of the right child and an empty
 * heap has rank 0. `sort` builds one heap from the array and then writes
 * the elements back in ascending order. *)
structure LeftistHeap : SORTER =
struct
  datatype Heap = E | T of int * int * Heap * Heap

  fun rank E = 0
    | rank (T (r, _, _, _)) = r

  (* Build a node from a root and two subheaps, placing the subheap with
   * the larger rank on the left. *)
  fun makeT x a b =
    if rank a >= rank b then T (rank b + 1, x, a, b)
    else T (rank a + 1, x, b, a)

  (* Heap holding exactly one element. *)
  fun one x = T (1, x, E, E)

  (* Heap holding exactly two elements. The right child is E (rank 0), so
   * the stored rank must be 1, matching what makeT would compute. *)
  fun two x y =
    if x <= y then T (1, x, one y, E) else T (1, y, one x, E)

  (* Merge two heaps by recursing down their right spines. *)
  fun merge h E = h
    | merge E h = h
    | merge (h1 as T (_, x, a1, b1)) (h2 as T (_, y, a2, b2)) =
        if x <= y then makeT x a1 (merge b1 h2)
        else makeT y a2 (merge h1 b2)

  (* Merge neighbouring heaps once, halving the length of the list. *)
  fun pairs (h1 :: h2 :: hs) = merge h1 h2 :: pairs hs
    | pairs hs = hs

  (* Merge a list of heaps into one by repeated pairwise passes. *)
  fun merget [] = E
    | merget [h] = h
    | merget hs = merget (pairs hs)

  (* Build a heap from arr by pairing up neighbouring elements; a trailing
   * unpaired element becomes a singleton heap. *)
  fun fromArr arr =
    let
      val n = Array.length arr
      fun chunks i =
        if i + 1 < n then
          two (Array.sub (arr, i)) (Array.sub (arr, i + 1)) :: chunks (i + 2)
        else if i + 1 = n then
          [one (Array.sub (arr, i))]
        else
          []
    in
      merget (chunks 0)
    end

  (* Remove the minimum; raises Empty on an empty heap. *)
  fun next E = raise Empty
    | next (T (_, x, a, b)) = (x, merge a b)

  (* Sort arr in place into ascending order. *)
  fun sort (arr: int array): unit =
    let
      val n = Array.length arr
      fun drain (i, heap) =
        if i = n then ()
        else
          let val (smallest, rest) = next heap
          in Array.update (arr, i, smallest); drain (i + 1, rest) end
    in
      drain (0, fromArr arr)
    end
end
(* Create an array of `len` ints, each taken from MLton.Random.rand ()
 * reduced modulo 2147483648 and converted to int. *)
fun makeArray len =
  let
    val bound : word = 0w2147483648
    fun randomInt _ = Word.toInt (MLton.Random.rand () mod bound)
  in
    Array.tabulate (len, randomInt)
  end
(* Check that arr is in non-decreasing order.
 *
 * Returns the last element of arr, or 0 when arr is empty.
 * Raises Fail "not sorted" as soon as an element is smaller than the one
 * before it.
 *)
fun assertSorted (arr: int array) =
  let
    fun step (x, prev) = if x < prev then raise Fail "not sorted" else x
  in
    (* Seed the fold with the first element rather than 0 so that sorted
     * arrays beginning with a negative value are not rejected. *)
    if Array.length arr = 0 then 0
    else Array.foldl step (Array.sub (arr, 0)) arr
  end
(* Smoke test for a sort function: prints a fixed ten-element array, applies
 * `sort` to it, and prints the array again. *)
fun test sort =
  let
    val sample = Array.fromList [103,196,94,64,20,146,192,21,52,141]
    (* One line of text: every element preceded by a space, then a newline. *)
    fun render a =
      Array.foldr (fn (x, rest) => " " ^ Int.toString x ^ rest) "\n" a
  in
    print "Before:\n";
    print (render sample);
    sort sample;
    print "After:\n";
    print (render sample)
  end
(* Benchmark entry point.
 *
 * Expects exactly two command-line arguments: the heap implementation
 * ("binary", "pairing", "splay", "binomial" or "leftist") and the number of
 * elements to sort. It fills an array of that size with random ints, sorts
 * it with the chosen implementation, checks the result with assertSorted,
 * and prints the first (smallest) element.
 *
 * Raises Fail when the arguments are missing, the implementation name is
 * unknown, or the length is not a positive integer.
 *)
fun main () =
  let
    fun pick "binary" = BinaryHeap.sort
      | pick "pairing" = PairingHeap.sort
      | pick "splay" = SplayHeap.sort
      | pick "binomial" = BinomialHeap.sort
      | pick "leftist" = LeftistHeap.sort
      | pick _ = raise Fail "bad args"

    (* The length must be at least 1 because the first element of the
     * sorted array is printed at the end. *)
    fun parseLen s =
      case Int.fromString s of
        SOME n =>
          if n >= 1 then n
          else raise Fail "bad args: length must be a positive integer"
      | NONE => raise Fail "bad args: length must be a positive integer"

    val (sort, len) =
      case CommandLine.arguments () of
        [name, lenStr] => (pick name, parseLen lenStr)
      | _ => raise Fail "bad args"

    val arr = makeArray len
  in
    sort arr;
    ignore (assertSorted arr);
    print ("Starts with " ^ Int.toString (Array.sub (arr, 0)) ^ "\n")
  end

val () = main ()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The purely functional data structures perform significantly better on already highly structured input than on a random array. For reverse-sorted input, the splay tree surprisingly ends up beating the imperative version!