Skip to content

Instantly share code, notes, and snippets.

@matejdro
Last active April 15, 2020 09:43
Show Gist options
  • Save matejdro/a9c838bf0066595fb52b4b8816f49252 to your computer and use it in GitHub Desktop.
Save matejdro/a9c838bf0066595fb52b4b8816f49252 to your computer and use it in GitHub Desktop.
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.InternalCoroutinesApi
import kotlinx.coroutines.Job
import kotlinx.coroutines.NonCancellable
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.channels.ValueOrClosed
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.consumeAsFlow
import kotlinx.coroutines.flow.emitAll
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.produceIn
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import kotlinx.coroutines.selects.select
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withContext
/**
 * Shares a single collection of [original] between any number of downstream collectors.
 *
 * All bookkeeping (collector list, upstream channel, conflated value, debounce deadline) is
 * mutated only from one actor coroutine launched in [GlobalScope]. Collectors talk to it by
 * sending [MulticastActorAction] messages to its inbox ([actor]).
 *
 * The upstream is started (via `produceIn`) when a collector registers and no upstream is
 * running. It is cancelled when the last collector leaves (immediately, or after [debounceMs]),
 * or when the upstream itself completes or fails. The actor stops itself at the same moment and
 * is restarted on demand by the next collector.
 *
 * @param original upstream flow; at most one collection of it is running at a time.
 * @param conflate if `true`, a newly registered collector first receives the most recent value
 *   emitted by the currently running upstream (if any).
 * @param debounceMs milliseconds to keep the upstream alive after the last collector leaves;
 *   values `<= 0` cancel it immediately.
 */
@UseExperimental(InternalCoroutinesApi::class)
@Suppress("EXPERIMENTAL_API_USAGE")
class MulticastFlow<T>(
    private val original: Flow<T>,
    private val conflate: Boolean,
    private val debounceMs: Int
) {
    /** Serializes (re)starting the actor so concurrent collectors start at most one. */
    private val mutex = Mutex()

    // Actor-confined state: only touched from the actor coroutine.
    private val collectors = ArrayList<SendChannel<T>>()
    private var lastValue: Result<T>? = null

    /** Wall-clock time (ms) at which the idle upstream is cancelled; -1 when no debounce is pending. */
    private var nextDebounceTarget: Long = -1L
    private var flowChannel: ReceiveChannel<T>? = null

    // Both are read by collectors outside of [mutex] (the fast path of ensureActorActive),
    // so writes made under the lock must be visible to them.
    @Volatile
    private var actor = Channel<MulticastActorAction<T>>(Channel.BUFFERED)

    @Volatile
    private var multicastActorJob: Job? = null

    /**
     * Starts the actor unless one is running: a cheap unlocked check for the common case,
     * repeated under [mutex] so that racing collectors start only one actor.
     */
    private suspend fun ensureActorActive() {
        if (multicastActorJob?.isActive != true) {
            mutex.withLock {
                if (multicastActorJob?.isActive != true) {
                    startFlowActor()
                }
            }
        }
    }

    /** Launches a new actor with a fresh inbox. Must be called with [mutex] held. */
    private fun startFlowActor() {
        // Create new channel to clear buffer of the previous channel
        val inbox = Channel<MulticastActorAction<T>>(Channel.BUFFERED)
        actor = inbox
        multicastActorJob = GlobalScope.launch {
            while (isActive) {
                val currentFlowChannel = flowChannel
                select<Unit> {
                    inbox.onReceive { action ->
                        onActorAction(action)
                    }
                    @Suppress("IfThenToSafeAccess")
                    if (currentFlowChannel != null) {
                        currentFlowChannel.onReceiveOrClosed { valueOrClosed ->
                            onOriginalFlowData(valueOrClosed)
                        }
                    }
                    if (nextDebounceTarget >= 0) {
                        onTimeout(nextDebounceTarget - System.currentTimeMillis()) {
                            // No collector re-joined within the debounce window.
                            closeOriginalFlow()
                        }
                    }
                }
            }
        }
    }

    /** Handles a collector registering or leaving. Runs on the actor coroutine. */
    private suspend fun onActorAction(action: MulticastActorAction<T>) {
        when (action) {
            is MulticastActorAction.AddCollector -> {
                collectors.add(action.channel)
                if (flowChannel == null) {
                    flowChannel = original.produceIn(GlobalScope)
                }
                // Replay the conflated value. Best-effort, so a collector that is already gone
                // cannot crash the actor.
                lastValue?.let { sendToCollector(action.channel, it.getOrThrow()) }
                nextDebounceTarget = -1
            }
            is MulticastActorAction.RemoveCollector -> {
                // A channel this actor never registered (e.g. a late removal from a collector of
                // a previous upstream run) changes nothing. It must not trigger a shutdown of an
                // actor that may be about to receive an AddCollector.
                if (!collectors.remove(action.channel)) return
                action.channel.close()
                if (collectors.isEmpty()) {
                    if (debounceMs > 0) {
                        nextDebounceTarget = System.currentTimeMillis() + debounceMs
                    } else {
                        closeOriginalFlow()
                    }
                }
            }
        }
    }

    /**
     * Cancels the upstream, forgets the conflated value and stops the current actor.
     * Only ever called from the actor coroutine itself.
     */
    private fun closeOriginalFlow() {
        lastValue = null
        // A stale deadline would make the next actor's first select time out immediately
        // and shut that actor down again.
        nextDebounceTarget = -1
        flowChannel?.cancel()
        flowChannel = null
        multicastActorJob?.cancel()
    }

    /**
     * Fans one upstream event out to all collectors.
     *
     * Upstream completion or failure closes every collector with the same cause and stops
     * the actor.
     */
    private suspend fun onOriginalFlowData(valueOrClosed: ValueOrClosed<T>) {
        if (valueOrClosed.isClosed) {
            collectors.forEach { it.close(valueOrClosed.closeCause) }
            collectors.clear()
            closeOriginalFlow()
            return
        }
        val value = valueOrClosed.value
        // Record the latest value even when nobody is collecting (e.g. during the debounce
        // window), so a collector that joins later is replayed the most recent value.
        if (conflate) {
            lastValue = Result.success(value)
        }
        collectors.forEach { sendToCollector(it, value) }
    }

    /**
     * Sends [value] to [collector], ignoring failures. A collector whose channel is already
     * closed or cancelled (it stopped collecting or threw) must not take down the shared
     * upstream. The actor job is only cancelled by [closeOriginalFlow] on the actor itself,
     * so swallowing here does not hide cancellation of the actor.
     */
    private suspend fun sendToCollector(collector: SendChannel<T>, value: T) {
        try {
            collector.send(value)
        } catch (e: Exception) {
            // Ignore downstream exceptions
        }
    }

    /**
     * The shared flow. Each collection registers a rendezvous channel with the actor and
     * deregisters it when collection ends for any reason.
     */
    val multicastedFlow = flow {
        val channel = Channel<T>()
        try {
            ensureActorActive()
            // NOTE(review): if the actor shuts itself down between ensureActorActive() and this
            // send, the message lands in the stopped actor's inbox and this collector waits
            // forever; fixing that needs registration and shutdown to be serialized.
            actor.send(MulticastActorAction.AddCollector(channel))
            emitAll(channel.consumeAsFlow())
        } finally {
            // Guarantee the actor can never block forever sending to us, even if we were
            // cancelled before emitAll started consuming the channel.
            channel.cancel()
            // Deregistration is cleanup; it must still happen when this collector is cancelled.
            withContext(NonCancellable) {
                actor.send(MulticastActorAction.RemoveCollector(channel))
            }
        }
    }

    /** Messages sent from collectors to the actor. */
    private sealed class MulticastActorAction<T> {
        class AddCollector<T>(val channel: SendChannel<T>) : MulticastActorAction<T>()
        class RemoveCollector<T>(val channel: SendChannel<T>) : MulticastActorAction<T>()
    }
}
/**
 * Lets several collectors share one running collection of this flow instead of each
 * collector starting its own.
 *
 * Every call creates a new, independent sharing point, so call it once and reuse the
 * returned flow.
 *
 * @param conflate whether a newly subscribed collector should first receive the most
 *   recently collected value.
 * @param debounceMs number of milliseconds to keep the upstream running after the last
 *   collector stops, before cancelling it. Set to 0 to disable.
 * @return a flow that multicasts this flow to all of its collectors.
 */
fun <T> Flow<T>.share(
    conflate: Boolean = false,
    debounceMs: Int = 0
): Flow<T> = MulticastFlow(
    original = this,
    conflate = conflate,
    debounceMs = debounceMs
).multicastedFlow
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.Job
import kotlinx.coroutines.cancel
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.joinAll
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlinx.coroutines.sync.Semaphore
import kotlinx.coroutines.withTimeoutOrNull
import org.assertj.core.api.Assertions.assertThat
import org.junit.Ignore
import org.junit.Test
import java.lang.IllegalStateException
import java.lang.ref.PhantomReference
import java.lang.ref.WeakReference
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
/**
 * Behavioural tests for `share()`.
 *
 * All tests use `runBlocking` with real `delay`s (no virtual time), so several of them
 * depend on timing.
 */
class MulticastFlowTest {
    @Test
    fun `only launch flow once`() = runBlocking<Unit> {
        val numLaunches = AtomicInteger(0)
        val flow = flow {
            numLaunches.incrementAndGet()
            emit("A")
            delay(5)
        }.share()
        val tasks = ArrayList<Job>()
        tasks += launch {
            flow.collect {}
        }
        tasks += launch {
            flow.collect {}
        }
        tasks.joinAll()
        assertThat(numLaunches.get()).isEqualTo(1)
    }

    @Test
    fun `receive items on all collectors`() = runBlocking<Unit> {
        val itemsA = ArrayList<String>()
        val itemsB = ArrayList<String>()
        val flow = flow {
            delay(5)
            emit("A")
            emit("B")
            emit("C")
        }.share()
        val tasks = ArrayList<Job>()
        tasks += launch {
            flow.collect {
                itemsA += it
            }
        }
        tasks += launch {
            flow.collect {
                itemsB += it
            }
        }
        tasks.joinAll()
        assertThat(itemsA).containsExactly("A", "B", "C")
        assertThat(itemsB).containsExactly("A", "B", "C")
    }

    @Test
    fun `close flow after all collectors close`() = runBlocking<Unit> {
        val closedCompletable = CompletableDeferred<Unit>()
        val flow = flow {
            try {
                delay(10)
                emit("A")
                // Suspend forever; only cancellation of the upstream can get us to `finally`.
                suspendCancellableCoroutine<Unit> { }
            } finally {
                closedCompletable.complete(Unit)
            }
        }.share()
        val tasks = ArrayList<Job>()
        tasks += launch {
            flow.collect {
                coroutineContext.cancel()
            }
        }
        tasks += launch {
            flow.collect {
                coroutineContext.cancel()
            }
        }
        tasks.joinAll()
        withTimeoutOrNull(5_000) { closedCompletable.join() } ?: error("Flow not closed")
    }

    @Test
    fun `do not crash the whole flow if one collector throws exception`() = runBlocking<Unit> {
        val receivedItems = ArrayList<String>()
        val flow = flow {
            delay(5)
            emit("A")
            emit("B")
            emit("C")
        }.share()
        val tasks = ArrayList<Job>()
        tasks += launch {
            flow.collect {
                delay(10)
                receivedItems += it
            }
        }
        tasks += launch {
            try {
                flow.collect {
                    throw IllegalStateException("Test")
                }
            } catch (e: IllegalStateException) {
                // Expected: thrown by this collector itself. Any other failure must fail the test.
            }
        }
        tasks.joinAll()
        assertThat(receivedItems).containsExactly("A", "B", "C")
    }

    @Test
    fun `receive exceptions on all producers`() = runBlocking<Unit> {
        val receivedA = AtomicBoolean(false)
        val receivedB = AtomicBoolean(false)
        val flow = flow<String> {
            delay(5)
            throw CloneNotSupportedException()
        }.share()
        val tasks = ArrayList<Job>()
        tasks += launch {
            try {
                flow.collect {}
            } catch (e: CloneNotSupportedException) {
                receivedA.set(true)
            }
        }
        tasks += launch {
            try {
                flow.collect {}
            } catch (e: CloneNotSupportedException) {
                receivedB.set(true)
            }
        }
        tasks.joinAll()
        assertThat(receivedA.get()).isTrue()
        assertThat(receivedB.get()).isTrue()
    }

    @Test
    fun `receive last item when new collector starts collecting existing flow`() = runBlocking<Unit> {
        val itemsA = ArrayList<String>()
        val itemsB = ArrayList<String>()
        val flow = flow {
            delay(5)
            emit("A")
            emit("B")
            emit("C")
            delay(15)
        }.share(conflate = true)
        val tasks = ArrayList<Job>()
        tasks += launch {
            flow.collect {
                itemsA += it
            }
        }
        tasks += launch {
            // Join after A, B, C were emitted but while the upstream is still running.
            delay(10)
            flow.collect {
                itemsB += it
            }
        }
        tasks.joinAll()
        assertThat(itemsA).containsExactly("A", "B", "C")
        assertThat(itemsB).containsExactly("C")
    }

    @Test
    fun `do not receive conflated last item when there were no active collectors`() = runBlocking<Unit> {
        val itemsA = ArrayList<String>()
        val itemsB = ArrayList<String>()
        val flow = flow {
            delay(5)
            emit("A")
            emit("B")
            emit("C")
            delay(15)
        }.share(conflate = true)
        // Sequential collections: the upstream completes in between, so nothing is replayed.
        flow.collect {
            itemsA += it
        }
        flow.collect {
            itemsB += it
        }
        assertThat(itemsA).containsExactly("A", "B", "C")
        assertThat(itemsB).containsExactly("A", "B", "C")
    }

    @Test
    fun `do not receive last item when conflate is disabled`() = runBlocking<Unit> {
        val itemsA = ArrayList<String>()
        val itemsB = ArrayList<String>()
        // Starts with no permits; the upstream releases one once A, B, C have been emitted.
        val semaphore = Semaphore(1)
        semaphore.acquire()
        val flow = flow {
            delay(5)
            emit("A")
            emit("B")
            emit("C")
            delay(15)
            semaphore.release()
            delay(15)
        }.share(conflate = false)
        val tasks = ArrayList<Job>()
        tasks += launch {
            flow.collect {
                itemsA += it
            }
        }
        tasks += launch {
            semaphore.acquire()
            flow.collect {
                itemsB += it
            }
        }
        tasks.joinAll()
        assertThat(itemsA).containsExactly("A", "B", "C")
        assertThat(itemsB).isEmpty()
    }

    @Test
    fun `do not close flow within debounce period`() = runBlocking<Unit> {
        val numLaunches = AtomicInteger(0)
        val flow = flow {
            numLaunches.incrementAndGet()
            var counter = 0
            while (true) {
                emit(counter++)
                delay(5)
            }
        }.share(debounceMs = 500)
        val tasks = ArrayList<Job>()
        tasks += launch {
            withTimeoutOrNull(10) {
                flow.collect {}
            }
        }
        // Second collector re-joins well inside the 500 ms debounce window.
        delay(20)
        tasks += launch {
            withTimeoutOrNull(10) {
                flow.collect {}
            }
        }
        tasks.joinAll()
        assertThat(numLaunches.get()).isEqualTo(1)
    }

    @Test
    fun `close flow after debounce period`() = runBlocking<Unit> {
        val numLaunches = AtomicInteger(0)
        val flow = flow {
            numLaunches.incrementAndGet()
            emit("A")
            delay(999)
        }.share(debounceMs = 20)
        val tasks = ArrayList<Job>()
        tasks += launch {
            withTimeoutOrNull(10) {
                flow.collect {}
            }
        }
        // Wait well past the 20 ms debounce so the upstream is cancelled and must be restarted.
        delay(100)
        tasks += launch {
            withTimeoutOrNull(10) {
                flow.collect {}
            }
        }
        tasks.joinAll()
        assertThat(numLaunches.get()).isEqualTo(2)
    }

    @Test
    fun `do not leak actor to GlobalScope`() = runBlocking<Unit> {
        val flow = createWeakFlow()
        System.gc()
        // Give the 50 ms debounce time to expire so the actor stops, then collect again.
        delay(100)
        System.gc()
        assertThat(flow.get()).isNull()
    }

    /** Collects a shared flow once and returns only a weak reference to it. */
    private suspend fun createWeakFlow(): WeakReference<Flow<String>> {
        val flow = flow {
            emit("A")
            delay(999)
        }.share(debounceMs = 50)
        flow.collect {}
        return WeakReference(
            flow
        )
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment