Parallel Prim's Algorithm for Spark
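This gist contains two pieces: a small weighted edge list used as input, and a Spark/GraphX driver that grows a minimum spanning tree one cheapest crossing edge at a time, in the spirit of Prim's algorithm.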
50,70
1,8,10
2,17,10
2,34,17
3,35,16
3,38,1
4,5,6
4,22,8
4,25,3
5,12,13
5,19,2
5,21,20
5,50,8
6,31,2
6,32,8
6,38,16
7,20,1
7,44,11
8,20,3
8,32,2
8,41,6
9,23,9
9,35,10
10,23,13
10,27,12
11,24,5
11,26,14
11,30,14
11,37,15
11,38,14
13,34,5
13,41,18
14,23,15
14,26,7
14,44,9
15,27,19
15,33,20
16,34,11
16,46,13
17,39,16
17,47,9
18,29,19
18,32,14
20,43,10
21,34,1
22,40,5
23,28,20
23,32,13
23,46,6
23,48,10
24,36,9
25,30,16
25,32,17
25,36,8
25,42,2
26,36,16
29,49,15
32,34,4
32,35,11
32,47,10
33,35,17
33,45,18
34,49,16
35,36,3
37,46,6
38,45,15
38,49,10
39,50,5
40,45,19
42,43,15
44,50,16
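The first line of the edge list (presumably the NodeData.txt read by the driver below) is a header giving the node and edge counts, here 50 nodes and 70 edges; every following line is one undirected weighted edge in src,dst,weight form. The driver parses the header to size its vertex and edge arrays.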
import org.apache.log4j.Level
import org.apache.log4j.Logger
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.graphx.Edge
import org.apache.spark.graphx.EdgeTriplet
import org.apache.spark.graphx.Graph
import org.apache.spark.graphx.Graph.graphToGraphOps
import org.apache.spark.graphx.VertexId
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.RDD.rddToPairRDDFunctions
object ParallelPrims {
  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)
  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("Parallel Prims")
    val sc = new SparkContext(conf)
    val logFile = "/root/cluster-computing/demos/graph-generator/NodeData.txt"
    val logData = sc.textFile(logFile, 2).cache()
    // Split off the header line ("numNodes,numEdges") from the edge rows
    val headerAndRows = logData.map(line => line.split(",").map(_.trim))
    val header = headerAndRows.first
    val data = headerAndRows.filter(_(0) != header(0))
    // Parse the number of nodes and edges from the header
    val numNodes = header(0).toInt
    val numEdges = header(1).toInt
    val vertexArray = new Array[(Long, String)](numNodes)
    val edgeArray = new Array[Edge[Int]](numEdges)
    // Build the vertex array: vertex ids 1..numNodes, labelled "v1".."vN"
    for (count <- 0 to numNodes - 1) {
      vertexArray(count) = (count.toLong + 1, "v" + (count + 1))
    }
    // Collect the edge rows to the driver and build the edge array
    val rrdarr = data.take(data.count.toInt)
    for (count <- 0 to numEdges - 1) {
      val cols = rrdarr(count)
      edgeArray(count) = Edge(cols(0).toLong, cols(1).toLong, cols(2).toInt)
    }
    // Create the GraphX graph from the vertex and edge arrays
    val vertexRDD: RDD[(Long, String)] = sc.parallelize(vertexArray)
    val edgeRDD: RDD[Edge[Int]] = sc.parallelize(edgeArray)
    val graph: Graph[String, Int] = Graph(vertexRDD, edgeRDD)
    // graph.triplets.take(6).foreach(println)
    // Empty RDD that will accumulate the MST edges
    var MST = sc.parallelize(Array[EdgeTriplet[String, Int]]())
    // Seed the tree with a random vertex from the graph
    var Vt: RDD[VertexId] = sc.parallelize(Array(graph.pickRandomVertex))
    // Grow the tree until all vertices are in the Vt set
    val vcount = graph.vertices.count
    while (Vt.count < vcount) {
      // Key Vt by vertex id so it can be used in inner joins
      val hVt = Vt.map(x => (x, x))
      // Key the triplets by source id for the join
      val bySrc = graph.triplets.map(triplet => (triplet.srcId, triplet))
      // Key the triplets by destination id for the join
      val byDst = graph.triplets.map(triplet => (triplet.dstId, triplet))
      // All triplets whose source vertex is in Vt
      val bySrcJoined = bySrc.join(hVt).map(_._2._1)
      // All triplets whose destination vertex is in Vt
      val byDstJoined = byDst.join(hVt).map(_._2._1)
      // Union the two sets and subtract the triplets whose source and destination
      // are both in Vt, leaving only edges that cross from the tree to a new vertex
      val candidates = bySrcJoined.union(byDstJoined).subtract(byDstJoined.intersection(bySrcJoined))
      // Pick the crossing edge with the smallest weight
      val triplet = candidates.sortBy(triplet => triplet.attr).first
      // Add it to the MST
      MST = MST.union(sc.parallelize(Array(triplet)))
      // Add whichever endpoint is not yet in Vt
      if (!Vt.filter(x => x == triplet.srcId).isEmpty) {
        Vt = Vt.union(sc.parallelize(Array(triplet.dstId)))
      } else {
        Vt = Vt.union(sc.parallelize(Array(triplet.srcId)))
      }
    }
    // Print the final minimum spanning tree and its total weight
    MST.collect.foreach { p =>
      println(p.srcId + "<--->" + p.dstId + " " + p.attr)
    }
    val total = MST.map(_.attr.toDouble).collect
    println(total.reduceLeft(_ + _))
  }
}
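For a quick single-machine sanity check of the Spark output, a plain sequential Prim's pass over the same edge-list format can be useful. The sketch below is not part of the gist: it assumes the same src,dst,weight file layout and a connected graph, takes the file path as its first argument, and the names SequentialPrims, adj and inTree are illustrative only.

import scala.io.Source
import scala.collection.mutable

// Minimal sequential Prim's for cross-checking the Spark job (illustrative, not part of the gist).
// Assumes the same layout as the edge list above: a "numNodes,numEdges" header,
// then one "src,dst,weight" line per undirected edge, and a connected graph.
object SequentialPrims {
  def main(args: Array[String]): Unit = {
    val lines = Source.fromFile(args(0)).getLines().toVector
    val header = lines.head.split(",").map(_.trim)
    val numNodes = header(0).toInt

    // Adjacency list: vertex id -> list of (neighbour id, edge weight)
    val adj = mutable.Map.empty[Long, List[(Long, Int)]].withDefaultValue(Nil)
    for (line <- lines.tail) {
      val cols = line.split(",").map(_.trim)
      val (u, v, w) = (cols(0).toLong, cols(1).toLong, cols(2).toInt)
      adj(u) = (v, w) :: adj(u)
      adj(v) = (u, w) :: adj(v)
    }

    // Grow the tree from vertex 1, always taking the cheapest edge that leaves
    // the current tree -- the same rule the Spark loop applies to Vt.
    val inTree = mutable.Set[Long](1L)
    var total = 0
    while (inTree.size < numNodes) {
      val crossing = for {
        u <- inTree.toSeq
        (v, w) <- adj(u) if !inTree.contains(v)
      } yield (u, v, w)
      val (u, v, w) = crossing.minBy(_._3)
      println(u + "<--->" + v + " " + w)
      inTree += v
      total += w
    }
    println(total)
  }
}

The total it prints should match the sum printed by the Spark job even if the two runs start from different vertices, since the weight of a minimum spanning tree of a connected graph does not depend on the starting vertex.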
Currently having issues with the cluster when running this.