Skip to content

Instantly share code, notes, and snippets.

@squito
Last active August 29, 2015 14:17
Show Gist options
  • Save squito/c2d1dd5413a60830d6f3 to your computer and use it in GitHub Desktop.
Save squito/c2d1dd5413a60830d6f3 to your computer and use it in GitHub Desktop.
GroupedRDD
import java.io.{IOException, ObjectOutputStream}
import scala.language.existentials
import scala.reflect.ClassTag
import org.apache.spark._
import org.apache.spark.rdd.RDD
/**
 * A partition of a [[GroupedRDD]]: it stands for several partitions of a parent RDD.
 *
 * @param index          position of this partition within the grouped RDD
 * @param rdd            the parent RDD; marked transient, so it is not written out with this object
 * @param parentsIndices indices (into `rdd.partitions`) of the parent partitions in this group
 */
case class GroupedRDDPartition(
    index: Int,
    @transient rdd: RDD[_],
    parentsIndices: Array[Int]) extends Partition {

  /** The parent partitions this group covers, looked up by index in `rdd.partitions`. */
  var parents: Seq[Partition] = lookupParents()

  /** Resolves every entry of `parentsIndices` against the parent RDD's partitions. */
  private def lookupParents(): Seq[Partition] =
    parentsIndices.map(i => rdd.partitions(i))

  /**
   * Java serialization hook. Re-resolves `parents` immediately before the fields are written,
   * so the serialized form carries the parent partitions as they are at serialization time.
   */
  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = {
    // Must happen before defaultWriteObject(), otherwise the stale value is what gets written.
    parents = lookupParents()
    out.defaultWriteObject()
  }
}
/**
 * An RDD whose partitions are runs of consecutive partitions of `prev`, merged together without
 * a shuffle. Each partition of this RDD covers up to `partsPerGroup` parent partitions and, when
 * computed, yields the elements of those parent partitions one after another.
 *
 * @param prev          the parent RDD whose partitions are grouped
 * @param partsPerGroup number of parent partitions per group; must be positive
 * @throws IllegalArgumentException if `partsPerGroup` is not positive
 */
class GroupedRDD[T: ClassTag](
    @transient var prev: RDD[T],
    partsPerGroup: Int)
  extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies

  // Fail at construction time instead of later, when partitions are first computed.
  require(partsPerGroup > 0, s"partsPerGroup must be positive, but got $partsPerGroup")

  /** Takes the partitions of the parent RDD and puts them into groups of size `partsPerGroup`. */
  override def getPartitions: Array[Partition] = {
    // Use the public `partitions` accessor of the parent rather than its protected
    // `getPartitions` hook, which is not meant to be called on another RDD instance.
    prev.partitions.indices
      .grouped(partsPerGroup)
      .zipWithIndex
      .map { case (parentParts, idx) =>
        new GroupedRDDPartition(idx, prev, parentParts.toArray)
      }
      .toArray[Partition]
  }

  /** Chains the iterators of every parent partition in the group, in order. */
  override def compute(partition: Partition, context: TaskContext): Iterator[T] = {
    partition.asInstanceOf[GroupedRDDPartition].parents.iterator.flatMap { parentPartition =>
      firstParent[T].iterator(parentPartition, context)
    }
  }

  /** A single narrow dependency on `prev`: group `id` depends on exactly its parent indices. */
  override def getDependencies: Seq[Dependency[_]] = {
    Seq(new NarrowDependency(prev) {
      def getParents(id: Int): Seq[Int] =
        partitions(id).asInstanceOf[GroupedRDDPartition].parentsIndices
    })
  }

  /** Drops the reference to the parent RDD in addition to the default cleanup. */
  override def clearDependencies(): Unit = {
    super.clearDependencies()
    prev = null
  }

  /**
   * Preferred locations of the group, taken from the first parent partition only.
   * Returns an empty list if the group has no parent partitions.
   */
  override def getPreferredLocations(partition: Partition): Seq[String] = {
    // TODO you could do something much fancier to figure out the preferred location, eg. take the
    // most common preferred locations, or something. But we'll be dumb -- take the preferred
    // locations of the first parent
    val group = partition.asInstanceOf[GroupedRDDPartition]
    group.parentsIndices.headOption
      .map(parentIdx => prev.preferredLocations(prev.partitions(parentIdx)))
      .getOrElse(Nil)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment