Last active
April 13, 2016 18:22
-
-
Save jayhutfles/6b2da3ae9df741971b60 to your computer and use it in GitHub Desktop.
GHS MinSpanningTree algorithm. Not as clean as Prim's or Kruskal's. But it scales, dood.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Let a "fragment" be a connected subgraph of a minimum spanning tree.
// To calculate a graph's entire minimum spanning tree:
//   "while there are mutually nearest fragments, merge them"
// "merge": given fragments F and G:
//   combine F and G's vertices
//   combine F and G's mst edges, plus their shared external edge
//   combine external edges, but filter out those which only connect F and G
//   take max of levels if different, or add one if same
// "nearest fragment": given a fragment F
//   then the fragment containing the other end of F's shortest external edge is the nearest fragment
// "mutually nearest": given a fragment F and its nearest fragment NF:
//   when NF's nearest fragment is F
//   then F and NF are a set of mutually nearest fragments
// a "fragment" is required to have:
//   vertices it covers
//   mst edges joining the vertices
//   external edges
//   a level
import scala.annotation.tailrec | |
/** Minimum spanning tree computation by repeated fragment merging.
  *
  * Every vertex starts as its own fragment; pairs of mutually nearest fragments are merged
  * until no such pair remains. For a disconnected graph this yields one fragment per
  * connected component, i.e. a minimum spanning forest.
  */
object MinSpanningTree {

  type Vertex = Long

  /** An undirected, weighted edge.
    *
    * Edges are ordered by `length`, with ties broken by `id`, so that any two edges with
    * distinct ids are strictly ordered and "the shortest edge" is always unambiguous.
    *
    * @param id       identifier used as the tie-breaker between equally long edges
    * @param vertices the endpoints of the edge
    * @param length   the weight of the edge
    */
  case class Edge(id: Long, vertices: Set[Vertex], length: Double) extends Ordered[Edge] {
    def compare(that: Edge): Int =
      if (this.length == that.length) this.id.compare(that.id)
      else this.length.compare(that.length)
  }

  /** A connected piece of the spanning tree under construction.
    *
    * @param vertices      vertices covered by this fragment
    * @param mstEdges      tree edges joining those vertices
    * @param externalEdges edges with exactly one endpoint inside the fragment
    * @param level         0 for a single vertex; see [[Fragment.merge]] for how it grows
    */
  case class Fragment(vertices: Set[Vertex], mstEdges: Set[Edge], externalEdges: Set[Edge], level: Int) {

    /** The shortest external edge, or `None` when the fragment has no external edges
      * (it already spans its whole connected component).
      */
    def shortestExternalEdgeOption: Option[Edge] =
      if (externalEdges.isEmpty) None else Some(externalEdges.min)

    /** The shortest external edge. Throws if `externalEdges` is empty; prefer
      * [[shortestExternalEdgeOption]] when that can happen.
      */
    def shortestExternalEdge: Edge = externalEdges.min

    /** The vertex at the far end of the shortest external edge, if there is such an edge. */
    def nearestVertexOption: Option[Vertex] = shortestExternalEdgeOption.map(outsideEndpointOf)

    /** The vertex at the far end of the shortest external edge. Throws if `externalEdges`
      * is empty; prefer [[nearestVertexOption]] when that can happen.
      */
    def nearestVertex: Vertex = outsideEndpointOf(shortestExternalEdge)

    // An external edge must have exactly one endpoint outside the fragment.
    private def outsideEndpointOf(edge: Edge): Vertex = {
      val outside = edge.vertices.diff(vertices)
      assert(outside.size == 1)
      outside.head
    }
  }

  object Fragment {

    /** Merges two fragments into one.
      *
      * The result covers both vertex sets, keeps both sets of tree edges plus the shared
      * shortest external edge (when there is one), and keeps only those external edges
      * that still leave the merged fragment. The level is incremented when both levels
      * are equal, otherwise the larger level is kept.
      */
    def merge(f: Fragment, g: Fragment): Fragment = {
      val newVertices = f.vertices ++ g.vertices
      val newMSTEdges = f.mstEdges ++ g.mstEdges ++ sharedExternalEdge(f, g).toList
      // Count endpoints via lookups in the merged vertex set; intersecting the (potentially
      // large) vertex set with each two-element edge would walk the whole vertex set per edge.
      val newExtEdges =
        (f.externalEdges ++ g.externalEdges).filter(e => e.vertices.count(newVertices.contains) == 1)
      val newLevel = if (f.level == g.level) f.level + 1 else math.max(f.level, g.level)
      Fragment(newVertices, newMSTEdges, newExtEdges, newLevel)
    }

    /** The edge that is the shortest external edge of both fragments, if any.
      * Returns `None` when the shortest edges differ or either fragment has no external edges.
      */
    def sharedExternalEdge(f: Fragment, g: Fragment): Option[Edge] =
      for {
        shortestOfF <- f.shortestExternalEdgeOption
        shortestOfG <- g.shortestExternalEdgeOption
        if shortestOfF == shortestOfG
      } yield shortestOfF
  }

  object MSTFragments {

    /** Builds the initial level-0 fragments: one per vertex mentioned by `edges`, each with
      * its incident edges as external edges.
      *
      * Self-loops (edges with fewer than two distinct endpoints) never leave a fragment, so
      * they are not recorded as external edges; their vertex still gets a fragment.
      */
    def fragmentsBasedOn(edges: Set[Edge]): Set[Fragment] = {
      val noEdges = Set.empty[Edge]
      // Single pass over the edges instead of re-filtering the full edge set per vertex.
      val incidentEdges = edges.foldLeft(Map.empty[Vertex, Set[Edge]]) { (index, edge) =>
        val isSelfLoop = edge.vertices.size < 2
        edge.vertices.foldLeft(index) { (acc, v) =>
          val known = acc.getOrElse(v, noEdges)
          acc.updated(v, if (isSelfLoop) known else known + edge)
        }
      }
      incidentEdges.iterator.map { case (v, adjacent) => Fragment(Set(v), noEdges, adjacent, 0) }.toSet
    }

    /** Maps every covered vertex to the fragment that contains it. */
    def createFragmentLookupMapFrom(fragments: Set[Fragment]): Map[Vertex, Fragment] =
      (for {
        f <- fragments
        v <- f.vertices
      } yield (v, f)).toMap

    /** Finds all pairs of mutually nearest fragments.
      *
      * Each element of the result is a two-element set `Set(f, nf)` where `nf` contains the
      * far end of `f`'s shortest external edge and vice versa. Fragments without external
      * edges, or whose nearest vertex is not covered by any given fragment, take part in no pair.
      */
    def mutuallyNearestFragmentsIn(fragments: Set[Fragment]): Set[Set[Fragment]] = {
      val lookupFragmentForVertex = createFragmentLookupMapFrom(fragments)
      fragments.flatMap { fragment =>
        val mutualPair = for {
          nearest  <- fragment.nearestVertexOption
          neighbor <- lookupFragmentForVertex.get(nearest)
          if neighbor.nearestVertexOption.exists(fragment.vertices.contains)
        } yield Set(fragment, neighbor)
        mutualPair.toList
      }
    }

    /** Repeatedly merges every pair of mutually nearest fragments until fewer than two
      * fragments remain or no mutually nearest pair exists.
      */
    @tailrec
    def mergeMutuallyNearest(fragments: Set[Fragment]): Set[Fragment] = {
      // Lazy so that the search is skipped entirely when there is nothing left to merge.
      lazy val mutuallyNearestFragments = mutuallyNearestFragmentsIn(fragments)
      if (fragments.size < 2 || mutuallyNearestFragments.isEmpty) {
        fragments
      } else {
        val unmergedFragments = fragments.diff(mutuallyNearestFragments.flatten)
        val mergedFragments =
          mutuallyNearestFragments.map(toMerge => Fragment.merge(toMerge.head, toMerge.last))
        mergeMutuallyNearest(mergedFragments ++ unmergedFragments)
      }
    }

    /** Computes the final fragments (one per connected component) for the given edges. */
    def apply(initialEdges: Set[Edge]): Set[Fragment] =
      mergeMutuallyNearest(fragmentsBasedOn(initialEdges))
  }

  /** Returns the edges of the minimum spanning tree (or forest, if the graph is
    * disconnected) of the graph described by `edges`.
    */
  def apply(edges: Set[Edge]): Set[Edge] =
    for {
      fragment <- MSTFragments(edges)
      mstEdge  <- fragment.mstEdges
    } yield mstEdge
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.scalatest.FunSuite | |
import MinSpanningTree._ | |
/** Unit tests for the fragment-merging rules and the mutually-nearest-fragment search
  * of `MinSpanningTree`.
  */
class MinSpanningTreeTests extends FunSuite {

  test("ALGEBRAIC: merging fragments with inequal levels results in correct level") {
    // GIVEN two Fragments with inequal levels
    val e12 = Edge(12, Set(1, 2), 1.0)
    val e13 = Edge(13, Set(1, 3), 2.0)
    val e23 = Edge(23, Set(2, 3), 3.0)
    val f1 = Fragment(Set(1), Set[Edge](), Set(e12, e13), 0)
    val f2 = Fragment(Set(2, 3), Set(e23), Set(e12, e13), 1)

    // WHEN merging them
    val f3 = Fragment.merge(f1, f2)

    // THEN the level is the max of the two levels
    assert(f3.level === Math.max(f1.level, f2.level))
  }

  test("ALGEBRAIC: merging fragments with equal levels results in correct level") {
    // GIVEN two Fragments with equal levels
    val e12 = Edge(12, Set(1, 2), 1.0)
    val e13 = Edge(13, Set(1, 3), 2.0)
    val e23 = Edge(23, Set(2, 3), 3.0)
    val f1 = Fragment(Set(1), Set[Edge](), Set(e12, e13), 0)
    val f2 = Fragment(Set(2), Set[Edge](), Set(e12, e23), 0)

    // WHEN merging them
    val f3 = Fragment.merge(f1, f2)

    // THEN the level is one more than the (shared) level of the inputs
    assert(f3.level === f1.level + 1)
    assert(f3.level === f2.level + 1)
  }

  test("ALGEBRAIC: merging fragments filters out external edges which would otherwise cause a cycle") {
    // GIVEN two Fragments whose merge would result in cyclic edges
    val e12 = Edge(12, Set(1, 2), 1.0)
    val e13 = Edge(13, Set(1, 3), 2.0)
    val e14 = Edge(14, Set(1, 4), 4.0)
    val e23 = Edge(23, Set(2, 3), 3.0)
    val f1 = Fragment(Set(1), Set[Edge](), Set(e12, e13, e14), 0)
    val f2 = Fragment(Set(2, 3), Set(e23), Set(e12, e13), 1)

    // WHEN merging them
    val f3 = Fragment.merge(f1, f2)

    // THEN the edge e13 is not an external edge, but e14 is
    assert(!f3.externalEdges.contains(e13))
    assert(f3.externalEdges.contains(e14))
  }

  test("ALGEBRAIC: merging fragments results in one more MSTEdge than before") {
    // GIVEN two fragments with at least one shared edge
    val e12 = Edge(12, Set(1, 2), 1.0)
    val e14 = Edge(14, Set(1, 4), 4.0)
    val e23 = Edge(23, Set(2, 3), 3.0)
    val f1 = Fragment(Set(1), Set[Edge](), Set(e12, e14), 0)
    val f2 = Fragment(Set(2, 3), Set(e23), Set(e12), 1)

    // WHEN merging them
    val f3 = Fragment.merge(f1, f2)

    // THEN the shared external edge is part of the MSTEdges
    assert(f3.mstEdges.contains(e12))
    // AND it is the only edge added on top of the two original MST edge sets
    assert(f3.mstEdges.size === f1.mstEdges.size + f2.mstEdges.size + 1)
  }

  test("INDUCTIVE HYPOTHESIS: mutually nearest neighbors are appropriately identified") {
    // GIVEN a fragment with two external edges
    val e12 = Edge(12, Set(1, 2), 1.0)
    val e13 = Edge(13, Set(1, 3), 2.0)
    val f1 = Fragment(Set(1), Set[Edge](), Set(e12, e13), 0)
    val f2 = Fragment(Set(2), Set[Edge](), Set(e12), 0)
    val f3 = Fragment(Set(3), Set[Edge](), Set(e13), 0)
    val fragments = Set(f1, f2, f3)

    // WHEN finding its mutually nearest neighbor
    val mutuallyNearestNeighbors = MSTFragments.mutuallyNearestFragmentsIn(fragments)

    // THEN it identifies the correct ones: f1 and f2 point at each other via e12,
    // while f3 points at f1 but f1 does not point back
    assert(mutuallyNearestNeighbors.contains(Set(f1, f2)))
    assert(mutuallyNearestNeighbors.size == 1)
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
import java.io._
import scala.io.Source
// Reads a comma-separated edge list, one edge per line with five fields:
// id, source vertex, destination vertex, an ignored field, length.
// A line with a different number of fields fails the pattern match (MatchError).
val edgeLines: List[Edge] = {
  val source = Source.fromFile("/path/to/file")
  try {
    // toList forces the whole file to be read before the source is closed.
    source.getLines().toList.map(_.split(",") match {
      case Array(id, src, dst, _, length) => Edge(id.toLong, Set(src.toLong, dst.toLong), length.toDouble)
    })
  } finally {
    // Always release the file handle, even when parsing fails.
    source.close()
  }
}