-
-
Save timothyklim/f1b7e842754ae6e821c8f0aa9ece6df8 to your computer and use it in GitHub Desktop.
Implementation of "Purely Functional Random Access Lists" by Chris Okasaki in Scala. This gives O(1) cons and uncons, and lookup in about 2 log_2 N steps.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.bykn.list | |
import cats.Applicative | |
import cats.implicits._ | |
/**
 * Implementation of "Purely Functional Random Access Lists" by Chris Okasaki.
 * This gives O(1) cons and uncons, and 2 log_2 N lookup.
 *
 * The class is sealed, so every implementation lives in this file.
 */
sealed abstract class TreeList[+A] {

  /** Splits the list into its first element and the remainder; `None` when there is no first element. */
  def uncons: Option[(A, TreeList[A])]

  /** Returns a new list with `a1` placed in front of the elements of this list. */
  def cons[A1 >: A](a1: A1): TreeList[A1]

  /** Looks up the element at position `idx`, wrapped in an `Option`. */
  def get(idx: Long): Option[A]

  /** Number of elements in the list. */
  def size: Long

  /** Folds the elements into `init` with `fn`, accumulator first. */
  def foldLeft[B](init: B)(fn: (B, A) => B): B

  /** Folds the elements into `fin` with `fn`, accumulator second. */
  def foldRight[B](fin: B)(fn: (A, B) => B): B

  /** Builds a list of the results of applying `fn` to every element. */
  def map[B](fn: A => B): TreeList[B]

  /** Returns the list without its first `n` elements. */
  def drop(n: Long): TreeList[A]

  /**
   * Split the list roughly in half
   */
  def split: (TreeList[A], TreeList[A])

  /** Operator alias for [[cons]], so that `a :: list` prepends `a`. */
  def ::[A1 >: A](a1: A1): TreeList[A1] = cons(a1)

  /**
   * Renders the list as `TreeList(e1, e2, ...)` by repeatedly calling [[uncons]].
   *
   * A `null` element is rendered as the text `null` instead of throwing a
   * NullPointerException from `toString`.
   */
  override def toString: String = {
    val strb = new java.lang.StringBuilder
    strb.append("TreeList(")

    // Checked tail recursion: the loop must not grow the stack with the list length.
    @annotation.tailrec
    def loop(first: Boolean, l: TreeList[A]): Unit =
      l.uncons match {
        case None => ()
        case Some((h, t)) =>
          if (!first) strb.append(", ")
          // Guard the null element explicitly; calling toString on it would throw.
          strb.append(if (h == null) "null" else h.toString)
          loop(false, t)
      }

    loop(true, this)
    strb.append(")")
    strb.toString
  }
}
object TreeList { | |
/** A natural number that is also tracked in the type; `value` is its ordinary `Int` form. */
sealed trait Nat {
  /** The number as a plain `Int`. */
  def value: Int
}
/**
 * Evidence that the two `Nat` types `A` and `B` are the same type.
 *
 * Holding such a value allows any `F[A]` to be re-typed as `F[B]` through [[subst]].
 */
sealed abstract class NatEq[A <: Nat, B <: Nat] {
  /** Re-types `f` from `F[A]` to `F[B]`. */
  def subst[F[_ <: Nat]](f: F[A]): F[B]
}
object NatEq {
  /** Reflexivity: every `Nat` type equals itself, so `subst` hands its argument back untouched. */
  implicit def refl[A <: Nat]: NatEq[A, A] =
    new NatEq[A, A] {
      def subst[F[_ <: Nat]](fa: F[A]): F[A] = fa
    }
}
object Nat {
  /** The successor of `prev`; its `value` is computed once, at construction. */
  case class Succ[P <: Nat](prev: P) extends Nat {
    val value: Int = prev.value + 1
  }

  /** The natural number zero. */
  case object Zero extends Nat {
    def value: Int = 0
  }

  /**
   * Compares two naturals at runtime and, when their `value`s agree, produces
   * evidence that their types agree as well.
   *
   * The compiler cannot derive that evidence, so the reflexive proof for `N1`
   * is cast to the requested type.
   */
  def maybeEq[N1 <: Nat, N2 <: Nat](n1: N1, n2: N2): Option[NatEq[N1, N2]] =
    if (n1.value != n2.value) None
    // I don't see how to prove this in scala, but it is true
    else Some(NatEq.refl[N1].asInstanceOf[NatEq[N1, N2]])
}
/**
 * A binary tree whose depth `N` is tracked in the type and which stores one
 * element of type `A` per node.
 */
sealed abstract class Tree[+N <: Nat, +A] {
  /** The element stored at this node. */
  def value: A

  /** The depth of this tree as a value of the type-level natural `N`. */
  def depth: N

  /** Number of elements in the tree. */
  def size: Long // this is 2^(depth + 1) - 1

  /** Looks up the element at position `idx` within this tree, wrapped in an `Option`. */
  def get(idx: Long): Option[A]

  /** Applies `fn` to every element, keeping the shape (and so the depth) of the tree. */
  def map[B](fn: A => B): Tree[N, B]

  /** Folds the elements of the tree into `fin` with `fn`. */
  def foldRight[B](fin: B)(fn: (A, B) => B): B
}
/** A tree of depth zero: a single element and no children. */
case class Root[A](value: A) extends Tree[Nat.Zero.type, A] {
  def depth: Nat.Zero.type = Nat.Zero

  def size: Long = 1L

  /** Only index 0 exists in a single-element tree. */
  def get(idx: Long): Option[A] =
    idx match {
      case 0L => Some(value)
      case _ => None
    }

  def map[B](fn: A => B): Root[B] = Root(fn(value))

  def foldRight[B](fin: B)(fn: (A, B) => B): B = fn(value, fin)
}
/**
 * An inner node: one element plus two subtrees that share the same depth type `N`.
 *
 * Positions are numbered with this node's element first, then the whole left
 * subtree, then the whole right subtree.
 */
case class Balanced[N <: Nat, A](value: A, left: Tree[N, A], right: Tree[N, A]) extends Tree[Nat.Succ[N], A] {
  val depth: Nat.Succ[N] = Nat.Succ(left.depth)

  val size: Long = 1L + left.size + right.size

  def get(idx: Long): Option[A] = {
    val leftSize = left.size
    if (idx == 0L) Some(value)
    else if (idx > leftSize) right.get(idx - (leftSize + 1))
    else left.get(idx - 1) // positions 1..leftSize belong to the left subtree
  }

  def map[B](fn: A => B): Balanced[N, B] = {
    // Apply fn in position order: this node, then left, then right.
    val newValue = fn(value)
    val newLeft = left.map(fn)
    val newRight = right.map(fn)
    Balanced[N, B](newValue, newLeft, newRight)
  }

  def foldRight[B](fin: B)(fn: (A, B) => B): B =
    // Innermost call runs first: right subtree, then left subtree, then this node.
    fn(value, left.foldRight(right.foldRight(fin)(fn))(fn))
}
/**
 * Applies the effectful function `fn` to every element of `ta` and rebuilds a
 * tree of the same depth inside `F`.
 *
 * Effects are combined with the `Applicative` of `F` in the order: node
 * element, left subtree, right subtree.
 */
def traverseTree[F[_]: Applicative, A, B, N <: Nat](ta: Tree[N, A], fn: A => F[B]): F[Tree[N, B]] =
  ta match {
    case Root(elem) =>
      fn(elem).map(Root(_))
    case Balanced(elem, lhs, rhs) =>
      (fn(elem), traverseTree(lhs, fn), traverseTree(rhs, fn)).mapN { (b, newLeft, newRight) =>
        Balanced(b, newLeft, newRight)
      }
  }
/**
 * The implementation of [[TreeList]]: the elements are stored in a list of
 * trees, and the list order followed by each tree's own position order gives
 * the element order.
 */
private case class Trees[A](treeList: List[Tree[Nat, A]]) extends TreeList[A] {

  /**
   * Prepends `a1`.
   *
   * When the two leading trees have equal depth they become the children of a
   * new `Balanced` node holding `a1`; otherwise `a1` is added as a new `Root`.
   */
  def cons[A1 >: A](a1: A1): TreeList[A1] =
    treeList match {
      case h1 :: h2 :: rest =>
        // The helper gives the two depths their own type parameters so the
        // runtime equality check can be turned into type-level evidence.
        def go[N1 <: Nat, N2 <: Nat, A2 <: A](t1: Tree[N1, A2], t2: Tree[N2, A2]): TreeList[A1] =
          Nat.maybeEq[N1, N2](t1.depth, t2.depth) match {
            case Some(eqv) =>
              type T[N <: Nat] = Tree[N, A2]
              Trees(Balanced[N2, A1](a1, eqv.subst[T](t1), t2) :: rest)
            case None =>
              Trees(Root(a1) :: treeList)
          }
        go(h1, h2)
      case lessThan2 => Trees(Root(a1) :: lessThan2)
    }

  /**
   * Removes the first element: a leading `Root` is dropped, a leading
   * `Balanced` node is replaced by its two children.
   */
  def uncons: Option[(A, TreeList[A])] =
    treeList match {
      case Nil => None
      case Root(a) :: rest => Some((a, Trees(rest)))
      case Balanced(a, l, r) :: rest => Some((a, Trees(l :: r :: rest)))
    }

  /**
   * Returns the element at zero-based position `idx`, or `None` when `idx` is
   * negative or not smaller than the number of elements.
   */
  def get(idx: Long): Option[A] = {
    @annotation.tailrec
    def loop(idx: Long, treeList: List[Tree[Nat, A]]): Option[A] =
      if (idx < 0L) None
      else
        treeList match {
          case Nil => None
          case h :: tail =>
            // Skip whole trees until the one containing the position is found.
            if (h.size <= idx) loop(idx - h.size, tail)
            else h.get(idx)
        }
    loop(idx, treeList)
  }

  /** Sums the sizes of the trees; this takes one step per tree, not per element. */
  def size: Long = {
    @annotation.tailrec
    def loop(treeList: List[Tree[Nat, A]], acc: Long): Long =
      treeList match {
        case Nil => acc
        case h :: tail => loop(tail, acc + h.size)
      }
    loop(treeList, 0L)
  }

  /** Folds from the first element to the last, unpacking trees on the way. */
  def foldLeft[B](init: B)(fn: (B, A) => B): B = {
    @annotation.tailrec
    def loop(init: B, rest: List[Tree[Nat, A]]): B =
      rest match {
        case Nil => init
        case Root(a) :: tail => loop(fn(init, a), tail)
        case Balanced(a, l, r) :: tail => loop(fn(init, a), l :: r :: tail)
      }
    loop(init, treeList)
  }

  /** Folds from the last element to the first: the trees are visited in reverse list order. */
  def foldRight[B](fin: B)(fn: (A, B) => B): B =
    treeList.reverse.foldLeft(fin) { (b, treea) =>
      treea.foldRight(b)(fn)
    }

  /** Maps every tree, preserving the tree structure of the list. */
  def map[B](fn: A => B) = Trees(treeList.map(_.map(fn)))

  /**
   * Removes the first `n` elements.
   *
   * A non-positive `n` removes nothing and returns this list unchanged; when
   * the trees run out before `n` elements are removed the result is `empty`.
   */
  def drop(n: Long): TreeList[A] = {
    // Within the loop n is never negative: the guard below rejects non-positive
    // counts and every recursive call subtracts at most the current n.
    @annotation.tailrec
    def loop(n: Long, treeList: List[Tree[Nat, A]]): TreeList[A] =
      treeList match {
        case Nil => empty
        case _ if n == 0L => Trees(treeList)
        case h :: tail =>
          if (h.size <= n) loop(n - h.size, tail)
          else {
            h match {
              case Root(_) =>
                // Not reachable for n >= 1 (a Root has size 1 and is skipped
                // above); kept so the match stays exhaustive.
                loop(n - 1, tail)
              case Balanced(_, l, r) =>
                // n is between 1 and h.size - 1 here.
                if (n > l.size + 1L) loop(n - l.size - 1L, r :: tail)
                else if (n > 1L) loop(n - 1L, l :: r :: tail)
                else Trees(l :: r :: tail)
            }
          }
      }
    // A negative count used to fall through to the branches above and remove
    // elements; treat it like zero instead.
    if (n <= 0L) this else loop(n, treeList)
  }

  /**
   * Splits the list in two parts whose concatenation is the original list.
   *
   * A single `Balanced` tree is split into (its element followed by the left
   * child, the right child); with several trees the last tree becomes the
   * second part.
   */
  def split: (TreeList[A], TreeList[A]) =
    treeList match {
      case Nil => (empty, empty)
      case Root(_) :: Nil => (this, empty)
      case Balanced(a, l, r) :: Nil => (Trees(Root(a) :: l :: Nil), Trees(r :: Nil))
      case moreThanOne => (Trees(moreThanOne.init), Trees(moreThanOne.last :: Nil))
    }
}
/**
 * Extension methods for [[TreeList]] that need the element type to be
 * invariant and therefore cannot be members of the covariant class itself.
 */
implicit class InvariantTreeList[A](val treeList: TreeList[A]) extends AnyVal {
  /**
   * Applies the effectful function `fn` to every element, in list order, and
   * collects the results into a new list inside `F`.
   */
  def traverse[F[_]: Applicative, B](fn: A => F[B]): F[TreeList[B]] =
    treeList match {
      case Trees(trees) =>
        trees.traverse { t => traverseTree(t, fn) }.map(Trees(_))
    }
}
/** The list with no elements; usable at any element type because `TreeList` is covariant. */
val empty: TreeList[Nothing] = Trees(List.empty[Tree[Nat, Nothing]])
/**
 * Builds a [[TreeList]] holding the elements of `list` in the same order.
 *
 * The input is walked back to front so that each element can be prepended
 * with `cons`.
 */
def fromList[A](list: List[A]): TreeList[A] =
  list.reverse.foldLeft(empty: TreeList[A]) { (acc, elem) => acc.cons(elem) }
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment