GeneratedClass for Final aggregate
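Below is the Java source that Spark SQL's whole-stage code generation emits for the Final HashAggregateExec of the query in the first comment. As a minimal sketch of how to dump this kind of output yourself (assuming a local Spark setup; the people.json path and the people view are hypothetical, and EXPLAIN CODEGEN prints the generated classes for each whole-stage subtree):

import org.apache.spark.sql.SparkSession;

public class CodegenDump {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .appName("codegen-dump").master("local[*]").getOrCreate();
    // Hypothetical data source; any people(age, gender, ...) view works.
    spark.read().json("people.json").createOrReplaceTempView("people");
    // EXPLAIN CODEGEN prints the whole-stage generated classes, including a
    // GeneratedIterator for the final hash aggregate like the one below.
    spark.sql("EXPLAIN CODEGEN SELECT count(*), avg(age), gender FROM people "
        + "WHERE age > 18 GROUP BY gender LIMIT 20").show(false);
    spark.stop();
  }
}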
//// SELECT count(*), avg(age), gender FROM people WHERE age > 18 GROUP BY gender LIMIT 20
package org.apache.spark.sql.catalyst.expressions;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.unsafe.types.UTF8String;
public class GeneratedClass extends
org.apache.spark.sql.catalyst.expressions.codegen.GeneratedClass {
public Object generate(Object[] references) {
return new GeneratedIterator(references);
}
final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
private Object[] references;
private scala.collection.Iterator[] inputs;
private boolean agg_initAgg;
private boolean agg_bufIsNull;
private long agg_bufValue;
private boolean agg_bufIsNull1;
private double agg_bufValue1;
private boolean agg_bufIsNull2;
private long agg_bufValue2;
private org.apache.spark.sql.execution.aggregate.HashAggregateExec agg_plan;
private org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap agg_hashMap;
private org.apache.spark.sql.execution.UnsafeKVExternalSorter agg_sorter;
private org.apache.spark.unsafe.KVIterator agg_mapIter;
private org.apache.spark.sql.execution.metric.SQLMetric agg_peakMemory;
private org.apache.spark.sql.execution.metric.SQLMetric agg_spillSize;
private scala.collection.Iterator inputadapter_input;
private UnsafeRow agg_result;
private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder;
private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter;
private int agg_value9;
private UnsafeRow agg_result1;
private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder1;
private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter1;
private org.apache.spark.sql.execution.metric.SQLMetric wholestagecodegen_numOutputRows;
private org.apache.spark.sql.execution.metric.SQLMetric wholestagecodegen_aggTime;
public GeneratedIterator(Object[] references) {
this.references = references;
}
public void init(int index, scala.collection.Iterator[] inputs) {
partitionIndex = index;
this.inputs = inputs;
agg_initAgg = false;
this.agg_plan = (org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0];
this.agg_peakMemory = (org.apache.spark.sql.execution.metric.SQLMetric) references[1];
this.agg_spillSize = (org.apache.spark.sql.execution.metric.SQLMetric) references[2];
inputadapter_input = inputs[0];
agg_result = new UnsafeRow(1);
this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 32);
this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1);
agg_result1 = new UnsafeRow(3);
this.agg_holder1 = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result1, 32);
this.agg_rowWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder1, 3);
this.wholestagecodegen_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[3];
this.wholestagecodegen_aggTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[4];
}
//// Function added by HashAggregateExec
//// It calls CodegenSupport#produce to generate the code
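//// agg_doAggregateWithKeys drains the child iterator, keeping one aggregation buffer
//// per grouping key (gender) in agg_hashMap, then exposes the map through agg_mapIter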
private void agg_doAggregateWithKeys() throws java.io.IOException {
agg_hashMap = agg_plan.createHashMap();
////////// InputAdapter#doProduce START //////////
while (inputadapter_input.hasNext()) {
InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
UTF8String inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getUTF8String(0));
long inputadapter_value1 = inputadapter_row.getLong(1);
boolean inputadapter_isNull2 = inputadapter_row.isNullAt(2);
double inputadapter_value2 = inputadapter_isNull2 ? -1.0 : (inputadapter_row.getDouble(2));
boolean inputadapter_isNull3 = inputadapter_row.isNullAt(3);
long inputadapter_value3 = inputadapter_isNull3 ? -1L : (inputadapter_row.getLong(3));
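//// This is the Final aggregate, so each input row carries partial results:
//// column 0 = gender (grouping key), 1 = partial count(*), 2 = partial sum(age), 3 = partial count(age)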
////////// HashAggregateExec#doConsumeWithKeys START //////////
UnsafeRow agg_unsafeRowAggBuffer = null;
UnsafeRow agg_fastAggBuffer = null;
if (agg_fastAggBuffer == null) {
// generate grouping key
agg_holder.reset();
agg_rowWriter.zeroOutNullBytes();
if (inputadapter_isNull) {
agg_rowWriter.setNullAt(0);
} else {
agg_rowWriter.write(0, inputadapter_value);
}
agg_result.setTotalSize(agg_holder.totalSize());
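//// compute the Murmur3 hash of the grouping key (seed 42), used to probe the hash map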
agg_value9 = 42;
if (!inputadapter_isNull) {
agg_value9 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(
inputadapter_value.getBaseObject(),
inputadapter_value.getBaseOffset(),
inputadapter_value.numBytes(), agg_value9);
}
if (true) {
// try to get the buffer from hash map
agg_unsafeRowAggBuffer =
agg_hashMap.getAggregationBufferFromUnsafeRow(agg_result, agg_value9);
}
if (agg_unsafeRowAggBuffer == null) {
if (agg_sorter == null) {
agg_sorter = agg_hashMap.destructAndCreateExternalSorter();
} else {
agg_sorter.merge(agg_hashMap.destructAndCreateExternalSorter());
}
// the hash map has been spilled, it should have enough memory now,
// try to allocate buffer again.
agg_unsafeRowAggBuffer =
agg_hashMap.getAggregationBufferFromUnsafeRow(agg_result, agg_value9);
if (agg_unsafeRowAggBuffer == null) {
// failed to allocate the first page
throw new OutOfMemoryError("No enough memory for aggregation");
}
}
}
if (agg_fastAggBuffer != null) {
// update fast row
} else {
// update unsafe row
// common sub-expressions
// evaluate aggregate function
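//// merge the incoming partial values into the buffer:
//// buffer(0) += count(*), buffer(1) += sum(age), buffer(2) += count(age)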
boolean agg_isNull8 = false;
long agg_value12 = agg_unsafeRowAggBuffer.getLong(0);
long agg_value11 = -1L;
agg_value11 = agg_value12 + inputadapter_value1;
boolean agg_isNull11 = true;
double agg_value14 = -1.0;
boolean agg_isNull12 = agg_unsafeRowAggBuffer.isNullAt(1);
double agg_value15 = agg_isNull12 ? -1.0 : (agg_unsafeRowAggBuffer.getDouble(1));
if (!agg_isNull12) {
if (!inputadapter_isNull2) {
agg_isNull11 = false; // resultCode could change nullability.
agg_value14 = agg_value15 + inputadapter_value2;
}
}
boolean agg_isNull14 = true;
long agg_value17 = -1L;
boolean agg_isNull15 = agg_unsafeRowAggBuffer.isNullAt(2);
long agg_value18 = agg_isNull15 ? -1L : (agg_unsafeRowAggBuffer.getLong(2));
if (!agg_isNull15) {
if (!inputadapter_isNull3) {
agg_isNull14 = false; // resultCode could change nullability.
agg_value17 = agg_value18 + inputadapter_value3;
}
}
// update unsafe row buffer
agg_unsafeRowAggBuffer.setLong(0, agg_value11);
if (!agg_isNull11) {
agg_unsafeRowAggBuffer.setDouble(1, agg_value14);
} else {
agg_unsafeRowAggBuffer.setNullAt(1);
}
if (!agg_isNull14) {
agg_unsafeRowAggBuffer.setLong(2, agg_value17);
} else {
agg_unsafeRowAggBuffer.setNullAt(2);
}
}
////////// HashAggregateExec#doConsumeWithKeys END //////////
if (shouldStop()) return;
}
////////// InputAdapter#doProduce END //////////
agg_mapIter = agg_plan.finishAggregate(agg_hashMap, agg_sorter, agg_peakMemory, agg_spillSize);
}
protected void processNext() throws java.io.IOException {
////////// HashAggregateExec#doProduceWithKeys START //////////
if (!agg_initAgg) {
agg_initAgg = true;
long wholestagecodegen_beforeAgg = System.nanoTime();
agg_doAggregateWithKeys();
wholestagecodegen_aggTime.add((System.nanoTime() - wholestagecodegen_beforeAgg) / 1000000);
}
// output the result
//// Contains the code produced by calling CodegenSupport#consume
while (agg_mapIter.next()) {
wholestagecodegen_numOutputRows.add(1);
UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey();
UnsafeRow agg_aggBuffer = (UnsafeRow) agg_mapIter.getValue();
boolean agg_isNull17 = agg_aggKey.isNullAt(0);
UTF8String agg_value20 = agg_isNull17 ? null : (agg_aggKey.getUTF8String(0));
long agg_value21 = agg_aggBuffer.getLong(0);
boolean agg_isNull19 = agg_aggBuffer.isNullAt(1);
double agg_value22 = agg_isNull19 ? -1.0 : (agg_aggBuffer.getDouble(1));
boolean agg_isNull20 = agg_aggBuffer.isNullAt(2);
long agg_value23 = agg_isNull20 ? -1L : (agg_aggBuffer.getLong(2));
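//// avg(age) is finalized here as sum / count; it is null when the count is null or zero, or the sum is null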
boolean agg_isNull25 = agg_isNull20;
double agg_value28 = -1.0;
if (!agg_isNull20) {
agg_value28 = (double) agg_value23;
}
boolean agg_isNull22 = false;
double agg_value25 = -1.0;
if (agg_isNull25 || agg_value28 == 0) {
agg_isNull22 = true;
} else {
boolean agg_isNull23 = agg_isNull19;
double agg_value26 = -1.0;
if (!agg_isNull19) {
agg_value26 = agg_value22;
}
if (agg_isNull23) {
agg_isNull22 = true;
} else {
agg_value25 = (double)(agg_value26 / agg_value28);
}
}
agg_holder1.reset();
agg_rowWriter1.zeroOutNullBytes();
agg_rowWriter1.write(0, agg_value21);
if (agg_isNull22) {
agg_rowWriter1.setNullAt(1);
} else {
agg_rowWriter1.write(1, agg_value25);
}
if (agg_isNull17) {
agg_rowWriter1.setNullAt(2);
} else {
agg_rowWriter1.write(2, agg_value20);
}
agg_result1.setTotalSize(agg_holder1.totalSize());
append(agg_result1);
if (shouldStop()) return;
}
agg_mapIter.close();
if (agg_sorter == null) {
agg_hashMap.free();
}
////////// HashAggregateExec#doProduceWithKeys END //////////
}
}
}