Skip to content

Instantly share code, notes, and snippets.

@Dessix
Created November 1, 2016 01:57
Show Gist options
  • Save Dessix/a17d9a9cb116d2a645620e092d184125 to your computer and use it in GitHub Desktop.
Save Dessix/a17d9a9cb116d2a645620e092d184125 to your computer and use it in GitHub Desktop.
Crazy attempt at heapsort in Scala. Also, Scala's generics are finicky when it comes to upper type bounds.
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer
import scala.language.postfixOps
import scala.math._
import scala.math.Ordering._
import scala.math.Ordered._
// MIN-HEAP: every parent is <= its children (zero-indexed, stored in a ListBuffer)
/** Copies the elements of `in`, in order, into a fresh mutable [[ListBuffer]]. */
def makeMutable[T](in: Seq[T]): ListBuffer[T] = {
  val buffer = ListBuffer.empty[T]
  buffer ++= in
  buffer
}
/** Returns the elements of `in` rearranged into min-heap order.
  * `in` is copied into a scratch buffer first, so the caller's sequence is left untouched.
  */
def heapify[T](in: Seq[T])(implicit ev: Ordering[T]): List[T] = {
  val scratch = ListBuffer[T](in: _*)
  heapifyMut(scratch)
}
/** Zero-based index of a node in an array-backed heap.
  *
  * Arithmetic returns a new address. The comparison operators compare `addr`
  * against a plain `Int` (for example a heap length) or against another [[NodeAddr]].
  */
final case class NodeAddr(addr: Int) {
  def -(offset: Int): NodeAddr = NodeAddr(addr - offset)
  def +(offset: Int): NodeAddr = NodeAddr(addr + offset)
  /** The preceding address (`addr - 1`). */
  def -- : NodeAddr = this - 1
  /** The following address (`addr + 1`). */
  def ++ : NodeAddr = this + 1

  // Comparisons against a raw index/length.
  def >(i: Int): Boolean = addr > i
  def <(i: Int): Boolean = addr < i
  def <=(i: Int): Boolean = addr <= i
  def >=(i: Int): Boolean = addr >= i
  def ==(i: Int): Boolean = addr == i

  // Comparisons against another address.
  def >(i: NodeAddr): Boolean = addr > i.addr
  def <(i: NodeAddr): Boolean = addr < i.addr
  def <=(i: NodeAddr): Boolean = addr <= i.addr
  def >=(i: NodeAddr): Boolean = addr >= i.addr
  def ==(i: NodeAddr): Boolean = addr == i.addr
}
/** Parent of a node in a zero-indexed binary heap: `(i - 1) / 2`.
  * Integer division truncates toward zero, so the root (index 0) maps to itself.
  */
def parentAddressZ(childAddress: NodeAddr): NodeAddr = {
  val childIndex = childAddress.addr
  NodeAddr((childIndex - 1) / 2)
}
/** Left child of a node in a zero-indexed binary heap: `2 * i + 1`. */
def leftChildAddressZ(parentAddress: NodeAddr): NodeAddr = NodeAddr(1 + 2 * parentAddress.addr)
/** Right child of a node in a zero-indexed binary heap: `2 * i + 2` (one past the left child). */
def rightChildAddressZ(parentAddress: NodeAddr): NodeAddr = NodeAddr(2 * parentAddress.addr + 2)
/** Array-backed heap storage: a mutable ListBuffer whose positions are heap node addresses. */
type MutableHeap[T] = scala.collection.mutable.ListBuffer[T]
/** Rearranges `heap` in place into min-heap order and returns a snapshot of it as a `List`.
  *
  * Bottom-up construction: sift down every node that has at least one child, from the last
  * such node back to the root. Leaves (indices >= length / 2) have no children, so sifting
  * them would be a no-op and they are skipped.
  *
  * @param heap buffer to reorder; mutated in place
  * @return the heap-ordered contents of `heap`
  */
