Created
June 5, 2014 10:46
-
-
Save edeustace/b4759a17a8097de6c094 to your computer and use it in GitHub Desktop.
Another topological sorter in Scala
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package com.ahum | |
| import scala.annotation.tailrec | |
/**
 * Dependency-ordered (topological) sorting of nodes given as
 * `(node, dependencies)` pairs.
 */
object TopologicalSorter {

  /**
   * Orders `nodes` so that every node comes after all of the nodes it depends on.
   *
   * A dependency that is referenced but never declared as a node is added as a
   * node with no dependencies of its own. Each such dependency is added once,
   * even if it is listed several times.
   *
   * @param nodes pairs of a node and the nodes it depends on
   * @tparam T the node identifier type
   * @return the declared nodes plus any synthesized ones, dependencies first
   * @throws RuntimeException if, at some point, none of the remaining nodes has
   *                          all of its dependencies already sorted (e.g. a cycle)
   */
  def sort[T](nodes: (T, Seq[T])*): Seq[(T, Seq[T])] = {
    type DepNode = (T, Seq[T])

    // Renders the nodes that could not be placed, one "node -> deps" per line.
    def prettyPrint(remaining: Seq[DepNode]): String =
      remaining.map { n => s"${n._1} -> ${n._2.mkString(",")}" }.mkString("\n")

    // Repeatedly moves every node whose dependencies are all in `acc` to the end of `acc`.
    @tailrec
    def innerSort(raw: Seq[DepNode], acc: Seq[DepNode]): Seq[DepNode] =
      if (raw.isEmpty) {
        acc
      } else {
        // Built once per pass rather than once per node.
        val sortedNames = acc.map(_._1).toSet
        // A node with no dependencies trivially satisfies `forall`.
        val (edge, inner) = raw.partition { case (_, deps) => deps.forall(sortedNames.contains) }
        if (edge.isEmpty) {
          val msg =
            s"""
               |Can't find edge of graph with remaining nodes.
               |Check that there are no cyclical dependencies:
               |${prettyPrint(raw)}""".stripMargin
          throw new RuntimeException(msg)
        } else {
          innerSort(inner, acc ++ edge)
        }
      }

    val declaredNames = nodes.map(_._1).toSet

    // foldRight walks the input from the last node to the first; each node is
    // followed by empty nodes for any of its dependencies not known so far.
    val expandedNodes = nodes.foldRight(Seq.empty[DepNode]) { (node, acc) =>
      val knownNames = acc.map(_._1).toSet
      // `distinct` keeps a dependency that is listed twice from being synthesized twice.
      val undefinedDeps = node._2.distinct.filterNot { d =>
        declaredNames.contains(d) || knownNames.contains(d)
      }
      val newDeps: Seq[DepNode] = undefinedDeps.map { d => (d, Seq.empty[T]) }
      acc ++ Seq(node) ++ newDeps
    }

    innerSort(expandedNodes, Seq.empty)
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package com.ahum | |
| import org.specs2.mutable.Specification | |
/**
 * Examples for `TopologicalSorter.sort`: failure on cyclic input, ordering of
 * chained and diamond-shaped dependencies, and dependencies that were never
 * declared as nodes.
 */
class TopologicalSorterSpec extends Specification {

  import TopologicalSorter.sort

  // Result shape of `sort` for the Int-keyed graphs used below.
  private type Sorted = Seq[(Int, Seq[Int])]

  "Topological Sorter" should {

    "sort throws an error if it can't find an edge" in {
      // Kept as a def so the call only runs inside the matcher.
      def cyclic: Sorted = sort(1 -> Seq(2), 2 -> Seq(3, 1), 3 -> Seq())
      cyclic must throwA[RuntimeException]
    }

    "simple: sort throws an error if it can't find an edge" in {
      // 1 and 2 depend on each other.
      def cyclic: Sorted = sort(1 -> Seq(2), 2 -> Seq(1))
      cyclic must throwA[RuntimeException]
    }

    "sorts" in {
      val expected: Sorted = List((3, List()), (2, List(3)), (1, List(2)))
      val actual = sort(
        1 -> Seq(2),
        2 -> Seq(3),
        3 -> Seq.empty)
      actual === expected
    }

    "diamond" in {
      val expected: Sorted = List(
        (4, Seq.empty),
        (3, Seq(4)),
        (2, Seq(4)),
        (1, Seq(2, 3)))
      val actual = sort(
        1 -> Seq(2, 3),
        2 -> Seq(4),
        3 -> Seq(4),
        4 -> Seq.empty)
      actual === expected
    }

    "sorts undefined nodes" in {
      // 2 and 3 are only ever mentioned as dependencies of 1.
      val expected: Sorted = List(
        (2, Seq.empty),
        (3, Seq.empty),
        (1, Seq(2, 3)))
      sort(1 -> Seq(2, 3)) === expected
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment