Created
November 1, 2016 01:57
-
-
Save Dessix/a17d9a9cb116d2a645620e092d184125 to your computer and use it in GitHub Desktop.
Crazy attempt at heapsort in Scala. Also, Scala's generics are finicky when it comes to upper type bounds.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.annotation.tailrec | |
import scala.collection.mutable.ListBuffer | |
import scala.language.postfixOps | |
import scala.math._ | |
import scala.math.Ordering._ | |
import scala.math.Ordered._ | |
//MIN HEAP | |
/** Copies `in` into a fresh [[ListBuffer]].
  *
  * The buffer is independent of `in`, so mutating it never affects the caller's sequence.
  */
def makeMutable[T](in: Seq[T]): ListBuffer[T] = {
  val buffer = ListBuffer.empty[T]
  buffer ++= in
  buffer
}
/** Returns the elements of `in` rearranged into a zero-indexed binary min-heap under `ev`.
  *
  * The input is copied into a scratch buffer first, so `in` itself is never modified.
  */
def heapify[T](in: Seq[T])(implicit ev: Ordering[T]): List[T] = {
  val scratch: MutableHeap[T] = makeMutable(in)
  heapifyMut(scratch)(ev)
}
/** Position of a node in a zero-indexed, array-backed binary heap.
  *
  * Wrapping the raw `Int` keeps heap positions distinct from heap values. The comparison
  * operators compare the wrapped index against a plain `Int` (e.g. a heap length) or against
  * another `NodeAddr`.
  */
case class NodeAddr(addr: Int) {
  /** Address `offset` positions earlier. */
  def -(offset: Int): NodeAddr = NodeAddr(addr - offset)

  /** Address `offset` positions later. */
  def +(offset: Int): NodeAddr = NodeAddr(addr + offset)

  /** The preceding address. Returns a new value; it does not mutate `this`. */
  def -- : NodeAddr = this - 1

  /** The following address. Returns a new value; it does not mutate `this`. */
  def ++ : NodeAddr = this + 1

  // Fixed: `>` previously delegated to `<`, so e.g. `NodeAddr(5) > 3` evaluated to false.
  def >(i: Int): Boolean = addr > i
  def <(i: Int): Boolean = addr < i
  def <=(i: Int): Boolean = addr <= i
  def >=(i: Int): Boolean = addr >= i
  def ==(i: Int): Boolean = addr == i

  def >(i: NodeAddr): Boolean = addr > i.addr
  def <(i: NodeAddr): Boolean = addr < i.addr
  def <=(i: NodeAddr): Boolean = addr <= i.addr
  def >=(i: NodeAddr): Boolean = addr >= i.addr
  def ==(i: NodeAddr): Boolean = addr == i.addr
}
/** Parent of `childAddress` in a zero-indexed binary heap, i.e. index `(i - 1) / 2`.
  *
  * Integer division truncates toward zero, so the root (index 0) maps to itself.
  */
def parentAddressZ(childAddress: NodeAddr): NodeAddr =
  childAddress.copy(addr = (childAddress.addr - 1) / 2)
/** Left child of `parentAddress` in a zero-indexed binary heap: index `2 * i + 1`. */
def leftChildAddressZ(parentAddress: NodeAddr): NodeAddr =
  parentAddress.copy(addr = 2 * parentAddress.addr + 1)
/** Right child of `parentAddress` in a zero-indexed binary heap: index `2 * i + 2` (left child + 1). */
def rightChildAddressZ(parentAddress: NodeAddr): NodeAddr =
  NodeAddr(2 * parentAddress.addr + 2)
/** Backing store for an in-place binary heap.
  *
  * Note: `ListBuffer` indexed `apply`/`update` are linear-time (see the Scala collections
  * performance table), so heap operations on it are slower than on an array-backed buffer.
  */
type MutableHeap[A] = ListBuffer[A]
/** Rearranges `heap` in place into a zero-indexed binary min-heap under `ev` (bottom-up
  * construction), then returns a snapshot of the result as an immutable `List`.
  *
  * Only internal nodes are sifted. Indices `>= heap.length / 2` are leaves, for which
  * `siftDownMut` does nothing. Walking from the last internal node back to the root ensures
  * both subtrees of a node are already heaps by the time that node is sifted.
  */
def heapifyMut[T](heap: MutableHeap[T])(implicit ev: Ordering[T]): List[T] = {
  val lastInternal = heap.length / 2 - 1
  // The range is empty for length 0 or 1; such buffers are trivially heaps.
  for (i <- lastInternal to 0 by -1) {
    siftDownMut(heap, NodeAddr(i), heap.length)
  }
  heap.toList
}
/** Reads the element stored at heap position `node`. */
def nodeAt[T](node: NodeAddr, heap: MutableHeap[T]): T = {
  val index = node.addr
  heap.apply(index)
}
/** Restores the min-heap property (under `ev`) for the subtree rooted at `node`.
  *
  * Only the first `len` elements of `heap` are treated as part of the heap. Elements at index
  * `len` and beyond are never read or written, which lets heap sort exclude its sorted suffix.
  * Assumes both child subtrees of `node` already satisfy the heap property. While the node is
  * greater than its smaller child, the two are swapped and the sift continues from the child.
  *
  * @param heap buffer holding the heap; modified in place
  * @param node address to sift down
  * @param len  logical heap size
  */
@tailrec
def siftDownMut[T](heap: MutableHeap[T], node: NodeAddr, len: Int)(implicit ev: Ordering[T]): Unit = {
  val lc = leftChildAddressZ(node)
  // If the left child lies outside the heap, `node` is a leaf and there is nothing to do.
  if (lc.addr < len) {
    val rc = rightChildAddressZ(node)
    val l = heap(lc.addr)
    // Choose the smaller child. On ties the right child is chosen, matching the earlier
    // behavior. Values are cached because ListBuffer indexing is linear-time.
    val (child, childValue) =
      if (rc.addr < len) {
        val r = heap(rc.addr)
        if (ev.gteq(l, r)) (rc, r) else (lc, l)
      } else (lc, l)
    val n = heap(node.addr)
    if (ev.gt(n, childValue)) {
      heap(node.addr) = childValue
      heap(child.addr) = n
      siftDownMut(heap, child, len)
    }
  }
}
/** Sorts `data` with an in-place heap sort over a min-heap.
  *
  * Each step swaps the current minimum (the root) to the end of the unsorted region, so the
  * result is in non-increasing (descending) order.
  */
def heapSort(data: List[Int]): List[Int] =
  // lengthCompare avoids walking the whole list just to detect the trivial cases.
  if (data.lengthCompare(1) <= 0) data
  else {
    val heap = makeMutable(heapify(data))
    for (end <- heap.length - 1 until 0 by -1) {
      // Move the root (current minimum) to the last slot of the unsorted prefix...
      val min = heap.head
      heap(0) = heap(end)
      heap(end) = min
      // ...then restore the heap property on the shrunken prefix [0, end).
      siftDownMut(heap, NodeAddr(0), end)
    }
    heap.toList
  }
// Smoke test: heapify random data, check the min-heap property, then check heapSort's output.
val sampleLength = 10
val sampleData = makeMutable((0 until sampleLength).map(_ => util.Random.nextInt(sampleLength)))
val heapified = heapify(sampleData)
// A Vector gives constant-time indexing; indexing the List directly costs O(i) per lookup.
val heapifiedVec = heapified.toVector
// Min-heap property: every non-root element is >= its parent.
for (i <- 1 until heapifiedVec.length) {
  val parent = parentAddressZ(NodeAddr(i))
  assert(heapifiedVec(i) >= heapifiedVec(parent.addr))
}
// heapSort is built on a min-heap, so it returns the data in non-increasing order.
val sortedSample = heapSort(sampleData.toList)
assert(sortedSample == sampleData.toList.sortWith(_ > _))
//sampleLength: Int = 10 | |
//sampleData: scala.collection.mutable.ListBuffer[Int] = ListBuffer(6, 0, 4, 1, 4, 0, 3, 6, 0, 9) | |
//heapified: List[Int] = List(0, 0, 3, 0, 4, 4, 6, 6, 1, 9) | |
//10 | |
//res1: List[Int] = List(9, 6, 6, 4, 4, 3, 1, 0, 0, 0) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment