Created
August 26, 2025 00:08
-
-
Save arosien/43ef1771626d83f1b3a02eda1d678ee3 to your computer and use it in GitHub Desktop.
Kruskal's minimum spanning tree graph algorithm using cats-collections
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//> using toolkit typelevel:default | |
//> using dep org.typelevel::cats-collections-core::0.9.10 | |
import cats.* | |
import cats.collections.DisjointSets | |
import cats.data.* | |
import cats.instances.order.* | |
import cats.syntax.all.* | |
/** A labelled, undirected edge joining two nodes.
 *
 * Encoded as a plain triple: the two endpoints first, then the label.
 *
 * @tparam N
 *   node type
 * @tparam E
 *   edge label type
 */
type Edge[N, E] = Tuple3[N, N, E]
/** An undirected graph ("UG"), given as a set of nodes plus a set of labelled edges.
 *
 * @tparam N
 *   node type
 * @tparam E
 *   edge label type
 * @param nodes
 *   the nodes of the graph
 * @param edges
 *   the labelled edges of the graph; their endpoints are presumably meant to be members of `nodes`, but nothing in
 *   this class enforces that
 */
case class UG[N, E](
    nodes: Set[N],
    edges: Set[Edge[N, E]]
)
/** Computes a minimum spanning tree of `ug` with Kruskal's algorithm.
 *
 * Edges are visited in increasing label order. An edge is kept only when its two endpoints currently belong to
 * different connected components; keeping it merges those components. If the graph is not connected, the result
 * is a minimum spanning forest (one tree per component).
 *
 * The algorithm is expressed as a `State[S, A]` value (essentially an `S => (S, A)` function), where:
 *   - the state `S` is a `DisjointSets[N]` tracking the connected components, and
 *   - the output `A` is the `Set[Edge[N, E]]` of edges chosen so far.
 *
 * @tparam N
 *   node type; its `Order` is required by `DisjointSets`
 * @tparam E
 *   edge label type; its `Order` defines the edge "cost" ordering
 * @param ug
 *   the graph to span
 * @return
 *   the selected edges
 */
def minimumSpanningTree[N: Order, E: Order](ug: UG[N, E]): Set[Edge[N, E]] =
  type Tree = Set[Edge[N, E]]

  // One step of the fold: keep `edge` unless it would close a cycle.
  def addUnlessCycle(tree: Tree, edge: Edge[N, E]): State[DisjointSets[N], Tree] =
    val (from, to, _) = edge
    for
      fromComponent <- DisjointSets.find(from)
      toComponent <- DisjointSets.find(to)
      grown <-
        if fromComponent != toComponent then
          // the endpoints are disconnected: connect them and accumulate the edge
          DisjointSets.union(from, to).as(tree + edge)
        else
          // the edge would make a cycle, so "skip" it by returning the unmodified edge set
          State.pure[DisjointSets[N], Tree](tree)
    yield grown

  val kruskal: State[DisjointSets[N], Tree] =
    ug.edges.toList // for each edge
      .sortBy(_._3) // in increasing cost order
      .foldM(Set.empty[Edge[N, E]])(addUnlessCycle) // thread the accumulator and the component state

  // Initial state: every node is its own (disconnected) component.
  // Edge endpoints are included as well as `ug.nodes`: `find` yields `None` for a label the structure does not
  // contain, so an edge touching an unlisted node would otherwise be accepted without its components ever being
  // merged (allowing cycles), and an edge between two unlisted nodes would be silently dropped.
  val endpoints: Set[N] = ug.edges.flatMap { case (from, to, _) => Set(from, to) }
  val initialState = DisjointSets((ug.nodes ++ endpoints).toSeq*)

  // run the algorithm given the initial state and return the final value, discarding the final state.
  kruskal.runA(initialState).value
/** Example from https://en.wikipedia.org/wiki/Disjoint-set_data_structure#Applications.
 *
 * {{{
 * a--1--c
 * |    /|\
 * 3   4 6 7
 * |  /  |  \
 * b--5--d--2--e
 * }}}
 *
 * Edges (with cost): a-b (3), a-c (1), b-c (4), b-d (5), c-d (6), c-e (7), d-e (2).
 *
 * The minimum spanning tree is: a-c, a-b, b-d, d-e.
 */
val ug: UG[String, Int] =
  val nodes = Set("a", "b", "c", "d", "e")
  val edges: Set[Edge[String, Int]] =
    Set(
      ("a", "b", 3),
      ("a", "c", 1),
      ("b", "c", 4),
      ("b", "d", 5),
      ("c", "d", 6),
      ("c", "e", 7),
      ("d", "e", 2)
    )
  UG(nodes, edges)

// Expected output: Set((a,c,1), (d,e,2), (a,b,3), (b,d,5))
println(minimumSpanningTree(ug))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment