Skip to content

Instantly share code, notes, and snippets.

@egalpin
Last active May 24, 2023 20:25
Show Gist options
  • Save egalpin/162a04b896dc7be1d0899acf17e676b3 to your computer and use it in GitHub Desktop.
Save egalpin/162a04b896dc7be1d0899acf17e676b3 to your computer and use it in GitHub Desktop.
Apache Beam RateLimit
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
import com.google.auto.value.AutoValue;
import java.io.Serializable;
import java.util.concurrent.TimeUnit;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.CombiningState;
import org.apache.beam.sdk.state.State;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.state.Timer;
import org.apache.beam.sdk.state.TimerSpec;
import org.apache.beam.sdk.state.TimerSpecs;
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TimestampedValue;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Uninterruptibles;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Rate-limits a keyed {@link PCollection}.
 *
 * <p>Elements are passed through unchanged until the configured element-count or byte-size budget
 * for the current interval has been used up. Elements arriving after that are buffered in state
 * and re-submitted against a fresh budget when the processing-time interval timer fires.
 *
 * <p>Typical construction: {@code RateLimit.<K, V>ofSize(100).per(Duration.standardSeconds(1))}.
 * An interval must be supplied through {@link #per(Duration)}; {@link #expand} rejects the
 * transform otherwise.
 */
public class RateLimit<K, InputT>
    extends PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, InputT>>> {

  /** Immutable, serializable bundle of the user-supplied limits. */
  @AutoValue
  public abstract static class LimitParams<InputT> implements Serializable {

    /**
     * Creates a parameter bundle.
     *
     * @param batchSize maximum number of elements per interval; {@code Long.MAX_VALUE} is used by
     *     the factories in this file to mean "no element limit"
     * @param batchSizeBytes maximum number of bytes per interval; {@code Long.MAX_VALUE} is used by
     *     the factories in this file to mean "no byte limit"
     * @param elementByteSize optional function measuring one element's byte size, or {@code null}
     * @param limitInterval length of one rate-limit interval
     */
    public static <InputT> LimitParams<InputT> create(
        long batchSize,
        long batchSizeBytes,
        @Nullable SerializableFunction<InputT, Long> elementByteSize,
        Duration limitInterval) {
      return new AutoValue_RateLimit_LimitParams<>(
          batchSize, batchSizeBytes, elementByteSize, limitInterval);
    }

    /** Maximum number of elements emitted per interval. */
    public abstract long getBatchSize();

    /** Maximum number of bytes emitted per interval. */
    public abstract long getBatchSizeBytes();

    /** User-supplied byte-size function, or {@code null} if none was given. */
    @Nullable
    public abstract SerializableFunction<InputT, Long> getElementByteSize();

    /** Length of one rate-limit interval. */
    public abstract Duration getIntervalDuration();

    /**
     * Returns the function used to weigh elements in bytes.
     *
     * <p>The user-supplied function wins when present. Otherwise, if a byte limit is configured
     * (i.e. it is below {@code Long.MAX_VALUE}), a function that measures elements through {@code
     * valueCoder} is returned. With no byte limit and no user function the result is {@code null}.
     *
     * @param valueCoder coder of the element values, used only for the fallback weigher
     * @return the weigher, or {@code null} when bytes do not need to be tracked
     */
    @Nullable
    public SerializableFunction<InputT, Long> getWeigher(Coder<InputT> valueCoder) {
      SerializableFunction<InputT, Long> userWeigher = getElementByteSize();
      boolean byteLimitConfigured = getBatchSizeBytes() < Long.MAX_VALUE;
      if (userWeigher != null || !byteLimitConfigured) {
        return userWeigher;
      }
      // The user didn't specify a byte-size function, so use the Coder to determine the byte size.
      // Note: if Coder.isRegisterByteSizeObserverCheap == false, then this will be expensive.
      return (InputT element) -> {
        try {
          ByteSizeObserver observer = new ByteSizeObserver();
          valueCoder.registerByteSizeObserver(element, observer);
          observer.advance();
          return observer.getElementByteSize();
        } catch (Exception e) {
          throw new RuntimeException(e);
        }
      };
    }
  }

  private final LimitParams<InputT> params;

  private RateLimit(LimitParams<InputT> params) {
    this.params = params;
  }

  /**
   * Limits output to at most {@code batchSize} elements per interval.
   *
   * <p>The interval itself must be set afterwards with {@link #per(Duration)}.
   *
   * @param batchSize element budget per interval; must be positive and below {@code
   *     Long.MAX_VALUE}
   * @throws IllegalArgumentException if {@code batchSize} is not positive
   * @throws IllegalStateException if {@code batchSize} is {@code Long.MAX_VALUE}
   */
  public static <K, InputT> RateLimit<K, InputT> ofSize(long batchSize) {
    // A non-positive budget can never be satisfied: every element would be buffered forever.
    checkArgument(batchSize > 0, "batchSize must be positive, got %s", batchSize);
    Preconditions.checkState(batchSize < Long.MAX_VALUE);
    return new RateLimit<>(LimitParams.create(batchSize, Long.MAX_VALUE, null, Duration.ZERO));
  }

  /**
   * Limits output to roughly {@code batchSizeBytes} bytes per interval.
   *
   * <p>This option uses the PCollection's coder to determine the byte size of each element. This
   * may not always be what is desired (e.g. the encoded size is not the same as the memory usage of
   * the Java object). This is also only recommended if the coder returns true for
   * isRegisterByteSizeObserverCheap, otherwise the transform will perform a possibly-expensive
   * encoding of each element in order to measure its byte size. An alternate approach is to use
   * {@link #ofByteSize(long, SerializableFunction)} to specify code to calculate the byte size.
   *
   * @param batchSizeBytes byte budget per interval; must be positive and below {@code
   *     Long.MAX_VALUE}
   * @throws IllegalArgumentException if {@code batchSizeBytes} is not positive
   * @throws IllegalStateException if {@code batchSizeBytes} is {@code Long.MAX_VALUE}
   */
  public static <K, InputT> RateLimit<K, InputT> ofByteSize(long batchSizeBytes) {
    checkArgument(batchSizeBytes > 0, "batchSizeBytes must be positive, got %s", batchSizeBytes);
    Preconditions.checkState(batchSizeBytes < Long.MAX_VALUE);
    return new RateLimit<>(LimitParams.create(Long.MAX_VALUE, batchSizeBytes, null, Duration.ZERO));
  }

  /**
   * Limits output to roughly {@code batchSizeBytes} bytes per interval, using the provided function
   * to determine the byte size of each element.
   *
   * @param batchSizeBytes byte budget per interval; must be positive and below {@code
   *     Long.MAX_VALUE}
   * @param getElementByteSize function returning the byte size of one element
   * @throws IllegalArgumentException if {@code batchSizeBytes} is not positive
   * @throws IllegalStateException if {@code batchSizeBytes} is {@code Long.MAX_VALUE}
   */
  public static <K, InputT> RateLimit<K, InputT> ofByteSize(
      long batchSizeBytes, SerializableFunction<InputT, Long> getElementByteSize) {
    checkArgument(batchSizeBytes > 0, "batchSizeBytes must be positive, got %s", batchSizeBytes);
    Preconditions.checkState(batchSizeBytes < Long.MAX_VALUE);
    return new RateLimit<>(
        LimitParams.create(Long.MAX_VALUE, batchSizeBytes, getElementByteSize, Duration.ZERO));
  }

  /**
   * Returns a copy of this transform whose limits apply per {@code intervalDuration}.
   *
   * @param intervalDuration length of one rate-limit interval; validated in {@link #expand}
   */
  public RateLimit<K, InputT> per(Duration intervalDuration) {
    return new RateLimit<>(
        LimitParams.create(
            params.getBatchSize(),
            params.getBatchSizeBytes(),
            params.getElementByteSize(),
            intervalDuration));
  }

  /** Returns user supplied parameters for rate limiting. */
  public LimitParams<InputT> getLimitParams() {
    return params;
  }

  /** Accumulates the sizes reported for a single element into one total. */
  private static class ByteSizeObserver extends ElementByteSizeObserver {
    private long elementByteSize = 0;

    @Override
    protected void reportElementSize(long elementByteSize) {
      this.elementByteSize += elementByteSize;
    }

    public long getElementByteSize() {
      return this.elementByteSize;
    }
  }

  /**
   * Applies the rate-limiting {@link DoFn} to {@code input}.
   *
   * @throws IllegalArgumentException if the input coder is not a {@link KvCoder}, or if no
   *     positive interval has been configured via {@link #per(Duration)}
   */
  @Override
  public PCollection<KV<K, InputT>> expand(PCollection<KV<K, InputT>> input) {
    checkArgument(
        input.getCoder() instanceof KvCoder,
        "coder specified in the input PCollection is not a KvCoder");
    // Compare by value, not by reference: a zero or negative interval would make the timer fire
    // immediately and defeat the rate limit.
    Duration interval = params.getIntervalDuration();
    checkArgument(
        interval.isLongerThan(Duration.ZERO),
        "RateLimit interval must be greater than zero (set it with per(Duration)), got %s",
        interval);

    KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) input.getCoder();
    final Coder<InputT> valueCoder = inputCoder.getValueCoder();
    SerializableFunction<InputT, Long> weigher = params.getWeigher(valueCoder);

    return input.apply(
        ParDo.of(
            new RateLimitFn<>(
                params.getBatchSize(),
                params.getBatchSizeBytes(),
                weigher,
                interval,
                valueCoder)));
  }

  /**
   * Stateful {@link DoFn} implementing the limit.
   *
   * <p>Each incoming element is either emitted immediately (and counted against the interval's
   * budget) or, once the budget is exhausted, appended to a bag. When the interval timer fires the
   * counters and bag are cleared and the previously buffered elements are run through the same
   * emit-or-buffer logic again.
   */
  @VisibleForTesting
  private static class RateLimitFn<K, InputT> extends DoFn<KV<K, InputT>, KV<K, InputT>> {
    private static final Logger LOG = LoggerFactory.getLogger(RateLimitFn.class);

    private final long maxElements;
    private final long maxBytes;
    @Nullable private final SerializableFunction<InputT, Long> weigher;
    private final Duration intervalDuration;

    // This timer expires at the end of the current interval; buffered data is re-submitted then.
    private static final String END_OF_INTERVAL_ID = "endOfInterval";

    @TimerId(END_OF_INTERVAL_ID)
    private final TimerSpec intervalTimer = TimerSpecs.timer(TimeDomain.PROCESSING_TIME);

    // Elements (with their timestamps) that could not be emitted because the limit was reached.
    private static final String BUFFER_ID = "buffer";

    @StateId(BUFFER_ID)
    private final StateSpec<BagState<TimestampedValue<InputT>>> bufferSpec;

    // Num elements over and above the rate limit that needed to be buffered for later output
    private static final String NUM_ELEMENTS_BUFFERED_ID = "numElementsBuffered";

    @StateId(NUM_ELEMENTS_BUFFERED_ID)
    private final StateSpec<CombiningState<Long, long[], Long>> bufferSizeSpec;

    // The number of elements emitted since the counters were last cleared.
    private static final String NUM_ELEMENTS_IN_INTERVAL_ID = "numElementsInInterval";

    @StateId(NUM_ELEMENTS_IN_INTERVAL_ID)
    private final StateSpec<CombiningState<Long, long[], Long>> intervalSizeSpec;

    // The number of bytes emitted since the counters were last cleared (only tracked when a
    // weigher is present).
    private static final String NUM_BYTES_IN_INTERVAL_ID = "numBytesInInterval";

    @StateId(NUM_BYTES_IN_INTERVAL_ID)
    private final StateSpec<CombiningState<Long, long[], Long>> intervalSizeBytesSpec;

    // The firing time (epoch millis) of the most recently set interval timer.
    private static final String TIMER_TIMESTAMP = "timerTs";

    @StateId(TIMER_TIMESTAMP)
    private final StateSpec<ValueState<Long>> timerTsSpec;

    // The minimum element timestamp currently buffered in the bag. This is used to set the output
    // timestamp on the timer, which ensures that the watermark correctly tracks the buffered
    // elements.
    private static final String MIN_OBSERVED_TS = "minBufferedTs";

    @StateId(MIN_OBSERVED_TS)
    private final StateSpec<CombiningState<Long, long[], Long>> minBufferedTsSpec;

    /**
     * @param maxElements element budget per interval
     * @param maxBytes byte budget per interval
     * @param weigher byte-size function, or {@code null} when bytes are not tracked
     * @param intervalDuration length of one interval
     * @param inputValueCoder coder for the element values, used to encode the buffer
     */
    RateLimitFn(
        long maxElements,
        long maxBytes,
        @Nullable SerializableFunction<InputT, Long> weigher,
        Duration intervalDuration,
        Coder<InputT> inputValueCoder) {
      this.maxElements = maxElements;
      this.maxBytes = maxBytes;
      this.weigher = weigher;
      this.intervalDuration = intervalDuration;
      this.bufferSpec = StateSpecs.bag(TimestampedValue.TimestampedValueCoder.of(inputValueCoder));

      // Running sum, starting from zero.
      Combine.BinaryCombineLongFn sumCombineFn =
          new Combine.BinaryCombineLongFn() {
            @Override
            public long identity() {
              return 0L;
            }

            @Override
            public long apply(long left, long right) {
              return left + right;
            }
          };

      // Running minimum; the identity is the largest representable timestamp.
      Combine.BinaryCombineLongFn minCombineFn =
          new Combine.BinaryCombineLongFn() {
            @Override
            public long identity() {
              return BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis();
            }

            @Override
            public long apply(long left, long right) {
              return Math.min(left, right);
            }
          };

      this.intervalSizeSpec = StateSpecs.combining(sumCombineFn);
      this.intervalSizeBytesSpec = StateSpecs.combining(sumCombineFn);
      this.timerTsSpec = StateSpecs.value();
      this.minBufferedTsSpec = StateSpecs.combining(minCombineFn);
      this.bufferSizeSpec = StateSpecs.combining(sumCombineFn);
    }

    /** Emits the element if the interval's budget allows it, otherwise buffers it. */
    @ProcessElement
    public void processElement(
        @TimerId(END_OF_INTERVAL_ID) Timer intervalTimer,
        @StateId(BUFFER_ID) BagState<TimestampedValue<InputT>> buffer,
        @StateId(NUM_ELEMENTS_BUFFERED_ID) CombiningState<Long, long[], Long> numBuffered,
        @AlwaysFetched @StateId(NUM_ELEMENTS_IN_INTERVAL_ID)
            CombiningState<Long, long[], Long> numOutputInInterval,
        @AlwaysFetched @StateId(NUM_BYTES_IN_INTERVAL_ID)
            CombiningState<Long, long[], Long> numBytesOutputInInterval,
        @StateId(TIMER_TIMESTAMP) ValueState<Long> timerTs,
        @StateId(MIN_OBSERVED_TS) CombiningState<Long, long[], Long> minBufferedTs,
        @Element KV<K, InputT> element,
        @Timestamp Instant elementTs,
        BoundedWindow window,
        OutputReceiver<KV<K, InputT>> receiver) {
      outputAndObserve(
          intervalTimer,
          buffer,
          numBuffered,
          numOutputInInterval,
          numBytesOutputInInterval,
          timerTs,
          minBufferedTs,
          element,
          elementTs,
          receiver);
    }

    /**
     * Starts a new interval: snapshots the buffer, resets the buffer and all counters, then
     * re-submits the snapshot through {@link #outputAndObserve}.
     */
    @OnTimer(END_OF_INTERVAL_ID)
    public void onIntervalTimer(
        OutputReceiver<KV<K, InputT>> receiver,
        @Timestamp Instant timestamp,
        @Key K key,
        @AlwaysFetched @StateId(BUFFER_ID) BagState<TimestampedValue<InputT>> buffer,
        @StateId(NUM_ELEMENTS_BUFFERED_ID) CombiningState<Long, long[], Long> numBuffered,
        @StateId(NUM_ELEMENTS_IN_INTERVAL_ID)
            CombiningState<Long, long[], Long> numOutputInInterval,
        @StateId(NUM_BYTES_IN_INTERVAL_ID)
            CombiningState<Long, long[], Long> numBytesOutputInInterval,
        @StateId(TIMER_TIMESTAMP) ValueState<Long> timerTs,
        @StateId(MIN_OBSERVED_TS) CombiningState<Long, long[], Long> minBufferedTs,
        @TimerId(END_OF_INTERVAL_ID) Timer intervalTimer) {
      LOG.debug(
          "*** END OF INTERVAL *** for timer timestamp {} with buffering duration {}",
          timestamp,
          intervalDuration.getMillis());

      // Read before clearing: the snapshot is replayed below and anything that still does not fit
      // is appended to the (now empty) buffer again.
      Iterable<TimestampedValue<InputT>> batch = buffer.read();
      clearState(buffer, numBuffered, numOutputInInterval, numBytesOutputInInterval, minBufferedTs);
      flushBatch(
          receiver,
          key,
          intervalTimer,
          batch,
          buffer,
          numBuffered,
          numOutputInInterval,
          numBytesOutputInInterval,
          timerTs,
          minBufferedTs);
    }

    /**
     * Drains whatever is still buffered when the window expires.
     *
     * <p>No timer is available here, so the method blocks the calling thread: it sleeps for the
     * remainder of one interval, resets the counters, replays the buffer, and repeats until
     * nothing is left buffered.
     */
    @OnWindowExpiration
    public void onWindowExpiration(
        OutputReceiver<KV<K, InputT>> receiver,
        @Key K key,
        @AlwaysFetched @StateId(BUFFER_ID) BagState<TimestampedValue<InputT>> buffer,
        @AlwaysFetched @StateId(NUM_ELEMENTS_BUFFERED_ID)
            CombiningState<Long, long[], Long> numBuffered,
        @StateId(NUM_ELEMENTS_IN_INTERVAL_ID)
            CombiningState<Long, long[], Long> numOutputInInterval,
        @StateId(NUM_BYTES_IN_INTERVAL_ID)
            CombiningState<Long, long[], Long> numBytesOutputInInterval,
        @StateId(TIMER_TIMESTAMP) ValueState<Long> timerTs,
        @StateId(MIN_OBSERVED_TS) CombiningState<Long, long[], Long> minBufferedTs) {
      // NOTE(review): per the original author this is only reached when the pipeline is shut down
      // and the watermark is fast-forwarded to its max value — confirm for the target runner.
      if (LOG.isDebugEnabled()) {
        // Only touch state when the message will actually be logged, and tolerate an unset value
        // instead of failing on unboxing.
        Long lastTimerTs = timerTs.read();
        LOG.debug(
            "*** END OF WINDOW *** for timer timestamp {} with buffering duration {}",
            lastTimerTs == null ? null : Instant.ofEpochMilli(lastTimerTs),
            intervalDuration.getMillis());
      }

      Instant lastFlushStart = null;
      while (numBuffered.read() > 0) {
        // Since we have no way of knowing if the limit was hit just prior, wait one
        // interval before outputting. On later iterations only the unspent part of the interval
        // is slept.
        Instant now = Instant.now();
        long elapsedMs =
            lastFlushStart == null ? 0L : now.getMillis() - lastFlushStart.getMillis();
        long delayMs = intervalDuration.getMillis() - elapsedMs;
        Uninterruptibles.sleepUninterruptibly(Math.max(0L, delayMs), TimeUnit.MILLISECONDS);
        lastFlushStart = Instant.now();

        Iterable<TimestampedValue<InputT>> batch = buffer.read();
        clearState(
            buffer, numBuffered, numOutputInInterval, numBytesOutputInInterval, minBufferedTs);
        flushBatch(
            receiver,
            key,
            null, // no timer can be set during window expiration
            batch,
            buffer,
            numBuffered,
            numOutputInInterval,
            numBytesOutputInInterval,
            timerTs,
            minBufferedTs);
      }
    }

    /**
     * Re-submits every element of {@code batch} under {@code key} through {@link
     * #outputAndObserve}.
     *
     * @param intervalTimer timer to (re)arm, or {@code null} when no timer may be set
     */
    private void flushBatch(
        OutputReceiver<KV<K, InputT>> receiver,
        K key,
        @Nullable Timer intervalTimer,
        Iterable<TimestampedValue<InputT>> batch,
        BagState<TimestampedValue<InputT>> elementBuffer,
        CombiningState<Long, long[], Long> numBuffered,
        CombiningState<Long, long[], Long> numOutputInInterval,
        CombiningState<Long, long[], Long> numBytesOutputInInterval,
        ValueState<Long> timerTs,
        CombiningState<Long, long[], Long> minBufferedTs) {
      // When the timer fires, elementBuffer state might be empty
      if (Iterables.isEmpty(batch)) {
        return;
      }
      for (TimestampedValue<InputT> buffered : batch) {
        outputAndObserve(
            intervalTimer,
            elementBuffer,
            numBuffered,
            numOutputInInterval,
            numBytesOutputInInterval,
            timerTs,
            minBufferedTs,
            KV.of(key, buffered.getValue()),
            buffered.getTimestamp(),
            receiver);
      }
    }

    /**
     * Core emit-or-buffer decision for a single element.
     *
     * <p>If nothing has been emitted yet in this interval and a timer is available, the interval
     * timer is armed. If either the element count or the byte count emitted so far has reached its
     * maximum, the element is buffered together with {@code elementTs}; otherwise it is emitted
     * and the counters are advanced.
     *
     * @param intervalTimer timer to arm/update, or {@code null} when no timer may be set
     * @param elementTs timestamp stored with the element if it has to be buffered
     */
    private void outputAndObserve(
        @Nullable Timer intervalTimer,
        BagState<TimestampedValue<InputT>> elementBuffer,
        CombiningState<Long, long[], Long> numBuffered,
        CombiningState<Long, long[], Long> numOutputInInterval,
        CombiningState<Long, long[], Long> numBytesOutputInInterval,
        ValueState<Long> timerTs,
        CombiningState<Long, long[], Long> minBufferedTs,
        KV<K, InputT> element,
        Instant elementTs,
        OutputReceiver<KV<K, InputT>> receiver) {
      // Request both counters up front, before the first blocking read.
      numOutputInInterval.readLater();
      numBytesOutputInInterval.readLater();

      long numElementsSoFar = numOutputInInterval.read();
      if (numElementsSoFar == 0L && intervalTimer != null) {
        // First emission of the interval: this is where the interval starts.
        setTimerFiringTime(intervalTimer, timerTs);
      }

      boolean limitReached =
          numElementsSoFar >= maxElements || numBytesOutputInInterval.read() >= maxBytes;
      if (limitReached) {
        // We can't output, we've hit the limit for the interval already. Buffer this
        // element into the bag and move on since it will be flushed later when there is
        // capacity under the rate limit
        elementBuffer.add(TimestampedValue.of(element.getValue(), elementTs));
        numBuffered.add(1L);

        long oldOutputTs =
            MoreObjects.firstNonNull(
                minBufferedTs.read(), BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis());
        // Blind add is supported with combiningState
        minBufferedTs.add(elementTs.getMillis());
        if (minBufferedTs.read() != oldOutputTs && intervalTimer != null) {
          // The minimum buffered timestamp moved, so the timer's output timestamp must follow.
          updateTimerOutputTimestamp(intervalTimer, timerTs, minBufferedTs);
        }
        return;
      }

      // NOTE(review): elementTs is not applied to the output here, so elements replayed from the
      // buffer are emitted with whatever timestamp the runner assigns in the calling context —
      // confirm this is intended.
      receiver.output(element);
      LOG.debug("*** OUTPUT DOWNSTREAM *** {} for timer timestamp {}", element, Instant.now());

      // Blind add is supported with combiningState
      numOutputInInterval.add(1L);
      if (weigher != null) {
        // Blind add is supported with combiningState
        numBytesOutputInInterval.add(weigher.apply(element.getValue()));
      }
    }

    /**
     * Re-sets the timer at its recorded firing time with the current minimum buffered timestamp as
     * output timestamp.
     */
    private void updateTimerOutputTimestamp(
        Timer intervalTimer,
        ValueState<Long> timerTs,
        CombiningState<Long, long[], Long> minBufferedTs) {
      // Update the output timestamp, but not the firing time
      intervalTimer
          .withOutputTimestamp(Instant.ofEpochMilli(minBufferedTs.read()))
          .set(Instant.ofEpochMilli(timerTs.read()));
    }

    /**
     * Arms the timer one interval from now (capped at one day before the maximum timestamp) with
     * no output timestamp, and records the firing time in {@code timerTs}.
     */
    private void setTimerFiringTime(Timer intervalTimer, ValueState<Long> timerTs) {
      long oneIntervalFromNow = Instant.now().plus(intervalDuration.getMillis()).getMillis();
      long latestAllowed =
          BoundedWindow.TIMESTAMP_MAX_VALUE.minus(Duration.standardDays(1)).getMillis();
      long triggerTs = Math.min(oneIntervalFromNow, latestAllowed);

      LOG.debug("Setting timer to trigger at {}", Instant.ofEpochMilli(triggerTs));
      intervalTimer.withNoOutputTimestamp().set(Instant.ofEpochMilli(triggerTs));
      timerTs.write(triggerTs);
    }

    /** Clears every given state cell. */
    private void clearState(State... specs) {
      for (State s : specs) {
        s.clear();
      }
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment