RDD group by small number of groups
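The first file below implements a groupByKey variant tailored to a small number of (potentially large) groups: instead of materializing each group as an in-memory Iterable, it hash-partitions the pair RDD, writes each group once to its own sequence file in a temp directory, and hands back a Map from key to a lazily-read RDD of that group's values. A minimal sketch of the intended call site, assuming the PimpPairRDD implicit from the example file further down is in scope (clicksByUser is a hypothetical RDD, not part of the gist):

import org.apache.spark.rdd.RDD

// Hypothetical input; any key type whose toString is distinct and filesystem-safe works,
// since each group file is named after key.toString (see fileNameForKey below).
val clicksByUser: RDD[(String, Long)] = ???
val groups: Map[String, RDD[Long]] = clicksByUser.groupByKeySmallNumOfGroups(partitions = 10)
// Each group is now an independent RDD, read back from its own sequence file on demand.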
package core.sparkTest.utils

import java.io._
import java.nio.file.Files

import core.Pimps._
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.io.{BytesWritable, NullWritable}
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.lib.MultipleSequenceFileOutputFormat
import org.apache.spark.rdd.RDD
import org.apache.spark.{HashPartitioner, SparkContext}
import org.slf4j.LoggerFactory

import scala.reflect.ClassTag

/**
 * Created by mishael on 4/13/15.
 */
class GroupByKeySmallNumberOfGroups[K: ClassTag, V: ClassTag](val pairRdd: RDD[(K, V)]) {

  import GroupByKeySmallNumberOfGroups._

  val logger = LoggerFactory.getLogger(getClass)

  /**
   * Groups the RDD by key, returning one lazily-read RDD per key instead of
   * one in-memory Iterable per key. Each group is written once to its own
   * sequence file under a temp directory and read back from there.
   */
  def groupByKey(partitions: Int, bufferSize: Int,
                 codec: Option[Class[_ <: CompressionCodec]]): Map[K, RDD[V]] = {
    if (!pairRdd.context.isLocal)
      throw new NotImplementedError("Non local mode not supported yet.")

    val tempDir = Files.createTempDirectory(tempDirPrefix).toFile
    // "/" is a path-concatenation helper from core.Pimps.
    val groupsDirPath = tempDir.getAbsolutePath / groupsDirName
    logger.info(s"Group by temp files will be written to: $groupsDirPath")

    val partitioned = pairRdd.partitionBy(new HashPartitioner(partitions))
    GroupByKeySmallNumberOfGroups.saveAsMultiSequenceFile(partitioned, groupsDirPath, bufferSize, codec)

    // Distinct keys only; collecting duplicates would just build redundant RDD handles.
    val rddSeq = for {
      key <- partitioned.keys.distinct().collect()
      filePath = groupsDirPath / fileNameForKey(key)
      rdd = pairRdd.sparkContext.multipleSequenceFile[V](filePath)
    } yield (key, rdd)
    rddSeq.toMap
  }
}

object GroupByKeySmallNumberOfGroups {

  val logger = LoggerFactory.getLogger(getClass)

  val keyPrefix = "key_"
  val tempDirPrefix = "groupBySmallNumOfGroups"
  val groupsDirName = "groups"

  implicit class PimpSparkContextPairRDDGroupBy(val sc: SparkContext) extends AnyVal {
    /** Reads one group's sequence file back, flattening each serialized chunk of values. */
    def multipleSequenceFile[V: ClassTag](path: String, minPartitions: Int = sc.defaultMinPartitions) = {
      for {
        (_, valueWritable) <- sc.sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minPartitions)
        // getBytes may return a buffer padded past getLength; Java deserialization
        // stops at the end of the object, so the trailing padding is ignored.
        value <- GroupByKeySmallNumberOfGroups.deserialize[Array[V]](valueWritable.getBytes)
      } yield value
    }
  }

  /** Routes each record to a file named after its key; the key itself is not written. */
  class RDDMultipleSequenceOutputFormat[K] extends MultipleSequenceFileOutputFormat[Any, BytesWritable] {
    override def generateFileNameForKeyValue(key: Any, value: BytesWritable, name: String): String =
      fileNameForKey(key.asInstanceOf[K])

    override def generateActualKey(key: Any, value: BytesWritable): Any = NullWritable.get()
  }

  /** The hashCode suffix disambiguates distinct keys that happen to share a toString. */
  def fileNameForKey[K](key: K) = s"$keyPrefix${key.toString}_${key.hashCode}"

  def classLoader = Option(Thread.currentThread().getContextClassLoader).getOrElse(this.getClass.getClassLoader)

  /**
   * Writes the pair RDD as one sequence file per key. Each partition is consumed in
   * chunks of bufferSize; values sharing a key within a chunk are batched into a single
   * serialized array record, so at most bufferSize values are held in memory at once.
   */
  def saveAsMultiSequenceFile[K: ClassTag, V: ClassTag](pairRdd: RDD[(K, V)], path: String, bufferSize: Int = 100,
                                                        codec: Option[Class[_ <: CompressionCodec]] = None) = {
    val writableRdd = pairRdd.mapPartitions {
      _.grouped(bufferSize)
        .flatMap(_.groupBy(_._1).map { case (key, values) => (key, values.map(_._2).toArray) })
        .map { case (key, valueArr) => (key, new BytesWritable(serialize(valueArr))) }
      // Alternative kept from the original: serialize the key into the record as well.
      // .map { case (key, valueArr) => (new BytesWritable(serialize(key)), new BytesWritable(serialize(valueArr))) }
    }
    val format = classOf[RDDMultipleSequenceOutputFormat[K]]
    val jobConf = new JobConf(writableRdd.context.hadoopConfiguration)
    writableRdd.saveAsHadoopFile(path, classOf[NullWritable], classOf[BytesWritable], format, jobConf, codec)
  }

  def serialize[T](o: T): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(o)
    oos.close()
    bos.toByteArray
  }

  def deserialize[T](bytes: Array[Byte], loader: ClassLoader = classLoader): T = {
    val bis = new ByteArrayInputStream(bytes)
    val ois = new ObjectInputStream(bis) {
      override def resolveClass(desc: ObjectStreamClass) =
        Class.forName(desc.getName, false, loader)
    }
    ois.readObject.asInstanceOf[T]
  }
}
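The serialize/deserialize pair above is plain Java serialization. Note that deserialize is fed BytesWritable.getBytes, whose backing array may be longer than the valid payload (valid only up to getLength); that is harmless here because ObjectInputStream stops reading once the object ends. A minimal round-trip sketch of that behavior (hypothetical values, not part of the gist):

import GroupByKeySmallNumberOfGroups._

val bytes = serialize(Array(1, 2, 3))
// Simulate BytesWritable-style zero padding after the payload; it is never read.
val restored = deserialize[Array[Int]](bytes ++ Array.fill[Byte](4)(0))
assert(restored.sameElements(Array(1, 2, 3)))

The second file wires the pimped syntax up end to end: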
package core.sparkTest.utils

import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

import scala.reflect.ClassTag

/**
 * Created by mishael on 4/13/15.
 */
object GroupBySmallNumOfGroupsExample extends App {

  implicit class PimpPairRDD[K: ClassTag, V: ClassTag](val pairRdd: RDD[(K, V)]) {
    /**
     * Assumes distinct keys produce distinct toString outputs, since group
     * files are named after them. Note: writes to a temp directory.
     */
    def groupByKeySmallNumOfGroups(partitions: Int = 10, bufferSize: Int = 100,
                                   codec: Option[Class[_ <: CompressionCodec]] = None): Map[K, RDD[V]] =
      new GroupByKeySmallNumberOfGroups[K, V](pairRdd).groupByKey(partitions, bufferSize, codec)
  }

  val conf = new SparkConf()
    .setMaster("local[*]")
    .setAppName("Simple Application")
  implicit val sc = new SparkContext(conf)

  case class Person(name: String)

  val names = List("David", "Miriam", "Rachel")
  val rdd = sc.parallelize(1 to 100).map(i => Person(names(i % 3)))
  val keyed = rdd.keyBy(identity)

  // https://github.com/fullcontact/hadoop-sstable/issues/11
  val res = keyed.groupByKeySmallNumOfGroups()
  println(s"result:\n${res.mapValues(_.collect().toList).mkString("\n")}")
}
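For contrast, the stock way to get one RDD per key is a filter per key, which rescans the source RDD once for every group, whereas the utility above pays a single write pass plus one file read per group. A sketch of that baseline over the same example data (naive is a hypothetical name, not part of the gist):

val keys = keyed.keys.distinct().collect()
val naive: Map[Person, RDD[Person]] =
  keys.map(k => k -> keyed.filter { case (key, _) => key == k }.values).toMap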