Last active: May 24, 2023 20:25
-
-
Save egalpin/162a04b896dc7be1d0899acf17e676b3 to your computer and use it in GitHub Desktop.
Apache Beam RateLimit
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; | |
import com.google.auto.value.AutoValue; | |
import java.io.Serializable; | |
import java.util.concurrent.TimeUnit; | |
import org.apache.beam.sdk.coders.Coder; | |
import org.apache.beam.sdk.coders.KvCoder; | |
import org.apache.beam.sdk.state.BagState; | |
import org.apache.beam.sdk.state.CombiningState; | |
import org.apache.beam.sdk.state.State; | |
import org.apache.beam.sdk.state.StateSpec; | |
import org.apache.beam.sdk.state.StateSpecs; | |
import org.apache.beam.sdk.state.TimeDomain; | |
import org.apache.beam.sdk.state.Timer; | |
import org.apache.beam.sdk.state.TimerSpec; | |
import org.apache.beam.sdk.state.TimerSpecs; | |
import org.apache.beam.sdk.state.ValueState; | |
import org.apache.beam.sdk.transforms.Combine; | |
import org.apache.beam.sdk.transforms.DoFn; | |
import org.apache.beam.sdk.transforms.PTransform; | |
import org.apache.beam.sdk.transforms.ParDo; | |
import org.apache.beam.sdk.transforms.SerializableFunction; | |
import org.apache.beam.sdk.transforms.windowing.BoundedWindow; | |
import org.apache.beam.sdk.util.common.ElementByteSizeObserver; | |
import org.apache.beam.sdk.values.KV; | |
import org.apache.beam.sdk.values.PCollection; | |
import org.apache.beam.sdk.values.TimestampedValue; | |
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; | |
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects; | |
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions; | |
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; | |
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Uninterruptibles; | |
import org.checkerframework.checker.nullness.qual.Nullable; | |
import org.joda.time.Duration; | |
import org.joda.time.Instant; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
public class RateLimit<K, InputT> | |
extends PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, InputT>>> { | |
@AutoValue | |
public abstract static class LimitParams<InputT> implements Serializable { | |
public static <InputT> LimitParams<InputT> create( | |
long batchSize, | |
long batchSizeBytes, | |
SerializableFunction<InputT, Long> elementByteSize, | |
Duration limitInterval) { | |
return new AutoValue_RateLimit_LimitParams<>( | |
batchSize, batchSizeBytes, elementByteSize, limitInterval); | |
} | |
public abstract long getBatchSize(); | |
public abstract long getBatchSizeBytes(); | |
@Nullable | |
public abstract SerializableFunction<InputT, Long> getElementByteSize(); | |
public abstract Duration getIntervalDuration(); | |
public SerializableFunction<InputT, Long> getWeigher(Coder<InputT> valueCoder) { | |
SerializableFunction<InputT, Long> weigher = getElementByteSize(); | |
if (getBatchSizeBytes() < Long.MAX_VALUE) { | |
if (weigher == null) { | |
// If the user didn't specify a byte-size function, then use the Coder to determine the | |
// byte | |
// size. | |
// Note: if Coder.isRegisterByteSizeObserverCheap == false, then this will be expensive. | |
weigher = | |
(InputT element) -> { | |
try { | |
ByteSizeObserver observer = new ByteSizeObserver(); | |
valueCoder.registerByteSizeObserver(element, observer); | |
observer.advance(); | |
return observer.getElementByteSize(); | |
} catch (Exception e) { | |
throw new RuntimeException(e); | |
} | |
}; | |
} | |
} | |
return weigher; | |
} | |
} | |
private final LimitParams<InputT> params; | |
private RateLimit(LimitParams<InputT> params) { | |
this.params = params; | |
} | |
/** Aim to create batches each with the specified element count. */ | |
public static <K, InputT> RateLimit<K, InputT> ofSize(long batchSize) { | |
Preconditions.checkState(batchSize < Long.MAX_VALUE); | |
return new RateLimit<>(LimitParams.create(batchSize, Long.MAX_VALUE, null, Duration.ZERO)); | |
} | |
/** | |
* Aim to create batches each with the specified byte size. | |
* | |
* <p>This option uses the PCollection's coder to determine the byte size of each element. This | |
* may not always be what is desired (e.g. the encoded size is not the same as the memory usage of | |
* the Java object). This is also only recommended if the coder returns true for | |
* isRegisterByteSizeObserverCheap, otherwise the transform will perform a possibly-expensive | |
* encoding of each element in order to measure its byte size. An alternate approach is to use | |
* {@link #ofByteSize(long, SerializableFunction)} to specify code to calculate the byte size. | |
*/ | |
public static <K, InputT> RateLimit<K, InputT> ofByteSize(long batchSizeBytes) { | |
Preconditions.checkState(batchSizeBytes < Long.MAX_VALUE); | |
return new RateLimit<>(LimitParams.create(Long.MAX_VALUE, batchSizeBytes, null, Duration.ZERO)); | |
} | |
/** | |
* Aim to create batches each with the specified byte size. The provided function is used to | |
* determine the byte size of each element. | |
*/ | |
public static <K, InputT> RateLimit<K, InputT> ofByteSize( | |
long batchSizeBytes, SerializableFunction<InputT, Long> getElementByteSize) { | |
Preconditions.checkState(batchSizeBytes < Long.MAX_VALUE); | |
return new RateLimit<>( | |
LimitParams.create(Long.MAX_VALUE, batchSizeBytes, getElementByteSize, Duration.ZERO)); | |
} | |
public RateLimit<K, InputT> per(Duration intervalDuration) { | |
return new RateLimit<>( | |
LimitParams.create( | |
params.getBatchSize(), | |
params.getBatchSizeBytes(), | |
params.getElementByteSize(), | |
intervalDuration)); | |
} | |
/** Returns user supplied parameters for batching. */ | |
public LimitParams<InputT> getLimitParams() { | |
return params; | |
} | |
private static class ByteSizeObserver extends ElementByteSizeObserver { | |
private long elementByteSize = 0; | |
@Override | |
protected void reportElementSize(long elementByteSize) { | |
this.elementByteSize += elementByteSize; | |
} | |
public long getElementByteSize() { | |
return this.elementByteSize; | |
} | |
} | |
@Override | |
public PCollection<KV<K, InputT>> expand(PCollection<KV<K, InputT>> input) { | |
checkArgument( | |
input.getCoder() instanceof KvCoder, | |
"coder specified in the input PCollection is not a KvCoder"); | |
checkArgument( | |
params.getIntervalDuration() != Duration.ZERO, "RateLimit interval may not be null"); | |
KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) input.getCoder(); | |
final Coder<InputT> valueCoder = (Coder<InputT>) inputCoder.getCoderArguments().get(1); | |
SerializableFunction<InputT, Long> weigher = params.getWeigher(valueCoder); | |
return input.apply( | |
ParDo.of( | |
new RateLimitFn<>( | |
params.getBatchSize(), | |
params.getBatchSizeBytes(), | |
weigher, | |
params.getIntervalDuration(), | |
valueCoder))); | |
} | |
@VisibleForTesting | |
private static class RateLimitFn<K, InputT> extends DoFn<KV<K, InputT>, KV<K, InputT>> { | |
private static final Logger LOG = LoggerFactory.getLogger(RateLimitFn.class); | |
private final long maxElements; | |
private final long maxBytes; | |
@Nullable private final SerializableFunction<InputT, Long> weigher; | |
private final Duration intervalDuration; | |
// This timer expires when it's time to batch and output the buffered data. | |
private static final String END_OF_INTERVAL_ID = "endOfInterval"; | |
@TimerId(END_OF_INTERVAL_ID) | |
private final TimerSpec intervalTimer = TimerSpecs.timer(TimeDomain.PROCESSING_TIME); | |
// The set of elements that will go in the next batch. | |
private static final String BUFFER_ID = "buffer"; | |
@StateId(BUFFER_ID) | |
private final StateSpec<BagState<TimestampedValue<InputT>>> bufferSpec; | |
// Num elements over and above the rate limit that needed to be buffered for later output | |
private static final String NUM_ELEMENTS_BUFFERED_ID = "numElementsBuffered"; | |
@StateId(NUM_ELEMENTS_BUFFERED_ID) | |
private final StateSpec<CombiningState<Long, long[], Long>> bufferSizeSpec; | |
// The cumulative number of elements since the start of the interval | |
private static final String NUM_ELEMENTS_IN_INTERVAL_ID = "numElementsInInterval"; | |
@StateId(NUM_ELEMENTS_IN_INTERVAL_ID) | |
private final StateSpec<CombiningState<Long, long[], Long>> intervalSizeSpec; | |
// The cumulative number of bytes since the start of the interval | |
private static final String NUM_BYTES_IN_INTERVAL_ID = "numBytesInInterval"; | |
@StateId(NUM_BYTES_IN_INTERVAL_ID) | |
private final StateSpec<CombiningState<Long, long[], Long>> intervalSizeBytesSpec; | |
// The timestamp of the current active timer. | |
private static final String TIMER_TIMESTAMP = "timerTs"; | |
@StateId(TIMER_TIMESTAMP) | |
private final StateSpec<ValueState<Long>> timerTsSpec; | |
// The minimum element timestamp currently buffered in the bag. This is used to set the output | |
// timestamp | |
// on the timer which ensures that the watermark correctly tracks the buffered elements. | |
private static final String MIN_OBSERVED_TS = "minBufferedTs"; | |
@StateId(MIN_OBSERVED_TS) | |
private final StateSpec<CombiningState<Long, long[], Long>> minBufferedTsSpec; | |
RateLimitFn( | |
long maxElements, | |
long maxBytes, | |
@Nullable SerializableFunction<InputT, Long> weigher, | |
Duration intervalDuration, | |
Coder<InputT> inputValueCoder) { | |
this.maxElements = maxElements; | |
this.maxBytes = maxBytes; | |
this.weigher = weigher; | |
this.intervalDuration = intervalDuration; | |
this.bufferSpec = StateSpecs.bag(TimestampedValue.TimestampedValueCoder.of(inputValueCoder)); | |
Combine.BinaryCombineLongFn sumCombineFn = | |
new Combine.BinaryCombineLongFn() { | |
@Override | |
public long identity() { | |
return 0L; | |
} | |
@Override | |
public long apply(long left, long right) { | |
return left + right; | |
} | |
}; | |
Combine.BinaryCombineLongFn minCombineFn = | |
new Combine.BinaryCombineLongFn() { | |
@Override | |
public long identity() { | |
return BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis(); | |
} | |
@Override | |
public long apply(long left, long right) { | |
return Math.min(left, right); | |
} | |
}; | |
this.intervalSizeSpec = StateSpecs.combining(sumCombineFn); | |
this.intervalSizeBytesSpec = StateSpecs.combining(sumCombineFn); | |
this.timerTsSpec = StateSpecs.value(); | |
this.minBufferedTsSpec = StateSpecs.combining(minCombineFn); | |
this.bufferSizeSpec = StateSpecs.combining(sumCombineFn); | |
} | |
@ProcessElement | |
public void processElement( | |
@TimerId(END_OF_INTERVAL_ID) Timer intervalTimer, | |
@StateId(BUFFER_ID) BagState<TimestampedValue<InputT>> buffer, | |
@StateId(NUM_ELEMENTS_BUFFERED_ID) | |
CombiningState<Long, long[], Long> numBuffered, | |
@AlwaysFetched @StateId(NUM_ELEMENTS_IN_INTERVAL_ID) | |
CombiningState<Long, long[], Long> numOutputInInterval, | |
@AlwaysFetched @StateId(NUM_BYTES_IN_INTERVAL_ID) | |
CombiningState<Long, long[], Long> numBytesOutputInInterval, | |
@StateId(TIMER_TIMESTAMP) ValueState<Long> timerTs, | |
@StateId(MIN_OBSERVED_TS) CombiningState<Long, long[], Long> minBufferedTs, | |
@Element KV<K, InputT> element, | |
@Timestamp Instant elementTs, | |
BoundedWindow window, | |
OutputReceiver<KV<K, InputT>> receiver) { | |
outputAndObserve( | |
intervalTimer, | |
buffer, | |
numBuffered, | |
numOutputInInterval, | |
numBytesOutputInInterval, | |
timerTs, | |
minBufferedTs, | |
element, | |
elementTs, | |
receiver); | |
} | |
@OnTimer(END_OF_INTERVAL_ID) | |
public void onIntervalTimer( | |
OutputReceiver<KV<K, InputT>> receiver, | |
@Timestamp Instant timestamp, | |
@Key K key, | |
@AlwaysFetched @StateId(BUFFER_ID) BagState<TimestampedValue<InputT>> buffer, | |
@StateId(NUM_ELEMENTS_BUFFERED_ID) | |
CombiningState<Long, long[], Long> numBuffered, | |
@StateId(NUM_ELEMENTS_IN_INTERVAL_ID) | |
CombiningState<Long, long[], Long> numOutputInInterval, | |
@StateId(NUM_BYTES_IN_INTERVAL_ID) | |
CombiningState<Long, long[], Long> numBytesOutputInInterval, | |
@StateId(TIMER_TIMESTAMP) ValueState<Long> timerTs, | |
@StateId(MIN_OBSERVED_TS) CombiningState<Long, long[], Long> minBufferedTs, | |
@TimerId(END_OF_INTERVAL_ID) Timer intervalTimer) { | |
LOG.debug( | |
"*** END OF INTERVAL *** for timer timestamp {} with buffering duration {}", | |
timestamp.toInstant(), | |
intervalDuration.getMillis()); | |
Iterable<TimestampedValue<InputT>> batch = buffer.read(); | |
clearState(buffer, numBuffered, numOutputInInterval, numBytesOutputInInterval, minBufferedTs); | |
flushBatch( | |
receiver, | |
key, | |
intervalTimer, | |
batch, | |
buffer, | |
numBuffered, | |
numOutputInInterval, | |
numBytesOutputInInterval, | |
timerTs, | |
minBufferedTs); | |
} | |
@OnWindowExpiration | |
public void onWindowExpiration( | |
OutputReceiver<KV<K, InputT>> receiver, | |
@Key K key, | |
@AlwaysFetched @StateId(BUFFER_ID) BagState<TimestampedValue<InputT>> buffer, | |
@AlwaysFetched @StateId(NUM_ELEMENTS_BUFFERED_ID) | |
CombiningState<Long, long[], Long> numBuffered, | |
@StateId(NUM_ELEMENTS_IN_INTERVAL_ID) | |
CombiningState<Long, long[], Long> numOutputInInterval, | |
@StateId(NUM_BYTES_IN_INTERVAL_ID) | |
CombiningState<Long, long[], Long> numBytesOutputInInterval, | |
@StateId(TIMER_TIMESTAMP) ValueState<Long> timerTs, | |
@StateId(MIN_OBSERVED_TS) CombiningState<Long, long[], Long> minBufferedTs) { | |
// This code will only be invoked when pipelines are cancelled and watermark is | |
// fast-forward to max value | |
LOG.debug( | |
"*** END OF WINDOW *** for timer timestamp {} with buffering duration {}", | |
Instant.ofEpochMilli(timerTs.read()), | |
intervalDuration.getMillis()); | |
Instant start = null; | |
while (numBuffered.read() > 0) { | |
// Since we have no way of knowing if the limit was hit just prior, wait one | |
// interval before outputting | |
Instant now = Instant.now(); | |
long delayMs = | |
intervalDuration.getMillis() | |
- (now.minus(MoreObjects.firstNonNull(start, now).getMillis()).getMillis()); | |
Uninterruptibles.sleepUninterruptibly(Math.max(0L, delayMs), TimeUnit.MILLISECONDS); | |
start = Instant.now(); | |
Iterable<TimestampedValue<InputT>> batch = buffer.read(); | |
clearState( | |
buffer, numBuffered, numOutputInInterval, numBytesOutputInInterval, minBufferedTs); | |
flushBatch( | |
receiver, | |
key, | |
null, | |
batch, | |
buffer, | |
numBuffered, | |
numOutputInInterval, | |
numBytesOutputInInterval, | |
timerTs, | |
minBufferedTs); | |
} | |
} | |
private void flushBatch( | |
OutputReceiver<KV<K, InputT>> receiver, | |
K key, | |
Timer intervalTimer, | |
Iterable<TimestampedValue<InputT>> batch, | |
BagState<TimestampedValue<InputT>> elementBuffer, | |
CombiningState<Long, long[], Long> numBuffered, | |
CombiningState<Long, long[], Long> numOutputInInterval, | |
CombiningState<Long, long[], Long> numBytesOutputInInterval, | |
ValueState<Long> timerTs, | |
CombiningState<Long, long[], Long> minBufferedTs) { | |
// When the timer fires, elementBuffer state might be empty | |
if (!Iterables.isEmpty(batch)) { | |
for (TimestampedValue<InputT> val : batch) { | |
outputAndObserve( | |
intervalTimer, | |
elementBuffer, | |
numBuffered, | |
numOutputInInterval, | |
numBytesOutputInInterval, | |
timerTs, | |
minBufferedTs, | |
KV.of(key, val.getValue()), | |
val.getTimestamp(), | |
receiver); | |
} | |
} | |
} | |
private void outputAndObserve( | |
@Nullable Timer intervalTimer, | |
BagState<TimestampedValue<InputT>> elementBuffer, | |
CombiningState<Long, long[], Long> numBuffered, | |
CombiningState<Long, long[], Long> numOutputInInterval, | |
CombiningState<Long, long[], Long> numBytesOutputInInterval, | |
ValueState<Long> timerTs, | |
CombiningState<Long, long[], Long> minBufferedTs, | |
KV<K, InputT> element, | |
Instant elementTs, | |
OutputReceiver<KV<K, InputT>> receiver) { | |
numOutputInInterval.readLater(); | |
numBytesOutputInInterval.readLater(); | |
long numElementsSoFar = numOutputInInterval.read(); | |
if (numElementsSoFar == 0L && intervalTimer != null) { | |
setTimerFiringTime(intervalTimer, timerTs); | |
} | |
if (numElementsSoFar >= maxElements || numBytesOutputInInterval.read() >= maxBytes) { | |
// We can't output, we've hit the limit for the interval already. Buffer this | |
// element into the bag and move on since it will be flushed later when there is | |
// capacity under the rate limit | |
elementBuffer.add(TimestampedValue.of(element.getValue(), elementTs)); | |
numBuffered.add(1L); | |
long oldOutputTs = | |
MoreObjects.firstNonNull( | |
minBufferedTs.read(), BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis()); | |
// Blind add is supported with combiningState | |
minBufferedTs.add(elementTs.getMillis()); | |
if (minBufferedTs.read() != oldOutputTs && intervalTimer != null) { | |
// we need to update the timer's output timestamp | |
updateTimerOutputTimestamp(intervalTimer, timerTs, minBufferedTs); | |
} | |
return; | |
} | |
receiver.output(element); | |
LOG.debug("*** OUTPUT DOWNSTREAM *** {} for timer timestamp {}", element, Instant.now()); | |
// Blind add is supported with combiningState | |
numOutputInInterval.add(1L); | |
if (weigher != null) { | |
// Blind add is supported with combiningState | |
numBytesOutputInInterval.add(weigher.apply(element.getValue())); | |
} | |
} | |
private void updateTimerOutputTimestamp( | |
Timer intervalTimer, | |
ValueState<Long> timerTs, | |
CombiningState<Long, long[], Long> minBufferedTs) { | |
// Update the output timestamp, but not the firing time | |
intervalTimer | |
.withOutputTimestamp(Instant.ofEpochMilli(minBufferedTs.read())) | |
.set(Instant.ofEpochMilli(timerTs.read())); | |
} | |
private void setTimerFiringTime(Timer intervalTimer, ValueState<Long> timerTs) { | |
long triggerTs = | |
Math.min( | |
Instant.now().plus(intervalDuration.getMillis()).getMillis(), | |
BoundedWindow.TIMESTAMP_MAX_VALUE.minus(Duration.standardDays(1)).getMillis()); | |
LOG.debug("Setting timer to trigger at {}", Instant.ofEpochMilli(triggerTs)); | |
intervalTimer.withNoOutputTimestamp().set(Instant.ofEpochMilli(triggerTs)); | |
timerTs.write(triggerTs); | |
} | |
private void clearState(State... specs) { | |
for (State s : specs) { | |
s.clear(); | |
} | |
} | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.