def heapifyMut[T](heap: MutableHeap[T])(implicit ev: Ordering[T]): List[T] = {
  val len = heap.length
  // `len / 2 - 1` is the last index with a child (2 * i + 1 < len); the range is empty for len <= 1.
  for (i <- (len / 2 - 1) to 0 by -1) {
    siftDownMut(heap, NodeAddr(i), len)
  }
  heap.toList
}
/** Reads the element stored at address `node` in `heap` (throws if the address is out of range). */
def nodeAt[T](node: NodeAddr, heap: MutableHeap[T]): T = {
  val index = node.addr
  heap(index)
}
/** Moves the element at `node` down until neither of its children (within `heap(0 until len)`)
  * is smaller than it, restoring the min-heap property along that path.
  *
  * Assumes both child subtrees of `node` are already min-heaps (the usual sift-down
  * precondition). Only positions below `len` are treated as part of the heap, which lets
  * heapsort shrink the heap without shrinking the buffer.
  *
  * @param heap buffer holding the heap; mutated in place
  * @param node address to sift down from
  * @param len  number of leading positions in `heap` that belong to the heap
  */
@tailrec
def siftDownMut[T](heap: MutableHeap[T], node: NodeAddr, len: Int)(implicit ev: Ordering[T]): Unit = {
  val lc = leftChildAddressZ(node)
  // No left child means no children at all: nothing to do.
  if (lc.addr < len) {
    val rc = rightChildAddressZ(node)
    // Choose the smaller child; on ties the right child wins.
    val smallest =
      if (rc.addr < len && ev.lteq(heap(rc.addr), heap(lc.addr))) rc else lc
    if (ev.lt(heap(smallest.addr), heap(node.addr))) {
      val tmp = heap(node.addr)
      heap(node.addr) = heap(smallest.addr)
      heap(smallest.addr) = tmp
      siftDownMut(heap, smallest, len)
    }
  }
}
/** Heap-sorts `data` into DESCENDING order.
  *
  * The heap is a min-heap. Each pass swaps the current minimum (the root) into the last slot
  * of the unsorted prefix and then re-sifts the shrunken heap. The smallest elements therefore
  * end up at the back of the result.
  *
  * @param data the values to sort; not modified
  * @return a new list with the same elements, largest first
  */
def heapSort(data: List[Int]): List[Int] = data match {
  case Nil | _ :: Nil => data // zero or one element: already sorted
  case _ =>
    val heap = makeMutable(heapify(data))
    for (end <- (heap.length - 1) to 1 by -1) {
      // Move the minimum into its final position, then restore the heap on heap(0 until end).
      val min = heap.head
      heap(0) = heap(end)
      heap(end) = min
      siftDownMut(heap, NodeAddr(0), end)
    }
    heap.toList
}
// Demo / self-check: build random sample data, verify the heap invariant, then sort it.
val sampleLength = 10
val sampleData = makeMutable((0 until sampleLength).map(_ => util.Random.nextInt(sampleLength)))
val heapified = heapify(sampleData)
// Min-heap invariant: every child that exists is >= its parent.
for (i <- heapified.indices) {
  val nv = heapified(i)
  val children = Seq(leftChildAddressZ(NodeAddr(i)), rightChildAddressZ(NodeAddr(i)))
  for (child <- children if child < heapified.length) {
    assert(heapified(child.addr) >= nv, s"heap property violated between $i and ${child.addr}")
  }
}
val sortedData = heapSort(sampleData.toList)
// heapSort yields descending order; it must be a permutation of the input in that order.
assert(sortedData == sampleData.toList.sorted.reverse, s"heapSort produced $sortedData")
sortedData
// Example worksheet output from one run (the sample data is random, so values vary):
//sampleLength: Int = 10
//sampleData: scala.collection.mutable.ListBuffer[Int] = ListBuffer(6, 0, 4, 1, 4, 0, 3, 6, 0, 9)
//heapified: List[Int] = List(0, 0, 3, 0, 4, 4, 6, 6, 1, 9)
//10
//res1: List[Int] = List(9, 6, 6, 4, 4, 3, 1, 0, 0, 0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment