diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 6c02004..83d47c7 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -577,7 +577,9 @@ public String toString() {
     StringBuilder build = new StringBuilder("[");
     for (int i = 0; i < sizeInBytes; i += 8) {
       build.append(java.lang.Long.toHexString(Platform.getLong(baseObject, baseOffset + i)));
-      build.append(',');
+      if (i + 8 < sizeInBytes) {
+        build.append(',');
+      }
     }
     build.append(']');
     return build.toString();
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index bf03c61..463de71 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -17,14 +17,21 @@
 package org.apache.spark.sql
+import scala.collection.mutable.ArrayBuffer
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
+import org.apache.spark.SparkEnv
+import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.sources.{BaseRelation, TableScan}
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
-import org.apache.spark.sql.types.StructField
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.memory.{MemoryBlock, TaskMemoryManager}
 import org.apache.spark.unsafe.types.UTF8String
 /**
@@ -110,4 +117,89 @@ private[sql] abstract class SQLImplicits {
     DataFrameHolder(
       _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
   }
+
+  /**
+   * ::Experimental::
+   *
+   * Pimp my library decorator for tungsten caching of DataFrames.
+   * @since 1.5.1
+   */
+  @Experimental
+  implicit class TungstenCache(df: DataFrame) {
+    /**
+     * Packs the rows of [[df]] into contiguous blocks of memory.
+     */
+    def tungstenCache(): (RDD[_], DataFrame) = {
+      val BLOCK_SIZE = 4000000 // 4 MB blocks
+      val schema = df.schema
+
+      val convert = CatalystTypeConverters.createToCatalystConverter(schema)
+      val internalRows = df.rdd.map(convert(_).asInstanceOf[InternalRow])
+      val cachedRDD = internalRows.mapPartitions { rowIterator =>
+        val bufferedRowIterator = rowIterator.buffered
+
+        val convertToUnsafe = UnsafeProjection.create(schema)
+        val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager)
+        new Iterator[MemoryBlock] {
+
+          // This assumes that size of row < BLOCK_SIZE
+          def next(): MemoryBlock = {
+            val block = taskMemoryManager.allocateUnchecked(BLOCK_SIZE)
+            var currOffset = 0
+
+            while (bufferedRowIterator.hasNext && currOffset < BLOCK_SIZE) {
+              val currRow = convertToUnsafe.apply(bufferedRowIterator.head)
+              val recordSize = 4 + currRow.getSizeInBytes
+
+              if (currOffset + recordSize < BLOCK_SIZE) {
+                // Pack into memory with layout [rowSize (4) | row (rowSize)]
+                Platform.putInt(
+                  block.getBaseObject, block.getBaseOffset + currOffset, currRow.getSizeInBytes)
+                currRow.writeToMemory(
+                  block.getBaseObject, block.getBaseOffset + currOffset + 4)
+                bufferedRowIterator.next()
+              }
+              currOffset += recordSize // Increment regardless to break loop when full
+            }
+            block
+          }
+
+          def hasNext: Boolean = bufferedRowIterator.hasNext
+        }
+      }.persist(StorageLevel.MEMORY_ONLY)
+
+      val baseRelation: BaseRelation = new BaseRelation with TableScan {
+        override val sqlContext = _sqlContext
+        override val schema = df.schema
+        override val needConversion = false
+
+        override def buildScan(): RDD[Row] = {
+          val numFields = this.schema.length
+
+          cachedRDD.flatMap { block =>
+            val rows = new ArrayBuffer[InternalRow]()
+            var currOffset = 0
+            var moreData = true
+            while (currOffset < block.size() && moreData) {
+              val rowSize = Platform.getInt(block.getBaseObject, block.getBaseOffset + currOffset)
+              currOffset += 4
+              if (rowSize > 0) {
+                val unsafeRow = new UnsafeRow()
+                unsafeRow.pointTo(
+                  block.getBaseObject, block.getBaseOffset + currOffset, numFields, rowSize)
+                rows.append(unsafeRow)
+                currOffset += rowSize
+              } else {
+                moreData = false
+              }
+            }
+            rows
+          }.asInstanceOf[RDD[Row]]
+        }
+
+        override def toString: String = getClass.getSimpleName + s"[${df.toString}]"
+      }
+      (cachedRDD, DataFrame(_sqlContext, LogicalRelation(baseRelation)))
+    }
+  }
 }
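
For context, a minimal usage sketch of the tungstenCache() decorator added above. This is not part of the patch: it assumes a Spark 1.5-era SQLContext named sqlContext with its implicits imported, and the example DataFrame is illustrative only.

// Sketch only: exercises the TungstenCache implicit class from the patch above.
// Assumptions: `sqlContext` is an existing SQLContext; the range DataFrame is a
// stand-in for any DataFrame you want to cache.
import sqlContext.implicits._

val df = sqlContext.range(0, 1000)

// tungstenCache() converts each row to an UnsafeRow, packs the rows into ~4 MB
// MemoryBlocks using the [rowSize (4) | row (rowSize)] layout, persists the block
// RDD MEMORY_ONLY, and returns a DataFrame backed by a TableScan over the blocks.
val (blockRDD, cachedDF) = df.tungstenCache()

cachedDF.count() // scans the packed blocks instead of recomputing `df`

The block RDD is returned alongside the DataFrame, presumably so callers can unpersist the raw blocks themselves; the BaseRelation exposes no unpersist hook of its own.
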
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index af7590c..9036df2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.Accumulators
 import org.apache.spark.sql.columnar._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.storage.{StorageLevel, RDDBlockId}
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}
 private case class BigData(s: String)
@@ -75,17 +75,17 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
   }
   test("unpersist an uncached table will not raise exception") {
-    assert(None == ctx.cacheManager.lookupCachedData(testData))
+    assert(None === ctx.cacheManager.lookupCachedData(testData))
     testData.unpersist(blocking = true)
-    assert(None == ctx.cacheManager.lookupCachedData(testData))
+    assert(None === ctx.cacheManager.lookupCachedData(testData))
     testData.unpersist(blocking = false)
-    assert(None == ctx.cacheManager.lookupCachedData(testData))
+    assert(None === ctx.cacheManager.lookupCachedData(testData))
     testData.persist()
-    assert(None != ctx.cacheManager.lookupCachedData(testData))
+    assert(None !== ctx.cacheManager.lookupCachedData(testData))
     testData.unpersist(blocking = true)
-    assert(None == ctx.cacheManager.lookupCachedData(testData))
+    assert(None === ctx.cacheManager.lookupCachedData(testData))
     testData.unpersist(blocking = false)
-    assert(None == ctx.cacheManager.lookupCachedData(testData))
+    assert(None === ctx.cacheManager.lookupCachedData(testData))
   }
   test("cache table as select") {
@@ -333,7 +333,13 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
       val accsSize = Accumulators.originals.size
       ctx.uncacheTable("t1")
       ctx.uncacheTable("t2")
-      assert((accsSize - 2) == Accumulators.originals.size)
+      assert((accsSize - 2) === Accumulators.originals.size)
     }
   }
+
+  test("tungsten cache table and read") {
+    val data = testData
+    val (cachedRDD, tungstenCachedDF) = data.tungstenCache()
+    assert(tungstenCachedDF.collect() === testData.collect())
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
index 2476b10..4f52535 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
@@ -21,7 +21,7 @@ import java.io.ByteArrayOutputStream
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.memory.MemoryAllocator
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
index 97b2c93..cc78fc8 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
@@ -175,6 +175,16 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError {
   }
   /**
+   * Allocates a contiguous block of memory without the leak tracking provided by
+   * {@code allocatedNonPageMemory}, so the caller is responsible for freeing it.
+   */
+  public MemoryBlock allocateUnchecked(long size) throws OutOfMemoryError {
+    assert(size > 0) : "Size must be positive, but got " + size;
+    final MemoryBlock memory = executorMemoryManager.allocate(size);
+    return memory;
+  }
+
+  /**
    * Free memory allocated by {@link TaskMemoryManager#allocate(long)}.
    */
   public void free(MemoryBlock memory) {
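
For reference, a short illustrative sketch (not part of the patch) of how allocateUnchecked is meant to be used, mirroring the mapPartitions closure in tungstenCache() above: the blocks are handed to a persisted RDD and intentionally outlive the task, which is why the leak-tracked allocate(long) is avoided.

// Sketch only, executor-side code as in the patch: obtain a TaskMemoryManager
// from SparkEnv and allocate an untracked block.
import org.apache.spark.SparkEnv
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.memory.{MemoryBlock, TaskMemoryManager}

val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager)

// Nothing records this block in allocatedNonPageMemory, so leak detection at
// task end will neither flag nor reclaim it.
val block: MemoryBlock = taskMemoryManager.allocateUnchecked(4000000L)

// Write and read one record header at offset 0, using the same
// [rowSize (4) | row (rowSize)] layout that tungstenCache() packs and
// buildScan() later decodes.
Platform.putInt(block.getBaseObject, block.getBaseOffset, 42)
val headerBack = Platform.getInt(block.getBaseObject, block.getBaseOffset) // 42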