Skip to content

Instantly share code, notes, and snippets.

@RBusarow
Last active December 8, 2019 05:59
Show Gist options
  • Save RBusarow/d2c4c164d965a52c7cd4ec0e1cacd955 to your computer and use it in GitHub Desktop.
Save RBusarow/d2c4c164d965a52c7cd4ec0e1cacd955 to your computer and use it in GitHub Desktop.
A non-blocking, non-suspending, "cold" implementation of cache/replay functionality for coroutines Flow.
import kotlinx.coroutines.flow.*
/**
 * Wraps this [Flow] in a "replay" flow which remembers the last [size] values
 * that passed through it.
 *
 * A collector which starts after values have already been recorded receives those
 * recorded values first (oldest to newest), followed by the values of this flow.
 *
 * example:
 * ```
 * val ints = flowOf(1, 2, 3, 4).cache(2) // cache the last 2 values
 *
 * ints.take(4).collect { ... } // 4 values are emitted, but also recorded. The last 2 remain.
 *
 * ints.collect { ... } // collects [3, 4, 1, 2, 3, 4]
 * ```
 *
 * @param size the maximum number of values to remember; must be greater than 0.
 * @return a [Flow] which replays the remembered values before the values of this flow.
 * @throws IllegalArgumentException if [size] is not greater than 0.
 */
fun <T> Flow<T>.cache(size: Int): Flow<T> {
    return CachedFlow(sourceFlow = this, size = size)
}
/**
 * A "replay" [Flow] which records the last [size] values collected from [sourceFlow].
 *
 * Every collection of this flow first emits whatever is currently recorded
 * (oldest to newest) and only then the values of [sourceFlow]. Each value coming from
 * [sourceFlow] is written to the [cache] as it passes through, so the recording is
 * shared by all collectors of the same [CachedFlow] instance.
 *
 * example:
 * ```
 * val ints = flowOf(1, 2, 3, 4).cache(2) // cache the last 2 values
 *
 * ints.take(4).collect { ... } // 4 values are emitted, but also recorded. The last 2 remain.
 *
 * ints.collect { ... } // collects [3, 4, 1, 2, 3, 4]
 * ```
 *
 * @param sourceFlow the upstream flow whose values are recorded and forwarded.
 * @param size the maximum number of values kept for replay; must be greater than 0.
 * @throws IllegalArgumentException if [size] is not greater than 0.
 */
internal class CachedFlow<T>(
    private val sourceFlow: Flow<T>,
    private val size: Int
) : Flow<T> {

    init {
        require(size > 0) { "size parameter must be greater than 0, but was $size" }
    }

    /** Holds the most recent values seen from [sourceFlow], at most [size] of them. */
    internal val cache = CircularArray<T>(size)

    /**
     * Emits the recorded values followed by the values of [sourceFlow] into [collector].
     */
    override suspend fun collect(collector: FlowCollector<T>) {
        collector.emitAll(replayThenRecord())
    }

    /**
     * Builds the flow handed to each collector: a replay of the cache,
     * followed by [sourceFlow] with every value recorded on its way through.
     */
    private fun replayThenRecord(): Flow<T> {
        // While flowing, every upstream value is also written to the cache.
        val recording = sourceFlow.onEach { item -> cache.add(item) }

        // Before any upstream value is emitted, replay the cache starting with the oldest entry.
        return recording.onStart {
            cache.forEach { cached -> emit(cached) }
        }
    }
}
/**
 * CircularArray implementation which will hold the latest of up to [size] elements.
 *
 * After the array has been filled, all further additions will overwrite the oldest value.
 *
 * This class performs no synchronization of its own.
 *
 * @param size the capacity of the array; must be greater than 0.
 * @throws IllegalArgumentException if [size] is not greater than 0.
 */
internal class CircularArray<T>(size: Int) {

    init {
        // Fail fast with a descriptive message. Without this check a capacity of 0 would only
        // surface as a divide-by-zero inside add(), and a negative capacity as a
        // NegativeArraySizeException from the array allocation below.
        require(size > 0) { "size parameter must be greater than 0, but was $size" }
    }

    // Backing storage; slots which have never been written hold null.
    private val array: Array<Any?> = arrayOfNulls(size)

    // Number of slots written so far, capped at the capacity.
    private var count: Int = 0

    // Index of the most recently written slot, or -1 while nothing has been added.
    private var tail: Int = -1

    /**
     * Adds [item] to the [CircularArray].
     *
     * If the [CircularArray] has not yet been filled,
     * [item] will simply be added to the next available slot.
     *
     * If the [CircularArray] has already been filled,
     * [item] will replace the oldest item already in the array.
     *
     * example:
     * ```
     * val ca = CircularArray<Int>(3)
     *
     * ca.add(0) // ca contents : [0, null, null]
     * ca.add(1) // ca contents : [0, 1, null]
     * ca.add(2) // ca contents : [0, 1, 2]
     * // overwrite the oldest value
     * ca.add(3) // ca contents : [3, 1, 2]
     * ```
     *
     * @param item the value to store.
     */
    fun add(item: T) {
        // Advance the write position, wrapping back to index 0 after the last slot.
        tail = (tail + 1) % array.size
        array[tail] = item
        if (count < array.size) count++
    }

    /**
     * Iterates over the [CircularArray], performing [block] on each item.
     *
     * Order is always first-in-first-out, with the oldest item being used first.
     *
     * The iteration works on a copy of the contents taken when [forEach] is called,
     * so items added from within [block] do not change what this call visits.
     *
     * example:
     * ```
     * val ca = CircularArray<Int>(3)
     *
     * ca.add(0) // ca contents : [0, null, null]
     * ca.add(1) // ca contents : [0, 1, null]
     * ca.add(2) // ca contents : [0, 1, 2]
     * // overwrite the oldest value
     * ca.add(3) // ca contents : [3, 1, 2]
     *
     * ca.forEach { ... } // visits [1, 2, 3]
     * ```
     *
     * @param block the action invoked once per stored item, oldest first.
     */
    @Suppress("UNCHECKED_CAST")
    internal inline fun forEach(block: (T) -> Unit) {
        // Capture all three pieces of state up front so the iteration below
        // only ever reads values which belong together.
        val arraySnapshot = array.copyOf()
        val tailSnapshot = tail
        val countSnapshot = count

        val capacity = arraySnapshot.size

        /*
         * The oldest item sits `countSnapshot` slots behind the slot after the tail.
         * While the array is still filling up this is index 0; once it is full
         * it is the slot right after the tail (wrapping around the end of the array).
         * Adding `capacity` keeps the left operand of `%` non-negative.
         */
        val oldestIndex = (tailSnapshot + 1 - countSnapshot + capacity) % capacity

        repeat(countSnapshot) { offset ->
            block(arraySnapshot[(oldestIndex + offset) % capacity] as T)
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment