Skip to content

Instantly share code, notes, and snippets.

@ChunMinChang
Last active June 2, 2017 08:12
Show Gist options
  • Save ChunMinChang/81862ff828866aa095e2f3626370beaf to your computer and use it in GitHub Desktop.
Save ChunMinChang/81862ff828866aa095e2f3626370beaf to your computer and use it in GitHub Desktop.
minimal spanning tree

Minimal Spanning Tree

  • Kruskal: Kruskal.cpp
  • Prim: Prim.cpp

Kruskal is relatively simple and easy to implement. Prim is similar to Dijkstra: Prim selects the vertex with the shortest edge to the spanning tree formed so far, while Dijkstra selects the vertex with the shortest distance to the source vertex.

The time complexity of both Kruskal and Prim is O(E log V). Kruskal's edge sort costs O(E log E), which equals O(E log V) because E ≤ V², so log E ≤ 2 log V.

How to compile

Run `$ make` to compile and `$ make clean` to remove the executables and objects.

#if !defined(DISJOINTSET_H)
#define DISJOINTSET_H
// The disjoint-set/union-find/merge-find data structure:
// https://en.wikipedia.org/wiki/Disjoint-set_data_structure
//
// The elements are the integers [0, aNumbers). Union by rank combined with
// path halving in FindRoot keeps the trees shallow, giving near-constant
// amortized time per operation.
class DisjointSet
{
public:
  // Creates aNumbers singleton sets: every element starts as its own root.
  DisjointSet(unsigned int aNumbers)
  {
    mParents.reserve(aNumbers);
    mRanks.reserve(aNumbers);
    for (unsigned int i = 0 ; i < aNumbers ; ++i) {
      // The parent of every node is itself at the beginning.
      mParents.push_back(i);
      // All the ranks are 0 at the beginning.
      mRanks.push_back(0);
    }
  }

  // Returns true if aX and aY belong to the same set.
  // Both arguments must be smaller than the constructor's aNumbers (unchecked).
  bool AreSameSets(unsigned int aX, unsigned int aY)
  {
    return FindRoot(aX) == FindRoot(aY);
  }

  // Unites the set containing aX and the set containing aY.
  void Union(unsigned int aX, unsigned int aY)
  {
    unsigned int rootX = FindRoot(aX);
    unsigned int rootY = FindRoot(aY);
    // Do nothing if aX and aY are already in the same set.
    if (rootX == rootY) {
      return;
    }
    // Attaching the lower-rank tree under the root of the higher-rank tree
    // keeps the trees balanced (height O(log n) even without path halving).
    if (mRanks[rootX] > mRanks[rootY]) {
      mParents[rootY] = rootX;
    } else if (mRanks[rootX] < mRanks[rootY]) {
      mParents[rootX] = rootY;
    } else { // mRanks[rootX] == mRanks[rootY]
      mParents[rootY] = rootX;
      mRanks[rootX]++;
    }
  }

private:
  // Returns the root of aElement's tree.
  // Path halving: every node visited on the way up is re-pointed to its
  // grandparent, so later lookups through this path are shorter.
  unsigned int FindRoot(unsigned int aElement)
  {
    unsigned int node = aElement;
    while (mParents[node] != node) {
      mParents[node] = mParents[mParents[node]];
      node = mParents[node];
    }
    return node;
  }

  // The parent of each number (a root is its own parent).
  std::vector<unsigned int> mParents;
  // Upper bound on the height of the tree rooted at each number; only
  // meaningful for roots, because path halving may make trees shallower.
  std::vector<unsigned int> mRanks;
};
#endif // DISJOINTSET_H
#if !defined(GRAPH_H)
#define GRAPH_H
#include <assert.h> // for assert
#include <vector> // for std::vector
// A weighted edge between two vertices. Edges compare by weight only.
// No user-declared destructor (Rule of Zero), so Edge stays trivially
// copyable and destructible.
struct Edge
{
  Edge(unsigned int aFrom, unsigned int aTo, int aWeight, bool aDirected)
    : mFrom(aFrom)
    , mTo(aTo)
    , mWeight(aWeight)
    , mDirected(aDirected)
  {
  }

  // Used by std::greater in Prim's priority queue.
  bool operator>(const Edge& aOther) const
  {
    // Comparing a directed edge with an undirected one is a programming error.
    assert(mDirected == aOther.mDirected);
    return mWeight > aOther.mWeight;
  }

  // Used by the edge sort in Kruskal.
  bool operator<(const Edge& aOther) const
  {
    assert(mDirected == aOther.mDirected);
    return mWeight < aOther.mWeight;
  }

  unsigned int mFrom;
  unsigned int mTo;
  int mWeight;
  // If the edge is directed, it goes from mFrom to mTo.
  // Otherwise, it is bi-directional:
  // it can be traversed from mFrom to mTo and from mTo to mFrom.
  bool mDirected;
};
typedef struct Graph
{
Graph(unsigned int aVertices, bool aDirected)
: mVertices(aVertices)
, mDirected(aDirected)
{
assert(mVertices);
}
~Graph() {}
void AddEdge(unsigned int aFrom, unsigned int aTo, int aWeight)
{
assert(aFrom < mVertices && aTo < mVertices);
mEdges.push_back(Edge(aFrom, aTo, aWeight, mDirected));
}
unsigned int mVertices; // number of vertices
std::vector<Edge> mEdges; // edges between vertices
bool mDirected; // Indicating this is directed graph or not.
} Graph;
#endif // GRAPH_H
#include "Kruskal.h"
#include "DisjointSet.h"
#include <algorithm> // for std::sort
#include <assert.h> // for assert
// Hands a copy of aGraph to the MST base and computes the tree immediately,
// so Weight() and Edges() are valid as soon as construction finishes.
Kruskal::Kruskal(Graph aGraph) : MST(aGraph)
{
  Compute();
}
// No resources beyond what the members release on their own.
Kruskal::~Kruskal() = default;
// Sorts aEdges in non-decreasing weight order.
// File-local (static): it is only a helper for Kruskal::Compute.
// A stable sort keeps equal-weight edges in insertion order, so the MST that
// Kruskal selects among equal-weight alternatives is deterministic rather
// than dependent on the standard library's std::sort implementation.
static void SortEdge(std::vector<Edge>& aEdges)
{
  std::stable_sort(aEdges.begin(), aEdges.end());
}
// Runs Kruskal's algorithm on mGraph and fills mEdges and mCost.
// Edges are scanned in non-decreasing weight order. An edge is kept only when
// its endpoints are still in different components, i.e. it closes no cycle.
void
Kruskal::Compute()
{
  // A spanning tree needs at least N - 1 edges,
  // where N is the number of vertices.
  assert(mGraph.mVertices && mGraph.mEdges.size() >= mGraph.mVertices - 1);

  // Sort the edges in increasing order (from min to max).
  SortEdge(mGraph.mEdges);

  // Tracks which vertices are already connected by the edges picked so far.
  DisjointSet ds(mGraph.mVertices);

  // Total weight of the spanning tree.
  mCost = 0;

  // Pick the smallest remaining edge one by one. Bind by const reference so
  // each edge is not copied just to be inspected.
  for (const auto& e: mGraph.mEdges) {
    assert(!e.mDirected); // We only handle the undirected graph in Kruskal.
    const unsigned int u = e.mFrom;
    const unsigned int v = e.mTo;
    // Drop the edge if u and v are already connected: adding it would form a
    // cycle with the spanning tree built so far.
    if (ds.AreSameSets(u, v)) {
      continue;
    }
    // Update the total weight and record the selected edge.
    mCost += e.mWeight;
    mEdges.push_back(e);
    // Merge u's group and v's group into one group.
    ds.Union(u, v);
    // Done once there are N - 1 edges.
    if (mEdges.size() == mGraph.mVertices - 1) {
      break;
    }
  }

  // Check that exactly N - 1 edges form the minimal spanning tree
  // (fails if the graph is disconnected).
  assert(mEdges.size() == mGraph.mVertices - 1);
}
#if !defined(KRUSKAL_H)
#define KRUSKAL_H
#include "MST.h"
// Implement the Kruskal's algorithm to find the minimal spanning tree
// in the graph
class Kruskal final: public MST
{
public:
Kruskal(Graph aGraph);
~Kruskal();
private:
// Find the minimal spanning tree in the graph.
void Compute();
};
#endif // KRUSKAL_H
CXX=g++
CFLAGS=-Wall -std=c++14
SOURCES=Kruskal.cpp\
        Prim.cpp\
        test.cpp
HEADERS=DisjointSet.h\
        Graph.h\
        Kruskal.h\
        MST.h\
        Prim.h
OBJECTS=$(SOURCES:.cpp=.o)
EXECUTABLE=run_test

# These targets are commands, not files.
.PHONY: all clean

all: $(EXECUTABLE)

$(EXECUTABLE): $(OBJECTS)
	$(CXX) $(CFLAGS) $(OBJECTS) -o $(EXECUTABLE)

# Rebuild every object when any header changes (the headers are header-only
# code shared by all translation units).
$(OBJECTS): $(HEADERS)

.cpp.o:
	$(CXX) -c $(CFLAGS) $< -o $@

# -f: do not fail when there is nothing to remove.
clean:
	rm -f $(EXECUTABLE) *.o
#if !defined(MST_H)
#define MST_H
#include "Graph.h"
#include <vector> // for std::vector
// Base class for minimal spanning tree implementations.
// Derived classes (Kruskal, Prim) compute the tree from their constructors and
// store the result in mEdges and mCost.
class MST
{
public:
  // Stores a copy of aGraph. Only undirected graphs are supported.
  explicit MST(Graph aGraph)
    : mGraph(aGraph)
    , mCost(0)
  {
    assert(!mGraph.mDirected);
  }
  virtual ~MST() = default;

  // Returns the total weight of the minimal spanning tree.
  // NOTE(review): the sum is accumulated as unsigned int while edge weights
  // are int, so negative weights rely on modular wrap-around.
  unsigned int Weight() const { return mCost; }

  // Returns a copy of the edges forming the minimal spanning tree.
  std::vector<Edge> Edges() const { return mEdges; }

protected:
  Graph mGraph;
  unsigned int mCost;       // The total weight of the minimal spanning tree.
  std::vector<Edge> mEdges; // The edges forming the minimal spanning tree.
};
#endif // MST_H
#include "Prim.h"
#include <assert.h> // for assert
// Infinity: the max value of int (all bits of unsigned int set, shifted
// right once). Computed at compile time, without a C-style cast.
constexpr int INF = static_cast<int>(~0u >> 1);
// Hands a copy of aGraph to the MST base and computes the tree immediately,
// so Weight() and Edges() are valid as soon as construction finishes.
Prim::Prim(Graph aGraph) : MST(aGraph)
{
  Compute();
}
// No resources beyond what the members release on their own.
Prim::~Prim() = default;
// Builds mAdjacentEdges from mGraph's edge list.
// Each undirected edge(u, v) = w is stored twice, as the directed edges
// u -> v and v -> u, so all edges leaving a vertex can be enumerated directly.
void
Prim::ConvertGraph()
{
  // An edge-less graph is valid when it has a single vertex (its MST is
  // empty, and Compute() accepts it), so do not require edges here. Only
  // require that the adjacency lists have not been built yet.
  assert(mGraph.mVertices && mAdjacentEdges.empty());
  mAdjacentEdges.resize(mGraph.mVertices);
  for (const auto& e: mGraph.mEdges) { // e: edge(u, v) = w
    assert(!e.mDirected); // We only handle the undirected graph in Prim.
    const unsigned int u = e.mFrom;
    const unsigned int v = e.mTo;
    // edge(u -> v) = w
    mAdjacentEdges[u].emplace_back(u, v, e.mWeight, true);
    // edge(v -> u) = w
    mAdjacentEdges[v].emplace_back(v, u, e.mWeight, true);
  }
}
void
Prim::Compute()
{
// We must have N - 1 edges at least to form the spanning tree,
// where N is the number of vertices.
assert(mGraph.mVertices &&
mGraph.mEdges.size() >= mGraph.mVertices - 1 &&
mQueue.empty());
ConvertGraph();
// Indicating whether the vertex is visited or not.
std::vector<bool> visited(mGraph.mVertices, false);
// Track the distances between unvisited vertices to the spanning tree.
std::vector<int> distances(mGraph.mVertices, INF);
// Take vertex 0 as the beginning vertex to form the spanning tree.
unsigned int start = 0;
distances[start] = 0;
mQueue.push(Edge(start, start, 0, true));
while (!mQueue.empty()) {
// Pick the vertex that has minimal distance to the spanning tree.
unsigned int nearest = mQueue.top().mTo;
// Suppose there are two edges connected to a vertex v:
// edge(u, v) = w, edge(u', v) = w' and w' < w.
// If edge(u, v) is pushed into queue before edge(u', v),
// then edge(u, v) should be skipped after edge(u', v) is popped.
if (visited[nearest]) {
mQueue.pop();
continue;
}
// Update the weight sum and record the picked edge that can form the MST.
if (nearest == start) {
// The weight sum should be 0 at first.
assert(!mQueue.top().mWeight && !mCost);
} else {
assert(!visited[nearest]);
visited[nearest] = true;
mCost += mQueue.top().mWeight;
assert(mQueue.top().mWeight == distances[nearest]);
mEdges.push_back(mQueue.top());
}
// It's finished if there are already N - 1 edges.
if (mEdges.size() == mGraph.mVertices - 1) {
break;
}
// Add to the nearest vertex to the spanning tree and update
// all the distances from the vertices to the tree formed so far.
mQueue.pop();
for (auto e: mAdjacentEdges[nearest]) {
assert(e.mFrom == nearest);
// Use !visited[e.mTo] to avoid re-pushing the edge that has been added.
if (!visited[e.mTo] && e.mWeight < distances[e.mTo]) {
distances[e.mTo] = e.mWeight;
// Use edge weight as the priority so the edge
// with minimal weight will be on the top of the priority queue.
mQueue.push(e);
}
}
}
}
// The algorithm was developed in 1930 by Czech mathematician Vojtěch Jarník
// and later rediscovered and republished by computer scientist Robert C. Prim in 1957
// and Edsger W. Dijkstra in 1959.
// Therefore, it is also sometimes called the DJP algorithm,
// Jarník's algorithm, the Prim–Jarník algorithm,
// or the Prim–Dijkstra algorithm.
#if !defined(PRIM_H)
#define PRIM_H
#include "MST.h"
#include <functional> // for std::greater
#include <queue> // for std::priority_queue
#include <vector> // for std::vector
// Implement the Prim's algorithm to find the minimal spanning tree
// in the input graph
class Prim final: public MST
{
public:
Prim(Graph aGraph);
~Prim();
private:
// Find the minimal spanning tree in the graph.
void Compute();
// Convert to the adjacent edges of all the vertice in the graph.
void ConvertGraph();
// The adjacent edge allows us to find the connected edges from one vertex
// in O(1). All the connected edges of vertex u will be stored in
// mAdjacentEdges[u].
//
// Undirected edge(u, v) = w will be stored as two directed edges:
// edge(u -> v) = w:
// adjacent[u][x].mFrom = u, adjacent[u][x].mTo = v,
// adjacent[u][x].mWeight = w,
// edge(v -> u) = w:
// adjacent[v][y].mFrom = v, adjacent[v][y].mTo = u,
// adjacent[v][y].mWeight = w,
// where x, y are the indices of vector adjacent[u] and adjacent[v].
std::vector< std::vector<Edge> > mAdjacentEdges;
// The priority queue allows us to get the vertex with minimal distance
// from mStart in O(1), and update the distance in O(log V).
// The std::priority_queue is in decreasing order by default,
// so we need to use std::greater to sort the elements in increasing order.
std::priority_queue<Edge,
std::vector<Edge>,
std::greater<Edge>> mQueue;
};
#endif // PRIM_H
#include "Graph.h"
#include "Kruskal.h"
#include "Prim.h"
#include <iostream>
// Prints the MST produced by the algorithm named aName (one edge per line)
// and asserts that the edge weights add up to the reported total aMST.
// aEdges is taken by const reference; temporaries such as k.Edges() still bind.
void PrintMST(const char* aName, const std::vector<Edge>& aEdges, int aMST)
{
  // std::endl (flushing) is kept on purpose: if the assert below fires, the
  // output printed so far is not lost in an unflushed buffer.
  std::cout << aName << ": the weight of the MST is " << aMST << std::endl;
  int sum = 0;
  for (const auto& e: aEdges) {
    sum += e.mWeight;
    std::cout << e.mFrom << " -> " << e.mTo << " : " << e.mWeight << std::endl;
  }
  assert(sum == aMST);
}
int main()
{
  /*
  The diagrams live in a block comment: several of their lines end with a
  backslash, which inside a // comment splices the next source line into the
  comment (and triggers -Wcomment under -Wall).

  Graph:
  8 7
  (1)----(2)----(3)
  4 / | 2/ \ | \9
  / |11 / \ 14| \
  (0) | (8) \4 | (4)
  \ | 7/ \6 \ | /
  8 \ | / \ \ | /10
  (7)----(6)----(5)
  1 2

  MST:
  I.
  7
  (1) (2)----(3)
  4 / 2/ \ \9
  / / \ \
  (0) (8) \4 (4)
  \ \
  8 \ \
  (7)----(6)----(5)
  1 2

  II.
  8 7
  (1)----(2)----(3)
  4 / 2/ \ \9
  / / \ \
  (0) (8) \4 (4)
  \
  \
  (7)----(6)----(5)
  1 2
  */

  // Edge list of the graph above, added in this order: {from, to, weight}.
  struct EdgeSpec
  {
    unsigned int from;
    unsigned int to;
    int weight;
  };
  const EdgeSpec kEdges[] = {
    {0, 1, 4},  {0, 7, 8},  {1, 2, 8},  {1, 7, 11}, {2, 3, 7},
    {2, 8, 2},  {2, 5, 4},  {3, 4, 9},  {3, 5, 14}, {4, 5, 10},
    {5, 6, 2},  {6, 7, 1},  {6, 8, 6},  {7, 8, 7},
  };

  Graph g(9, false);
  for (const EdgeSpec& spec : kEdges) {
    g.AddEdge(spec.from, spec.to, spec.weight);
  }

  Kruskal k(g);
  PrintMST("Kruskal", k.Edges(), k.Weight());
  Prim p(g);
  PrintMST("Prim", p.Edges(), p.Weight());

  // The MST may differ when edges tie (see diagrams I and II), but both
  // algorithms must agree on its total weight.
  assert(k.Weight() == p.Weight());
  return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment