// Gist catap/8dd82bf4e447c043ef83e5c9d997e2dc: Scala implementation of SplayTree.
// Created April 21, 2020 22:39.
// Persistent splay tree. This implementation is based on:
// - D. D. Sleator and R. E. Tarjan, "Self-adjusting binary search trees",
//   Journal of the ACM 32(3), 1985. https://doi.org/10.1145/3828.3835

/** An immutable binary search tree mapping keys of type `K` to values of type `V`.
  *
  * Trees are built from immutable nodes, so operations hand back a tree instead of
  * modifying the receiver. Operations that compare keys take the `Ordering` implicitly.
  *
  * @tparam K key type; covariant, so the empty tree (keyed by `Nothing`) fits any key type
  * @tparam V value type
  */
sealed trait SplayTree[+K, V] {

  // ---- Shape of the root ----

  /** True when the tree holds no entries. */
  val isEmpty: Boolean

  /** Number of nodes in the tree. */
  val size: Int

  /** Subtree to the left of the root. */
  val left: SplayTree[K, V]

  /** Key stored at the root; the empty tree throws `IllegalArgumentException`. */
  def key: K

  /** Value stored at the root; the empty tree throws `IllegalArgumentException`. */
  def value: V

  /** Subtree to the right of the root. */
  val right: SplayTree[K, V]

  // ---- Minimum ----

  /** The entry with the smallest key, or `None` for the empty tree. */
  val top: Option[(K, V)]

  /** The smallest key paired with the tree left after removing it.
    * The empty tree throws `IllegalArgumentException`.
    */
  def pop: (K, SplayTree[K, V])

  // ---- Ordered operations ----

  /** Combines this tree with `that`. Keys present in both trees are not
    * de-duplicated, so this is presumably meant for trees with disjoint key sets.
    */
  def merge[K1 >: K](that: SplayTree[K1, V])(implicit ord: Ordering[K1]): SplayTree[K1, V]

  /** Returns a tree that contains the entry `key -> value`. */
  def insert[K1 >: K](key: K1, value: V)(implicit ord: Ordering[K1]): SplayTree[K1, V]

  /** Splits the tree around `pivot`: keys strictly below `pivot` go to the first
    * component, keys strictly above it to the second.
    */
  def cut[K1 >: K](pivot: K1)(implicit ord: Ordering[K1]): (SplayTree[K1, V], SplayTree[K1, V])

  /** Returns a tree holding the same entries, possibly rotated to move `key` toward the root. */
  def splay[K1 >: K](key: K1)(implicit ord: Ordering[K1]): SplayTree[K1, V]

  /** The value stored under `key`, or `None` when there is none. */
  def find[K1 >: K](key: K1)(implicit ord: Ordering[K1]): Option[V]

  /** Returns a tree without the entry for `key`; an absent key leaves the entries unchanged. */
  def remove[K1 >: K](key: K1)(implicit ord: Ordering[K1]): SplayTree[K1, V]
}
/** The tree with no entries.
  *
  * Both subtrees point back at the tree itself. Reading `key`, `value` or `pop`
  * throws `IllegalArgumentException("empty")`.
  */
sealed case class SplayTreeEmpty[V]() extends SplayTree[Nothing, V] {
  val isEmpty: Boolean = true
  val size: Int = 0
  val top: Option[Nothing] = None

  val left: SplayTree[Nothing, V] = this
  val right: SplayTree[Nothing, V] = this

  // Shared failure for every accessor that needs a root entry.
  private def noEntry: Nothing = throw new IllegalArgumentException("empty")

  def key: Nothing = noEntry
  def value: V = noEntry
  def pop: (Nothing, SplayTree[Nothing, V]) = noEntry

  // Lookups, rotations and deletions on nothing are trivial.
  override def find[K1 >: Nothing](key: K1)(implicit ord: Ordering[K1]): Option[V] = None

  override def splay[K1 >: Nothing](key: K1)(implicit ord: Ordering[K1]): SplayTree[K1, V] = this

  override def remove[K1 >: Nothing](key: K1)(implicit ord: Ordering[K1]): SplayTree[K1, V] = this

  // Merging with nothing yields the other tree as-is.
  override def merge[K1 >: Nothing](that: SplayTree[K1, V])(implicit ord: Ordering[K1]): SplayTree[K1, V] = that

  // Both halves of an empty tree are empty.
  override def cut[K1 >: Nothing](pivot: K1)(implicit ord: Ordering[K1]): (SplayTree[K1, V], SplayTree[K1, V]) =
    (SplayTreeEmpty[V](), SplayTreeEmpty[V]())

  // Inserting into nothing produces a single leaf node.
  override def insert[K1 >: Nothing](key: K1, value: V)(implicit ord: Ordering[K1]): SplayTree[K1, V] =
    SplayTreeNode[K1, V](SplayTreeEmpty[V](), key, value, SplayTreeEmpty[V]())
}
/** A non-empty tree node.
  *
  * The operations below maintain the binary-search-tree invariant: every key in
  * `left` is smaller than `key`, and every key in `right` is larger, under the
  * `Ordering` passed to each operation.
  *
  * @param left  subtree of keys smaller than `key`
  * @param key   key stored at this node
  * @param value value stored under `key`
  * @param right subtree of keys larger than `key`
  */
sealed case class SplayTreeNode[K : Ordering, V](left: SplayTree[K, V], key: K, value: V, right: SplayTree[K, V]) extends SplayTree[K, V] {
  val isEmpty: Boolean = false
  val size: Int = left.size + 1 + right.size

  /** Entry with the smallest key, i.e. the leftmost node. */
  val top: Option[(K, V)] =
    if (left.isEmpty) Some((key, value))
    else left.top

  /** Removes the smallest key and returns it with the remaining tree. On the way
    * down the left spine it rotates right at every second level.
    *
    * This is a `def`, not an eager `val`. A `val` was evaluated whenever a node was
    * constructed, and because the result itself constructs new nodes, whose `pop`
    * was then evaluated too, every allocation paid for a `pop` nobody had asked for.
    */
  override def pop: (K, SplayTree[K, V]) =
    if (left.isEmpty) (key, right)
    else if (left.left.isEmpty) (left.key, SplayTreeNode(left.right, key, value, right))
    else {
      val (min, rest) = left.left.pop
      (min, SplayTreeNode(rest, left.key, left.value, SplayTreeNode(left.right, key, value, right)))
    }

  /** Combines this tree with `that` by splitting `that` around this node's key.
    *
    * Keys present in both trees are not de-duplicated; intended for disjoint key sets.
    * Merging with an empty tree returns this node as-is. Without that shortcut every
    * node of this tree is rebuilt, which made `remove` linear in the subtree size.
    */
  override def merge[K1 >: K](that: SplayTree[K1, V])(implicit ord: Ordering[K1]): SplayTree[K1, V] =
    if (that.isEmpty) this
    else {
      val (below, above) = that cut key
      SplayTreeNode(left merge below, key, value, right merge above)
    }

  /** Inserts `newKey -> newValue`, then splays `newKey` to the root.
    *
    * If `newKey` is already present, its value is replaced. Previously a second node
    * with the same key was added below the old one, which inflated `size` while `find`
    * kept returning the old value.
    */
  override def insert[K1 >: K](newKey: K1, newValue: V)(implicit ord: Ordering[K1]): SplayTree[K1, V] = {
    // Plain binary-search-tree insertion; the splay happens once, afterwards.
    def put(tree: SplayTree[K1, V]): SplayTree[K1, V] = tree match {
      case SplayTreeNode(l, k, v, r) =>
        val c = ord.compare(newKey, k)
        if (c < 0) SplayTreeNode[K1, V](put(l), k, v, r)
        else if (c > 0) SplayTreeNode[K1, V](l, k, v, put(r))
        else SplayTreeNode[K1, V](l, newKey, newValue, r)
      case _ =>
        SplayTreeNode[K1, V](SplayTreeEmpty[V](), newKey, newValue, SplayTreeEmpty[V]())
    }
    put(this) splay newKey
  }

  /** Splits into (entries with keys < `pivot`, entries with keys >= `pivot`).
    *
    * It works two levels at a time, rotating as it goes (zig-zig / zig-zag).
    * Previously the `pivot <= left.key` branch attached all of `left.left` to the upper
    * half instead of only its upper part, so keys below `pivot` appeared in both halves.
    * That branch also used a strict comparison, so a `left.key` equal to `pivot` went
    * to the lower half.
    */
  override def cut[K1 >: K](pivot: K1)(implicit ord: Ordering[K1]): (SplayTree[K1, V], SplayTree[K1, V]) =
    ord.lt(key, pivot) match {
      case true if right.isEmpty =>
        (this, SplayTreeEmpty())
      case true if ord.lt(right.key, pivot) =>
        // key < right.key < pivot: both go below; keep splitting right.right.
        val (small, big) = right.right cut pivot
        (SplayTreeNode(SplayTreeNode(left, key, value, right.left), right.key, right.value, small), big)
      case true =>
        // key < pivot <= right.key: split right.left between them.
        val (small, big) = right.left cut pivot
        (SplayTreeNode(left, key, value, small), SplayTreeNode(big, right.key, right.value, right.right))
      case false if left.isEmpty =>
        (SplayTreeEmpty(), this)
      case false if ord.lteq(pivot, left.key) =>
        // pivot <= left.key < key: both go above; keep splitting left.left.
        val (small, big) = left.left cut pivot
        (small, SplayTreeNode(big, left.key, left.value, SplayTreeNode(left.right, key, value, right)))
      case false =>
        // left.key < pivot <= key: split left.right between them.
        val (small, big) = left.right cut pivot
        (SplayTreeNode(left.left, left.key, left.value, small), SplayTreeNode(big, key, value, right))
    }

  /** Splays the tree at `key`. The node holding `key`, or the last node visited while
    * searching for it when it is absent, becomes the root. This uses recursive
    * (bottom-up) zig / zig-zig / zig-zag steps.
    *
    * The previous version compared keys with `ord.eq(a, b)`. That is `AnyRef.eq` with
    * auto-tupled arguments, which is always false, so it never rotated. It also only
    * looked two levels deep and wrote the search key in place of this node's key.
    */
  override def splay[K1 >: K](key: K1)(implicit ord: Ordering[K1]): SplayTree[K1, V] = {
    // The parameter shadows this node's `key` field; use an unambiguous name below.
    val target: K1 = key

    def node(l: SplayTree[K1, V], k: K1, v: V, r: SplayTree[K1, V]): SplayTree[K1, V] =
      SplayTreeNode(l, k, v, r)

    // Single rotations; a tree of the wrong shape is returned unchanged.
    def rotateRight(t: SplayTree[K1, V]): SplayTree[K1, V] = t match {
      case SplayTreeNode(SplayTreeNode(t1, xk, xv, t2), yk, yv, t3) => node(t1, xk, xv, node(t2, yk, yv, t3))
      case _ => t
    }
    def rotateLeft(t: SplayTree[K1, V]): SplayTree[K1, V] = t match {
      case SplayTreeNode(t1, xk, xv, SplayTreeNode(t2, yk, yv, t3)) => node(node(t1, xk, xv, t2), yk, yv, t3)
      case _ => t
    }

    def go(t: SplayTree[K1, V]): SplayTree[K1, V] = t match {
      case SplayTreeNode(l, k, v, r) =>
        val c = ord.compare(target, k)
        if (c < 0) {
          l match {
            case SplayTreeNode(ll, lk, lv, lr) =>
              val lc = ord.compare(target, lk)
              if (lc < 0 && !ll.isEmpty) {
                // zig-zig: splay the left-left grandchild, then rotate right twice.
                rotateRight(rotateRight(node(node(go(ll), lk, lv, lr), k, v, r)))
              } else if (lc > 0 && !lr.isEmpty) {
                // zig-zag: splay the left-right grandchild, lift it past `l`, then past `t`.
                rotateRight(node(rotateLeft(node(ll, lk, lv, go(lr))), k, v, r))
              } else {
                // zig: the search ends at the left child.
                rotateRight(t)
              }
            case _ =>
              t // target absent and `t` is the last node on the search path
          }
        } else if (c > 0) {
          r match {
            case SplayTreeNode(rl, rk, rv, rr) =>
              val rc = ord.compare(target, rk)
              if (rc > 0 && !rr.isEmpty) {
                // zig-zig: splay the right-right grandchild, then rotate left twice.
                rotateLeft(rotateLeft(node(l, k, v, node(rl, rk, rv, go(rr)))))
              } else if (rc < 0 && !rl.isEmpty) {
                // zig-zag: splay the right-left grandchild, lift it past `r`, then past `t`.
                rotateLeft(node(l, k, v, rotateRight(node(go(rl), rk, rv, rr))))
              } else {
                // zig: the search ends at the right child.
                rotateLeft(t)
              }
            case _ =>
              t
          }
        } else {
          t
        }
      case _ => t
    }

    go(this)
  }

  /** Looks up `searchKey` without restructuring the tree. */
  override def find[K1 >: K](searchKey: K1)(implicit ord: Ordering[K1]): Option[V] =
    ord.compare(searchKey, key) match {
      case 0 =>
        Some(value)
      case x if x < 0 =>
        left.find(searchKey)
      case _ =>
        right.find(searchKey)
    }

  /** Removes the entry for `removeKey`. The two subtrees of the removed node are
    * joined with `merge`, which is sound because every key in `left` is smaller than
    * every key in `right`. An absent key leaves the entries unchanged.
    */
  override def remove[K1 >: K](removeKey: K1)(implicit ord: Ordering[K1]): SplayTree[K1, V] =
    ord.compare(removeKey, key) match {
      case 0 =>
        left merge right
      case x if x < 0 =>
        SplayTreeNode(left remove removeKey, key, value, right)
      case _ =>
        SplayTreeNode(left, key, value, right remove removeKey)
    }
}
/** Factory methods for [[SplayTree]]. */
object SplayTree {

  /** An empty tree. The `Ordering` context bound is accepted but not used here. */
  def empty[A: Ordering, V]: SplayTree[A, V] = SplayTreeEmpty[V]()

  /** Builds a tree by inserting the pairs of `xs` one at a time, in iteration order.
    *
    * Accepts any `Iterable` of key/value pairs. A `Map`, the previous parameter type,
    * still works unchanged. How a repeated key is handled is determined by `insert`.
    */
  def apply[A: Ordering, V](xs: Iterable[(A, V)]): SplayTree[A, V] =
    xs.foldLeft(empty[A, V]) { case (tree, (k, v)) =>
      tree.insert(k, v)
    }
}
// ---- Test specification (ScalaTest) ----
import org.scalatest._

/** Behavioural checks for [[SplayTree]]: lookup, minimum extraction, removal and merge. */
class SplayTreeTest extends WordSpec with Matchers {
  "SplayTree" should {
    "works at simple case" in {
      val map = (-666 to 666).map { k =>
        (k, k.toString)
      }.toMap
      val tree = SplayTree(map)
      tree.size should be(map.size)
      map.foreach { case (key, value) =>
        tree.find(key) should be(Some(value))
      }
      tree.find(777) should be(None)

      // `top` reports the smallest entry and `pop` removes exactly that key.
      tree.top should be(Some((-666, "-666")))
      val (minKey, rest) = tree.pop
      minKey should be(-666)
      rest.size should be(map.size - 1)
      rest.find(-666) should be(None)

      map.foreach { case (key, _) =>
        val removedTree = tree.remove(key)
        // Exactly one entry disappears, and it is the requested one.
        removedTree.size should be(map.size - 1)
        removedTree.find(key) should be(None)
        map.filterNot(_._1 == key).foreach { case (sKey, value) =>
          removedTree.find(sKey) should be(Some(value))
        }
      }
    }
    "merge trees whose key ranges do not overlap" in {
      val low = SplayTree((1 to 50).map(k => (k, k.toString)).toMap)
      val high = SplayTree((51 to 100).map(k => (k, k.toString)).toMap)
      val merged = low merge high
      merged.size should be(100)
      (1 to 100).foreach { k =>
        merged.find(k) should be(Some(k.toString))
      }
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment