Skip to content

Instantly share code, notes, and snippets.

@konrad-kaminski
Created May 9, 2017 23:10
Show Gist options
  • Save konrad-kaminski/74942a238bcac5318c4c1b3a464a4e77 to your computer and use it in GitHub Desktop.
Save konrad-kaminski/74942a238bcac5318c4c1b3a464a4e77 to your computer and use it in GitHub Desktop.
CountDownLatch naive implementation based on ConflatedBroadcastChannel
/*
* Copyright 2016-2017 JetBrains s.r.o.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kotlinx.coroutines.experimental.sync
import kotlinx.coroutines.experimental.CancellationException
import kotlinx.coroutines.experimental.channels.ConflatedBroadcastChannel
import kotlinx.coroutines.experimental.channels.consumeEach
import kotlinx.coroutines.experimental.withTimeoutOrNull
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
/**
* Equivalent of [java.util.concurrent.CountDownLatch] for coroutines.
*/
interface CountDownLatch {
/**
* Decrements the count of the latch, resuming all suspended coroutines if
* the count reaches zero.
*
* If the current count is greater than zero then it is decremented.
* If the new count is zero then all suspended coroutines are resumed.
*
* If the current count equals zero then nothing happens.
*/
suspend fun countDown()
/**
* Returns the current count.
*
* This method is typically used for debugging and testing purposes.
*
* @return the current count
*/
fun getCount(): Long
/**
* Causes the current coroutine to suspend until the latch has counted down to
* zero, unless the coroutine is cancelled.
*
* If the current count is zero then this method returns immediately.
*
* If the current count is greater than zero then the current
* coroutine is suspended and waits until one of two things happens:
*
* * the count reaches zero due to invocations of the [countDown] method; or
* * the coroutine is cancelled.
*
* If the current coroutine:
*
* * is already cancelled; or
* * is cancelled while waiting,
*
* then [CancellationException] is thrown.
*
* @throws CancellationException if the current coroutine is cancelled
* while waiting or is already cancelled
*/
@Throws(CancellationException::class)
suspend fun await()
/**
* Causes the current coroutine to suspend until the latch has counted down to
* zero, unless the coroutine is cancelled or the specified waiting time elapses.
*
* If the current count is zero then this method returns immediately.
*
* If the current count is greater than zero then the current
* coroutine is suspended and waits until one of three things happens:
*
* * the count reaches zero due to invocations of the [countDown] method; or
* * the coroutine is cancelled; or
* * the specified waiting time elapses.
*
* If the current coroutine:
*
* * is already cancelled; or
* * is cancelled while waiting,
*
* then [CancellationException] is thrown.
*
* If the specified waiting time elapses then the value `false`
* is returned. If the time is less than or equal to zero, the method
* will not wait at all.
*
* @param time the maximum time to wait
* @param unit the time unit of the `time` argument; defaults to [TimeUnit.MILLISECONDS]
* @return `true` if the count reached zero and `false`
* if the waiting time elapsed before the count reached zero
* @throws CancellationException if the current coroutine is cancelled
* while waiting or is already cancelled
*/
@Throws(CancellationException::class)
suspend fun await(time: Long, unit: TimeUnit = TimeUnit.MILLISECONDS): Boolean
/**
* Factory for [CountDownLatch] instances.
*
* Declaring `invoke` as an operator lets callers write `CountDownLatch(n)`
* as if the interface had a constructor.
*/
companion object {
/**
* Creates new [CountDownLatch] instance.
*
* @param initialCount initial count of the latch; must not be negative
* @throws IllegalArgumentException if `initialCount` is negative
*/
operator fun invoke(initialCount: Int): CountDownLatch = CountDownLatchImpl(initialCount)
}
}
/**
 * [CountDownLatch] implementation built on an atomic counter and a [ConflatedBroadcastChannel].
 *
 * The channel is used purely as a wake-up signal: waiters subscribe to it and stay
 * suspended until [countDown] closes it, which happens exactly once, when the counter
 * goes from one to zero.
 *
 * @param initialCount initial count of the latch; must not be negative
 * @throws IllegalArgumentException if [initialCount] is negative
 */
internal class CountDownLatchImpl(initialCount: Int) : CountDownLatch {
    private val count = AtomicInteger(initialCount)
    private val channel = ConflatedBroadcastChannel(false)

    init {
        require(initialCount >= 0) { "initialCount < 0" }
    }

    /**
     * Decrements the counter unless it is already zero. The caller that moves the
     * counter from one to zero publishes `true` and closes the channel, which
     * releases every coroutine suspended in [await].
     */
    override suspend fun countDown() {
        // getAndUpdate returns the value seen *before* the update, so exactly one
        // caller observes 1 and becomes responsible for releasing the waiters.
        val counterBeforeUpdate = count.getAndUpdate { counter ->
            if (counter > 0) counter - 1 else 0
        }
        if (counterBeforeUpdate == 1) {
            try {
                channel.send(true)
            } finally {
                // Waiters are released by the close, so it must happen even if send fails.
                channel.close()
            }
        }
    }

    /** Returns the current count widened to [Long]. */
    override fun getCount(): Long = count.get().toLong()

    /**
     * Waits for the latch to open for at most [time] in the given [unit].
     *
     * @return `true` if the count reached zero, `false` if the waiting time elapsed first
     */
    override suspend fun await(time: Long, unit: TimeUnit): Boolean {
        // Fast path: an already-open latch reports success regardless of the timeout value.
        if (count.get() == 0) return true
        return withTimeoutOrNull(time, unit) { await() } != null
    }

    /**
     * Suspends until the count reaches zero; returns immediately if it already has.
     */
    override suspend fun await() {
        // Fast path. It is required for a latch created with a count of zero: such a
        // latch never goes through the 1 -> 0 transition, so its channel is never
        // closed and subscribing to it would suspend forever.
        if (count.get() == 0) return
        channel.consumeEach {
            // The delivered values are irrelevant: the loop ends when countDown()
            // closes the channel.
        }
    }
}
/*
* Copyright 2016-2017 JetBrains s.r.o.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kotlinx.coroutines.experimental.sync
import kotlinx.coroutines.experimental.TestBase
import kotlinx.coroutines.experimental.launch
import kotlinx.coroutines.experimental.runBlocking
import kotlinx.coroutines.experimental.yield
import org.junit.Assert.assertEquals
import org.junit.Test
/**
 * Tests for the coroutine-based [CountDownLatch].
 */
class CountDownLatchTest : TestBase() {
    /**
     * A coroutine awaiting a latch of two stays suspended after the first
     * `countDown` and proceeds only after the second one. The numbers passed to
     * `expect`/`finish` spell out the interleaving the test anticipates.
     */
    @Test
    fun testSimple() = runBlocking {
        val gate = CountDownLatch(2)
        expect(1)
        launch(context) {
            expect(4)
            gate.await() // parks here until the count hits zero
            expect(7) // both countDown calls have happened
        }
        expect(2)
        gate.countDown()
        expect(3)
        yield()
        expect(5)
        gate.countDown()
        expect(6)
        yield()
        finish(8)
    }

    /**
     * Each `countDown` lowers the reported count by one, and `await` on a latch
     * that has reached zero completes.
     */
    @Test
    fun countDownTest() = runBlocking {
        val gate = CountDownLatch(3)
        assertEquals(3, gate.getCount())
        for (remaining in 2L downTo 0L) {
            gate.countDown()
            assertEquals(remaining, gate.getCount())
        }
        gate.await()
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment