Skip to content

Instantly share code, notes, and snippets.

@igalshilman
Created February 7, 2017 15:58
Show Gist options
  • Save igalshilman/fba48dc5602860afb3cc270c318ffe06 to your computer and use it in GitHub Desktop.
Save igalshilman/fba48dc5602860afb3cc270c318ffe06 to your computer and use it in GitHub Desktop.
package com.igal
import java.lang.{Iterable => JavaIterable}
import org.apache.flink.api.common.functions.{GroupReduceFunction, Partitioner}
import org.apache.flink.api.common.operators.Order
import org.apache.flink.api.java.utils.ParameterTool
import org.apache.flink.api.scala.{DataSet, _}
import org.apache.flink.util.Collector
import scala.annotation.tailrec
import scala.collection.JavaConversions._
import scala.collection.mutable
object Reachability2 {
/**
 * Builds the edges reachable within at most `n` hops of the input graph.
 *
 * The first layer is `edges` itself. Every further layer joins the previous
 * layer's second field with the first field of `edges` and emits `(x, w)` for
 * each pair `(x, y)`, `(y, w)` whose endpoints differ. The union of all layers
 * is returned; duplicates are not removed here.
 *
 * @param edges edges as `(source, target)` pairs
 * @param n     maximum number of hops; any value below 1 is treated as 1
 * @return the union of the layers for 1 up to `n` hops
 */
def expandGraph(edges: DataSet[(String, String)], n: Int): DataSet[(String, String)] = {
  @tailrec
  def loop(acc: List[DataSet[(String, String)]], remaining: Int): List[DataSet[(String, String)]] = {
    // `<=` instead of `==`: with n < 1 an equality test is never hit, so the
    // counter would run down past 1 and keep stacking join layers.
    if (remaining <= 1) {
      acc.reverse
    } else {
      val previousLayer = acc.head
      val nextLayer = previousLayer.join(edges).where(1).equalTo(0)({
        (l, r, out: Collector[(String, String)]) =>
          val x = l._1
          val w = r._2
          // Skip paths that lead back to their own start vertex.
          if (x != w) {
            out.collect((x, w))
          }
      })
      loop(nextLayer :: acc, remaining - 1)
    }
  }
  // The accumulator always contains at least `edges`, so `reduce` is safe.
  loop(List(edges), n).reduce((a, b) => a.union(b))
}
/**
 * Entry point of the job.
 *
 * Reads a tab-separated edge list, adds the reversed copy of every edge,
 * expands the graph to two hops, removes duplicate edges, collects the
 * neighbors of each source vertex and prints the records ordered by the
 * first field.
 *
 * @param args command line arguments; `--input <path>` selects the edge file
 *             and defaults to `./data/friends.tsv`
 */
def main(args: Array[String]): Unit = {
  val parameters = ParameterTool.fromArgs(args)
  // Previously the parsed parameters were never consulted; the old hard-coded
  // path is kept as the default so existing invocations behave the same.
  val input = parameters.get("input", "./data/friends.tsv")
  val env = ExecutionEnvironment.getExecutionEnvironment
  val firstDegree = env.readCsvFile[(String, String)](input, fieldDelimiter = "\t")
    .flatMap(e => List(e, e.swap))
  val expanded = expandGraph(firstDegree, 2)
  val result = expanded
    .distinct()
    .groupBy(0)
    .reduceGroup(CollectUniqueNeighborNames)
    .partitionByRange(0)
    .sortPartition(0, Order.ASCENDING)
  result.collect().foreach(println)
}
/**
 * Group reducer that folds all edges of one group into a single record:
 * the first field of the last edge seen, paired with the sorted set of the
 * second fields of all edges in the group.
 */
object CollectUniqueNeighborNames extends GroupReduceFunction[(String, String), (String, Iterable[String])] {
  override def reduce(edges: JavaIterable[(String, String)],
                      out: Collector[(String, Iterable[String])]): Unit = {
    // TreeSet de-duplicates the names and keeps them in sorted order.
    val targets = new mutable.TreeSet[String]()
    var groupKey: String = ""
    // Walk the Java iterator directly instead of going through an implicit conversion.
    val cursor = edges.iterator()
    while (cursor.hasNext) {
      val (from, to) = cursor.next()
      targets += to
      groupKey = from
    }
    val record: (String, Iterable[String]) = (groupKey, targets)
    out.collect(record)
  }
}
/**
 * Renders one record as a tab-separated line: the id first, followed by
 * every reachable friend.
 *
 * @param user pair of the user id and the friends reachable from it
 * @return the id and the friends joined by tab characters
 */
private def formatResults(user: (String, Iterable[String])): String = user match {
  case (userId, friends) =>
    val columns = friends.mkString("\t")
    s"$userId\t$columns"
}
/**
 * Store the result [[DataSet]] to a single file at `outputPath`, ordered
 * lexicographically by user id.
 * Since it has to be a single file we have to shuffle the data into a single
 * partition, then sort it. Every operator from the partitioning down to the
 * sink is pinned to parallelism 1: if only the partitioning step were pinned,
 * the sort, the formatting and the sink would run at the default parallelism,
 * spreading the records over several subtasks and output files.
 *
 * @param outputPath the file to save
 * @param result the dataset containing each user and a list of their reachable friends.
 */
private def storeLexicographicallySortedSingleFile(outputPath: String, result: DataSet[(String, Iterable[String])]): Unit = {
  result
    .partitionCustom(new ConstantPartitioner[String], 0).setParallelism(1)
    .sortPartition(0, Order.ASCENDING).setParallelism(1)
    .map(r => formatResults(r)).setParallelism(1)
    // A sink with parallelism 1 writes one file instead of one file per subtask.
    .writeAsText(outputPath).setParallelism(1)
}
/**
 * A [[Partitioner]] that ignores the key and sends every record to partition 0.
 *
 * @tparam T type of the partitioning key
 */
class ConstantPartitioner[T] extends Partitioner[T] {
  override def partition(key: T, numPartitions: Int): Int = {
    // Same target no matter which key or how many partitions there are.
    0
  }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment