Created
February 12, 2016 23:06
-
-
Save ankurdave/63acd24ef744aaac87e0 to your computer and use it in GitHub Desktop.
BFS in GraphX
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark.graphx._ | |
/**
 * Returns the shortest directed-edge path from src to dst in the graph, found by breadth-first
 * search. If no path exists, returns the empty list.
 *
 * Every vertex is labelled with (dist from src, id of vertex with dist-1). The labels are grown
 * by one hop per iteration until dst is labelled; the path is then read off by following the
 * predecessor ids back from dst to src. Each iteration, and each hop of the backward walk, runs
 * at least one Spark job.
 *
 * @param graph the graph to search; its vertex attributes are ignored and edges are followed in
 *              their own direction only
 * @param src   id of the vertex the path starts at
 * @param dst   id of the vertex the path ends at
 * @return the vertex ids on the path, src first and dst last; List(src) if src == dst; the empty
 *         list if dst cannot be reached from src, including when dst is not a vertex of the graph
 */
def bfs[VD, ED](graph: Graph[VD, ED], src: VertexId, dst: VertexId): Seq[VertexId] = {
  if (src == dst) return List(src)

  // The attribute of each vertex is (dist from src, id of vertex with dist-1)
  type Label = (Int, VertexId)
  val unreached: Label = (Int.MaxValue, 0L)

  // Label of vertex `id`, or None when `labels` has no such vertex. Uses take(1) rather than
  // first() because first() throws when the filtered RDD is empty.
  def labelOf(labels: Graph[Label, ED], id: VertexId): Option[Label] =
    labels.vertices.filter(_._1 == id).take(1).headOption.map(_._2)

  // Follows predecessor ids from the head of `path` until src is at the front.
  @scala.annotation.tailrec
  def backtrack(labels: Graph[Label, ED], path: List[VertexId]): List[VertexId] =
    if (path.head == src) path
    else {
      val current = path.head
      val predecessor = labelOf(labels, current)
        .getOrElse(throw new NoSuchElementException(s"vertex $current is not in the graph"))
        ._2
      backtrack(labels, predecessor :: path)
    }

  val seed: Graph[Label, ED] =
    graph.mapVertices((id, _) => if (id == src) (0, 0L) else unreached).cache()

  // Uncaches a graph built by this function.
  // NOTE(review): the seed graph's edge RDD may share storage with the caller's `graph`
  // (depends on how mapVertices is implemented - not verified), so its edges are left alone.
  def uncache(old: Graph[Label, ED]): Unit = {
    old.unpersistVertices(blocking = false)
    if (old ne seed) old.edges.unpersist(blocking = false)
  }

  var g = seed
  // The graph and messages that `g` was derived from. They are uncached only after the next
  // round of messages has been counted, because that job is what materializes `g`; dropping
  // them earlier would force them to be recomputed.
  var superseded: Option[(Graph[Label, ED], VertexRDD[Label])] = None

  def releaseSuperseded(): Unit = {
    superseded.foreach { case (oldGraph, oldMsgs) =>
      uncache(oldGraph)
      oldMsgs.unpersist(blocking = false)
    }
    superseded = None
  }

  try {
    // Traverse forward from src
    var dstAttr = unreached
    while (dstAttr._1 == Int.MaxValue) {
      val msgs = g.aggregateMessages[Label](
        // Only labelled vertices send, and only where they would shorten the neighbour's dist;
        // the MaxValue check also keeps `+ 1` from overflowing.
        e => if (e.srcAttr._1 != Int.MaxValue && e.srcAttr._1 + 1 < e.dstAttr._1) {
          e.sendToDst((e.srcAttr._1 + 1, e.srcId))
        },
        (a, b) => if (a._1 < b._1) a else b).cache()
      val improvedCount = msgs.count()
      releaseSuperseded()
      if (improvedCount == 0) {
        // No label changed, so everything reachable is labelled and dst is not among it.
        msgs.unpersist(blocking = false)
        return List.empty
      }
      superseded = Some((g, msgs))
      g = g.ops.joinVertices(msgs) {
        (id, oldAttr, newAttr) =>
          if (newAttr._1 < oldAttr._1) newAttr else oldAttr
      }.cache()
      labelOf(g, dst) match {
        case Some(attr) => dstAttr = attr
        // dst is not a vertex of the graph, so no path can exist.
        case None => return List.empty
      }
    }
    // Traverse backward from dst and collect the path
    backtrack(g, dstAttr._2 :: dst :: Nil)
  } finally {
    releaseSuperseded()
    uncache(g)
  }
}
// Sample graph: the chain 1 -> 2 -> 3 -> 4 with a shortcut 2 -> 4, plus a separate edge
// 10 -> 11. `sc` is the SparkContext provided by the shell. Every vertex gets the attribute 1.
val g = Graph.fromEdgeTuples(
  sc.parallelize(
    List(
      1L -> 2L,
      2L -> 3L,
      3L -> 4L,
      2L -> 4L,
      10L -> 11L)),
  defaultValue = 1)

// Reachable through the shortcut: List(1, 2, 4)
bfs(g, 1L, 4L)
// 10 lies in the other component: List()
bfs(g, 1L, 10L)
// Edges are directed and 4 has no outgoing edge: List()
bfs(g, 4L, 3L)
// Source equals destination: List(3)
bfs(g, 3L, 3L)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment