Last active
August 29, 2015 14:17
-
-
Save squito/c2d1dd5413a60830d6f3 to your computer and use it in GitHub Desktop.
GroupedRDD
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.io.{IOException, ObjectOutputStream} | |
import scala.language.existentials | |
import scala.reflect.ClassTag | |
import org.apache.spark._ | |
import org.apache.spark.rdd.RDD | |
/**
 * A partition made by bundling several partitions of a parent RDD together.
 *
 * @param index          index of this partition
 * @param rdd            the parent RDD whose partitions are bundled; declared `@transient`
 * @param parentsIndices indices into `rdd.partitions` of the parent partitions in this bundle
 */
case class GroupedRDDPartition(
    index: Int,
    @transient rdd: RDD[_],
    parentsIndices: Array[Int]) extends Partition {

  /** Parent partitions looked up from `rdd`; a `var` because `writeObject` reassigns it. */
  var parents: Seq[Partition] = lookupParents()

  /** Resolves every entry of `parentsIndices` against the current `rdd.partitions`. */
  private def lookupParents(): Seq[Partition] =
    parentsIndices.map(parentIdx => rdd.partitions(parentIdx))

  /**
   * Java serialization hook: refreshes `parents` and then writes the default fields.
   */
  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream): Unit = {
    // Re-resolve the parent partitions right before this object is written out,
    // so the serialized form carries the lookup as of serialization time.
    parents = lookupParents()
    oos.defaultWriteObject()
  }
}
/**
 * An RDD that merges runs of consecutive parent partitions into single partitions.
 *
 * The parent's partition indices are chunked, in order, into groups of `partsPerGroup`
 * (the final group may be smaller). Each group becomes one [[GroupedRDDPartition]], and
 * computing it concatenates the iterators of its parent partitions in index order.
 *
 * @param prev          the parent RDD; set to null by `clearDependencies`
 * @param partsPerGroup number of parent partitions per partition of this RDD; must be positive
 * @tparam T element type, the same as the parent's
 */
class GroupedRDD[T: ClassTag](
    @transient var prev: RDD[T],
    partsPerGroup: Int)
  extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies

  // Reject a bad group size up front instead of failing later, when the partitions are
  // first computed (`grouped` rejects non-positive sizes).
  require(partsPerGroup > 0, s"partsPerGroup must be positive, got $partsPerGroup")

  /** Builds one [[GroupedRDDPartition]] per run of `partsPerGroup` parent partition indices. */
  override def getPartitions: Array[Partition] = {
    // Use the public `partitions` accessor: `getPartitions` is a protected hook for
    // subclasses to implement and is not meant to be called on another RDD instance.
    prev.partitions.indices
      .grouped(partsPerGroup)
      .zipWithIndex
      .map { case (parentParts, idx) => new GroupedRDDPartition(idx, prev, parentParts.toArray) }
      .toArray
  }

  /** Chains the parent RDD's iterators for every parent partition in `partition`. */
  override def compute(partition: Partition, context: TaskContext): Iterator[T] = {
    partition.asInstanceOf[GroupedRDDPartition].parents.iterator.flatMap { parentPartition =>
      firstParent[T].iterator(parentPartition, context)
    }
  }

  /** A single narrow dependency on `prev`: each partition depends on its `parentsIndices`. */
  override def getDependencies: Seq[Dependency[_]] = {
    Seq(new NarrowDependency(prev) {
      override def getParents(id: Int): Seq[Int] =
        partitions(id).asInstanceOf[GroupedRDDPartition].parentsIndices
    })
  }

  /** Clears the inherited dependency state and drops the reference to the parent RDD. */
  override def clearDependencies(): Unit = {
    super.clearDependencies()
    prev = null
  }

  /**
   * Preferred locations of the first parent partition in the group.
   *
   * TODO you could do something much fancier, e.g. take the most common preferred locations
   * across all parents. For now we stay simple and only look at the first parent.
   */
  override def getPreferredLocations(partition: Partition): Seq[String] = {
    val grouped = partition.asInstanceOf[GroupedRDDPartition]
    // `headOption` keeps a group with no parents from throwing; it simply has no preference.
    grouped.parentsIndices.headOption
      .map(parentIdx => prev.preferredLocations(prev.partitions(parentIdx)))
      .getOrElse(Nil)
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment