Last active
April 15, 2020 09:43
-
-
Save matejdro/a9c838bf0066595fb52b4b8816f49252 to your computer and use it in GitHub Desktop.
Multicast flow - https://github.com/Kotlin/kotlinx.coroutines/issues/1261
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import kotlinx.coroutines.GlobalScope | |
import kotlinx.coroutines.InternalCoroutinesApi | |
import kotlinx.coroutines.Job | |
import kotlinx.coroutines.channels.Channel | |
import kotlinx.coroutines.channels.ReceiveChannel | |
import kotlinx.coroutines.channels.SendChannel | |
import kotlinx.coroutines.channels.ValueOrClosed | |
import kotlinx.coroutines.flow.Flow | |
import kotlinx.coroutines.flow.consumeAsFlow | |
import kotlinx.coroutines.flow.emitAll | |
import kotlinx.coroutines.flow.flow | |
import kotlinx.coroutines.flow.produceIn | |
import kotlinx.coroutines.isActive | |
import kotlinx.coroutines.launch | |
import kotlinx.coroutines.selects.select | |
import kotlinx.coroutines.sync.Mutex | |
import kotlinx.coroutines.sync.withLock | |
/**
 * Shares one collection of [original] between any number of downstream collectors.
 *
 * Collectors of [multicastedFlow] register themselves by sending a [MulticastActorAction] to
 * [actor]. A single actor coroutine (see [startFlowActor]) owns all sharing state ([collectors],
 * [lastValue], [flowChannel], [nextDebounceTarget]) and forwards every upstream value to each
 * registered collector channel.
 *
 * @param original upstream flow that is collected at most once at a time
 * @param conflate when `true`, the most recent upstream value is remembered and replayed to
 * collectors that join while the upstream is still running
 * @param debounceMs how long, in milliseconds, the upstream is kept running after the last
 * collector leaves; any value that is not positive stops the upstream immediately
 */
@UseExperimental(InternalCoroutinesApi::class)
@Suppress("EXPERIMENTAL_API_USAGE")
class MulticastFlow<T>(
    private val original: Flow<T>,
    private val conflate: Boolean,
    private val debounceMs: Int
) {
    /** Serializes actor (re)starts requested by concurrent collectors. */
    private val mutex = Mutex()

    /** Channels of the currently registered collectors. */
    private val collectors = ArrayList<SendChannel<T>>()

    /** Last upstream value, kept only when [conflate] is enabled; `null` when nothing is cached. */
    private var lastValue: Result<T>? = null

    /** Wall-clock deadline (ms) of the pending debounced shutdown, or a negative value if none. */
    private var nextDebounceTarget: Long = -1L

    /** Mailbox through which collectors talk to the actor coroutine. */
    private var actor = Channel<MulticastActorAction<T>>(Channel.BUFFERED)

    /** Channel fed by the running upstream collection, or `null` while the upstream is stopped. */
    private var flowChannel: ReceiveChannel<T>? = null

    /** Job of the actor coroutine, if one has been started. */
    private var multicastActorJob: Job? = null

    /**
     * Starts the actor coroutine unless one is already active.
     *
     * The activity check is repeated under [mutex] so that concurrent callers start only one actor.
     */
    private suspend fun ensureActorActive() {
        if (multicastActorJob?.isActive != true) {
            mutex.withLock {
                if (multicastActorJob?.isActive != true) {
                    startFlowActor()
                }
            }
        }
    }

    /**
     * Launches the actor loop. On every iteration it waits for whichever happens first: a collector
     * action, an upstream value/close (only while the upstream is running), or the debounce deadline
     * (only while one is pending).
     */
    private fun startFlowActor() {
        // Create new channel to clear buffer of the previous channel
        actor = Channel(Channel.BUFFERED)
        multicastActorJob = GlobalScope.launch {
            while (isActive) {
                val currentFlowChannel = flowChannel
                select<Unit> {
                    actor.onReceive { action ->
                        onActorAction(action)
                    }
                    @Suppress("IfThenToSafeAccess")
                    if (currentFlowChannel != null) {
                        currentFlowChannel.onReceiveOrClosed { valueOrClosed ->
                            onOriginalFlowData(valueOrClosed)
                        }
                    }
                    if (nextDebounceTarget >= 0) {
                        onTimeout(nextDebounceTarget - System.currentTimeMillis()) {
                            // closeOriginalFlow() also clears the debounce deadline.
                            closeOriginalFlow()
                        }
                    }
                }
            }
        }
    }

    /** Applies a collector registration or de-registration. Runs on the actor coroutine. */
    private suspend fun onActorAction(action: MulticastActorAction<T>) {
        when (action) {
            is MulticastActorAction.AddCollector -> {
                collectors.add(action.channel)
                if (flowChannel == null) {
                    flowChannel = original.produceIn(GlobalScope)
                }
                // A collector is present again, so abort any pending debounced shutdown.
                nextDebounceTarget = -1
                val replay = lastValue
                if (replay != null) {
                    try {
                        action.channel.send(replay.getOrThrow())
                    } catch (e: Exception) {
                        // Ignore downstream exceptions, exactly like the fan-out in
                        // onOriginalFlowData: a collector that is already gone must not
                        // take down the actor shared by everyone else.
                    }
                }
            }
            is MulticastActorAction.RemoveCollector -> {
                if (collectors.remove(action.channel)) {
                    action.channel.close()
                }
                if (collectors.isEmpty()) {
                    if (debounceMs > 0) {
                        nextDebounceTarget = System.currentTimeMillis() + debounceMs
                    } else {
                        closeOriginalFlow()
                    }
                }
            }
        }
    }

    /** Stops the upstream collection, forgets cached state and cancels the actor coroutine. */
    private fun closeOriginalFlow() {
        lastValue = null
        // Clear the deadline on every shutdown path; otherwise an actor started later would see
        // an already-expired deadline and could shut itself down before serving its collector.
        nextDebounceTarget = -1
        flowChannel?.cancel()
        flowChannel = null
        multicastActorJob?.cancel()
    }

    /**
     * Handles one event from the upstream channel: a close is propagated (with its cause) to every
     * collector, a value is cached when [conflate] is set and then sent to every collector.
     */
    private suspend fun onOriginalFlowData(valueOrClosed: ValueOrClosed<T>) {
        if (valueOrClosed.isClosed) {
            collectors.forEach { it.close(valueOrClosed.closeCause) }
            collectors.clear()
            closeOriginalFlow()
            return
        }

        val value = valueOrClosed.value
        if (conflate) {
            // Cache once per emission, even when nobody is collecting (debounce window), so that
            // the next collector is replayed the latest value rather than a stale one.
            lastValue = Result.success(value)
        }
        collectors.forEach { collector ->
            try {
                collector.send(value)
            } catch (e: Exception) {
                // Ignore downstream exceptions
            }
        }
    }

    /**
     * Flow handed out to consumers. Each collection registers a fresh rendezvous channel with the
     * actor, relays everything received on it, and de-registers the channel when collection ends.
     */
    val multicastedFlow: Flow<T> = flow {
        val channel = Channel<T>()
        try {
            ensureActorActive()
            actor.send(MulticastActorAction.AddCollector(channel))
            emitAll(channel.consumeAsFlow())
        } finally {
            // NOTE(review): this send runs even when the collector was cancelled; confirm it
            // cannot suspend (full action buffer) in that situation, or the removal may be lost.
            actor.send(MulticastActorAction.RemoveCollector(channel))
        }
    }

    /** Messages understood by the actor coroutine. */
    private sealed class MulticastActorAction<T> {
        /** Registers [channel] as a new collector. */
        class AddCollector<T>(val channel: SendChannel<T>) : MulticastActorAction<T>()

        /** De-registers and closes [channel]. */
        class RemoveCollector<T>(val channel: SendChannel<T>) : MulticastActorAction<T>()
    }
}
/**
 * Wraps this flow in a [MulticastFlow] so that several collectors can be served by the same
 * running instance of it.
 *
 * @param conflate `true` if a newly attached collector should first be handed the last value
 * that was collected
 * @param debounceMs how many milliseconds the original flow is kept open once its last collector
 * has closed; `0` disables the grace period
 * @return the shared view of this flow
 */
fun <T> Flow<T>.share(
    conflate: Boolean = false,
    debounceMs: Int = 0
): Flow<T> = MulticastFlow(
    original = this,
    conflate = conflate,
    debounceMs = debounceMs
).multicastedFlow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import kotlinx.coroutines.CancellationException | |
import kotlinx.coroutines.CompletableDeferred | |
import kotlinx.coroutines.Job | |
import kotlinx.coroutines.cancel | |
import kotlinx.coroutines.delay | |
import kotlinx.coroutines.flow.Flow | |
import kotlinx.coroutines.flow.collect | |
import kotlinx.coroutines.flow.flow | |
import kotlinx.coroutines.joinAll | |
import kotlinx.coroutines.launch | |
import kotlinx.coroutines.runBlocking | |
import kotlinx.coroutines.suspendCancellableCoroutine | |
import kotlinx.coroutines.sync.Semaphore | |
import kotlinx.coroutines.withTimeoutOrNull | |
import org.assertj.core.api.Assertions.assertThat | |
import org.junit.Ignore | |
import org.junit.Test | |
import java.lang.IllegalStateException | |
import java.lang.ref.PhantomReference | |
import java.lang.ref.WeakReference | |
import java.util.concurrent.atomic.AtomicBoolean | |
import java.util.concurrent.atomic.AtomicInteger | |
/**
 * Behavioural tests for flows created with `share()`.
 *
 * Every test runs inside `runBlocking`, launches its collectors as child coroutines and waits
 * for them with `joinAll()` before asserting. Several tests depend on real-time delays.
 */
class MulticastFlowTest {
    @Test
    fun `only launch flow once`() = runBlocking<Unit> {
        val numLaunches = AtomicInteger(0)
        val flow = flow {
            numLaunches.incrementAndGet()
            emit("A")
            delay(5)
        }.share()

        listOf(
            launch { flow.collect {} },
            launch { flow.collect {} }
        ).joinAll()

        assertThat(numLaunches.get()).isEqualTo(1)
    }

    @Test
    fun `receive items on all collectors`() = runBlocking<Unit> {
        val itemsA = mutableListOf<String>()
        val itemsB = mutableListOf<String>()
        val flow = flow {
            delay(5)
            emit("A")
            emit("B")
            emit("C")
        }.share()

        listOf(
            launch { flow.collect { itemsA += it } },
            launch { flow.collect { itemsB += it } }
        ).joinAll()

        assertThat(itemsA).containsExactly("A", "B", "C")
        assertThat(itemsB).containsExactly("A", "B", "C")
    }

    @Test
    fun `close flow after all collectors close`() = runBlocking<Unit> {
        val closedCompletable = CompletableDeferred<Unit>()
        val flow = flow {
            try {
                delay(10)
                emit("A")
                // Suspend forever; only cancellation can get past this point.
                suspendCancellableCoroutine<Unit> { }
            } finally {
                closedCompletable.complete(Unit)
            }
        }.share()

        // Each collector cancels itself as soon as it sees the first item.
        listOf(
            launch { flow.collect { coroutineContext.cancel() } },
            launch { flow.collect { coroutineContext.cancel() } }
        ).joinAll()

        withTimeoutOrNull(5_000) { closedCompletable.join() } ?: error("Flow not closed")
    }

    @Test
    fun `do not crash the whole flow if one collector throws exception`() = runBlocking<Unit> {
        val receivedItems = mutableListOf<String>()
        val flow = flow {
            delay(5)
            emit("A")
            emit("B")
            emit("C")
        }.share()

        listOf(
            launch {
                flow.collect {
                    delay(10)
                    receivedItems += it
                }
            },
            launch {
                try {
                    flow.collect {
                        throw IllegalStateException("Test")
                    }
                } catch (e: Exception) {
                    // Expected: the failure of this collector is the scenario under test and
                    // must not fail the enclosing runBlocking scope.
                }
            }
        ).joinAll()

        assertThat(receivedItems).containsExactly("A", "B", "C")
    }

    @Test
    fun `receive exceptions on all producers`() = runBlocking<Unit> {
        val receivedA = AtomicBoolean(false)
        val receivedB = AtomicBoolean(false)
        val flow = flow<String> {
            delay(5)
            throw CloneNotSupportedException()
        }.share()

        listOf(
            launch {
                try {
                    flow.collect {}
                } catch (e: CloneNotSupportedException) {
                    receivedA.set(true)
                }
            },
            launch {
                try {
                    flow.collect {}
                } catch (e: CloneNotSupportedException) {
                    receivedB.set(true)
                }
            }
        ).joinAll()

        assertThat(receivedA.get()).isTrue()
        assertThat(receivedB.get()).isTrue()
    }

    @Test
    fun `receive last item when new collector starts collecting existing flow`() = runBlocking<Unit> {
        val itemsA = mutableListOf<String>()
        val itemsB = mutableListOf<String>()
        val flow = flow {
            delay(5)
            emit("A")
            emit("B")
            emit("C")
            delay(15)
        }.share(conflate = true)

        listOf(
            launch { flow.collect { itemsA += it } },
            launch {
                // Join after all items were emitted but before the flow completes.
                delay(10)
                flow.collect { itemsB += it }
            }
        ).joinAll()

        assertThat(itemsA).containsExactly("A", "B", "C")
        assertThat(itemsB).containsExactly("C")
    }

    @Test
    fun `do not receive conflated last item when there were no active collectors`() = runBlocking<Unit> {
        val itemsA = mutableListOf<String>()
        val itemsB = mutableListOf<String>()
        val flow = flow {
            delay(5)
            emit("A")
            emit("B")
            emit("C")
            delay(15)
        }.share(conflate = true)

        // Two back-to-back collections: the second starts only after the first has finished.
        flow.collect { itemsA += it }
        flow.collect { itemsB += it }

        assertThat(itemsA).containsExactly("A", "B", "C")
        assertThat(itemsB).containsExactly("A", "B", "C")
    }

    @Test
    fun `do not receive last item when conflate is disabled`() = runBlocking<Unit> {
        val itemsA = mutableListOf<String>()
        val itemsB = mutableListOf<String>()
        // The permit is taken here and released by the flow once every item has been emitted.
        val semaphore = Semaphore(1)
        semaphore.acquire()
        val flow = flow {
            delay(5)
            emit("A")
            emit("B")
            emit("C")
            delay(15)
            semaphore.release()
            delay(15)
        }.share(conflate = false)

        listOf(
            launch { flow.collect { itemsA += it } },
            launch {
                semaphore.acquire()
                flow.collect { itemsB += it }
            }
        ).joinAll()

        assertThat(itemsA).containsExactly("A", "B", "C")
        assertThat(itemsB).isEmpty()
    }

    @Test
    fun `do not close flow within debounce period`() = runBlocking<Unit> {
        val numLaunches = AtomicInteger(0)
        val flow = flow {
            numLaunches.incrementAndGet()
            var counter = 0
            while (true) {
                emit(counter++)
                delay(5)
            }
        }.share(debounceMs = 500)

        val first = launch {
            withTimeoutOrNull(10) { flow.collect {} }
        }
        // Well inside the 500 ms debounce window.
        delay(20)
        val second = launch {
            withTimeoutOrNull(10) { flow.collect {} }
        }
        listOf(first, second).joinAll()

        assertThat(numLaunches.get()).isEqualTo(1)
    }

    @Test
    fun `close flow after debounce period`() = runBlocking<Unit> {
        val numLaunches = AtomicInteger(0)
        val flow = flow {
            numLaunches.incrementAndGet()
            emit("A")
            delay(999)
        }.share(debounceMs = 20)

        val first = launch {
            withTimeoutOrNull(10) { flow.collect {} }
        }
        // Longer than the 20 ms debounce window.
        delay(100)
        val second = launch {
            withTimeoutOrNull(10) { flow.collect {} }
        }
        listOf(first, second).joinAll()

        assertThat(numLaunches.get()).isEqualTo(2)
    }

    @Test
    fun `do not leak actor to GlobalScope`() = runBlocking<Unit> {
        val flow = createWeakFlow()

        System.gc()
        delay(100)
        System.gc()

        assertThat(flow.get()).isNull()
    }

    /**
     * Builds a shared flow, collects it to completion and returns only a weak reference to it,
     * so the caller holds no strong reference of its own.
     */
    private suspend fun createWeakFlow(): WeakReference<Flow<String>> {
        val flow = flow {
            emit("A")
            delay(999)
        }.share(debounceMs = 50)

        flow.collect {}

        return WeakReference(flow)
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment