diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 6c02004..83d47c7 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -577,7 +577,9 @@ public String toString() {
     StringBuilder build = new StringBuilder("[");
     for (int i = 0; i < sizeInBytes; i += 8) {
       build.append(java.lang.Long.toHexString(Platform.getLong(baseObject, baseOffset + i)));
-      build.append(',');
+      if (i + 8 < sizeInBytes) {
+        build.append(',');
+      }
     }
     build.append(']');
     return build.toString();
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index bf03c61..6b4145c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -17,14 +17,25 @@
 package org.apache.spark.sql
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
+import scala.collection.mutable.ArrayBuffer
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
+import org.apache.spark.SparkEnv
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.sources.{BaseRelation, TableScan}
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
-import org.apache.spark.sql.types.StructField
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.array.ByteArrayMethods
+import org.apache.spark.unsafe.memory.{MemoryBlock, TaskMemoryManager}
 import org.apache.spark.unsafe.types.UTF8String
 /**
@@ -110,4 +121,165 @@ private[sql] abstract class SQLImplicits {
     DataFrameHolder(
       _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
   }
+
+  /**
+   * ::Experimental::
+   *
+   * "Pimp my library" decorator for Tungsten caching of DataFrames.
+   * @since 1.5.1
+   */
+  @Experimental
+  implicit class TungstenCache(df: DataFrame) {
+    /**
+     * Packs the rows of [[df]] into contiguous blocks of memory.
+     * @param compressionType "" (default), "lz4", "lzf", or "snappy", see
+     *                        [[CompressionCodec.ALL_COMPRESSION_CODECS]]
+     * @param blockSize size of each MemoryBlock (default = 4 MB)
+     */
+    def tungstenCache(
+        compressionType: String = "", blockSize: Int = 4000000): (RDD[_], DataFrame) = {
+      val schema = df.schema
+
+      val convert = CatalystTypeConverters.createToCatalystConverter(schema)
+      val internalRows = df.rdd.map(convert(_).asInstanceOf[InternalRow])
+      val cachedRDD = internalRows.mapPartitions { rowIterator =>
+        val bufferedRowIterator = rowIterator.buffered
+        val convertToUnsafe = UnsafeProjection.create(schema)
+        val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager)
+        val compressionCodec: Option[CompressionCodec] = if (compressionType.isEmpty) {
+          None
+        } else {
+          Some(CompressionCodec.createCodec(SparkEnv.get.conf, compressionType))
+        }
+
+        new Iterator[MemoryBlock] {
+          // NOTE: This assumes that the size of every row is < blockSize
+          def next(): MemoryBlock = {
+            // Packs rows into a contiguous block of `blockSize` bytes, starting a new block
+            // whenever the current one fills up.
+            // Each row is laid out in memory as [rowSize (4)|row (rowSize)]
+            val block = taskMemoryManager.allocateUnchecked(blockSize)
+
+            var currOffset = 0
+            while (bufferedRowIterator.hasNext && currOffset < blockSize) {
+              val currRow = convertToUnsafe.apply(bufferedRowIterator.head)
+              val recordSize = 4 + currRow.getSizeInBytes
+              if (currOffset + recordSize < blockSize) {
+                Platform.putInt(
+                  block.getBaseObject, block.getBaseOffset + currOffset, currRow.getSizeInBytes)
+                currRow.writeToMemory(block.getBaseObject, block.getBaseOffset + currOffset + 4)
+                bufferedRowIterator.next()
+              }
+              currOffset += recordSize // Increment currOffset regardless to break loop when full
+            }
+
+            // Optionally compress block before writing
+            compressionCodec match {
+              case Some(codec) =>
+                // Compress the block using an on-heap byte array
+                val blockArray = new Array[Byte](blockSize)
+                Platform.copyMemory(
+                  block.getBaseObject,
+                  block.getBaseOffset,
+                  blockArray,
+                  Platform.BYTE_ARRAY_OFFSET,
+                  blockSize)
+                val baos = new ByteArrayOutputStream(blockSize)
+                val compressedBaos = codec.compressedOutputStream(baos)
+                compressedBaos.write(blockArray)
+                compressedBaos.flush()
+                compressedBaos.close()
+                val compressedBlockArray = baos.toByteArray
+
+                // Allocate a new block with compressed byte array padded to word boundary
+                val totalRecordSize = compressedBlockArray.size + 4
+                val nearestWordBoundary =
+                  ByteArrayMethods.roundNumberOfBytesToNearestWord(totalRecordSize)
+                val padding = nearestWordBoundary - totalRecordSize
+                val compressedBlock = taskMemoryManager.allocateUnchecked(totalRecordSize + padding)
+                Platform.putInt(
+                  compressedBlock.getBaseObject,
+                  compressedBlock.getBaseOffset,
+                  padding)
+                Platform.copyMemory(
+                  compressedBlockArray,
+                  Platform.BYTE_ARRAY_OFFSET,
+                  compressedBlock.getBaseObject,
+                  compressedBlock.getBaseOffset + 4,
+                  compressedBlockArray.size)
+                taskMemoryManager.freeUnchecked(block)
+                compressedBlock
+              case None => block
+            }
+          }
+
+          def hasNext: Boolean = bufferedRowIterator.hasNext
+        }
+      }.setName(compressionType + "_" + df.toString).persist(StorageLevel.MEMORY_ONLY)
+
+      val baseRelation: BaseRelation = new BaseRelation with TableScan {
+        override val sqlContext = _sqlContext
+        override val schema = df.schema
+        override val needConversion = false
+
+        override def buildScan(): RDD[Row] = {
+          val numFields = this.schema.length
+          val _compressionType: String = compressionType
+          val _blockSize = blockSize
+
+          cachedRDD.flatMap { rawBlock =>
+            // Optionally decompress block
+            val compressionCodec: Option[CompressionCodec] = if (_compressionType.isEmpty) {
+              None
+            } else {
+              Some(CompressionCodec.createCodec(SparkEnv.get.conf, _compressionType))
+            }
+            val block = compressionCodec match {
+              case Some(codec) =>
+                // Copy compressed block (excluding padding) to on-heap byte array
+                val padding = Platform.getInt(rawBlock.getBaseObject, rawBlock.getBaseOffset)
+                val compressedBlockArray = new Array[Byte](_blockSize)
+                Platform.copyMemory(
+                  rawBlock.getBaseObject,
+                  rawBlock.getBaseOffset + 4,
+                  compressedBlockArray,
+                  Platform.BYTE_ARRAY_OFFSET,
+                  rawBlock.size() - padding)
+
+                // Decompress into MemoryBlock backed by on-heap byte array
+                val compressedBaos = new ByteArrayInputStream(compressedBlockArray)
+                val uncompressedBlockArray = new Array[Byte](_blockSize)
+                val cis = codec.compressedInputStream(compressedBaos)
+                cis.read(uncompressedBlockArray)
+                cis.close()
+                MemoryBlock.fromByteArray(uncompressedBlockArray)
+              case None => rawBlock
+            }
+
+            val rows = new ArrayBuffer[InternalRow]()
+            var currOffset = 0
+            var moreData = true
+            while (currOffset < block.size() && moreData) {
+              val rowSize = Platform.getInt(block.getBaseObject, block.getBaseOffset + currOffset)
+              currOffset += 4
+              // TODO: should probably have a null terminator rather than relying on zeroed-out memory
+              if (rowSize > 0) {
+                val unsafeRow = new UnsafeRow()
+                unsafeRow.pointTo(
+                  block.getBaseObject, block.getBaseOffset + currOffset, numFields, rowSize)
+                rows.append(unsafeRow)
+                currOffset += rowSize
+              } else {
+                moreData = false
+              }
+            }
+            rows
+          }.asInstanceOf[RDD[Row]]
+        }
+
+        override def toString: String = getClass.getSimpleName + s"[${df.toString}]"
+      }
+      (cachedRDD, DataFrame(_sqlContext, LogicalRelation(baseRelation)))
+    }
+  }
 }
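
For orientation, here is a minimal usage sketch of the tungstenCache decorator added above. It is not part of the patch; it assumes a running SQLContext named sqlContext with its implicits imported and a hypothetical DataFrame df. Each block packs rows as [rowSize (4)|row (rowSize)] records.

  // illustrative only: assumes `sqlContext` and `df` are already defined
  import sqlContext.implicits._   // brings the TungstenCache decorator into scope

  // Uncompressed, default 4 MB blocks; returns the packed-block RDD and a DataFrame view over it
  val (blockRDD, cachedDF) = df.tungstenCache()

  // LZ4-compressed blocks of roughly 1 MB each
  val (_, lz4CachedDF) = df.tungstenCache("lz4", 1000000)

  cachedDF.count()                // scans the packed blocks through the TableScan relation
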
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index af7590c..2eade42 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.Accumulators
 import org.apache.spark.sql.columnar._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.storage.{StorageLevel, RDDBlockId}
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}
 private case class BigData(s: String)
@@ -75,17 +75,17 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
   }
   test("unpersist an uncached table will not raise exception") {
-    assert(None == ctx.cacheManager.lookupCachedData(testData))
+    assert(None === ctx.cacheManager.lookupCachedData(testData))
     testData.unpersist(blocking = true)
-    assert(None == ctx.cacheManager.lookupCachedData(testData))
+    assert(None === ctx.cacheManager.lookupCachedData(testData))
     testData.unpersist(blocking = false)
-    assert(None == ctx.cacheManager.lookupCachedData(testData))
+    assert(None === ctx.cacheManager.lookupCachedData(testData))
     testData.persist()
-    assert(None != ctx.cacheManager.lookupCachedData(testData))
+    assert(None !== ctx.cacheManager.lookupCachedData(testData))
     testData.unpersist(blocking = true)
-    assert(None == ctx.cacheManager.lookupCachedData(testData))
+    assert(None === ctx.cacheManager.lookupCachedData(testData))
     testData.unpersist(blocking = false)
-    assert(None == ctx.cacheManager.lookupCachedData(testData))
+    assert(None === ctx.cacheManager.lookupCachedData(testData))
   }
   test("cache table as select") {
@@ -333,7 +333,33 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
       val accsSize = Accumulators.originals.size
       ctx.uncacheTable("t1")
       ctx.uncacheTable("t2")
-      assert((accsSize - 2) == Accumulators.originals.size)
+      assert((accsSize - 2) === Accumulators.originals.size)
     }
   }
+
+  test("tungsten cache uncompressed table and read") {
+    val data = testData
+    // Use a 0.4 KB block size to force multiple blocks
+    val (_, tungstenCachedDF) = data.tungstenCache("", 400)
+    assert(tungstenCachedDF.collect() === testData.collect())
+  }
+
+  test("tungsten cache lz4 compressed table and read") {
+    val data = testData
+    val (_, tungstenCachedDF) = data.tungstenCache("lz4", 400)
+    assert(tungstenCachedDF.collect() === testData.collect())
+  }
+
+  test("tungsten cache lzf compressed table and read") {
+    val data = testData
+    val (_, tungstenCachedDF) = data.tungstenCache("lzf", 400)
+    assert(tungstenCachedDF.collect() === testData.collect())
+  }
+
+  test("tungsten cache snappy compressed table and read") {
+    val data = testData
+    val (_, tungstenCachedDF) = data.tungstenCache("snappy", 400)
+    assert(tungstenCachedDF.collect() === testData.collect())
+  }
+
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
index 2476b10..4f52535 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
@@ -21,7 +21,7 @@ import java.io.ByteArrayOutputStream
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.memory.MemoryAllocator
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
index dd75820..3a51f0e 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
@@ -52,4 +52,11 @@ public long size() {
   public static MemoryBlock fromLongArray(final long[] array) {
     return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, array.length * 8);
   }
+
+  /**
+   * Creates a memory block pointing to the memory used by the byte array.
+   */
+  public static MemoryBlock fromByteArray(final byte[] array) {
+    return new MemoryBlock(array, Platform.BYTE_ARRAY_OFFSET, array.length);
+  }
 }
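
As a quick illustration (not from the patch), MemoryBlock.fromByteArray simply wraps an on-heap byte array, so Platform reads and writes go straight through to the array; this is how the decompression path in SQLImplicits.scala above consumes it:

  // illustrative only
  val bytes = new Array[Byte](16)
  val block = MemoryBlock.fromByteArray(bytes)
  Platform.putInt(block.getBaseObject, block.getBaseOffset, 42)   // writes into bytes(0..3)
  assert(Platform.getInt(block.getBaseObject, block.getBaseOffset) == 42)
  assert(block.size() == 16)
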
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
index 97b2c93..8824f98 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
@@ -175,6 +175,16 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError {
   }
   /**
+   * Allocates a contiguous block of memory without the leak tracking provided by
+   * {@code allocatedNonPageMemory}.
+   */
+  public MemoryBlock allocateUnchecked(long size) throws OutOfMemoryError {
+    assert(size > 0) : "Size must be positive, but got " + size;
+    final MemoryBlock memory = executorMemoryManager.allocate(size);
+    return memory;
+  }
+
+  /**
    * Free memory allocated by {@link TaskMemoryManager#allocate(long)}.
    */
   public void free(MemoryBlock memory) {
@@ -187,6 +197,15 @@ public void free(MemoryBlock memory) {
   }
   /**
+   * Frees a contiguous block of memory allocated by {@code allocateUnchecked}, without the
+   * leak tracking provided by {@code allocatedNonPageMemory}.
+   */
+  public void freeUnchecked(MemoryBlock memory) {
+    assert (memory.pageNumber == -1) : "Should call freePage() for pages, not freeUnchecked()";
+    executorMemoryManager.free(memory);
+  }
+
+  /**
    * Given a memory page and offset within that page, encode this address into a 64-bit long.
    * This address will remain valid as long as the corresponding page has not been freed.
    *
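
Finally, a hedged sketch (not part of the patch) of how the unchecked allocate/free pair is intended to be used from task code, mirroring the cache's write path above; because these blocks bypass the allocatedNonPageMemory leak tracking, the caller owns the free. It assumes it runs inside a Spark task so SparkEnv.get is available.

  // illustrative only: hypothetical names `tmm` and `block`
  val tmm = new TaskMemoryManager(SparkEnv.get.executorMemoryManager)
  val block = tmm.allocateUnchecked(4000000)  // 4 MB block, not leak-tracked
  try {
    // a real writer would pack [rowSize (4)|row (rowSize)] records here, as tungstenCache does;
    // writing a 0 rowSize marks the block as empty for the read path
    Platform.putInt(block.getBaseObject, block.getBaseOffset, 0)
  } finally {
    tmm.freeUnchecked(block)                  // caller is responsible for freeing the block
  }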