Skip to content

Instantly share code, notes, and snippets.

Last active May 28, 2018 11:42
Show Gist options
  • Save streetsofboston/6ea225c61566e6d349883082fbb9f020 to your computer and use it in GitHub Desktop.
Save streetsofboston/6ea225c61566e6d349883082fbb9f020 to your computer and use it in GitHub Desktop.
Kotlin unit-test utility for testing functions that use coroutines: TestCoroutineContext
* Copyright (c) 2018 Intrepid Pursuits,Inc. All rights reserved.
package kotlinx.coroutines.experimental.intrepid
import kotlinx.coroutines.experimental.*
import java.util.concurrent.PriorityBlockingQueue
import java.util.concurrent.TimeUnit
import kotlin.coroutines.experimental.CoroutineContext
/**
 * Upper bound, in milliseconds, applied to every delay or timeout handed to the test dispatcher.
 */
private const val MAX_DELAY = Long.MAX_VALUE / 2 // cannot delay for too long on Android

/**
 * This [CoroutineContext] dispatcher can be used to simulate virtual time to speed up
 * code, especially tests, that deal with delays and timeouts.
 *
 * Specify an instance of this context when calling the *non-blocking* [kotlinx.coroutines.experimental.launch]
 * or [kotlinx.coroutines.experimental.async] and then advance time or trigger the actions
 * to make the co-routines execute.
 * This works much like the *TestScheduler* in RxJava, which allows to speed up tests that deal
 * with non-blocking Rx chains that contain delays or timeouts.
 *
 * This dispatcher can also handle *blocking* coroutines that are started by
 * [kotlinx.coroutines.experimental.runBlocking]. This dispatcher's virtual time will be automatically
 * advanced based on the delayed actions within the coroutine(s).
 *
 * Virtual time never wraps around: very long delays are clamped to the end of the virtual
 * timeline ([Long.MAX_VALUE] nanoseconds) instead of overflowing into the past.
 *
 * @param name optional name returned by [toString]; when `null`, the internal handler's `toString()` is used.
 */
class TestCoroutineContext(private val name: String? = null) : CoroutineContext {
    private val handler = TestHandler()

    private val context: CoroutineContext

    private val caughtExceptions = mutableListOf<Throwable>()

    /**
     * Exceptions that were caught during a [launch] or a [async] + [Deferred.await].
     */
    val exceptions: List<Throwable> get() = caughtExceptions

    init {
        // The context is the virtual-time dispatcher plus a handler that records uncaught exceptions.
        context = Dispatcher() + CoroutineExceptionHandler(this::handleException)
    }

    override fun <R> fold(initial: R, operation: (R, CoroutineContext.Element) -> R): R =
        context.fold(initial, operation)

    override fun <E : CoroutineContext.Element> get(key: CoroutineContext.Key<E>): E? = context[key]

    override fun minusKey(key: CoroutineContext.Key<*>) = context.minusKey(key)

    /**
     * @return The current virtual clock-time as it is known to this CoroutineContext
     */
    fun now(unit: TimeUnit = TimeUnit.MILLISECONDS) =
        handler.now(unit)

    /**
     * Moves the CoroutineContext's virtual clock forward by a specified amount of time.
     *
     * Every action scheduled at or before the new time is triggered, in scheduling order.
     *
     * @param delayTime
     * the amount of time to move the CoroutineContext's clock forward
     * @param unit
     * the units of time that [delayTime] is expressed in
     */
    fun advanceTimeBy(delayTime: Long, unit: TimeUnit = TimeUnit.MILLISECONDS) {
        handler.advanceTimeBy(delayTime, unit)
    }

    /**
     * Moves the CoroutineContext's clock-time to a particular moment in time.
     *
     * Every action scheduled at or before that moment is triggered. The clock never moves backwards.
     *
     * @param delayTime
     * the point in time to move the CoroutineContext's clock to
     * @param unit
     * the units of time that [delayTime] is expressed in
     */
    fun advanceTimeTo(delayTime: Long, unit: TimeUnit = TimeUnit.MILLISECONDS) {
        handler.advanceTimeTo(delayTime, unit)
    }

    /**
     * Triggers any actions that have not yet been triggered and that are scheduled to be triggered at or
     * before this CoroutineContext's present virtual time.
     */
    fun triggerActions() {
        handler.triggerActions()
    }

    /**
     * Cancels all not yet triggered actions. Be careful calling this, since it can seriously
     * mess with your coroutines work. This method should usually be called on tear-down of a
     * unit test.
     */
    fun cancelAllActions() {
        handler.cancelAllActions()
    }

    override fun toString() = name ?: handler.toString()

    override fun equals(other: Any?) = other is TestCoroutineContext && other.handler === handler

    override fun hashCode() = System.identityHashCode(handler)

    /** Records an exception that escaped a coroutine running in this context; see [exceptions]. */
    private fun handleException(@Suppress("UNUSED_PARAMETER") context: CoroutineContext, exception: Throwable) {
        caughtExceptions += exception
    }

    /**
     * Dispatcher that routes every dispatch, delay and timeout onto [handler]'s virtual-time queue.
     * As an [EventLoop] it also lets `runBlocking` drive (and auto-advance) the virtual clock.
     */
    private inner class Dispatcher : CoroutineDispatcher(), Delay, EventLoop {
        override fun dispatch(context: CoroutineContext, block: Runnable) {
            handler.post(block)
        }

        override fun scheduleResumeAfterDelay(time: Long, unit: TimeUnit, continuation: CancellableContinuation<Unit>) {
            handler.postDelayed(Runnable {
                // Resume directly on the virtual-time loop instead of re-dispatching.
                with(continuation) { resumeUndispatched(Unit) }
            }, unit.toMillis(time).coerceAtMost(MAX_DELAY))
        }

        override fun invokeOnTimeout(time: Long, unit: TimeUnit, block: Runnable): DisposableHandle {
            handler.postDelayed(block, unit.toMillis(time).coerceAtMost(MAX_DELAY))
            return object : DisposableHandle {
                override fun dispose() {
                    handler.removeCallbacks(block)
                }
            }
        }

        override fun processNextEvent() = handler.processNextEvent()
    }

    /**
     * Virtual-time scheduler: a priority queue of [TimedRunnable]s ordered by due time, then by
     * scheduling order, plus the current virtual clock (stored in nanoseconds).
     */
    private class TestHandler {
        /** The ordered queue for the runnable tasks. */
        private val queue = PriorityBlockingQueue<TimedRunnable>(16)

        /** The per-scheduler global order counter. */
        @Volatile private var counter = 0L

        // Storing time in nanoseconds internally.
        @Volatile private var time = 0L

        // 0 tells an event loop to keep processing; Long.MAX_VALUE means there is nothing left to run.
        private val nextEventTime get() = if (queue.isEmpty()) Long.MAX_VALUE else 0L

        /** Schedules [block] to run immediately (due time 0 means "at the current virtual time"). */
        internal fun post(block: Runnable) {
            val run = TimedRunnable(block, counter++)
            queue.add(run)
        }

        /**
         * Schedules [block] to run [delayTime] milliseconds after the current virtual time.
         * Negative delays are treated as 0. The due time saturates at [Long.MAX_VALUE] instead of
         * wrapping negative, which would otherwise run the task immediately and rewind the clock.
         */
        internal fun postDelayed(block: Runnable, delayTime: Long) {
            val delayNanos = TimeUnit.MILLISECONDS.toNanos(delayTime.coerceAtLeast(0L))
            val run = TimedRunnable(block, counter++, saturatedAdd(time, delayNanos))
            queue.add(run)
        }

        /** Removes one pending action wrapping [block]; [TimedRunnable] equality compares only the wrapped runnable. */
        internal fun removeCallbacks(block: Runnable) {
            queue.remove(TimedRunnable(block))
        }

        internal fun now(unit: TimeUnit) = unit.convert(time, TimeUnit.NANOSECONDS)

        /**
         * Advances the clock by [delayTime] (saturating on overflow) and triggers everything due.
         * @return the amount of time actually advanced, expressed in [unit].
         */
        internal fun advanceTimeBy(delayTime: Long, unit: TimeUnit): Long {
            val oldTime = time
            advanceTimeTo(saturatedAdd(time, unit.toNanos(delayTime)), TimeUnit.NANOSECONDS)
            return unit.convert(time - oldTime, TimeUnit.NANOSECONDS)
        }

        /** Triggers every action due at or before the target time, then moves the clock there (never backwards). */
        internal fun advanceTimeTo(delayTime: Long, unit: TimeUnit) {
            val targetTime = unit.toNanos(delayTime)
            triggerActions(targetTime)
            if (targetTime > time) {
                time = targetTime
            }
        }

        internal fun triggerActions() {
            triggerActions(time)
        }

        internal fun cancelAllActions() {
            queue.clear()
        }

        /**
         * [EventLoop] hook: runs everything up to the due time of the earliest pending action.
         * @return 0 while actions remain queued, [Long.MAX_VALUE] once the queue is empty.
         */
        internal fun processNextEvent(): Long {
            val current = queue.peek()
            if (current != null) {
                // Automatically advance time for EventLoop-callbacks.
                triggerActions(current.time)
            }
            return nextEventTime
        }

        /** Runs, in queue order, every action whose due time is at or before [targetTime]. */
        private fun triggerActions(targetTime: Long) {
            while (true) {
                val current = queue.peek()
                if (current == null || current.time > targetTime) {
                    break
                }
                // If the scheduled time is 0 (immediate) use current virtual time
                time = if (current.time == 0L) time else current.time
                // poll() removes exactly the head we just peeked; remove(Object) would match by the
                // wrapped Runnable only and could drop a different entry.
                queue.poll()
                current.run()
            }
        }

        /** Returns `base + delta`, clamped to [Long.MAX_VALUE]/[Long.MIN_VALUE] instead of wrapping around. */
        private fun saturatedAdd(base: Long, delta: Long): Long {
            val sum = base + delta
            // Overflow happened iff both operands share a sign that differs from the sum's sign.
            return if (((base xor sum) and (delta xor sum)) < 0) {
                if (delta > 0) Long.MAX_VALUE else Long.MIN_VALUE
            } else {
                sum
            }
        }
    }

    /**
     * A queued action: the wrapped [run], its scheduling order [count] and its due [time] in nanoseconds.
     * Ordering is by [time], then [count]. Equality/hash use only [run] so [TestHandler.removeCallbacks]
     * can look an action up by its Runnable.
     */
    private class TimedRunnable(
        private val run: Runnable,
        private val count: Long = 0,
        internal val time: Long = 0
    ) : Comparable<TimedRunnable>, Runnable {
        override fun run() {
            run.run()
        }

        override fun compareTo(other: TimedRunnable) = if (time == other.time) {
            count.compareTo(other.count)
        } else {
            time.compareTo(other.time)
        }

        override fun hashCode() = run.hashCode()

        override fun equals(other: Any?) = other is TimedRunnable && (run == other.run)

        override fun toString() =
            String.format("TimedRunnable(time = %d, run = %s)", time, run.toString())
    }
}
* Copyright (c) 2018 Intrepid Pursuits,Inc. All rights reserved.
package kotlinx.coroutines.experimental.intrepid
import kotlinx.coroutines.experimental.*
import org.junit.Test
import kotlin.test.*
// Tests for TestCoroutineContext: virtual-time delays with launch/async/runBlocking,
// nested async/runBlocking results, withTimeout, and capture of coroutine exceptions.
// NOTE(review): this copy appears to be a lossy web scrape. @Test annotations, closing braces and
// several statements (assertions, lambda result expressions, helper tails) seem to be missing.
// Restore from the original gist before compiling.
class TestCoroutineContextTest {
// Single virtual-time context passed to every launch/async/runBlocking call in this class.
private val context = TestCoroutineContext()
// launch + delay: advances virtual time in two halves of `delay`.
// NOTE(review): the checks on `executed` between/after the two steps are not visible in this copy.
fun test_launch_delay() {
val delay = 1000L
var executed = false
launch(context) {
suspendedDelay(delay) {
executed = true
context.advanceTimeBy(delay / 2)
context.advanceTimeBy(delay / 2)
// Same scenario as test_launch_delay, but started with async.
fun test_async_delay() {
val delay = 1000L
var executed = false
async(context) {
suspendedDelay(delay) {
executed = true
context.advanceTimeBy(delay / 2)
context.advanceTimeBy(delay / 2)
// runBlocking on the test context: no manual advanceTimeBy here. This relies on the context
// auto-advancing virtual time for blocking coroutines.
fun test_blocking_delay() {
val delay = 1000L
var executed = false
runBlocking(context) {
suspendedDelay(delay) {
executed = true
// Helper: presumably delays for `delay` and then calls `runAfter` (body not visible in this copy).
private suspend fun suspendedDelay(delay: Long, runAfter: () -> Unit) {
// runBlocking whose result comes from a nested async (see suspendedDelayWithResult).
fun test_blocking_with_async_result() {
val delay = 1000L
val expectedValue = 16
val result = runBlocking(context) {
suspendedDelayWithResult(delay) {
assertEquals(expectedValue, result)
// async + nested async. After half the delay the deferred must not be complete yet; after the full
// delay getCompleted() must return expectedValue.
// NOTE(review): the statement inside `try` (presumably deferred.getCompleted()) is missing here.
// The catch targets Exception, so presumably the AssertionError from `fail` still propagates.
fun test_async_with_async_result() {
val delay = 1000L
val expectedValue = 16
val deferred = async(context) {
suspendedDelayWithResult(delay) {
context.advanceTimeBy(delay / 2)
try {
fail("The Job should not have been completed yet.")
} catch (e: Exception) {
// Success.
context.advanceTimeBy(delay / 2)
assertEquals(expectedValue, deferred.getCompleted())
// Helper: waits a quarter of `delay`, then starts a nested async on the same context that waits the
// remaining three quarters.
// NOTE(review): the tail (presumably runAfter() inside the async and an await()) is not visible here.
private suspend fun <T> suspendedDelayWithResult(delay: Long, runAfter: () -> T): T {
delay(delay / 4)
return async(context) {
delay((delay / 4) * 3)
// runBlocking whose result comes from a nested runBlocking (see blockingDelayWithResult).
fun test_blocking_with_blocking_result() {
val delay = 1000L
val expectedValue = 16
val result = runBlocking(context) {
blockingDelayWithResult(delay) {
assertEquals(expectedValue, result)
// async that calls a nested runBlocking: advances to just before the first quarter of `delay`.
// NOTE(review): the second argument of the assertEquals below and the statement inside `try` are
// missing from this copy.
fun test_async_with_blocking_result() {
val delay = 1000L
val expectedValue = 16
val deferred = async(context) {
blockingDelayWithResult(delay) {
context.advanceTimeBy((delay / 4) - 1)
assertEquals((delay / 4) - 1,
try {
fail("The Job should not have been completed yet.")
} catch (e: Exception) {
// Success.
assertEquals(expectedValue, deferred.getCompleted())
// Helper: waits a quarter of `delay`, then runs a nested runBlocking on the same context for the remaining
// three quarters.
// NOTE(review): the tail (presumably returning runAfter()) is not visible here.
private suspend fun <T> blockingDelayWithResult(delay: Long, runAfter: () -> T): T {
delay(delay / 4)
return runBlocking(context) {
delay((delay / 4) * 3)
// Timeout (delay + 1) is longer than the total delay, so the deferred is expected to complete normally.
fun test_async_delay_and_no_timeout() {
val delay = 1000L
val expectedValue = 67
val result = async(context) {
blockingDelayWithTimeout(delay, delay + 1) {
assertEquals(expectedValue, result.getCompleted())
// Timeout equals the total delay, so the deferred is expected to fail with TimeoutCancellationException.
fun test_async_delay_and_timeout() {
val delay = 1000L
val expectedValue = 67
val result = async(context) {
blockingDelayWithTimeout(delay, delay) {
assertTrue(result.getCompletionExceptionOrNull() is TimeoutCancellationException)
// Same timeout scenario under runBlocking: the TimeoutCancellationException must escape runBlocking.
fun test_blocking_delay_and_timeout() {
val delay = 1000L
val expectedValue = 67
try {
runBlocking(context) {
blockingDelayWithTimeout(delay, delay) {
fail("Expected TimeoutCancellationException to be thrown.")
} catch (e: TimeoutCancellationException) {
// Success
} catch (e: Throwable) {
fail("Expected TimeoutCancellationException to be thrown: $e")
// Helper: inside withTimeout(timeOut), waits half of `delay`, calls runAfter, then waits the other half.
// NOTE(review): presumably returns `ret`; the closing lines are not visible here.
private suspend fun <T> blockingDelayWithTimeout(delay: Long, timeOut: Long, runAfter: () -> T): T {
return runBlocking(context) {
withTimeout(timeOut) {
delay(delay / 2)
val ret = runAfter()
delay(delay / 2)
// An exception thrown by a launched coroutine must be recorded in context.exceptions.
fun test_launch_with_exception() {
val expectedError = IllegalAccessError("hello")
launch(context) {
throw expectedError
assertTrue(expectedError === context.exceptions[0])
// A failing child async that is awaited inside launch must surface in context.exceptions.
fun test_launch_with_child_exception() {
val delay = 1000L
val expectedError = IllegalAccessError("hello")
val expectedValue = 12
launch(context) {
asyncDelayWithException(delay, expectedError, expectedValue, true)
assertTrue(expectedError === context.exceptions[0])
// A failing child async that is NOT awaited: the parent async still completes with expectedValue.
fun test_async_with_exception_no_await() {
val delay = 1000L
val expectedError = IllegalAccessError("hello")
val expectedValue = 12
val result = async(context) {
asyncDelayWithException(delay, expectedError, expectedValue, false)
assertEquals(expectedValue, result.getCompleted())
// A failing child async that IS awaited: the parent deferred completes with that same exception.
fun test_async_with_exception_await() {
val delay = 1000L
val expectedError = IllegalAccessError("hello")
val expectedValue = 12
val result = async(context) {
asyncDelayWithException(delay, expectedError, expectedValue, true)
val e = result.getCompletionExceptionOrNull()
assertTrue(expectedError === e, "Expected to be thrown: '$expectedError' but was '$e'")
// runBlocking with a failing, un-awaited child async: runBlocking still returns expectedValue.
fun test_blocking_with_exception_no_await() {
val delay = 1000L
val expectedError = IllegalAccessError("hello")
val expectedValue = 12
val result = runBlocking(context) {
asyncDelayWithException(delay, expectedError, expectedValue, false)
assertEquals(expectedValue, result)
// runBlocking with a failing, awaited child async: the exact error must propagate out of runBlocking.
// AssertionError is rethrown first so the `fail` below is not swallowed by the Throwable branch.
fun test_blocking_with_exception_await() {
val delay = 1000L
val expectedError = IllegalAccessError("hello")
val expectedValue = 12
try {
runBlocking(context) {
asyncDelayWithException(delay, expectedError, expectedValue, true)
fail("Expected to be thrown: '$expectedError'")
} catch (e: AssertionError) {
throw e
} catch (e: Throwable) {
assertTrue(expectedError === e, "Expected to be thrown: '$expectedError' but was '$e'")
// Helper: starts an async on the test context that throws `exception` after (delay - 1).
// When `await` is true it presumably awaits that deferred (the await line is not visible in this copy);
// otherwise it just returns `value`.
private suspend fun <T> asyncDelayWithException(delay: Long, exception: Throwable, value: T, await: Boolean): T {
val deferred = async(context) {
delay(delay - 1)
throw exception
if (await) {
return value
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment