Last active: October 1, 2023 22:16
-
-
Save paulo-raca/ef6a827046a5faec95024ff406d3a692 to your computer and use it in GitHub Desktop.
Condition Variables for Kotlin Coroutines
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package kotlinx.coroutines.sync | |
import kotlinx.coroutines.NonCancellable
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.withContext
import kotlinx.coroutines.withTimeout
import kotlinx.coroutines.withTimeoutOrNull
import java.util.function.Predicate
import kotlin.time.Duration
import kotlin.time.ExperimentalTime
import kotlin.time.nanoseconds
/**
 * A condition variable bound to a coroutine [Mutex].
 *
 * All methods are meant to be called while [mutex] is held. [await] releases [mutex], suspends until it is
 * woken by [signal]/[signalAll] (or its timeout elapses), and re-acquires [mutex] before returning.
 *
 * This should be part of kotlin-coroutines: https://github.com/Kotlin/kotlinx.coroutines/issues/2531
 */
class Condition(val mutex: Mutex) {
    /**
     * One initially-locked "gate" mutex per suspended waiter, in arrival order.
     *
     * A signaller removes a gate from this set and then unlocks it, which lets the corresponding waiter's
     * `lock()` complete. Intended to be touched only while [mutex] is held (each public method checks that
     * [mutex] is locked before using it).
     */
    val waiting = LinkedHashSet<Mutex>()

    /**
     * Suspends until [predicate] returns true or [timeout] has elapsed.
     *
     * [predicate] is checked immediately and again after every wake-up (signal or per-wait timeout), so stale
     * or unrelated signals are tolerated. Callers are expected to hold [mutex]; it is released while suspended.
     *
     * @param timeout total time budget across all waits; [Duration.INFINITE] waits indefinitely.
     * @param owner lock owner used with [mutex]; `null` for an ownerless lock.
     * @param predicate the condition to wait for.
     * @return true if [predicate] became true, false if [timeout] elapsed first.
     * @throws IllegalStateException if a wait is needed and [mutex] is not locked.
     */
    @ExperimentalTime
    suspend fun awaitUntil(timeout: Duration = Duration.INFINITE, owner: Any? = null, predicate: () -> Boolean): Boolean {
        val start = System.nanoTime()
        while (!predicate()) {
            val remainingTimeout = timeout - (System.nanoTime() - start).nanoseconds
            if (remainingTimeout < Duration.ZERO) {
                return false // Timeout elapsed without the predicate becoming true
            }
            await(remainingTimeout, owner)
        }
        return true
    }

    /**
     * Releases [mutex], suspends until woken by [signal]/[signalAll] or until [timeout] elapses, then
     * re-acquires [mutex] before returning. [mutex] is re-acquired even if this coroutine is cancelled.
     *
     * @param timeout maximum time to wait; [Duration.INFINITE] waits indefinitely.
     * @param owner lock owner used with [mutex]; `null` for an ownerless lock.
     * @return true if this coroutine was woken by [signal]/[signalAll], false if the timeout elapsed first.
     * @throws IllegalStateException if [mutex] is not locked (by [owner], when one is given).
     */
    @ExperimentalTime
    suspend fun await(timeout: Duration = Duration.INFINITE, owner: Any? = null): Boolean {
        ensureLocked(owner, "await")
        // The gate starts locked; signal()/signalAll() unlock it. Registering it *before* releasing `mutex`
        // guarantees that a signal sent right after the release still reaches this waiter.
        val gate = Mutex(true)
        waiting.add(gate)
        mutex.unlock(owner)
        var stillQueued = true
        try {
            // withTimeoutOrNull only absorbs its own timeout; an enclosing timeout or cancellation propagates.
            withTimeoutOrNull(timeout) { gate.lock() }
        } finally {
            // Re-acquire even when cancelled, so the caller's withLock/unlock still finds the mutex held.
            withContext(NonCancellable) { mutex.lock(owner) }
            stillQueued = waiting.remove(gate)
        }
        // Signallers remove the gate before unlocking it. If it is no longer queued we were signalled, even if
        // the timeout fired concurrently; reporting success here keeps that signal from being lost.
        return !stillQueued
    }

    /**
     * Wakes up the longest-waiting coroutine suspended in [await], if any.
     *
     * @param owner lock owner used with [mutex]; `null` for an ownerless lock.
     * @throws IllegalStateException if [mutex] is not locked (by [owner], when one is given).
     */
    suspend fun signal(owner: Any? = null) {
        ensureLocked(owner, "signal")
        val gate = waiting.firstOrNull() ?: return
        waiting.remove(gate)
        gate.unlock()
    }

    /**
     * Wakes up all coroutines currently suspended in [await].
     *
     * @param owner lock owner used with [mutex]; `null` for an ownerless lock.
     * @throws IllegalStateException if [mutex] is not locked (by [owner], when one is given).
     */
    suspend fun signalAll(owner: Any? = null) {
        ensureLocked(owner, "signalAll")
        val gates = waiting.toList()
        waiting.clear()
        gates.forEach { it.unlock() }
    }

    /**
     * Throws [IllegalStateException] unless [mutex] is locked.
     *
     * With a `null` [owner] this only checks that *someone* holds [mutex]; with a non-null [owner] it checks
     * that [owner] holds it.
     */
    internal fun ensureLocked(owner: Any?, funcName: String) {
        val isLocked = if (owner == null) mutex.isLocked else mutex.holdsLock(owner)
        check(isLocked) { "$funcName requires a locked mutex" }
    }
}
/** Creates a new [Condition] bound to this mutex; equivalent to `Condition(this)`. */
fun Mutex.newCondition(): Condition = Condition(this)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package kotlinx.coroutines.sync | |
import kotlinx.coroutines.sync.newCondition | |
import kotlinx.coroutines.* | |
import kotlinx.coroutines.sync.Mutex | |
import kotlinx.coroutines.sync.withLock | |
import org.junit.jupiter.api.Assertions | |
import org.junit.jupiter.api.Test | |
import java.util.stream.Collectors | |
import java.util.stream.IntStream | |
import kotlin.time.ExperimentalTime | |
import kotlin.time.milliseconds | |
import kotlin.time.seconds | |
@ExperimentalTime
class ConditionTest {
    val lock = Mutex()
    val cond = lock.newCondition()

    /** Starts [count] coroutines that each wait up to 1 second for a signal and report whether they got one. */
    private fun CoroutineScope.startWaiters(count: Int): List<Deferred<Boolean>> =
        List(count) {
            async { lock.withLock { cond.await(1.seconds) } }
        }

    @Test
    fun testAwaitWithoutSignal() {
        runBlocking {
            lock.withLock {
                Assertions.assertFalse(cond.await(1.seconds))
            }
        }
    }

    @Test
    fun testAwaitSignal() {
        runBlocking {
            launch {
                delay(500)
                lock.withLock {
                    cond.signal()
                }
            }
            lock.withLock {
                Assertions.assertTrue(cond.await(1.seconds))
                // The signal has been consumed, so a second wait must time out.
                Assertions.assertFalse(cond.await(1.seconds))
            }
        }
    }

    @Test
    fun testSignalAwait() {
        runBlocking {
            // A signal sent while nobody is waiting is not remembered.
            lock.withLock {
                cond.signal()
            }
            lock.withLock {
                delay(500)
                Assertions.assertFalse(cond.await(1.seconds))
            }
        }
    }

    @Test
    fun testNotifyOnce() {
        runBlocking {
            val waiters = startWaiters(5)
            delay(100.milliseconds)
            lock.withLock {
                cond.signal()
            }
            Assertions.assertEquals(1, waiters.awaitAll().count { it })
        }
    }

    @Test
    fun testNotifyAll() {
        runBlocking {
            val waiters = startWaiters(5)
            delay(100.milliseconds)
            lock.withLock {
                cond.signalAll()
            }
            val results = waiters.awaitAll()
            Assertions.assertEquals(results.size, results.count { it })
        }
    }

    @Test
    fun testAwaitUntilPredicateBecomesTrue() {
        runBlocking {
            var ready = false
            launch {
                delay(100.milliseconds)
                lock.withLock {
                    ready = true
                    cond.signalAll()
                }
            }
            lock.withLock {
                Assertions.assertTrue(cond.awaitUntil(1.seconds) { ready })
            }
        }
    }

    @Test
    fun testAwaitUntilTimeout() {
        runBlocking {
            lock.withLock {
                Assertions.assertFalse(cond.awaitUntil(200.milliseconds) { false })
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.