GeneratedClass for Partial aggregate
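For reference, a minimal sketch of how a dump like this can be reproduced, assuming Spark 2.x in local mode and a hypothetical people.json input with age and gender columns (the class name and file path are illustrative only):

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class DumpCodegen {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[*]")
        .appName("codegen-dump")
        .getOrCreate();

    // Hypothetical input; any table named `people` with (age, gender) columns works.
    spark.read().json("people.json").createOrReplaceTempView("people");

    Dataset<Row> df = spark.sql(
        "SELECT count(*), avg(age), gender FROM people WHERE age > 18 GROUP BY gender LIMIT 20");

    // Shows the physical plan; whole-stage-codegen subtrees are marked with '*'.
    df.explain(true);

    // Prints the generated Java for each WholeStageCodegen subtree, i.e. a class like the one
    // below (in Scala this is df.queryExecution.debug.codegen(); the inner debug object is
    // reached from Java through its generated accessor).
    df.queryExecution().debug().codegen();

    spark.stop();
  }
}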
//// SELECT count(*), avg(age), gender FROM people WHERE age > 18 GROUP BY gender LIMIT 20
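//// Whole-stage codegen fuses FileSourceScanExec -> FilterExec (age > 18) -> partial
//// HashAggregateExec into the single GeneratedIterator below. For each gender key the
//// partial aggregation buffer holds three columns: a count for count(*), plus a sum and
//// a count for avg(age) (see valueSchema in agg_FastHashMap).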
package org.apache.spark.sql.catalyst.expressions;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.types.UTF8String;
public class GeneratedClass extends
org.apache.spark.sql.catalyst.expressions.codegen.GeneratedClass {
public Object generate(Object[] references) {
return new GeneratedIterator(references);
}
final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
private Object[] references;
private scala.collection.Iterator[] inputs;
private boolean agg_initAgg;
private boolean agg_bufIsNull;
private long agg_bufValue;
private boolean agg_bufIsNull1;
private double agg_bufValue1;
private boolean agg_bufIsNull2;
private long agg_bufValue2;
private org.apache.spark.sql.execution.aggregate.HashAggregateExec agg_plan;
private agg_FastHashMap agg_fastHashMap;
private org.apache.spark.unsafe.KVIterator agg_fastHashMapIter;
private org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap agg_hashMap;
private org.apache.spark.sql.execution.UnsafeKVExternalSorter agg_sorter;
private org.apache.spark.unsafe.KVIterator agg_mapIter;
private org.apache.spark.sql.execution.metric.SQLMetric agg_peakMemory;
private org.apache.spark.sql.execution.metric.SQLMetric agg_spillSize;
private org.apache.spark.sql.execution.metric.SQLMetric scan_numOutputRows;
private scala.collection.Iterator scan_input;
private org.apache.spark.sql.execution.metric.SQLMetric filter_numOutputRows;
private UnsafeRow filter_result;
private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder filter_holder;
private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter filter_rowWriter;
private UnsafeRow project_result;
private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder project_holder;
private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter project_rowWriter;
private UnsafeRow agg_result1;
private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder;
private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter;
private int agg_value9;
private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowJoiner agg_unsafeRowJoiner;
private org.apache.spark.sql.execution.metric.SQLMetric wholestagecodegen_numOutputRows;
private org.apache.spark.sql.execution.metric.SQLMetric wholestagecodegen_aggTime;
public GeneratedIterator(Object[] references) {
this.references = references;
}
public void init(int index, scala.collection.Iterator[] inputs) {
partitionIndex = index;
this.inputs = inputs;
wholestagecodegen_init_0();
wholestagecodegen_init_1();
}
//// How is this initialization code generated? TODO
private void wholestagecodegen_init_0() {
agg_initAgg = false;
this.agg_plan = (org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0];
agg_fastHashMap = new agg_FastHashMap(agg_plan.getTaskMemoryManager(), agg_plan.getEmptyAggregationBuffer());
this.agg_peakMemory = (org.apache.spark.sql.execution.metric.SQLMetric) references[1];
this.agg_spillSize = (org.apache.spark.sql.execution.metric.SQLMetric) references[2];
this.scan_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[3];
scan_input = inputs[0];
this.filter_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[4];
filter_result = new UnsafeRow(2);
this.filter_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(filter_result, 32);
this.filter_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(filter_holder, 2);
project_result = new UnsafeRow(2);
this.project_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(project_result, 32);
this.project_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(project_holder, 2);
}
//// Class added by HashAggregateExec
//// Generated by calling HashMapGenerator#generate
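//// A row-based fast hash map used as a first-level store for aggregation buffers: keys and
//// values live in a RowBasedKeyValueBatch, and `buckets` is an open-addressing index into it
//// (capacity 1 << 16, linear probing, at most `maxSteps` probes per lookup).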
public class agg_FastHashMap {
private org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch;
private int[] buckets;
private int capacity = 1 << 16;
private double loadFactor = 0.5;
private int numBuckets = (int) (capacity / loadFactor);
private int maxSteps = 2;
private int numRows = 0;
private org.apache.spark.sql.types.StructType keySchema = new org.apache.spark.sql.types.StructType()
.add("gender", org.apache.spark.sql.types.DataTypes.StringType);
private org.apache.spark.sql.types.StructType valueSchema = new org.apache.spark.sql.types.StructType()
.add("count", org.apache.spark.sql.types.DataTypes.LongType)
.add("sum", org.apache.spark.sql.types.DataTypes.DoubleType)
.add("count", org.apache.spark.sql.types.DataTypes.LongType);
private Object emptyVBase;
private long emptyVOff;
private int emptyVLen;
private boolean isBatchFull = false;
public agg_FastHashMap(
org.apache.spark.memory.TaskMemoryManager taskMemoryManager,
InternalRow emptyAggregationBuffer) {
batch = org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch
.allocate(keySchema, valueSchema, taskMemoryManager, capacity);
final UnsafeProjection valueProjection = UnsafeProjection.create(valueSchema);
final byte[] emptyBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
emptyVBase = emptyBuffer;
emptyVOff = Platform.BYTE_ARRAY_OFFSET;
emptyVLen = emptyBuffer.length;
buckets = new int[numBuckets];
java.util.Arrays.fill(buckets, -1);
}
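//// Looks up the aggregation buffer for `agg_key`, appending an empty buffer on a miss.
//// Returns null when the batch is full or the probe budget is exhausted, in which case the
//// caller falls back to the regular UnsafeFixedWidthAggregationMap (agg_hashMap).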
public org.apache.spark.sql.catalyst.expressions.UnsafeRow findOrInsert(UTF8String agg_key) {
long h = hash(agg_key);
int step = 0;
int idx = (int) h & (numBuckets - 1);
while (step < maxSteps) {
// Return bucket index if it's either an empty slot or already contains the key
if (buckets[idx] == -1) {
if (numRows < capacity && !isBatchFull) {
// creating the unsafe for new entry
UnsafeRow agg_result = new UnsafeRow(1);
org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder
= new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 32);
org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter
= new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1);
agg_holder.reset(); //TODO: investigate if reset or zeroout are actually needed
agg_rowWriter.zeroOutNullBytes();
agg_rowWriter.write(0, agg_key);
agg_result.setTotalSize(agg_holder.totalSize());
Object kbase = agg_result.getBaseObject();
long koff = agg_result.getBaseOffset();
int klen = agg_result.getSizeInBytes();
UnsafeRow vRow
= batch.appendRow(kbase, koff, klen, emptyVBase, emptyVOff, emptyVLen);
if (vRow == null) {
isBatchFull = true;
} else {
buckets[idx] = numRows++;
}
return vRow;
} else {
// No more space
return null;
}
} else if (equals(idx, agg_key)) {
return batch.getValueRow(buckets[idx]);
}
idx = (idx + 1) & (numBuckets - 1);
step++;
}
// Didn't find it
return null;
}
private boolean equals(int idx, UTF8String agg_key) {
UnsafeRow row = batch.getKeyRow(buckets[idx]);
return (row.getUTF8String(0).equals(agg_key));
}
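//// Simple byte-wise hash of the UTF8String key, used only to pick a bucket in this fast map;
//// the fallback map instead hashes the grouping key with Murmur3 (see agg_value9 below).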
private long hash(UTF8String agg_key) {
long agg_hash = 0;
int agg_result = 0;
byte[] agg_bytes = agg_key.getBytes();
for (int i = 0; i < agg_bytes.length; i++) {
int agg_hash1 = agg_bytes[i];
agg_result = (agg_result ^ (0x9e3779b9)) + agg_hash1 + (agg_result << 6) + (agg_result >>> 2);
}
agg_hash = (agg_hash ^ (0x9e3779b9)) + agg_result + (agg_hash << 6) + (agg_hash >>> 2);
return agg_hash;
}
public org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> rowIterator() {
return batch.rowIterator();
}
public void close() {
batch.close();
}
}
//// Function added by HashAggregateExec
//// It calls CodegenSupport#produce to generate the code below
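//// Build phase: pulls every row from the scan iterator, applies the age > 18 filter, and
//// updates the per-gender partial aggregation buffer, trying the fast map first and falling
//// back to the regular hash map created by agg_plan.createHashMap().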
private void agg_doAggregateWithKeys() throws java.io.IOException {
agg_hashMap = agg_plan.createHashMap();
////////// FileSourceScanExec#doProduce START //////////
while (scan_input.hasNext()) {
InternalRow scan_row = (InternalRow) scan_input.next();
scan_numOutputRows.add(1);
////////// FilterExec#doConsume START //////////
boolean scan_isNull2 = scan_row.isNullAt(0);
long scan_value2 = scan_isNull2 ? -1L : (scan_row.getLong(0));
if (!(!(scan_isNull2))) continue;
boolean filter_isNull2 = false;
boolean filter_value2 = false;
filter_value2 = scan_value2 > 18L;
if (!filter_value2) continue;
////////// FilterExec#doConsume END //////////
filter_numOutputRows.add(1);
boolean scan_isNull3 = scan_row.isNullAt(1);
UTF8String scan_value3 = scan_isNull3 ? null : (scan_row.getUTF8String(1));
////////// HashAggregateExec#doConsumeWithKeys START //////////
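//// Buffer layout: column 0 = count(*) count, column 1 = avg(age) partial sum,
//// column 2 = avg(age) partial count. The same update expressions are generated twice,
//// once for the fast-map buffer and once for the regular-map buffer.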
UnsafeRow agg_unsafeRowAggBuffer = null;
UnsafeRow agg_fastAggBuffer = null;
if (true) {
if (!scan_isNull3) {
agg_fastAggBuffer = agg_fastHashMap.findOrInsert(scan_value3);
}
}
if (agg_fastAggBuffer == null) {
// generate grouping key
agg_holder.reset();
agg_rowWriter.zeroOutNullBytes();
if (scan_isNull3) {
agg_rowWriter.setNullAt(0);
} else {
agg_rowWriter.write(0, scan_value3);
}
agg_result1.setTotalSize(agg_holder.totalSize());
agg_value9 = 42;
if (!scan_isNull3) {
agg_value9 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(
scan_value3.getBaseObject(),
scan_value3.getBaseOffset(),
scan_value3.numBytes(), agg_value9);
}
if (true) {
// try to get the buffer from hash map
agg_unsafeRowAggBuffer =
agg_hashMap.getAggregationBufferFromUnsafeRow(agg_result1, agg_value9);
}
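//// If no buffer could be obtained, the regular map is spilled to an external sorter and the
//// allocation is retried; finishAggregate() later merges any spilled data (a sort-based
//// fallback, not visible in this generated class).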
if (agg_unsafeRowAggBuffer == null) {
if (agg_sorter == null) {
agg_sorter = agg_hashMap.destructAndCreateExternalSorter();
} else {
agg_sorter.merge(agg_hashMap.destructAndCreateExternalSorter());
}
// the hash map had be spilled, it should have enough memory now,
// try to allocate buffer again.
agg_unsafeRowAggBuffer =
agg_hashMap.getAggregationBufferFromUnsafeRow(agg_result1, agg_value9);
if (agg_unsafeRowAggBuffer == null) {
// failed to allocate the first page
throw new OutOfMemoryError("No enough memory for aggregation");
}
}
}
if (agg_fastAggBuffer != null) {
// update fast row
// common sub-expressions
// evaluate aggregate function
boolean agg_isNull25 = false;
long agg_value29 = agg_fastAggBuffer.getLong(0);
long agg_value28 = -1L;
agg_value28 = agg_value29 + 1L;
boolean agg_isNull28 = true;
double agg_value31 = -1.0;
boolean agg_isNull29 = agg_fastAggBuffer.isNullAt(1);
double agg_value32 = agg_isNull29 ? -1.0 : (agg_fastAggBuffer.getDouble(1));
if (!agg_isNull29) {
boolean agg_isNull31 = false;
double agg_value34 = -1.0;
if (!false) {
agg_value34 = (double) scan_value2;
}
boolean agg_isNull30 = agg_isNull31;
double agg_value33 = agg_value34;
if (agg_isNull30) {
boolean agg_isNull33 = false;
double agg_value36 = -1.0;
if (!false) {
agg_value36 = (double) 0;
}
if (!agg_isNull33) {
agg_isNull30 = false;
agg_value33 = agg_value36;
}
}
agg_isNull28 = false; // resultCode could change nullability.
agg_value31 = agg_value32 + agg_value33;
}
boolean agg_isNull35 = false;
long agg_value38 = -1L;
if (!false && false) {
boolean agg_isNull38 = agg_fastAggBuffer.isNullAt(2);
long agg_value41 = agg_isNull38 ? -1L : (agg_fastAggBuffer.getLong(2));
agg_isNull35 = agg_isNull38;
agg_value38 = agg_value41;
} else {
boolean agg_isNull39 = true;
long agg_value42 = -1L;
boolean agg_isNull40 = agg_fastAggBuffer.isNullAt(2);
long agg_value43 = agg_isNull40 ? -1L : (agg_fastAggBuffer.getLong(2));
if (!agg_isNull40) {
agg_isNull39 = false; // resultCode could change nullability.
agg_value42 = agg_value43 + 1L;
}
agg_isNull35 = agg_isNull39;
agg_value38 = agg_value42;
}
// update fast row
agg_fastAggBuffer.setLong(0, agg_value28);
if (!agg_isNull28) {
agg_fastAggBuffer.setDouble(1, agg_value31);
} else {
agg_fastAggBuffer.setNullAt(1);
}
if (!agg_isNull35) {
agg_fastAggBuffer.setLong(2, agg_value38);
} else {
agg_fastAggBuffer.setNullAt(2);
}
} else {
// update unsafe row
// common sub-expressions
// evaluate aggregate function
boolean agg_isNull8 = false;
long agg_value12 = agg_unsafeRowAggBuffer.getLong(0);
long agg_value11 = -1L;
agg_value11 = agg_value12 + 1L;
boolean agg_isNull11 = true;
double agg_value14 = -1.0;
boolean agg_isNull12 = agg_unsafeRowAggBuffer.isNullAt(1);
double agg_value15 = agg_isNull12 ? -1.0 : (agg_unsafeRowAggBuffer.getDouble(1));
if (!agg_isNull12) {
boolean agg_isNull14 = false;
double agg_value17 = -1.0;
if (!false) {
agg_value17 = (double) scan_value2;
}
boolean agg_isNull13 = agg_isNull14;
double agg_value16 = agg_value17;
if (agg_isNull13) {
boolean agg_isNull16 = false;
double agg_value19 = -1.0;
if (!false) {
agg_value19 = (double) 0;
}
if (!agg_isNull16) {
agg_isNull13 = false;
agg_value16 = agg_value19;
}
}
agg_isNull11 = false; // resultCode could change nullability.
agg_value14 = agg_value15 + agg_value16;
}
boolean agg_isNull18 = false;
long agg_value21 = -1L;
if (!false && false) {
boolean agg_isNull21 = agg_unsafeRowAggBuffer.isNullAt(2);
long agg_value24 = agg_isNull21 ? -1L : (agg_unsafeRowAggBuffer.getLong(2));
agg_isNull18 = agg_isNull21;
agg_value21 = agg_value24;
} else {
boolean agg_isNull22 = true;
long agg_value25 = -1L;
boolean agg_isNull23 = agg_unsafeRowAggBuffer.isNullAt(2);
long agg_value26 = agg_isNull23 ? -1L : (agg_unsafeRowAggBuffer.getLong(2));
if (!agg_isNull23) {
agg_isNull22 = false; // resultCode could change nullability.
agg_value25 = agg_value26 + 1L;
}
agg_isNull18 = agg_isNull22;
agg_value21 = agg_value25;
}
// update unsafe row buffer
agg_unsafeRowAggBuffer.setLong(0, agg_value11);
if (!agg_isNull11) {
agg_unsafeRowAggBuffer.setDouble(1, agg_value14);
} else {
agg_unsafeRowAggBuffer.setNullAt(1);
}
if (!agg_isNull18) {
agg_unsafeRowAggBuffer.setLong(2, agg_value21);
} else {
agg_unsafeRowAggBuffer.setNullAt(2);
}
}
////////// HashAggregateExec#doConsumeWithKeys END //////////
if (shouldStop()) return;
}
////////// FileSourceScanExec#doProduce END //////////
agg_fastHashMapIter = agg_fastHashMap.rowIterator();
agg_mapIter = agg_plan.finishAggregate(agg_hashMap, agg_sorter, agg_peakMemory, agg_spillSize);
}
private void wholestagecodegen_init_1() {
agg_result1 = new UnsafeRow(1);
this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result1, 32);
this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1);
agg_unsafeRowJoiner = agg_plan.createUnsafeJoiner();
this.wholestagecodegen_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[5];
this.wholestagecodegen_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[6];
}
protected void processNext() throws java.io.IOException {
////////// HashAggregateExec#doProduceWithKeys START //////////
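//// Output phase: on the first call, build the hash maps via agg_doAggregateWithKeys(), then
//// emit one row per group by joining the grouping key with its aggregation buffer through
//// the UnsafeRowJoiner, draining the fast map first and the regular map second.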
if (!agg_initAgg) {
agg_initAgg = true;
long wholestagecodegen_beforeAgg = System.nanoTime();
agg_doAggregateWithKeys();
wholestagecodegen_aggTime.add((System.nanoTime() - wholestagecodegen_beforeAgg) / 1000000);
}
// output the result
//// Contains the code produced by calling CodegenSupport#consume
while (agg_fastHashMapIter.next()) {
wholestagecodegen_numOutputRows.add(1);
UnsafeRow agg_aggKey = (UnsafeRow) agg_fastHashMapIter.getKey();
UnsafeRow agg_aggBuffer = (UnsafeRow) agg_fastHashMapIter.getValue();
UnsafeRow agg_resultRow = agg_unsafeRowJoiner.join(agg_aggKey, agg_aggBuffer);
append(agg_resultRow);
if (shouldStop()) return;
}
agg_fastHashMap.close();
while (agg_mapIter.next()) {
wholestagecodegen_numOutputRows.add(1);
UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey();
UnsafeRow agg_aggBuffer = (UnsafeRow) agg_mapIter.getValue();
UnsafeRow agg_resultRow = agg_unsafeRowJoiner.join(agg_aggKey, agg_aggBuffer);
append(agg_resultRow);
if (shouldStop()) return;
}
agg_mapIter.close();
if (agg_sorter == null) {
agg_hashMap.free();
}
////////// HashAggregateExec#doProduceWithKeys END //////////
}
}
}