Skip to content

Instantly share code, notes, and snippets.

@streetsofboston
Last active May 28, 2018 11:42
Show Gist options
  • Save streetsofboston/6ea225c61566e6d349883082fbb9f020 to your computer and use it in GitHub Desktop.
Save streetsofboston/6ea225c61566e6d349883082fbb9f020 to your computer and use it in GitHub Desktop.
Kotlin unit-test utility for testing functions that use coroutines: TestCoroutineContext
@file:Suppress("PackageDirectoryMismatch")
/*
* Copyright (c) 2018 Intrepid Pursuits,Inc. All rights reserved.
*/
package kotlinx.coroutines.experimental.intrepid
import kotlinx.coroutines.experimental.*
import java.util.concurrent.PriorityBlockingQueue
import java.util.concurrent.TimeUnit
import kotlin.coroutines.experimental.CoroutineContext
// Upper bound, in milliseconds, that delays and timeouts are clamped to (via coerceAtMost)
// before being handed to the virtual-time handler.
private const val MAX_DELAY = Long.MAX_VALUE / 2 // cannot delay for too long on Android
/**
 * This [CoroutineContext] dispatcher can be used to simulate virtual time to speed up
 * code, especially tests, that deal with delays and timeouts.
 *
 * Specify an instance of this context when calling the *non-blocking* [kotlinx.coroutines.experimental.launch]
 * or [kotlinx.coroutines.experimental.async] and then advance time or trigger the actions
 * to make the coroutines execute.
 *
 * This works much like the *TestScheduler* in RxJava, which allows to speed up tests that deal
 * with non-blocking Rx chains that contain delays or timeouts.
 *
 * This dispatcher can also handle *blocking* coroutines that are started by
 * [kotlinx.coroutines.experimental.runBlocking]. This dispatcher's virtual time will be automatically
 * advanced based on the delayed actions within the coroutine(s).
 *
 * Exceptions delivered to this context's [CoroutineExceptionHandler] are recorded in [exceptions].
 *
 * @param name optional name returned by [toString]; when `null`, the internal handler's
 *   `toString()` is used instead.
 */
@Suppress("unused")
class TestCoroutineContext(private val name: String? = null) : CoroutineContext {
    private val handler = TestHandler()

    private val caughtExceptions = mutableListOf<Throwable>()

    /**
     * Exceptions that were caught during a [launch] or an [async] + [Deferred.await].
     */
    val exceptions: List<Throwable> get() = caughtExceptions

    /**
     * The context every [CoroutineContext] operation is delegated to: the virtual-time
     * [Dispatcher] plus an exception handler that records into [caughtExceptions].
     */
    private val context: CoroutineContext = Dispatcher() + CoroutineExceptionHandler(this::handleException)

    override fun <R> fold(initial: R, operation: (R, CoroutineContext.Element) -> R): R =
        context.fold(initial, operation)

    override fun <E : CoroutineContext.Element> get(key: CoroutineContext.Key<E>): E? = context[key]

    override fun minusKey(key: CoroutineContext.Key<*>) = context.minusKey(key)

    /**
     * @return The current virtual clock-time as it is known to this CoroutineContext
     */
    fun now(unit: TimeUnit = TimeUnit.MILLISECONDS) = handler.now(unit)

    /**
     * Moves the CoroutineContext's virtual clock forward by a specified amount of time,
     * running every action that becomes due along the way.
     *
     * @param delayTime
     * the amount of time to move the CoroutineContext's clock forward
     * @param unit
     * the units of time that [delayTime] is expressed in
     */
    fun advanceTimeBy(delayTime: Long, unit: TimeUnit = TimeUnit.MILLISECONDS) {
        handler.advanceTimeBy(delayTime, unit)
    }

    /**
     * Moves the CoroutineContext's clock-time to a particular moment in time, running every
     * action due at or before that moment. The clock is never moved backwards.
     *
     * @param delayTime
     * the point in time to move the CoroutineContext's clock to
     * @param unit
     * the units of time that [delayTime] is expressed in
     */
    fun advanceTimeTo(delayTime: Long, unit: TimeUnit = TimeUnit.MILLISECONDS) {
        handler.advanceTimeTo(delayTime, unit)
    }

    /**
     * Triggers any actions that have not yet been triggered and that are scheduled to be triggered at or
     * before this CoroutineContext's present virtual time.
     */
    fun triggerActions() {
        handler.triggerActions()
    }

    /**
     * Cancels all not yet triggered actions. Be careful calling this, since it can seriously
     * mess with your coroutines work. This method should usually be called on tear-down of a
     * unit test.
     */
    fun cancelAllActions() {
        handler.cancelAllActions()
    }

    override fun toString() = name ?: handler.toString()

    // Two contexts are equal only when they share the same underlying handler (i.e. the same instance state).
    override fun equals(other: Any?) = other is TestCoroutineContext && other.handler === handler

    override fun hashCode() = System.identityHashCode(handler)

    /** Records [exception]; installed as this context's [CoroutineExceptionHandler]. */
    private fun handleException(@Suppress("UNUSED_PARAMETER") context: CoroutineContext, exception: Throwable) {
        caughtExceptions += exception
    }

    /**
     * Dispatcher that routes all work into [handler] instead of running it on real threads.
     * It also acts as the [EventLoop] that `runBlocking` drives, and as the [Delay] implementation
     * used by `delay`/`withTimeout`.
     */
    private inner class Dispatcher : CoroutineDispatcher(), Delay, EventLoop {
        override fun dispatch(context: CoroutineContext, block: Runnable) {
            handler.post(block)
        }

        override fun scheduleResumeAfterDelay(time: Long, unit: TimeUnit, continuation: CancellableContinuation<Unit>) {
            // The continuation is resumed directly on whichever call stack triggers the action.
            handler.postDelayed(Runnable {
                with(continuation) { resumeUndispatched(Unit) }
            }, unit.toMillis(time).coerceAtMost(MAX_DELAY))
        }

        override fun invokeOnTimeout(time: Long, unit: TimeUnit, block: Runnable): DisposableHandle {
            handler.postDelayed(block, unit.toMillis(time).coerceAtMost(MAX_DELAY))
            return object : DisposableHandle {
                override fun dispose() {
                    handler.removeCallbacks(block)
                }
            }
        }

        override fun processNextEvent() = handler.processNextEvent()
    }
}
/**
 * Virtual-time task queue backing [TestCoroutineContext].
 *
 * Tasks are kept in a priority queue ordered by due time, then by submission order. Immediate
 * tasks ([post]) are stored with a due time of `0`, so they sort ahead of every delayed task, and
 * running them leaves the virtual clock unchanged. Time is stored internally in nanoseconds.
 *
 * NOTE(review): `counter++` is not atomic, so concurrent posting from several threads could
 * produce duplicate sequence numbers.
 */
private class TestHandler {
    /** The ordered queue for the runnable tasks. */
    private val queue = PriorityBlockingQueue<TimedRunnable>(16)

    /** The per-scheduler global order counter; breaks ties between tasks due at the same time. */
    @Volatile private var counter = 0L

    /** Current virtual time, in nanoseconds. */
    @Volatile private var time = 0L

    /** `0` while tasks are still queued, otherwise [Long.MAX_VALUE]. */
    private val nextEventTime get() = if (queue.isEmpty()) Long.MAX_VALUE else 0L

    /** Enqueues [block] to run at the current virtual time, ahead of any delayed task. */
    internal fun post(block: Runnable) {
        queue.add(TimedRunnable(block, counter++))
    }

    /**
     * Enqueues [block] to run [delayTime] milliseconds after the current virtual time.
     *
     * The due time saturates at [Long.MAX_VALUE] instead of wrapping around. Huge delays such as
     * `delay(Long.MAX_VALUE)` otherwise overflow to a negative due time once the clock has moved
     * past zero, and such a task would fire immediately.
     */
    internal fun postDelayed(block: Runnable, delayTime: Long) {
        val dueTime = saturatedAdd(time, TimeUnit.MILLISECONDS.toNanos(delayTime))
        queue.add(TimedRunnable(block, counter++, dueTime))
    }

    /** Removes a pending task wrapping [block]; [TimedRunnable] equality compares only the wrapped runnable. */
    internal fun removeCallbacks(block: Runnable) {
        queue.remove(TimedRunnable(block))
    }

    /** Current virtual time expressed in [unit]. */
    internal fun now(unit: TimeUnit) = unit.convert(time, TimeUnit.NANOSECONDS)

    /**
     * Advances the clock by [delayTime] (saturating instead of overflowing), running due tasks.
     *
     * @return how far the clock actually moved, expressed in [unit].
     */
    internal fun advanceTimeBy(delayTime: Long, unit: TimeUnit): Long {
        val oldTime = time
        advanceTimeTo(saturatedAdd(time, unit.toNanos(delayTime)), TimeUnit.NANOSECONDS)
        return unit.convert(time - oldTime, TimeUnit.NANOSECONDS)
    }

    /** Runs every task due at or before [delayTime], then moves the clock there if it is in the future. */
    internal fun advanceTimeTo(delayTime: Long, unit: TimeUnit) {
        val targetTime = unit.toNanos(delayTime)
        triggerActions(targetTime)
        if (targetTime > time) {
            time = targetTime
        }
    }

    /** Runs every task due at or before the current virtual time. */
    internal fun triggerActions() {
        triggerActions(time)
    }

    /** Drops every pending task without running it. */
    internal fun cancelAllActions() {
        queue.clear()
    }

    /**
     * Runs the head task and every other task due no later than it, which automatically advances
     * the clock for `runBlocking`'s [EventLoop] callbacks.
     *
     * @return `0` when tasks remain queued, otherwise [Long.MAX_VALUE].
     */
    internal fun processNextEvent(): Long {
        val current = queue.peek()
        if (current != null) {
            // Automatically advance time for EventLoop-callbacks.
            triggerActions(current.time)
        }
        return nextEventTime
    }

    /** Runs queued tasks in order while their due time is at or before [targetTime]. */
    private fun triggerActions(targetTime: Long) {
        while (true) {
            val current = queue.peek()
            if (current == null || current.time > targetTime) {
                break
            }
            // Immediate tasks (time 0) keep the current virtual time; the clock never moves backwards.
            time = maxOf(time, current.time)
            queue.remove(current)
            current.run()
        }
    }

    /** Returns `a + b`, clamped to the [Long] range instead of wrapping around on overflow. */
    private fun saturatedAdd(a: Long, b: Long): Long {
        val sum = a + b
        // Overflow happened iff both operands share a sign that the sum does not.
        return if (((a xor sum) and (b xor sum)) < 0) {
            if (b > 0) Long.MAX_VALUE else Long.MIN_VALUE
        } else {
            sum
        }
    }
}
/**
 * A [Runnable] stored in the handler's queue, tagged with a due [time] and a submission [count].
 *
 * Ordering ([compareTo]) is by [time] first and by [count] second, so tasks due at the same
 * moment run in submission order. Equality and hash code look only at the wrapped [run]. This lets
 * a pending task be located (e.g. for removal) from a fresh wrapper around the same [Runnable],
 * which makes equality deliberately inconsistent with [compareTo].
 *
 * @param run the task to execute.
 * @param count submission sequence number, used as a tie-breaker.
 * @param time due time in the owning handler's clock units; `0` is used for immediate tasks.
 */
private class TimedRunnable(
    private val run: Runnable,
    private val count: Long = 0,
    internal val time: Long = 0
) : Comparable<TimedRunnable>, Runnable {

    override fun run() = run.run()

    override fun compareTo(other: TimedRunnable): Int = when {
        time != other.time -> time.compareTo(other.time)
        else -> count.compareTo(other.count)
    }

    override fun hashCode(): Int = run.hashCode()

    override fun equals(other: Any?): Boolean {
        if (other !is TimedRunnable) return false
        return run == other.run
    }

    override fun toString(): String =
        "TimedRunnable(time = %d, run = %s)".format(time, run.toString())
}
@file:Suppress("PackageDirectoryMismatch")
/*
* Copyright (c) 2018 Intrepid Pursuits,Inc. All rights reserved.
*/
package kotlinx.coroutines.experimental.intrepid
import kotlinx.coroutines.experimental.*
import org.junit.Test
import kotlin.test.*
/**
 * Tests for [TestCoroutineContext]: virtual-time `delay`s, timeouts and exception capture for
 * coroutines started with `launch`, `async` and `runBlocking` (including nested combinations).
 */
@Suppress("FunctionName")
class TestCoroutineContextTest {
    private val context = TestCoroutineContext()

    @Test
    fun test_launch_delay() {
        val delay = 1000L
        var executed = false
        launch(context) {
            suspendedDelay(delay) {
                executed = true
            }
        }
        context.advanceTimeBy(delay / 2)
        assertFalse(executed)
        context.advanceTimeBy(delay / 2)
        assertTrue(executed)
    }

    @Test
    fun test_async_delay() {
        val delay = 1000L
        var executed = false
        async(context) {
            suspendedDelay(delay) {
                executed = true
            }
        }
        context.advanceTimeBy(delay / 2)
        assertFalse(executed)
        context.advanceTimeBy(delay / 2)
        assertTrue(executed)
    }

    @Test
    fun test_blocking_delay() {
        val delay = 1000L
        var executed = false
        runBlocking(context) {
            suspendedDelay(delay) {
                executed = true
            }
        }
        assertTrue(executed)
        // runBlocking advanced the virtual clock on its own.
        assertEquals(delay, context.now())
    }

    /** Suspends for [delay] ms of (virtual) time, then invokes [runAfter]. */
    private suspend fun suspendedDelay(delay: Long, runAfter: () -> Unit) {
        delay(delay)
        runAfter()
    }

    @Test
    fun test_blocking_with_async_result() {
        val delay = 1000L
        val expectedValue = 16
        val result = runBlocking(context) {
            suspendedDelayWithResult(delay) {
                expectedValue
            }
        }
        assertEquals(expectedValue, result)
        assertEquals(delay, context.now())
    }

    @Test
    fun test_async_with_async_result() {
        val delay = 1000L
        val expectedValue = 16
        val deferred = async(context) {
            suspendedDelayWithResult(delay) {
                expectedValue
            }
        }
        context.advanceTimeBy(delay / 2)
        assertFalse(deferred.isCompleted, "The Job should not have been completed yet.")
        context.advanceTimeBy(delay / 2)
        assertEquals(expectedValue, deferred.getCompleted())
    }

    /** Waits a quarter of [delay], then awaits an `async` child that waits the remaining three quarters. */
    private suspend fun <T> suspendedDelayWithResult(delay: Long, runAfter: () -> T): T {
        delay(delay / 4)
        return async(context) {
            delay((delay / 4) * 3)
            runAfter()
        }.await()
    }

    @Test
    fun test_blocking_with_blocking_result() {
        val delay = 1000L
        val expectedValue = 16
        val result = runBlocking(context) {
            blockingDelayWithResult(delay) {
                expectedValue
            }
        }
        assertEquals(expectedValue, result)
        assertEquals(delay, context.now())
    }

    @Test
    fun test_async_with_blocking_result() {
        val delay = 1000L
        val expectedValue = 16
        val deferred = async(context) {
            blockingDelayWithResult(delay) {
                expectedValue
            }
        }
        context.advanceTimeBy((delay / 4) - 1)
        assertEquals((delay / 4) - 1, context.now())
        assertFalse(deferred.isCompleted, "The Job should not have been completed yet.")
        // Crossing the quarter mark starts the nested runBlocking, which drives the clock to the end.
        context.advanceTimeBy(1)
        assertEquals(delay, context.now())
        assertEquals(expectedValue, deferred.getCompleted())
    }

    /** Waits a quarter of [delay], then blocks (on the same context) for the remaining three quarters. */
    private suspend fun <T> blockingDelayWithResult(delay: Long, runAfter: () -> T): T {
        delay(delay / 4)
        return runBlocking(context) {
            delay((delay / 4) * 3)
            runAfter()
        }
    }

    @Test
    fun test_async_delay_and_no_timeout() {
        val delay = 1000L
        val expectedValue = 67
        val result = async(context) {
            blockingDelayWithTimeout(delay, delay + 1) {
                expectedValue
            }
        }
        context.triggerActions()
        assertEquals(expectedValue, result.getCompleted())
    }

    @Test
    fun test_async_delay_and_timeout() {
        val delay = 1000L
        val expectedValue = 67
        val result = async(context) {
            blockingDelayWithTimeout(delay, delay) {
                expectedValue
            }
        }
        context.triggerActions()
        val e = result.getCompletionExceptionOrNull()
        assertTrue(e is TimeoutCancellationException, "Expected TimeoutCancellationException but was '$e'")
    }

    @Test
    fun test_blocking_delay_and_timeout() {
        val delay = 1000L
        val expectedValue = 67
        // Fails with a clear message both when nothing is thrown and when a different exception is thrown.
        assertFailsWith<TimeoutCancellationException> {
            runBlocking(context) {
                blockingDelayWithTimeout(delay, delay) {
                    expectedValue
                }
            }
        }
    }

    /** Blocks for [delay] (in two halves, calling [runAfter] in between) under a [timeOut] deadline. */
    private suspend fun <T> blockingDelayWithTimeout(delay: Long, timeOut: Long, runAfter: () -> T): T {
        return runBlocking(context) {
            withTimeout(timeOut) {
                delay(delay / 2)
                val ret = runAfter()
                delay(delay / 2)
                ret
            }
        }
    }

    @Test
    fun test_launch_with_exception() {
        val expectedError = IllegalAccessError("hello")
        launch(context) {
            throw expectedError
        }
        context.triggerActions()
        assertSame(expectedError, context.exceptions.firstOrNull())
    }

    @Test
    fun test_launch_with_child_exception() {
        val delay = 1000L
        val expectedError = IllegalAccessError("hello")
        val expectedValue = 12
        launch(context) {
            asyncDelayWithException(delay, expectedError, expectedValue, true)
        }
        context.advanceTimeBy(delay)
        assertSame(expectedError, context.exceptions.firstOrNull())
    }

    @Test
    fun test_async_with_exception_no_await() {
        val delay = 1000L
        val expectedError = IllegalAccessError("hello")
        val expectedValue = 12
        val result = async(context) {
            asyncDelayWithException(delay, expectedError, expectedValue, false)
        }
        context.advanceTimeBy(delay)
        assertNull(result.getCompletionExceptionOrNull())
        assertEquals(expectedValue, result.getCompleted())
    }

    @Test
    fun test_async_with_exception_await() {
        val delay = 1000L
        val expectedError = IllegalAccessError("hello")
        val expectedValue = 12
        val result = async(context) {
            asyncDelayWithException(delay, expectedError, expectedValue, true)
        }
        context.advanceTimeBy(delay)
        val e = result.getCompletionExceptionOrNull()
        assertSame(expectedError, e, "Expected to be thrown: '$expectedError' but was '$e'")
    }

    @Test
    fun test_blocking_with_exception_no_await() {
        val delay = 1000L
        val expectedError = IllegalAccessError("hello")
        val expectedValue = 12
        val result = runBlocking(context) {
            asyncDelayWithException(delay, expectedError, expectedValue, false)
        }
        context.advanceTimeBy(delay)
        assertEquals(expectedValue, result)
    }

    @Test
    fun test_blocking_with_exception_await() {
        val delay = 1000L
        val expectedError = IllegalAccessError("hello")
        val expectedValue = 12
        val e = assertFails {
            runBlocking(context) {
                asyncDelayWithException(delay, expectedError, expectedValue, true)
            }
        }
        assertSame(expectedError, e, "Expected to be thrown: '$expectedError' but was '$e'")
    }

    /**
     * Starts an `async` child that throws [exception] after `delay - 1` ms. If [await] is true,
     * awaits it (propagating the exception); otherwise returns [value] immediately.
     */
    private suspend fun <T> asyncDelayWithException(delay: Long, exception: Throwable, value: T, await: Boolean): T {
        val deferred = async(context) {
            delay(delay - 1)
            throw exception
        }
        if (await) {
            deferred.await()
        }
        return value
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment