Created
November 27, 2018 15:54
-
-
Save jshiell/7ffb24c95440c0fd67d58641824d2473 to your computer and use it in GitHub Desktop.
Retry extension for JUnit 5
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.springer.oscar.test | |
import com.springer.oscar.test.RetryExtension.Companion.TEST_PASSED | |
import org.junit.AssumptionViolatedException | |
import org.junit.jupiter.api.TestTemplate | |
import org.junit.jupiter.api.extension.* | |
import org.junit.platform.commons.util.AnnotationUtils.findAnnotation | |
import org.junit.platform.commons.util.AnnotationUtils.isAnnotated | |
import org.opentest4j.TestAbortedException | |
import org.opentest4j.TestSkippedException | |
import java.util.* | |
import java.util.Spliterators.spliteratorUnknownSize | |
import java.util.stream.Stream | |
import java.util.stream.StreamSupport.stream | |
/**
 * JUnit Jupiter test-template provider that drives repeated attempts of a
 * method annotated with [RetryTest].
 *
 * Each attempt is represented by one [TestTemplateInvocationContext] produced
 * by a [RetryTestTemplateIterator].
 */
class RetryExtension : TestTemplateInvocationContextProvider {

    /** This provider only applies when the test method carries [RetryTest]. */
    override fun supportsTestTemplate(context: ExtensionContext): Boolean =
        isAnnotated(context.testMethod, RetryTest::class.java)

    /**
     * Returns a sequential stream of invocation contexts, one per attempt.
     *
     * The stream is backed by an iterator, so each context is created only
     * when the stream asks for the next element.
     *
     * @throws IllegalStateException if the required test method has no [RetryTest] annotation.
     */
    override fun provideTestTemplateInvocationContexts(context: ExtensionContext): Stream<TestTemplateInvocationContext> {
        val retryTest = findAnnotation(context.requiredTestMethod, RetryTest::class.java)
            .orElseThrow { IllegalStateException("The test method must be annotated with @RetryTest") }

        val attempts = RetryTestTemplateIterator(
            context.displayName,
            retryTest.maxAttempts,
            context.getStore(storeNamespace(context))
        )
        val contexts: Spliterator<TestTemplateInvocationContext> =
            spliteratorUnknownSize(attempts, Spliterator.NONNULL)

        // `false` requests a sequential (non-parallel) stream.
        return stream(contexts, false)
    }

    companion object {
        /** Store key under which the outcome of the latest attempt is recorded. */
        internal const val TEST_PASSED: String = "testPassed"
    }
}
/**
 * Creates the extension-store namespace for [context], built from the name of
 * its required test class followed by the name of its required test method.
 */
internal fun storeNamespace(context: ExtensionContext): ExtensionContext.Namespace {
    val className = context.requiredTestClass.name
    val methodName = context.requiredTestMethod.name
    return ExtensionContext.Namespace.create(className, methodName)
}
/**
 * Iterator that hands out one [RetryInvocationContext] per attempt.
 *
 * Iteration stops once [maxAttempts] contexts have been produced, or as soon
 * as the store holds `true` under [TEST_PASSED].
 *
 * @param displayName base display name passed on to every invocation context.
 * @param maxAttempts upper bound on the number of contexts produced.
 * @param store store that is read for the [TEST_PASSED] flag and passed on to each context.
 */
internal class RetryTestTemplateIterator(
    private val displayName: String,
    private val maxAttempts: Int,
    private val store: ExtensionContext.Store
) : Iterator<TestTemplateInvocationContext> {

    // Number of contexts handed out so far.
    private var currentAttempt = 0

    override fun hasNext(): Boolean {
        if (currentAttempt >= maxAttempts) {
            return false
        }
        // The flag may be absent (null) before the first attempt has finished.
        return store.get(TEST_PASSED) != true
    }

    /**
     * Blocks the calling thread for `BASE_DELAY * attemptsSoFar` milliseconds
     * (nothing before the first attempt), then returns the next context.
     *
     * @throws NoSuchElementException if [hasNext] is false.
     */
    override fun next(): TestTemplateInvocationContext {
        if (!hasNext()) {
            throw NoSuchElementException()
        }
        val delay = BASE_DELAY * currentAttempt
        if (delay > 0) {
            Thread.sleep(delay)
        }
        currentAttempt += 1
        return RetryInvocationContext(displayName, currentAttempt, maxAttempts, delay, store)
    }

    companion object {
        /** Back-off step in milliseconds; multiplied by the number of attempts already made. */
        private const val BASE_DELAY = 1000L
    }
}
/**
 * Invocation context for a single attempt of a retried test.
 *
 * @param displayName base display name of the test.
 * @param currentAttempt 1-based number of this attempt.
 * @param maxAttempts total number of attempts allowed.
 * @param delay milliseconds waited before this attempt; shown in the display name when positive.
 * @param store store passed to the [RetryAfterTestExecutionCallback] registered for this attempt.
 */
internal class RetryInvocationContext(
    private val displayName: String,
    private val currentAttempt: Int,
    private val maxAttempts: Int,
    private val delay: Long,
    private val store: ExtensionContext.Store
) : TestTemplateInvocationContext {

    /** Formats the name as `<name> (attempt i/n)` or `<name> (attempt i/n, delayed Xms)`. */
    override fun getDisplayName(invocationIndex: Int): String {
        val suffix = if (delay > 0) ", delayed ${delay}ms)" else ")"
        return "$displayName (attempt $currentAttempt/$maxAttempts$suffix"
    }

    /** Registers the outcome-recording callback and the exception handler for this attempt. */
    override fun getAdditionalExtensions(): List<Extension> {
        val outcomeRecorder = RetryAfterTestExecutionCallback(store)
        val failureHandler = CheckException(currentAttempt, maxAttempts)
        return listOf<Extension>(outcomeRecorder, failureHandler)
    }
}
/**
 * Exception handler that decides how a failure in one attempt is reported.
 *
 * On the final attempt the original throwable is rethrown unchanged. On any
 * earlier attempt a [TestAbortedException] is thrown instead, carrying the
 * original throwable as its cause so the real failure and its stack trace
 * remain visible in reports.
 *
 * @param currentAttempt 1-based number of the attempt this handler is attached to.
 * @param maxAttempts total number of attempts allowed.
 */
internal class CheckException(
    private val currentAttempt: Int,
    private val maxAttempts: Int
) : TestExecutionExceptionHandler {

    /**
     * @throws Throwable the original [throwable] when this is the last attempt.
     * @throws TestAbortedException wrapping [throwable] when further attempts remain.
     */
    override fun handleTestExecutionException(context: ExtensionContext, throwable: Throwable) {
        if (currentAttempt >= maxAttempts) {
            throw throwable
        }
        // Keep the underlying failure attached as the cause rather than discarding it.
        throw TestAbortedException("Test attempt failed (attempt $currentAttempt/$maxAttempts)", throwable)
    }
}
/**
 * Records the outcome of an attempt in [store] under [TEST_PASSED].
 *
 * The attempt counts as passed when the context reports no execution
 * exception, or when the exception's class is exactly
 * [AssumptionViolatedException]; any other exception counts as not passed.
 */
internal class RetryAfterTestExecutionCallback(
    private val store: ExtensionContext.Store
) : AfterTestExecutionCallback {

    override fun afterTestExecution(context: ExtensionContext) {
        val failure = context.executionException
        // Exact class comparison: subclasses of AssumptionViolatedException are not exempted.
        val testPassed = !failure.isPresent ||
            failure.get().javaClass == AssumptionViolatedException::class.java
        store.put(TEST_PASSED, testPassed)
    }
}
/**
 * Marks a test as a test template handled by [RetryExtension].
 *
 * @property maxAttempts maximum number of attempts for the test; defaults to 3.
 */
@TestTemplate
@ExtendWith(RetryExtension::class)
@Retention(AnnotationRetention.RUNTIME)
@Target(
    AnnotationTarget.FUNCTION,
    AnnotationTarget.PROPERTY_GETTER,
    AnnotationTarget.PROPERTY_SETTER,
    AnnotationTarget.CLASS,
    AnnotationTarget.FILE
)
annotation class RetryTest(val maxAttempts: Int = 3)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment