Skip to content

Instantly share code, notes, and snippets.

@mb64
Created March 13, 2021 00:27
Show Gist options
  • Save mb64/7ec0d944cb5e37cf9f29b2e1dd2e8feb to your computer and use it in GitHub Desktop.
Save mb64/7ec0d944cb5e37cf9f29b2e1dd2e8feb to your computer and use it in GitHub Desktop.
Heap benchmarks in MLton
(** A small benchmark for heap implementations
*
* Task: (mutably) sort an array, using heap sort
* Implementations:
* - BinaryHeap, a simple binary heap implemented with mutable array operations
* - PairingHeap, Okasaki's strict amortized pairing heap
* - SplayHeap, Okasaki's strict amortized splay heap
* - BinomialHeap, Okasaki's strict amortized binomial heap
* - LeftistHeap, Okasaki's strict rank-based leftist heap
*
* Compiled with 'mlton -codegen llvm' and measured with hyperfine. On my
* (noisy) machine the results are:
*
* './heap_bench binary 100000' ran
* 8.20 ± 0.33 times faster than './heap_bench leftist 100000'
* 8.86 ± 0.32 times faster than './heap_bench splay 100000'
* 9.71 ± 0.39 times faster than './heap_bench pairing 100000'
* 13.09 ± 0.51 times faster than './heap_bench binomial 100000'
*
*)
(* Common interface implemented by every heap structure in this file.
 * [sort] takes an int array and returns unit, so the sorted result is
 * delivered by mutating the array that was passed in. *)
signature SORTER = sig val sort : int array -> unit end
(* In-place heap sort using an implicit binary max-heap.
 *
 * The heap lives in the array itself: the children of index i are stored at
 * indices 2i+1 and 2i+2.  [sort] first heapifies the whole array, then
 * repeatedly swaps the maximum (index 0) with the last slot of the shrinking
 * heap prefix, which leaves the array in ascending order.
 *
 * Arrays of length 0 and 1 are handled explicitly: both loops stop as soon
 * as their index drops below the valid range instead of indexing at ~1. *)
structure BinaryHeap : SORTER =
struct
  open Array

  (* Index arithmetic for the implicit tree.  Note that [parent 0] is ~1,
   * because SML's div rounds toward negative infinity. *)
  fun parent i = (i - 1) div 2
  fun left i = 2 * i + 1
  fun right i = 2 * i + 2

  (* Exchange the elements stored at indices i and j of arr. *)
  fun swap i j arr =
    let val tmp = sub (arr, i)
    in update (arr, i, sub (arr, j)); update (arr, j, tmp) end

  (* Sift the element at index i down until neither child (within the first
   * len slots of arr) is larger than it.  Requires 0 <= i. *)
  fun bubbleDown (i: int) (len: int) (arr: int array): unit =
    let val l = left i
    in
      if l >= len then ()  (* i is a leaf of the len-prefix *)
      else
        let
          val r = right i
          (* the larger child; the right one only wins when strictly larger *)
          val c = if r < len andalso sub (arr, l) < sub (arr, r) then r else l
        in
          if sub (arr, i) < sub (arr, c)
          then (swap i c arr; bubbleDown c len arr)
          else ()
        end
    end

  (* Turn the whole array into a max-heap by sifting down every internal
   * node, from the last parent back to the root.  For arrays with fewer
   * than two elements the starting index is ~1 and the loop does nothing. *)
  fun makeHeap arr =
    let
      val len = length arr
      fun loop i =
        if i < 0 then ()
        else (bubbleDown i len arr; loop (i - 1))
    in loop (parent (len - 1)) end

  (* Sort arr in ascending order, in place. *)
  fun sort (arr: int array): unit =
    let
      (* Slots above i already hold their final values; move the current
       * maximum into slot i and re-heapify the prefix of length i. *)
      fun loop i =
        if i <= 0 then ()
        else
          let val x = sub (arr, i)
          in
            update (arr, i, sub (arr, 0));
            update (arr, 0, x);
            bubbleDown 0 i arr;
            loop (i - 1)
          end
    in makeHeap arr; loop (length arr - 1) end
end
(* Heap sort through a purely functional pairing heap (min-heap).
 *
 * The array is first converted to a heap, then the minimum is removed
 * repeatedly and written back into the array from index 0 upwards. *)
structure PairingHeap : SORTER =
struct
  open Array

  (* E is the empty heap; T (x, hs) has root x and child heaps hs. *)
  datatype Heap = E | T of int * Heap list

  (* Meld two heaps: the heap with the larger root becomes the first child
   * of the other one. *)
  fun merge a b =
    case (a, b) of
      (_, E) => a
    | (E, _) => b
    | (T (x, xs), T (y, ys)) =>
        if x < y then T (x, b :: xs) else T (y, a :: ys)

  (* fun insert x h = merge (T (x, [])) h *)

  (* Heap holding exactly one element. *)
  fun one v = T (v, [])

  (* Heap holding exactly two elements, built without going through merge. *)
  fun two v w = if v < w then T (v, [one w]) else T (w, [one v])

  (* Meld a list of heaps: adjacent heaps are melded pairwise, and each pair
   * is then melded with the result for the rest of the list. *)
  fun mergePairs heaps =
    case heaps of
      [] => E
    | [only] => only
    | first :: second :: rest =>
        merge (merge first second) (mergePairs rest)

  (* Remove the minimum: returns (minimum, remaining heap).
   * Raises Empty on the empty heap. *)
  fun next heap =
    case heap of
      E => raise Empty
    | T (smallest, kids) => (smallest, mergePairs kids)

  (* Build a heap from all elements of arr, taking them two at a time; a
   * trailing odd element becomes a singleton heap. *)
  fun fromArr arr =
    let
      val last = length arr - 1
      fun leaves i =
        case Int.compare (i, last) of
          LESS => two (sub (arr, i)) (sub (arr, i + 1)) :: leaves (i + 2)
        | EQUAL => [one (sub (arr, last))]
        | GREATER => []
    in mergePairs (leaves 0) end

  (* Sort arr in ascending order by draining the heap back into it. *)
  fun sort (arr: int array): unit =
    let
      val n = length arr
      fun drain (pos, heap) =
        if pos = n then ()
        else
          let val (smallest, rest) = next heap
          in update (arr, pos, smallest); drain (pos + 1, rest) end
    in drain (0, fromArr arr) end
end
(* Heap sort through a purely functional splay heap.
 *
 * The heap is a binary search tree: [insert] splits the tree around the new
 * element and makes that element the root, and [next] removes the leftmost
 * (smallest) element.  [sort] inserts every array element and then writes
 * the elements back in the order [next] yields them. *)
structure SplayHeap : SORTER =
struct
(* E is the empty tree; T (a, x, b) has left subtree a, element x and right
 * subtree b. *)
datatype Heap = E | T of Heap * int * Heap
(* partition pivot t splits t into a pair (small, big): elements <= pivot end
 * up in small, elements > pivot in big.  Two levels of the tree are examined
 * per step and the nodes on the search path are re-linked while descending,
 * so the pieces returned are restructured rather than copied verbatim. *)
fun partition pivot E = (E, E)
| partition pivot (t as T (a, x, b)) =
if x <= pivot then
(* x and its left subtree belong to the small side; keep looking right *)
case b of
E => (t, E)
| T (b1, y, b2) =>
if y <= pivot then
let val (small, big) = partition pivot b2
in (T (T (a, x, b1), y, small), big) end
else
let val (small, big) = partition pivot b1
in (T (a, x, small), T (big, y, b2)) end
else
(* x and its right subtree belong to the big side; keep looking left *)
case a of
E => (E, t)
| T (a1, y, a2) =>
if y <= pivot then
let val (small, big) = partition pivot a2
in (T (a1, y, small), T (big, x, b)) end
else
let val (small, big) = partition pivot a1
in (small, T (big, y, T (a2, x, b))) end
(* Insert x by splitting t around x and making x the new root.  The argument
 * is a pair so that insert can be passed directly to Array.foldl. *)
fun insert (x, t) = let val (a, b) = partition x t in T (a, x, b) end
(* Build a heap by inserting the array elements one by one, starting from E. *)
val fromArr = Array.foldl insert E
(* Remove the leftmost element: returns (that element, remaining tree).
 * Raises Empty on the empty tree.  The last clause also re-links x and y
 * while following the left spine. *)
fun next E = raise Empty
| next (T (E, x, b)) = (x, b)
| next (T (T (E, x, b), y, c)) = (x, T (b, y, c))
| next (T (T (a, x, b), y, c)) =
let val (min, a') = next a
in (min, T (a', x, T (b, y, c))) end
(* Overwrite arr from index 0 upwards with the elements in the order next
 * removes them from the heap built from arr. *)
fun sort (arr: int array): unit =
let val len = Array.length arr
fun loop i h = if i = len then () else
let val (x, h') = next h
in Array.update (arr, i, x); loop (i+1) h' end
in loop 0 (fromArr arr) end
end
(* Heap sort through a purely functional binomial heap (min-heap).
 *
 * A heap is a list of trees; the array is inserted element by element and
 * then drained back in ascending order. *)
structure BinomialHeap : SORTER =
struct
  (* Node (rank, element, children) *)
  datatype Tree = Node of int * int * Tree list
  type Heap = Tree list

  fun rank (Node (r, _, _)) = r
  fun root (Node (_, x, _)) = x

  (* Combine two trees into one of the next rank: the tree with the larger
   * root becomes the first child of the other.  The new rank is taken from
   * the first argument. *)
  fun link (a as Node (r, ka, kidsA)) (b as Node (_, kb, kidsB)) =
    let val r' = r + 1
    in
      if ka < kb then Node (r', ka, b :: kidsA)
      else Node (r', kb, a :: kidsB)
    end

  (* Insert tree t into a tree list, linking with the head tree for as long
   * as t's rank is not smaller than the head's. *)
  fun insTree t trees =
    case trees of
      [] => [t]
    | t' :: rest =>
        if rank t < rank t' then t :: trees
        else insTree (link t t') rest

  (* Insert a single element; takes a pair so it can be folded over an array. *)
  fun insert (x, trees) = insTree (Node (0, x, [])) trees

  (* Build a heap by inserting every array element, starting from []. *)
  fun fromArr arr = Array.foldl insert [] arr

  (* Merge two tree lists, comparing the head ranks at each step and linking
   * the heads when their ranks are equal. *)
  fun merge ts1 ts2 =
    case (ts1, ts2) of
      (_, []) => ts1
    | ([], _) => ts2
    | (t1 :: rest1, t2 :: rest2) =>
        (case Int.compare (rank t1, rank t2) of
           LESS => t1 :: merge rest1 ts2
         | GREATER => t2 :: merge ts1 rest2
         | EQUAL => insTree (link t1 t2) (merge rest1 rest2))

  (* Split off the tree with the smallest root: returns (that tree, the other
   * trees).  Raises Empty on the empty heap. *)
  fun removeMinTree trees =
    case trees of
      [] => raise Empty
    | [only] => (only, [])
    | head :: tail =>
        let val (best, others) = removeMinTree tail
        in
          if root head <= root best then (head, tail)
          else (best, head :: others)
        end

  (* Remove the minimum element: returns (minimum, remaining heap).  The
   * children of the removed tree are reversed before being merged back. *)
  fun next heap =
    let val (Node (_, smallest, kids), siblings) = removeMinTree heap
    in (smallest, merge (rev kids) siblings) end

  (* Sort arr in ascending order by draining the heap back into it. *)
  fun sort (arr: int array): unit =
    let
      val n = Array.length arr
      fun drain (pos, heap) =
        if pos = n then ()
        else
          let val (smallest, rest) = next heap
          in Array.update (arr, pos, smallest); drain (pos + 1, rest) end
    in drain (0, fromArr arr) end
end
(* Heap sort through a purely functional, rank-based leftist heap (min-heap).
 *
 * Every node stores its rank, which makeT computes as (rank of the right
 * child) + 1, with the empty heap having rank 0.  makeT keeps the child of
 * larger rank on the left. *)
structure LeftistHeap : SORTER =
struct
  open Array

  (* E is the empty heap; T (rank, element, left, right). *)
  datatype Heap = E | T of int * int * Heap * Heap

  fun rank E = 0
    | rank (T (r, _, _, _)) = r

  (* Build a node with element x from two subheaps, putting the one of
   * larger rank on the left and deriving the node's rank from the right. *)
  fun makeT x a b =
    if rank a >= rank b then T (rank b + 1, x, a, b)
    else T (rank a + 1, x, b, a)

  (* Heap holding exactly one element. *)
  fun one x = T (1, x, E, E)

  (* Heap holding exactly two elements.  It is built through makeT so the
   * stored rank is 1 (the right child is E); a hard-coded rank of 2 here
   * would overstate the rank and mislead makeT's left/right choice in
   * later merges. *)
  fun two x y =
    if x <= y then makeT x (one y) E else makeT y (one x) E

  (* Merge two heaps by recursing along their right children. *)
  fun merge h E = h
    | merge E h = h
    | merge (h1 as T (_, x, a1, b1)) (h2 as T (_, y, a2, b2)) =
        if x <= y then makeT x a1 (merge b1 h2)
        else makeT y a2 (merge h1 b2)

  (* Insert a single element.  Not used by sort; kept for reference. *)
  fun insert (x, E) = T (1, x, E, E)
    | insert (x, h as T (_, y, a, b)) =
        if x <= y then makeT x E h
        else makeT y a (insert (x, b))

  (* One pass of merging adjacent heaps; a trailing odd heap is kept as is. *)
  fun pairs (h1 :: h2 :: hs) = merge h1 h2 :: pairs hs
    | pairs hs = hs

  (* Merge a list of heaps by repeated pairwise passes until one is left. *)
  fun merget [] = E
    | merget [h] = h
    | merget (h1 :: h2 :: hs) = merget (merge h1 h2 :: pairs hs)

  (* Build a heap from all elements of arr, taking them two at a time; a
   * trailing odd element becomes a singleton heap. *)
  fun fromArr arr =
    let
      val max = length arr - 1
      fun go i =
        if i < max then
          two (sub (arr, i)) (sub (arr, i + 1)) :: go (i + 2)
        else if i = max then
          [one (sub (arr, max))]
        else []
    in merget (go 0) end

  (* Remove the minimum: returns (minimum, remaining heap).
   * Raises Empty on the empty heap. *)
  fun next E = raise Empty
    | next (T (_, x, a, b)) = (x, merge a b)

  (* Sort arr in ascending order by draining the heap back into it. *)
  fun sort (arr: int array): unit =
    let
      val len = length arr
      fun loop i h =
        if i = len then ()
        else
          let val (x, h') = next h
          in update (arr, i, x); loop (i + 1) h' end
    in loop 0 (fromArr arr) end
end
(* Create an int array of the given length whose elements are produced by
 * MLton.Random.rand, each reduced modulo 2147483648 before being converted
 * to int (so every element is non-negative). *)
fun makeArray len =
  let
    val bound : word = 0w2147483648
    fun randomElem _ = Word.toInt (MLton.Random.rand () mod bound)
  in Array.tabulate (len, randomElem) end
(* Check that arr is in non-decreasing order.
 *
 * Returns the last element of arr (0 for an empty array).
 * Raises Fail "not sorted" if some element is smaller than its predecessor.
 *
 * The fold is seeded with the first element rather than with 0, so arrays
 * that start with a negative number are not rejected spuriously. *)
fun assertSorted (arr: int array) =
  let
    (* x is the current element, prev the element before it *)
    fun cmp (x, prev) = if x < prev then raise Fail "not sorted" else x
  in
    if Array.length arr = 0 then 0
    else Array.foldl cmp (Array.sub (arr, 0)) arr
  end
(* Smoke test for a sort function: prints a fixed ten-element sample array,
 * applies sort to it, and prints the array again. *)
fun test sort =
  let
    val arr = Array.fromList [103,196,94,64,20,146,192,21,52,141]
    fun showElem x = print (" " ^ Int.toString x)
    fun printArray a = (Array.app showElem a; print "\n")
    (* print a heading followed by the current contents of arr *)
    fun section title = (print title; printArray arr)
  in
    section "Before:\n";
    sort arr;
    section "After:\n"
  end
(* Command-line entry point.
 *
 * Expects exactly two arguments: the heap name (binary, pairing, splay,
 * binomial or leftist) and the array length.  Builds a random array of that
 * length, sorts it with the chosen implementation, checks the result with
 * assertSorted and prints the first element.
 *
 * Raises Fail "bad args" for any other argument shape or heap name.
 * NOTE(review): a length of 0 reaches Array.sub (arr, 0) on an empty array. *)
fun main () =
  let
    val (sort, len) =
      case CommandLine.arguments () of
        [which, len] =>
          let
            val sorter =
              case which of
                "binary" => BinaryHeap.sort
              | "pairing" => PairingHeap.sort
              | "splay" => SplayHeap.sort
              | "binomial" => BinomialHeap.sort
              | "leftist" => LeftistHeap.sort
              | _ => raise Fail "bad args"
          in (sorter, len) end
      | _ => raise Fail "bad args"
    val size = Option.valOf (Int.fromString len)
    val arr = makeArray size
  in
    sort arr;
    assertSorted arr;
    print (concat ["Starts with ", Int.toString (Array.sub (arr, 0)), "\n"])
  end
(* Run the benchmark when the program starts; main returns unit. *)
val () = main ()
@mb64
Copy link
Author

mb64 commented Mar 13, 2021

The purely functional data structures perform significantly better on already highly structured input than on a random array. For a reverse-sorted input, the splay tree surprisingly ends up beating the imperative version!

  './heap_bench splay 100000' ran
    1.21 ± 0.11 times faster than './heap_bench binary 100000'
    2.30 ± 0.23 times faster than './heap_bench binomial 100000'
    2.61 ± 0.25 times faster than './heap_bench pairing 100000'
    3.10 ± 0.29 times faster than './heap_bench leftist 100000'

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment