Last active
May 28, 2018 11:42
-
-
Save streetsofboston/6ea225c61566e6d349883082fbb9f020 to your computer and use it in GitHub Desktop.
Kotlin unit-test utility for making functions that use coroutines testable: TestCoroutineContext
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@file:Suppress("PackageDirectoryMismatch") | |
/* | |
* Copyright (c) 2018 Intrepid Pursuits,Inc. All rights reserved. | |
*/ | |
package kotlinx.coroutines.experimental.intrepid | |
import kotlinx.coroutines.experimental.* | |
import java.util.concurrent.PriorityBlockingQueue | |
import java.util.concurrent.TimeUnit | |
import kotlin.coroutines.experimental.CoroutineContext | |
/**
 * Upper bound, in milliseconds, applied to every delay and timeout handed to the test scheduler
 * (kept from the original note: one cannot delay for too long on Android).
 */
private const val MAX_DELAY: Long = Long.MAX_VALUE / 2
/**
 * This [CoroutineContext] dispatcher can be used to simulate virtual time to speed up
 * code, especially tests, that deal with delays and timeouts.
 *
 * Specify an instance of this context when calling the *non-blocking* [kotlinx.coroutines.experimental.launch]
 * or [kotlinx.coroutines.experimental.async] and then advance time or trigger the actions
 * to make the coroutines execute.
 *
 * This works much like the *TestScheduler* in RxJava, which allows to speed up tests that deal
 * with non-blocking Rx chains that contain delays or timeouts.
 *
 * This dispatcher can also handle *blocking* coroutines that are started by
 * [kotlinx.coroutines.experimental.runBlocking]. This dispatcher's virtual time will be automatically
 * advanced based on the delayed actions within the coroutine(s).
 *
 * @param name optional name returned by [toString]; when `null`, the scheduler's own `toString()` is used.
 */
@Suppress("unused")
class TestCoroutineContext(private val name: String? = null) : CoroutineContext {
    private val handler = TestHandler()

    private val caughtExceptions = mutableListOf<Throwable>()

    /**
     * Exceptions that were caught during a [launch] or an [async] + [Deferred.await].
     *
     * This is a read-only view of the list that this context's [CoroutineExceptionHandler]
     * appends to, in the order the exceptions were reported.
     */
    val exceptions: List<Throwable> get() = caughtExceptions

    /** The context every [CoroutineContext] operation delegates to: the test dispatcher plus the exception handler. */
    private val context: CoroutineContext = Dispatcher() + CoroutineExceptionHandler(this::handleException)

    override fun <R> fold(initial: R, operation: (R, CoroutineContext.Element) -> R): R =
        context.fold(initial, operation)

    override fun <E : CoroutineContext.Element> get(key: CoroutineContext.Key<E>): E? = context[key]

    override fun minusKey(key: CoroutineContext.Key<*>) = context.minusKey(key)

    /**
     * @param unit the unit to express the result in; defaults to milliseconds
     * @return The current virtual clock-time as it is known to this CoroutineContext
     */
    fun now(unit: TimeUnit = TimeUnit.MILLISECONDS) = handler.now(unit)

    /**
     * Moves the CoroutineContext's virtual clock forward by a specified amount of time,
     * running every action that becomes due along the way.
     *
     * @param delayTime
     *   the amount of time to move the CoroutineContext's clock forward
     * @param unit
     *   the units of time that [delayTime] is expressed in
     */
    fun advanceTimeBy(delayTime: Long, unit: TimeUnit = TimeUnit.MILLISECONDS) {
        handler.advanceTimeBy(delayTime, unit)
    }

    /**
     * Moves the CoroutineContext's clock-time to a particular moment in time, running every
     * action due at or before that moment. The clock is never moved backwards.
     *
     * @param delayTime
     *   the point in time to move the CoroutineContext's clock to
     * @param unit
     *   the units of time that [delayTime] is expressed in
     */
    fun advanceTimeTo(delayTime: Long, unit: TimeUnit = TimeUnit.MILLISECONDS) {
        handler.advanceTimeTo(delayTime, unit)
    }

    /**
     * Triggers any actions that have not yet been triggered and that are scheduled to be triggered at or
     * before this CoroutineContext's present virtual time.
     */
    fun triggerActions() {
        handler.triggerActions()
    }

    /**
     * Cancels all not yet triggered actions. Be careful calling this, since it can seriously
     * mess with your coroutines' work. This method should usually be called on tear-down of a
     * unit test.
     */
    fun cancelAllActions() {
        handler.cancelAllActions()
    }

    override fun toString() = name ?: handler.toString()

    // Two contexts are equal only when they share the same scheduler instance.
    override fun equals(other: Any?) = other is TestCoroutineContext && other.handler === handler

    override fun hashCode() = System.identityHashCode(handler)

    /** Records [exception] so tests can inspect it through [exceptions]. */
    private fun handleException(@Suppress("UNUSED_PARAMETER") context: CoroutineContext, exception: Throwable) {
        caughtExceptions += exception
    }

    /**
     * Dispatcher that routes every dispatch, delay and timeout through [handler]. Queued work only runs
     * when the scheduler is driven: [advanceTimeBy], [advanceTimeTo], [triggerActions], or the event loop
     * calling [processNextEvent].
     */
    private inner class Dispatcher : CoroutineDispatcher(), Delay, EventLoop {
        override fun dispatch(context: CoroutineContext, block: Runnable) {
            handler.post(block)
        }

        override fun scheduleResumeAfterDelay(time: Long, unit: TimeUnit, continuation: CancellableContinuation<Unit>) {
            // NOTE(review): unit.toMillis truncates sub-millisecond delays to 0, and the queued runnable is not
            // removed if the continuation is cancelled first, so it can still pull the virtual clock forward
            // during processNextEvent — confirm whether that matters for your tests.
            handler.postDelayed(Runnable {
                with(continuation) { resumeUndispatched(Unit) }
            }, unit.toMillis(time).coerceAtMost(MAX_DELAY))
        }

        override fun invokeOnTimeout(time: Long, unit: TimeUnit, block: Runnable): DisposableHandle {
            handler.postDelayed(block, unit.toMillis(time).coerceAtMost(MAX_DELAY))
            // Disposing the handle removes the still-pending timeout task from the queue.
            return object : DisposableHandle {
                override fun dispose() {
                    handler.removeCallbacks(block)
                }
            }
        }

        override fun processNextEvent() = handler.processNextEvent()
    }
}
/**
 * Virtual-time task queue behind [TestCoroutineContext].
 *
 * Tasks are ordered by due time (nanoseconds of virtual time) and, for equal due times, by
 * submission order. Tasks queued via [post] have due time `0`, so they sort ahead of every
 * delayed task. Tasks run only inside [advanceTimeBy], [advanceTimeTo], [triggerActions]
 * and [processNextEvent].
 */
private class TestHandler {
    /** The ordered queue for the runnable tasks. */
    private val queue = PriorityBlockingQueue<TimedRunnable>(16)

    // NOTE(review): `counter++` on a @Volatile field is not atomic; safe only while a single thread schedules.
    /** The per-scheduler global order counter; breaks ties between tasks with the same due time. */
    @Volatile private var counter = 0L

    /** Current virtual time in nanoseconds. It never moves backwards. */
    @Volatile private var time = 0L

    /** `0L` while tasks are pending, [Long.MAX_VALUE] once the queue is empty (returned from [processNextEvent]). */
    private val nextEventTime get() = if (queue.isEmpty()) Long.MAX_VALUE else 0L

    /** Queues [block] for immediate execution (due time `0`). */
    internal fun post(block: Runnable) {
        queue.add(TimedRunnable(block, counter++))
    }

    /**
     * Queues [block] to run [delayTime] milliseconds after the current virtual time.
     *
     * The due time saturates at [Long.MAX_VALUE] instead of overflowing. Without this, a
     * `MAX_DELAY`-sized delay (already saturated to [Long.MAX_VALUE] nanoseconds by `toNanos`)
     * would wrap negative once the clock has moved past zero and fire immediately.
     */
    internal fun postDelayed(block: Runnable, delayTime: Long) {
        val dueTime = saturatedAdd(time, TimeUnit.MILLISECONDS.toNanos(delayTime))
        queue.add(TimedRunnable(block, counter++, dueTime))
    }

    /**
     * Removes one pending task wrapping [block]. Matching relies on [TimedRunnable] equality,
     * which compares only the wrapped runnable.
     */
    internal fun removeCallbacks(block: Runnable) {
        queue.remove(TimedRunnable(block))
    }

    /** Returns the current virtual time converted to [unit]. */
    internal fun now(unit: TimeUnit) = unit.convert(time, TimeUnit.NANOSECONDS)

    /**
     * Runs every task due within the next [delayTime] (in [unit]) and moves the clock forward by that amount.
     * The target time saturates instead of overflowing.
     *
     * @return how far the clock actually moved, expressed in [unit]
     */
    internal fun advanceTimeBy(delayTime: Long, unit: TimeUnit): Long {
        val oldTime = time
        advanceTimeTo(saturatedAdd(oldTime, unit.toNanos(delayTime)), TimeUnit.NANOSECONDS)
        return unit.convert(time - oldTime, TimeUnit.NANOSECONDS)
    }

    /**
     * Runs every task due at or before [delayTime] (in [unit]), then sets the clock to that
     * moment unless the clock is already later.
     */
    internal fun advanceTimeTo(delayTime: Long, unit: TimeUnit) {
        val targetTime = unit.toNanos(delayTime)
        triggerActions(targetTime)
        if (targetTime > time) {
            time = targetTime
        }
    }

    /** Runs every task due at or before the current virtual time. */
    internal fun triggerActions() {
        triggerActions(time)
    }

    /** Drops every pending task without running it. */
    internal fun cancelAllActions() {
        queue.clear()
    }

    /**
     * Runs the task at the head of the queue, advancing the virtual clock to its due time if needed,
     * together with every other task due by then.
     *
     * @return `0L` if tasks remain queued afterwards, [Long.MAX_VALUE] otherwise
     */
    internal fun processNextEvent(): Long {
        val current = queue.peek()
        if (current != null) {
            // Automatically advance time for EventLoop callbacks.
            triggerActions(current.time)
        }
        return nextEventTime
    }

    /**
     * Runs, in queue order, every task whose due time is at or before [targetTime], including
     * tasks that those tasks enqueue along the way.
     */
    private fun triggerActions(targetTime: Long) {
        while (true) {
            val current = queue.peek() ?: break
            if (current.time > targetTime) break
            // Immediate tasks (due time 0), and anything due in the past, keep the current virtual
            // time; the clock only ever moves forward.
            if (current.time > time) {
                time = current.time
            }
            queue.remove(current)
            current.run()
        }
    }

    /** Returns `base + delta`, clamped to the [Long] range instead of wrapping around on overflow. */
    private fun saturatedAdd(base: Long, delta: Long): Long {
        val sum = base + delta
        // Signed overflow happened iff both operands share a sign that the sum does not have.
        val overflowed = ((base xor sum) and (delta xor sum)) < 0
        return when {
            !overflowed -> sum
            base < 0 -> Long.MIN_VALUE
            else -> Long.MAX_VALUE
        }
    }
}
/**
 * A [Runnable] wrapper that carries the scheduling metadata used by the test scheduler's priority queue.
 *
 * Ordering ([compareTo]) looks at [time] first and the submission `count` second. Equality
 * ([equals]/[hashCode]) looks only at the wrapped runnable, so a fresh wrapper around the same
 * runnable can be used to locate and remove a queued task.
 *
 * @param run the wrapped task
 * @param count submission sequence number, used to break ties between equal [time]s
 * @property time due time in nanoseconds of virtual time; `0` marks a task queued for immediate execution
 */
private class TimedRunnable(
    private val run: Runnable,
    private val count: Long = 0,
    internal val time: Long = 0
) : Comparable<TimedRunnable>, Runnable {

    override fun run() = run.run()

    override fun compareTo(other: TimedRunnable): Int {
        val byTime = time.compareTo(other.time)
        return if (byTime != 0) byTime else count.compareTo(other.count)
    }

    override fun hashCode(): Int = run.hashCode()

    override fun equals(other: Any?): Boolean {
        if (other !is TimedRunnable) return false
        return run == other.run
    }

    override fun toString(): String =
        "TimedRunnable(time = %d, run = %s)".format(time, run.toString())
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@file:Suppress("PackageDirectoryMismatch") | |
/* | |
* Copyright (c) 2018 Intrepid Pursuits,Inc. All rights reserved. | |
*/ | |
package kotlinx.coroutines.experimental.intrepid | |
import kotlinx.coroutines.experimental.* | |
import org.junit.Test | |
import kotlin.test.* | |
/**
 * Unit tests for [TestCoroutineContext] covering launch/async/runBlocking with delays,
 * nested results, timeouts and exception propagation, all driven by virtual time.
 */
@Suppress("FunctionName")
class TestCoroutineContextTest {
    private val context = TestCoroutineContext()

    @Test
    fun test_launch_delay() {
        val delay = 1000L
        var executed = false
        launch(context) {
            suspendedDelay(delay) {
                executed = true
            }
        }
        context.advanceTimeBy(delay / 2)
        assertFalse(executed)
        context.advanceTimeBy(delay / 2)
        assertTrue(executed)
    }

    @Test
    fun test_async_delay() {
        val delay = 1000L
        var executed = false
        async(context) {
            suspendedDelay(delay) {
                executed = true
            }
        }
        context.advanceTimeBy(delay / 2)
        assertFalse(executed)
        context.advanceTimeBy(delay / 2)
        assertTrue(executed)
    }

    @Test
    fun test_blocking_delay() {
        val delay = 1000L
        var executed = false
        runBlocking(context) {
            suspendedDelay(delay) {
                executed = true
            }
        }
        assertTrue(executed)
        assertEquals(delay, context.now())
    }

    /** Suspends for [delay] ms of virtual time, then runs [runAfter]. */
    private suspend fun suspendedDelay(delay: Long, runAfter: () -> Unit) {
        delay(delay)
        runAfter()
    }

    @Test
    fun test_blocking_with_async_result() {
        val delay = 1000L
        val expectedValue = 16
        val result = runBlocking(context) {
            suspendedDelayWithResult(delay) {
                expectedValue
            }
        }
        assertEquals(expectedValue, result)
        assertEquals(delay, context.now())
    }

    @Test
    fun test_async_with_async_result() {
        val delay = 1000L
        val expectedValue = 16
        val deferred = async(context) {
            suspendedDelayWithResult(delay) {
                expectedValue
            }
        }
        context.advanceTimeBy(delay / 2)
        // getCompleted() throws IllegalStateException while the Deferred is still active.
        assertFailsWith<IllegalStateException>("The Job should not have been completed yet.") {
            deferred.getCompleted()
        }
        context.advanceTimeBy(delay / 2)
        assertEquals(expectedValue, deferred.getCompleted())
    }

    /** Waits a quarter of [delay], then awaits a nested [async] that waits the remaining three quarters. */
    private suspend fun <T> suspendedDelayWithResult(delay: Long, runAfter: () -> T): T {
        delay(delay / 4)
        return async(context) {
            delay((delay / 4) * 3)
            runAfter()
        }.await()
    }

    @Test
    fun test_blocking_with_blocking_result() {
        val delay = 1000L
        val expectedValue = 16
        val result = runBlocking(context) {
            blockingDelayWithResult(delay) {
                expectedValue
            }
        }
        assertEquals(expectedValue, result)
        assertEquals(delay, context.now())
    }

    @Test
    fun test_async_with_blocking_result() {
        val delay = 1000L
        val expectedValue = 16
        val deferred = async(context) {
            blockingDelayWithResult(delay) {
                expectedValue
            }
        }
        context.advanceTimeBy((delay / 4) - 1)
        assertEquals((delay / 4) - 1, context.now())
        // getCompleted() throws IllegalStateException while the Deferred is still active.
        assertFailsWith<IllegalStateException>("The Job should not have been completed yet.") {
            deferred.getCompleted()
        }
        context.advanceTimeBy(1)
        assertEquals(delay, context.now())
        assertEquals(expectedValue, deferred.getCompleted())
    }

    /** Waits a quarter of [delay], then runs a nested [runBlocking] that waits the remaining three quarters. */
    private suspend fun <T> blockingDelayWithResult(delay: Long, runAfter: () -> T): T {
        delay(delay / 4)
        return runBlocking(context) {
            delay((delay / 4) * 3)
            runAfter()
        }
    }

    @Test
    fun test_async_delay_and_no_timeout() {
        val delay = 1000L
        val expectedValue = 67
        val result = async(context) {
            blockingDelayWithTimeout(delay, delay + 1) {
                expectedValue
            }
        }
        context.triggerActions()
        assertEquals(expectedValue, result.getCompleted())
    }

    @Test
    fun test_async_delay_and_timeout() {
        val delay = 1000L
        val expectedValue = 67
        val result = async(context) {
            blockingDelayWithTimeout(delay, delay) {
                expectedValue
            }
        }
        context.triggerActions()
        assertTrue(result.getCompletionExceptionOrNull() is TimeoutCancellationException)
    }

    @Test
    fun test_blocking_delay_and_timeout() {
        val delay = 1000L
        val expectedValue = 67
        assertFailsWith<TimeoutCancellationException>("Expected TimeoutCancellationException to be thrown.") {
            runBlocking(context) {
                blockingDelayWithTimeout(delay, delay) {
                    expectedValue
                }
            }
        }
    }

    /** Runs [runAfter] halfway through a [delay]-long body wrapped in `withTimeout(timeOut)`. */
    private suspend fun <T> blockingDelayWithTimeout(delay: Long, timeOut: Long, runAfter: () -> T): T {
        return runBlocking(context) {
            withTimeout(timeOut) {
                delay(delay / 2)
                val ret = runAfter()
                delay(delay / 2)
                ret
            }
        }
    }

    @Test
    fun test_launch_with_exception() {
        val expectedError = IllegalAccessError("hello")
        launch(context) {
            throw expectedError
        }
        context.triggerActions()
        val caught = context.exceptions.firstOrNull()
        assertTrue(expectedError === caught, "Expected to be caught: '$expectedError' but was '$caught'")
    }

    @Test
    fun test_launch_with_child_exception() {
        val delay = 1000L
        val expectedError = IllegalAccessError("hello")
        val expectedValue = 12
        launch(context) {
            asyncDelayWithException(delay, expectedError, expectedValue, true)
        }
        context.advanceTimeBy(delay)
        val caught = context.exceptions.firstOrNull()
        assertTrue(expectedError === caught, "Expected to be caught: '$expectedError' but was '$caught'")
    }

    @Test
    fun test_async_with_exception_no_await() {
        val delay = 1000L
        val expectedError = IllegalAccessError("hello")
        val expectedValue = 12
        val result = async(context) {
            asyncDelayWithException(delay, expectedError, expectedValue, false)
        }
        context.advanceTimeBy(delay)
        assertNull(result.getCompletionExceptionOrNull())
        assertEquals(expectedValue, result.getCompleted())
    }

    @Test
    fun test_async_with_exception_await() {
        val delay = 1000L
        val expectedError = IllegalAccessError("hello")
        val expectedValue = 12
        val result = async(context) {
            asyncDelayWithException(delay, expectedError, expectedValue, true)
        }
        context.advanceTimeBy(delay)
        val e = result.getCompletionExceptionOrNull()
        assertTrue(expectedError === e, "Expected to be thrown: '$expectedError' but was '$e'")
    }

    @Test
    fun test_blocking_with_exception_no_await() {
        val delay = 1000L
        val expectedError = IllegalAccessError("hello")
        val expectedValue = 12
        val result = runBlocking(context) {
            asyncDelayWithException(delay, expectedError, expectedValue, false)
        }
        context.advanceTimeBy(delay)
        assertEquals(expectedValue, result)
    }

    @Test
    fun test_blocking_with_exception_await() {
        val delay = 1000L
        val expectedError = IllegalAccessError("hello")
        val expectedValue = 12
        val e = assertFailsWith<IllegalAccessError>("Expected to be thrown: '$expectedError'") {
            runBlocking(context) {
                asyncDelayWithException(delay, expectedError, expectedValue, true)
            }
        }
        assertTrue(expectedError === e, "Expected to be thrown: '$expectedError' but was '$e'")
    }

    /**
     * Starts an [async] that throws [exception] after `delay - 1` ms and returns [value];
     * when [await] is true, the deferred is awaited first so its failure propagates to the caller.
     */
    private suspend fun <T> asyncDelayWithException(delay: Long, exception: Throwable, value: T, await: Boolean): T {
        val deferred = async(context) {
            delay(delay - 1)
            throw exception
        }
        if (await) {
            deferred.await()
        }
        return value
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment