GeneratedClass for Partial aggregate
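The listing below is the whole-stage-generated code for the partial-aggregate stage of the query shown in its first comment line. A dump like this can be reproduced with Spark's codegen debug helper; a minimal Scala sketch, assuming Spark 2.x, pasted into spark-shell (where the `spark` session is predefined) with hypothetical stand-in data for the `people` table:

    import org.apache.spark.sql.execution.debug._   // provides debugCodegen()
    import spark.implicits._

    // Hypothetical stand-in rows for the real `people` table.
    Seq((25L, "M"), (17L, "F"), (33L, "F")).toDF("age", "gender")
      .createOrReplaceTempView("people")

    val df = spark.sql(
      "SELECT count(*), avg(age), gender FROM people WHERE age > 18 GROUP BY gender LIMIT 20")

    // Prints the generated Java source of every WholeStageCodegen subtree,
    // including a GeneratedIterator like the one in this gist.
    df.debugCodegen()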
//// SELECT count(*), avg(age), gender FROM people WHERE age > 18 GROUP BY gender LIMIT 20
package org.apache.spark.sql.catalyst.expressions;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.types.UTF8String;
public class GeneratedClass extends
org.apache.spark.sql.catalyst.expressions.codegen.GeneratedClass {
public Object generate(Object[] references) {
return new GeneratedIterator(references);
}
final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
private Object[] references;
private scala.collection.Iterator[] inputs;
private boolean agg_initAgg;
private boolean agg_bufIsNull;
private long agg_bufValue;
private boolean agg_bufIsNull1;
private double agg_bufValue1;
private boolean agg_bufIsNull2;
private long agg_bufValue2;
private org.apache.spark.sql.execution.aggregate.HashAggregateExec agg_plan;
private agg_FastHashMap agg_fastHashMap;
private org.apache.spark.unsafe.KVIterator agg_fastHashMapIter;
private org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap agg_hashMap;
private org.apache.spark.sql.execution.UnsafeKVExternalSorter agg_sorter;
private org.apache.spark.unsafe.KVIterator agg_mapIter;
private org.apache.spark.sql.execution.metric.SQLMetric agg_peakMemory;
private org.apache.spark.sql.execution.metric.SQLMetric agg_spillSize;
private org.apache.spark.sql.execution.metric.SQLMetric scan_numOutputRows;
private scala.collection.Iterator scan_input;
private org.apache.spark.sql.execution.metric.SQLMetric filter_numOutputRows;
private UnsafeRow filter_result;
private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder filter_holder;
private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter filter_rowWriter;
private UnsafeRow project_result;
private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder project_holder;
private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter project_rowWriter;
private UnsafeRow agg_result1;
private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder;
private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter;
private int agg_value9;
private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowJoiner agg_unsafeRowJoiner;
private org.apache.spark.sql.execution.metric.SQLMetric wholestagecodegen_numOutputRows;
private org.apache.spark.sql.execution.metric.SQLMetric wholestagecodegen_aggTime;
public GeneratedIterator(Object[] references) {
this.references = references;
}
public void init(int index, scala.collection.Iterator[] inputs) {
partitionIndex = index;
this.inputs = inputs;
wholestagecodegen_init_0();
wholestagecodegen_init_1();
}
//// How is this initialization code generated? TODO
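//// (Likely answer: each operator registers its per-partition state through
//// CodegenContext#addMutableState while its doProduce/doConsume runs; CodegenContext then
//// declares those fields and splits their initialization into the wholestagecodegen_init_*
//// methods so that no single method exceeds the JVM's 64KB bytecode limit.)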
private void wholestagecodegen_init_0() {
agg_initAgg = false;
this.agg_plan = (org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0];
agg_fastHashMap = new agg_FastHashMap(agg_plan.getTaskMemoryManager(), agg_plan.getEmptyAggregationBuffer());
this.agg_peakMemory = (org.apache.spark.sql.execution.metric.SQLMetric) references[1];
this.agg_spillSize = (org.apache.spark.sql.execution.metric.SQLMetric) references[2];
this.scan_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[3];
scan_input = inputs[0];
this.filter_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[4];
filter_result = new UnsafeRow(2);
this.filter_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(filter_result, 32);
this.filter_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(filter_holder, 2);
project_result = new UnsafeRow(2);
this.project_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(project_result, 32);
this.project_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(project_holder, 2);
}
//// Class added by HashAggregateExec
//// Generated by HashMapGenerator#generate
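//// A per-query "fast" first-level hash map: keys (gender) and aggregation buffers live in a
//// RowBasedKeyValueBatch, looked up by linear probing (maxSteps = 2). When a key cannot be
//// served here (batch full or too many probe steps), doConsumeWithKeys below falls back to
//// the general UnsafeFixedWidthAggregationMap (agg_hashMap).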
public class agg_FastHashMap {
private org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch;
private int[] buckets;
private int capacity = 1 << 16;
private double loadFactor = 0.5;
private int numBuckets = (int) (capacity / loadFactor);
private int maxSteps = 2;
private int numRows = 0;
private org.apache.spark.sql.types.StructType keySchema = new org.apache.spark.sql.types.StructType()
.add("gender", org.apache.spark.sql.types.DataTypes.StringType);
private org.apache.spark.sql.types.StructType valueSchema = new org.apache.spark.sql.types.StructType()
.add("count", org.apache.spark.sql.types.DataTypes.LongType)
.add("sum", org.apache.spark.sql.types.DataTypes.DoubleType)
.add("count", org.apache.spark.sql.types.DataTypes.LongType);
private Object emptyVBase;
private long emptyVOff;
private int emptyVLen;
private boolean isBatchFull = false;
public agg_FastHashMap(
org.apache.spark.memory.TaskMemoryManager taskMemoryManager,
InternalRow emptyAggregationBuffer) {
batch = org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch
.allocate(keySchema, valueSchema, taskMemoryManager, capacity);
final UnsafeProjection valueProjection = UnsafeProjection.create(valueSchema);
final byte[] emptyBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
emptyVBase = emptyBuffer;
emptyVOff = Platform.BYTE_ARRAY_OFFSET;
emptyVLen = emptyBuffer.length;
buckets = new int[numBuckets];
java.util.Arrays.fill(buckets, -1);
}
public org.apache.spark.sql.catalyst.expressions.UnsafeRow findOrInsert(UTF8String agg_key) {
long h = hash(agg_key);
int step = 0;
int idx = (int) h & (numBuckets - 1);
while (step < maxSteps) {
// Return bucket index if it's either an empty slot or already contains the key
if (buckets[idx] == -1) {
if (numRows < capacity && !isBatchFull) {
// creating the unsafe for new entry
UnsafeRow agg_result = new UnsafeRow(1);
org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder
= new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 32);
org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter
= new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1);
agg_holder.reset(); //TODO: investigate if reset or zeroout are actually needed
agg_rowWriter.zeroOutNullBytes();
agg_rowWriter.write(0, agg_key);
agg_result.setTotalSize(agg_holder.totalSize());
Object kbase = agg_result.getBaseObject();
long koff = agg_result.getBaseOffset();
int klen = agg_result.getSizeInBytes();
UnsafeRow vRow
= batch.appendRow(kbase, koff, klen, emptyVBase, emptyVOff, emptyVLen);
if (vRow == null) {
isBatchFull = true;
} else {
buckets[idx] = numRows++;
}
return vRow;
} else {
// No more space
return null;
}
} else if (equals(idx, agg_key)) {
return batch.getValueRow(buckets[idx]);
}
idx = (idx + 1) & (numBuckets - 1);
step++;
}
// Didn't find it
return null;
}
private boolean equals(int idx, UTF8String agg_key) {
UnsafeRow row = batch.getKeyRow(buckets[idx]);
return (row.getUTF8String(0).equals(agg_key));
}
private long hash(UTF8String agg_key) {
long agg_hash = 0;
int agg_result = 0;
byte[] agg_bytes = agg_key.getBytes();
for (int i = 0; i < agg_bytes.length; i++) {
int agg_hash1 = agg_bytes[i];
agg_result = (agg_result ^ (0x9e3779b9)) + agg_hash1 + (agg_result << 6) + (agg_result >>> 2);
}
agg_hash = (agg_hash ^ (0x9e3779b9)) + agg_result + (agg_hash << 6) + (agg_hash >>> 2);
return agg_hash;
}
public org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> rowIterator() {
return batch.rowIterator();
}
public void close() {
batch.close();
}
}
//// Function added by HashAggregateExec
//// Its body is generated by calling CodegenSupport#produce
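//// The body below is the fused produce/consume pipeline of this stage:
//// FileSourceScanExec#doProduce emits the scan loop, FilterExec#doConsume the age > 18
//// predicate, and HashAggregateExec#doConsumeWithKeys the per-row buffer update,
//// all inlined into one method (see the START/END markers).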
private void agg_doAggregateWithKeys() throws java.io.IOException {
agg_hashMap = agg_plan.createHashMap();
////////// FileSourceScanExec#doProduce START //////////
while (scan_input.hasNext()) {
InternalRow scan_row = (InternalRow) scan_input.next();
scan_numOutputRows.add(1);
////////// FilterExec#doConsume START //////////
boolean scan_isNull2 = scan_row.isNullAt(0);
long scan_value2 = scan_isNull2 ? -1L : (scan_row.getLong(0));
if (!(!(scan_isNull2))) continue;
boolean filter_isNull2 = false;
boolean filter_value2 = false;
filter_value2 = scan_value2 > 18L;
if (!filter_value2) continue;
////////// FilterExec#doConsume END //////////
filter_numOutputRows.add(1);
boolean scan_isNull3 = scan_row.isNullAt(1);
UTF8String scan_value3 = scan_isNull3 ? null : (scan_row.getUTF8String(1));
////////// HashAggregateExec#doConsumeWithKeys START //////////
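//// Per input row: try the fast map first; on a miss, build an UnsafeRow grouping key,
//// hash it with Murmur3, and look it up in (or insert it into) the regular hash map,
//// spilling to an UnsafeKVExternalSorter when the map cannot allocate a buffer.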
UnsafeRow agg_unsafeRowAggBuffer = null;
UnsafeRow agg_fastAggBuffer = null;
if (true) {
if (!scan_isNull3) {
agg_fastAggBuffer = agg_fastHashMap.findOrInsert(scan_value3);
}
}
if (agg_fastAggBuffer == null) {
// generate grouping key
agg_holder.reset();
agg_rowWriter.zeroOutNullBytes();
if (scan_isNull3) {
agg_rowWriter.setNullAt(0);
} else {
agg_rowWriter.write(0, scan_value3);
}
agg_result1.setTotalSize(agg_holder.totalSize());
agg_value9 = 42;
if (!scan_isNull3) {
agg_value9 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(
scan_value3.getBaseObject(),
scan_value3.getBaseOffset(),
scan_value3.numBytes(), agg_value9);
}
if (true) {
// try to get the buffer from hash map
agg_unsafeRowAggBuffer =
agg_hashMap.getAggregationBufferFromUnsafeRow(agg_result1, agg_value9);
}
if (agg_unsafeRowAggBuffer == null) {
if (agg_sorter == null) {
agg_sorter = agg_hashMap.destructAndCreateExternalSorter();
} else {
agg_sorter.merge(agg_hashMap.destructAndCreateExternalSorter());
}
// the hash map had been spilled; it should have enough memory now,
// try to allocate the buffer again.
agg_unsafeRowAggBuffer =
agg_hashMap.getAggregationBufferFromUnsafeRow(agg_result1, agg_value9);
if (agg_unsafeRowAggBuffer == null) {
// failed to allocate the first page
throw new OutOfMemoryError("No enough memory for aggregation");
}
}
}
if (agg_fastAggBuffer != null) {
// update fast row
// common sub-expressions
// evaluate aggregate function
boolean agg_isNull25 = false;
long agg_value29 = agg_fastAggBuffer.getLong(0);
long agg_value28 = -1L;
agg_value28 = agg_value29 + 1L;
boolean agg_isNull28 = true;
double agg_value31 = -1.0;
boolean agg_isNull29 = agg_fastAggBuffer.isNullAt(1);
double agg_value32 = agg_isNull29 ? -1.0 : (agg_fastAggBuffer.getDouble(1));
if (!agg_isNull29) {
boolean agg_isNull31 = false;
double agg_value34 = -1.0;
if (!false) {
agg_value34 = (double) scan_value2;
}
boolean agg_isNull30 = agg_isNull31;
double agg_value33 = agg_value34;
if (agg_isNull30) {
boolean agg_isNull33 = false;
double agg_value36 = -1.0;
if (!false) {
agg_value36 = (double) 0;
}
if (!agg_isNull33) {
agg_isNull30 = false;
agg_value33 = agg_value36;
}
}
agg_isNull28 = false; // resultCode could change nullability.
agg_value31 = agg_value32 + agg_value33;
}
boolean agg_isNull35 = false;
long agg_value38 = -1L;
if (!false && false) {
boolean agg_isNull38 = agg_fastAggBuffer.isNullAt(2);
long agg_value41 = agg_isNull38 ? -1L : (agg_fastAggBuffer.getLong(2));
agg_isNull35 = agg_isNull38;
agg_value38 = agg_value41;
} else {
boolean agg_isNull39 = true;
long agg_value42 = -1L;
boolean agg_isNull40 = agg_fastAggBuffer.isNullAt(2);
long agg_value43 = agg_isNull40 ? -1L : (agg_fastAggBuffer.getLong(2));
if (!agg_isNull40) {
agg_isNull39 = false; // resultCode could change nullability.
agg_value42 = agg_value43 + 1L;
}
agg_isNull35 = agg_isNull39;
agg_value38 = agg_value42;
}
// update fast row
agg_fastAggBuffer.setLong(0, agg_value28);
if (!agg_isNull28) {
agg_fastAggBuffer.setDouble(1, agg_value31);
} else {
agg_fastAggBuffer.setNullAt(1);
}
if (!agg_isNull35) {
agg_fastAggBuffer.setLong(2, agg_value38);
} else {
agg_fastAggBuffer.setNullAt(2);
}
} else {
// update unsafe row
// common sub-expressions
// evaluate aggregate function
boolean agg_isNull8 = false;
long agg_value12 = agg_unsafeRowAggBuffer.getLong(0);
long agg_value11 = -1L;
agg_value11 = agg_value12 + 1L;
boolean agg_isNull11 = true;
double agg_value14 = -1.0;
boolean agg_isNull12 = agg_unsafeRowAggBuffer.isNullAt(1);
double agg_value15 = agg_isNull12 ? -1.0 : (agg_unsafeRowAggBuffer.getDouble(1));
if (!agg_isNull12) {
boolean agg_isNull14 = false;
double agg_value17 = -1.0;
if (!false) {
agg_value17 = (double) scan_value2;
}
boolean agg_isNull13 = agg_isNull14;
double agg_value16 = agg_value17;
if (agg_isNull13) {
boolean agg_isNull16 = false;
double agg_value19 = -1.0;
if (!false) {
agg_value19 = (double) 0;
}
if (!agg_isNull16) {
agg_isNull13 = false;
agg_value16 = agg_value19;
}
}
agg_isNull11 = false; // resultCode could change nullability.
agg_value14 = agg_value15 + agg_value16;
}
boolean agg_isNull18 = false;
long agg_value21 = -1L;
if (!false && false) {
boolean agg_isNull21 = agg_unsafeRowAggBuffer.isNullAt(2);
long agg_value24 = agg_isNull21 ? -1L : (agg_unsafeRowAggBuffer.getLong(2));
agg_isNull18 = agg_isNull21;
agg_value21 = agg_value24;
} else {
boolean agg_isNull22 = true;
long agg_value25 = -1L;
boolean agg_isNull23 = agg_unsafeRowAggBuffer.isNullAt(2);
long agg_value26 = agg_isNull23 ? -1L : (agg_unsafeRowAggBuffer.getLong(2));
if (!agg_isNull23) {
agg_isNull22 = false; // resultCode could change nullability.
agg_value25 = agg_value26 + 1L;
}
agg_isNull18 = agg_isNull22;
agg_value21 = agg_value25;
}
// update unsafe row buffer
agg_unsafeRowAggBuffer.setLong(0, agg_value11);
if (!agg_isNull11) {
agg_unsafeRowAggBuffer.setDouble(1, agg_value14);
} else {
agg_unsafeRowAggBuffer.setNullAt(1);
}
if (!agg_isNull18) {
agg_unsafeRowAggBuffer.setLong(2, agg_value21);
} else {
agg_unsafeRowAggBuffer.setNullAt(2);
}
}
////////// HashAggregateExec#doConsumeWithKeys END //////////
if (shouldStop()) return;
}
////////// FileSourceScanExec#doProduce END //////////
agg_fastHashMapIter = agg_fastHashMap.rowIterator();
agg_mapIter = agg_plan.finishAggregate(agg_hashMap, agg_sorter, agg_peakMemory, agg_spillSize);
}
private void wholestagecodegen_init_1() {
agg_result1 = new UnsafeRow(1);
this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result1, 32);
this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1);
agg_unsafeRowJoiner = agg_plan.createUnsafeJoiner();
this.wholestagecodegen_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[5];
this.wholestagecodegen_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[6];
}
protected void processNext() throws java.io.IOException {
////////// HashAggregateExec#doProduceWithKeys START //////////
if (!agg_initAgg) {
agg_initAgg = true;
long wholestagecodegen_beforeAgg = System.nanoTime();
agg_doAggregateWithKeys();
wholestagecodegen_aggTime.add((System.nanoTime() - wholestagecodegen_beforeAgg) / 1000000);
}
// output the result
//// Contains the code produced by CodegenSupport#consume
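//// Each output row is the grouping key (gender) joined with its partial buffer
//// (count, sum, count) via the UnsafeRowJoiner; the final HashAggregateExec after the
//// shuffle merges these buffers and computes avg(age) = sum / count.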
while (agg_fastHashMapIter.next()) {
wholestagecodegen_numOutputRows.add(1);
UnsafeRow agg_aggKey = (UnsafeRow) agg_fastHashMapIter.getKey();
UnsafeRow agg_aggBuffer = (UnsafeRow) agg_fastHashMapIter.getValue();
UnsafeRow agg_resultRow = agg_unsafeRowJoiner.join(agg_aggKey, agg_aggBuffer);
append(agg_resultRow);
if (shouldStop()) return;
}
agg_fastHashMap.close();
while (agg_mapIter.next()) {
wholestagecodegen_numOutputRows.add(1);
UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey();
UnsafeRow agg_aggBuffer = (UnsafeRow) agg_mapIter.getValue();
UnsafeRow agg_resultRow = agg_unsafeRowJoiner.join(agg_aggKey, agg_aggBuffer);
append(agg_resultRow);
if (shouldStop()) return;
}
agg_mapIter.close();
if (agg_sorter == null) {
agg_hashMap.free();
}
////////// HashAggregateExec#doProduceWithKeys END //////////
}
}
}
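For reference, a hand-written Scala model of what the partial-aggregate update above maintains per gender group (ignoring null handling); the names here are illustrative, not Spark API:

    // Partial buffer per group: (count(*), sum(age), count(age)) -- mirrors the
    // (count, sum, count) value schema of agg_FastHashMap above.
    final case class Buf(count: Long, sum: Double, ageCount: Long)

    def update(b: Buf, age: Long): Buf =
      Buf(b.count + 1L, b.sum + age.toDouble, b.ageCount + 1L)

    // Rows that survived the age > 18 filter, as (gender, age).
    val rows = Seq(("M", 25L), ("F", 33L), ("F", 40L))

    val partial = rows.groupBy(_._1).map { case (gender, rs) =>
      gender -> rs.map(_._2).foldLeft(Buf(0L, 0.0, 0L))(update)
    }
    // partial: Map(M -> Buf(1, 25.0, 1), F -> Buf(2, 73.0, 2))
    // The final aggregate after the shuffle merges these buffers and
    // computes avg(age) = sum / ageCount.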