Skip to content

Instantly share code, notes, and snippets.

@sjwiesman
Created August 24, 2017 15:16
Show Gist options
  • Save sjwiesman/19a9b1cfdafc5b4d426c2c9badb99679 to your computer and use it in GitHub Desktop.
Save sjwiesman/19a9b1cfdafc5b4d426c2c9badb99679 to your computer and use it in GitHub Desktop.
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.api.windowing.assigners.WindowAssigner;
import org.apache.flink.streaming.api.windowing.windows.Window;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.ProcessingTimeCallback;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ScheduledFuture;
/**
 * Non-keyed pre-aggregation operator: buffers one accumulator per (key, window) pair in memory
 * and periodically forwards the partial aggregates downstream as {@code Tuple3<key, window, accumulator>}.
 *
 * <p>A couple of quick notes:
 * <ul>
 *   <li>This operator tries to mimic the window operator as much as possible.</li>
 *   <li>State is required to be with the {@link AggregateFunction} (i.e. in its accumulator).</li>
 *   <li>Output is a tuple containing the key, the window, and the accumulator.</li>
 *   <li>After this operator the only valid operation is to key by {@code tuple.f0}.</li>
 *   <li>The downstream window operator's assigner just forwards the window given in {@code tuple.f1},
 *       but has to provide the window serializer for the underlying window type.</li>
 *   <li>The aggregation in the downstream window operator runs {@link AggregateFunction#merge}.</li>
 *   <li>Because this uses operator (list) state, on restore multiple aggregates for the same
 *       (key, window) may be handed to the same operator instance, so they have to be merged.</li>
 *   <li>Elements are always forwarded ahead of the watermark so that records are not
 *       improperly labelled late.</li>
 *   <li>A processing-time timer keeps elements moving even if the watermark lags. In practice
 *       elements are forwarded at min(pause, watermark frequency). {@code pause} is a tuning
 *       parameter: the higher the value, the more pre-aggregation, but also the longer the
 *       checkpoint times, because checkpoints are not async.</li>
 *   <li>The window is stored in state as a {@code byte[]} that is serialized and deserialized
 *       manually with the supplied window serializer.</li>
 * </ul>
 *
 * <p>NOTE(review): the buffer is a plain {@link HashMap}; this assumes the runtime serializes
 * element processing, watermark processing, timer callbacks and snapshots (e.g. under the
 * checkpoint lock) -- confirm for the Flink version in use.
 *
 * @param <K>   key type extracted from each input element
 * @param <IN>  input element type
 * @param <ACC> accumulator type of the aggregate function
 * @param <W>   window type produced by the window assigner
 */
public class PreAggregatorOperator<K, IN, ACC, W extends Window>
        extends AbstractStreamOperator<Tuple3<K, W, ACC>>
        implements OneInputStreamOperator<IN, Tuple3<K, W, ACC>>, ProcessingTimeCallback {

    private final AggregateFunction<IN, ACC, ?> aggregateFunction;
    private final KeySelector<IN, K> keySelector;
    private final TupleTypeInfo<Tuple3<K, byte[], ACC>> typeInfo;
    private final WindowAssigner<? super IN, W> windowAssigner;
    private final TypeSerializer<W> windowSerializer;
    private final long pause;

    /** In-memory partial aggregates, one accumulator per (key, window); runtime-only, never shipped with the operator. */
    private transient Map<Tuple2<K, W>, ACC> aggregates;
    private transient WindowAssigner.WindowAssignerContext windowAssignerContext;
    private transient ListState<Tuple3<K, byte[], ACC>> keyValueListState;
    /** Handle of the periodic flush timer registered in {@link #open()}, cancelled in {@link #close()}. */
    private transient ScheduledFuture<?> flushTimer;

    /**
     * @param aggregateFunction  function used to create, update and merge accumulators
     * @param keySelector        extracts the key of each input element
     * @param keyTypeInformation type information of the key, used for the checkpointed tuples
     * @param accTypeInformation type information of the accumulator, used for the checkpointed tuples
     * @param windowSerializer   serializer used to turn windows into bytes for checkpoints
     * @param windowAssigner     assigns each input element to its windows
     * @param pause              delay and period, in milliseconds of processing time, of the periodic flush
     */
    public PreAggregatorOperator(
            AggregateFunction<IN, ACC, ?> aggregateFunction,
            KeySelector<IN, K> keySelector,
            TypeInformation<K> keyTypeInformation,
            TypeInformation<ACC> accTypeInformation,
            TypeSerializer<W> windowSerializer,
            WindowAssigner<? super IN, W> windowAssigner,
            long pause
    ) {
        this.aggregateFunction = aggregateFunction;
        this.keySelector = keySelector;
        // The window is checkpointed as raw bytes produced by windowSerializer, hence the byte[] slot.
        this.typeInfo = new TupleTypeInfo<>(keyTypeInformation, BasicArrayTypeInfo.BYTE_ARRAY_TYPE_INFO, accTypeInformation);
        this.windowAssigner = windowAssigner;
        this.windowSerializer = windowSerializer;
        this.pause = pause;
    }

    /**
     * Prepares the window assigner context and registers the periodic flush timer.
     * The aggregate buffer is only created here if {@link #initializeState} has not already
     * created (and possibly restored) it, so restored aggregates are never thrown away.
     */
    @Override
    public void open() throws Exception {
        super.open();
        if (aggregates == null) {
            aggregates = new HashMap<>();
        }
        windowAssignerContext = new WindowAssigner.WindowAssignerContext() {
            @Override
            public long getCurrentProcessingTime() {
                // Use the operator's time service so the assigner sees the same clock as the timers.
                return getProcessingTimeService().getCurrentProcessingTime();
            }
        };
        flushTimer = getProcessingTimeService().scheduleAtFixedRate(this, pause, pause);
    }

    /**
     * Cancels this operator's own flush timer, forwards whatever is still buffered and closes
     * the operator. The processing-time service itself is left running: it is obtained from the
     * runtime rather than created here, so shutting it down is not this operator's job.
     */
    @Override
    public void close() throws Exception {
        if (flushTimer != null) {
            flushTimer.cancel(false);
            flushTimer = null;
        }
        // Do not silently drop partial aggregates that were never flushed.
        forwardElements();
        super.close();
    }

    /**
     * Registers the operator list state and, on restore, rebuilds the in-memory buffer from it.
     * Entries for the same (key, window) are merged with {@link AggregateFunction#merge}.
     */
    @Override
    public void initializeState(StateInitializationContext context) throws Exception {
        super.initializeState(context);
        // The buffer must exist before restoring into it; open() keeps this instance.
        aggregates = new HashMap<>();
        ListStateDescriptor<Tuple3<K, byte[], ACC>> kvDescriptor = new ListStateDescriptor<>("kv-state", typeInfo);
        keyValueListState = context.getOperatorStateStore().getListState(kvDescriptor);
        if (context.isRestored()) {
            for (Tuple3<K, byte[], ACC> restored : keyValueListState.get()) {
                Tuple2<K, W> key = new Tuple2<>(restored.f0, deserializeWindow(restored.f1));
                ACC existing = aggregates.get(key);
                // First occurrence: take the restored accumulator as is; otherwise merge into the existing one.
                ACC combined = existing == null ? restored.f2 : aggregateFunction.merge(existing, restored.f2);
                aggregates.put(key, combined);
            }
        }
    }

    /**
     * Writes the current buffer into the operator list state. The state is cleared first so that
     * it reflects exactly the buffer at this checkpoint instead of accumulating entries from
     * earlier checkpoints (which would be double counted on restore).
     */
    @Override
    public void snapshotState(StateSnapshotContext context) throws Exception {
        super.snapshotState(context);
        keyValueListState.clear();
        for (Map.Entry<Tuple2<K, W>, ACC> entry : aggregates.entrySet()) {
            Tuple2<K, W> key = entry.getKey();
            keyValueListState.add(new Tuple3<>(key.f0, serializeWindow(key.f1), entry.getValue()));
        }
    }

    /**
     * Adds the element to the accumulator of every window it is assigned to, creating the
     * accumulator on first use.
     */
    @Override
    public void processElement(StreamRecord<IN> streamRecord) throws Exception {
        IN value = streamRecord.getValue();
        Collection<W> windows = windowAssigner.assignWindows(value, streamRecord.getTimestamp(), windowAssignerContext);
        K key = keySelector.getKey(value);
        for (W window : windows) {
            Tuple2<K, W> bufferKey = new Tuple2<>(key, window);
            ACC accumulator = aggregates.get(bufferKey);
            if (accumulator == null) {
                accumulator = aggregateFunction.createAccumulator();
            }
            // NOTE(review): relies on add() updating the accumulator in place; if the AggregateFunction
            // version in use returns a new accumulator from add(), that result must be stored instead.
            aggregateFunction.add(value, accumulator);
            aggregates.put(bufferKey, accumulator);
        }
    }

    /** Periodic timer callback: flushes the buffer so results keep flowing when the watermark lags. */
    @Override
    public void onProcessingTime(long l) throws Exception {
        forwardElements();
    }

    /** Flushes the buffer before forwarding the watermark, so no partial aggregate trails behind it. */
    @Override
    public void processWatermark(Watermark mark) throws Exception {
        forwardElements();
        super.processWatermark(mark);
    }

    /**
     * Emits every buffered (key, window, accumulator) with the window's max timestamp as record
     * timestamp, then empties the buffer. Does nothing if the buffer was never created.
     */
    private void forwardElements() {
        if (aggregates == null || aggregates.isEmpty()) {
            return;
        }
        for (Map.Entry<Tuple2<K, W>, ACC> entry : aggregates.entrySet()) {
            Tuple2<K, W> key = entry.getKey();
            Tuple3<K, W, ACC> result = new Tuple3<>(key.f0, key.f1, entry.getValue());
            output.collect(new StreamRecord<>(result, key.f1.maxTimestamp()));
        }
        aggregates.clear();
    }

    /** Serializes a window to bytes with the configured window serializer. */
    private byte[] serializeWindow(W window) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        DataOutputView view = new DataOutputViewStreamWrapper(bytes);
        windowSerializer.serialize(window, view);
        return bytes.toByteArray();
    }

    /** Reads a window back from bytes written by {@link #serializeWindow}. */
    private W deserializeWindow(byte[] serialized) throws IOException {
        return windowSerializer.deserialize(new DataInputViewStreamWrapper(new ByteArrayInputStream(serialized)));
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment