Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save hamurcuabi/c462a890b05ba098df95b7d2e86ee33b to your computer and use it in GitHub Desktop.

Select an option

Save hamurcuabi/c462a890b05ba098df95b7d2e86ee33b to your computer and use it in GitHub Desktop.
A custom Detekt rule that enforces rethrowing CancellationException (or calling ensureActive()) inside try/catch blocks in suspend functions and inside runCatching blocks used in coroutines.
import io.gitlab.arturbosch.detekt.api.CodeSmell
import io.gitlab.arturbosch.detekt.api.Config
import io.gitlab.arturbosch.detekt.api.Debt
import io.gitlab.arturbosch.detekt.api.Entity
import io.gitlab.arturbosch.detekt.api.Issue
import io.gitlab.arturbosch.detekt.api.Rule
import io.gitlab.arturbosch.detekt.api.Severity
import io.gitlab.arturbosch.detekt.api.internal.RequiresTypeResolution
import org.jetbrains.kotlin.builtins.isSuspendFunctionType
import org.jetbrains.kotlin.com.intellij.psi.PsiElement
import org.jetbrains.kotlin.descriptors.FunctionDescriptor
import org.jetbrains.kotlin.lexer.KtTokens
import org.jetbrains.kotlin.psi.KtCallExpression
import org.jetbrains.kotlin.psi.KtLambdaExpression
import org.jetbrains.kotlin.psi.KtNamedFunction
import org.jetbrains.kotlin.psi.KtTryExpression
import org.jetbrains.kotlin.resolve.BindingContext
import org.jetbrains.kotlin.types.KotlinType
/**
 * Detekt rule that flags `catch` clauses and `runCatching` calls that may swallow a
 * `CancellationException` inside a coroutine or a `suspend` function.
 *
 * A `catch` clause is inspected only when its declared type, matched by simple name, is
 * `CancellationException` or one of its supertypes: `IllegalStateException`,
 * `RuntimeException`, `Exception` or `Throwable`. Only those can actually catch a
 * cancellation. The clause counts as handled when its body does one of the following:
 *  - rethrows the caught parameter (`throw e`);
 *  - contains a `throw` and mentions `CancellationException`;
 *  - calls `ensureActive()`.
 *
 * A `runCatching` call counts as handled when its lambda body contains
 * `throw CancellationException` or `ensureActive()`.
 *
 * All "handled" checks are textual heuristics over the source text, not data-flow analysis.
 *
 * The coroutine context is found by walking up the PSI tree from the inspected element:
 *  - If a coroutine lambda is met first, the element counts.
 *  - Otherwise the nearest enclosing non-anonymous named function decides, and only a
 *    `suspend` one counts.
 *
 * A lambda is a coroutine lambda when the binding context reports it as `suspend`. As a
 * fallback that also works without type resolution, it also counts when it is passed to a
 * well-known coroutine builder such as `launch` or `withContext`.
 */
@RequiresTypeResolution
class CancellationExceptionInCoroutineCatchRule(
    config: Config = Config.empty
) : Rule(config) {

    override val issue = Issue(
        id = "MissingCancellationExceptionRethrow",
        severity = Severity.Defect,
        description = "Inside a coroutine or suspend function, always rethrow CancellationException in catch or runCatching blocks.",
        debt = Debt.TWENTY_MINS
    )

    override fun visitTryExpression(expression: KtTryExpression) {
        super.visitTryExpression(expression)
        if (expression.catchClauses.isEmpty()) return
        if (!isInsideCoroutine(expression)) return
        expression.catchClauses.forEach { catchClause ->
            val parameter = catchClause.catchParameter ?: return@forEach
            val paramName = parameter.name ?: return@forEach
            val paramType = parameter.typeReference?.text ?: return@forEach
            // Only CancellationException or one of its supertypes can swallow a cancellation.
            if (!canCatchCancellation(paramType)) return@forEach
            val catchBodyText = catchClause.catchBody?.text ?: return@forEach
            if (!isCatchBodyHandled(catchBodyText, paramName)) {
                report(
                    CodeSmell(
                        issue = issue,
                        entity = Entity.from(catchClause),
                        message = MESSAGE
                    )
                )
            }
        }
    }

    override fun visitCallExpression(expression: KtCallExpression) {
        super.visitCallExpression(expression)
        if (expression.calleeExpression?.text != RUN_CATCHING) return
        if (!isInsideCoroutine(expression)) return
        val lambdaBodyText = expression.blockLambda()?.bodyExpression?.text ?: return
        val isHandled = lambdaBodyText.contains("$THROW_KEYWORD $CANCELLATION_EXCEPTION") ||
            lambdaBodyText.contains(ENSURE_ACTIVE)
        if (!isHandled) {
            report(
                CodeSmell(
                    issue = issue,
                    entity = Entity.from(expression),
                    message = MESSAGE
                )
            )
        }
    }

    /** True when a `catch` declared with [typeText] (simple or qualified name) also catches a cancellation. */
    private fun canCatchCancellation(typeText: String): Boolean =
        typeText.substringAfterLast('.').trim() in CANCELLATION_CATCHING_TYPES

    /**
     * Textual check that a catch body does not swallow a cancellation.
     *
     * @param bodyText source text of the catch block.
     * @param paramName name of the caught exception parameter.
     * @return true if the body rethrows [paramName], throws while mentioning
     * `CancellationException`, or calls `ensureActive()`.
     */
    private fun isCatchBodyHandled(bodyText: String, paramName: String): Boolean {
        // Word boundaries keep `throw e` from matching `throw error`,
        // and `throw` from matching `throwable`.
        val rethrowsParameter = Regex("""\b$THROW_KEYWORD\s+${Regex.escape(paramName)}\b""")
            .containsMatchIn(bodyText)
        val throwsCancellation = THROW_REGEX.containsMatchIn(bodyText) &&
            bodyText.contains(CANCELLATION_EXCEPTION)
        return rethrowsParameter || throwsCancellation || bodyText.contains(ENSURE_ACTIVE)
    }

    /** Lambda passed to this call, either as a trailing lambda or as the last parenthesized argument. */
    private fun KtCallExpression.blockLambda(): KtLambdaExpression? =
        lambdaArguments.firstOrNull()?.getLambdaExpression()
            ?: valueArguments.lastOrNull()?.getArgumentExpression() as? KtLambdaExpression

    /**
     * Walks up from [element] and reports whether it executes in a coroutine context.
     *
     * A coroutine lambda found on the way returns true. The nearest non-anonymous named
     * function is a hard boundary: a non-suspend function cannot suspend, so an outer
     * suspend function must not make its body count. Anonymous functions do not stop the walk.
     */
    private fun isInsideCoroutine(element: PsiElement): Boolean {
        var current: PsiElement? = element
        while (current != null) {
            when (current) {
                is KtNamedFunction -> {
                    if (current.hasModifier(KtTokens.SUSPEND_KEYWORD)) return true
                    if (current.name != null) return false
                }
                is KtLambdaExpression -> if (isCoroutineLambda(current)) return true
            }
            current = current.parent
        }
        return false
    }

    /** True when [lambda] is suspend per type resolution, or is passed to a known coroutine builder. */
    private fun isCoroutineLambda(lambda: KtLambdaExpression): Boolean =
        lambda.isSuspendLambda(bindingContext) || lambda.enclosingCallName() in SUSPEND_LIKE_FUNCTIONS

    /**
     * Uses type information to decide whether this lambda is a `suspend` function literal.
     * With an empty binding context both lookups return null, so this returns false.
     */
    private fun KtLambdaExpression.isSuspendLambda(ctx: BindingContext): Boolean {
        val descriptor: FunctionDescriptor? = ctx[BindingContext.FUNCTION, functionLiteral]
        if (descriptor?.isSuspend == true) return true
        val lambdaType: KotlinType? = ctx.getType(this)
        return lambdaType?.isSuspendFunctionType == true
    }

    /**
     * Callee name of the nearest call found within three PSI levels above this lambda, or null.
     *
     * Three levels cover the following shapes:
     *  - trailing lambdas: Lambda -> LambdaArgument -> Call;
     *  - labeled trailing lambdas: Lambda -> Labeled -> LambdaArgument -> Call;
     *  - parenthesized arguments: Lambda -> ValueArgument -> ValueArgumentList -> Call.
     */
    private fun KtLambdaExpression.enclosingCallName(): String? =
        generateSequence(parent) { it.parent }
            .take(3)
            .filterIsInstance<KtCallExpression>()
            .firstOrNull()
            ?.calleeExpression
            ?.text

    companion object {
        private const val THROW_KEYWORD = "throw"
        private const val ENSURE_ACTIVE = "ensureActive()"
        private const val RUN_CATCHING = "runCatching"
        private const val CANCELLATION_EXCEPTION = "CancellationException"
        private val THROW_REGEX = Regex("""\b$THROW_KEYWORD\b""")

        /** Builders whose lambda runs as a coroutine; used when no type information is available. */
        private val SUSPEND_LIKE_FUNCTIONS = setOf(
            "launch",
            "async",
            "withContext",
            "coroutineScope",
            "supervisorScope",
            "runBlocking",
            "withTimeout",
            "withTimeoutOrNull"
        )

        /** Simple names of CancellationException and its supertypes; catching any of them can swallow it. */
        private val CANCELLATION_CATCHING_TYPES = setOf(
            "Throwable",
            "Exception",
            "RuntimeException",
            "IllegalStateException",
            CANCELLATION_EXCEPTION
        )

        private const val MESSAGE =
            "You must rethrow CancellationException or call ensureActive() inside coroutine try-catch or runCatching blocks."
    }
}
import io.gitlab.arturbosch.detekt.api.Config
import io.gitlab.arturbosch.detekt.test.lint
import org.junit.jupiter.api.Assertions
import org.junit.Test
/**
 * Tests for [CancellationExceptionInCoroutineCatchRule]; each case asserts the number of findings
 * reported for a small Kotlin snippet.
 *
 * NOTE(review): these tests call `lint`, which presumably runs without type resolution. If so,
 * only the rule's name-based coroutine-builder detection is exercised here, and type-resolved
 * behavior would need `lintWithContext`.
 *
 * NOTE(review): `@Test` comes from JUnit 4 (`org.junit.Test`) while the assertions come from
 * JUnit 5 (`org.junit.jupiter.api.Assertions`). Confirm that the configured test engine
 * discovers these tests.
 */
class CancellationExceptionInCoroutineCatchRuleTest {

    // The rule is configured directly with an empty Config (not wrapped in another rule instance).
    private val rule = CancellationExceptionInCoroutineCatchRule(Config.empty)

    @Test
    fun `suspend function without rethrow should warn`() {
        val code = """
            suspend fun suspendWithoutRethrow() {
                try {
                    delay(1)
                    println("Doing work")
                } catch (e: Exception) {
                    println("Caught: \${'$'}e")
                }
            }
        """.trimIndent()
        Assertions.assertEquals(1, rule.lint(code).size)
    }

    @Test
    fun `suspend function with rethrow should not warn`() {
        val code = """
            suspend fun suspendWithRethrow() {
                try {
                    delay(1)
                    println("Doing work")
                } catch (e: Exception) {
                    if (e is CancellationException) throw e
                    println("Caught: \${'$'}e")
                }
            }
        """.trimIndent()
        Assertions.assertEquals(0, rule.lint(code).size)
    }

    @Test
    fun `suspend function with unconditional rethrow should not warn`() {
        val code = """
            suspend fun rethrowAlways() {
                try {
                    delay(1)
                } catch (e: Exception) {
                    println("Cleaning up")
                    throw e
                }
            }
        """.trimIndent()
        Assertions.assertEquals(0, rule.lint(code).size)
    }

    @Test
    fun `suspend function with ensureActive should not warn`() {
        val code = """
            suspend fun suspendWithEnsureActive() {
                try {
                    delay(1)
                    println("Doing work")
                } catch (e: Exception) {
                    ensureActive()
                    println("Caught: \${'$'}e")
                }
            }
        """.trimIndent()
        Assertions.assertEquals(0, rule.lint(code).size)
    }

    @Test
    fun `suspend function catching Throwable without rethrow should warn`() {
        val code = """
            suspend fun catchThrowable() {
                try {
                    delay(1)
                } catch (t: Throwable) {
                    println("Caught")
                }
            }
        """.trimIndent()
        Assertions.assertEquals(1, rule.lint(code).size)
    }

    @Test
    fun `suspend function catching an unrelated exception type should not warn`() {
        val code = """
            suspend fun catchIo() {
                try {
                    delay(1)
                } catch (e: java.io.IOException) {
                    println("IO failed")
                }
            }
        """.trimIndent()
        Assertions.assertEquals(0, rule.lint(code).size)
    }

    @Test
    fun `non suspend function should not warn`() {
        val code = """
            fun nonSuspendFunction() {
                try {
                    println("Doing work")
                } catch (e: Exception) {
                    println("Caught: \${'$'}e")
                }
            }
        """.trimIndent()
        Assertions.assertEquals(0, rule.lint(code).size)
    }

    @Test
    fun `coroutine lambda without rethrow should warn`() {
        val code = """
            fun coroutineLambdaWithoutRethrow() {
                GlobalScope.launch {
                    delay(1)
                    try {
                        println("Coroutine work")
                    } catch (e: Exception) {
                        println("Caught: \${'$'}e")
                    }
                }
            }
        """.trimIndent()
        Assertions.assertEquals(1, rule.lint(code).size)
    }

    @Test
    fun `coroutine lambda with rethrow should not warn`() {
        val code = """
            fun coroutineLambdaWithRethrow() {
                GlobalScope.launch {
                    delay(1)
                    try {
                        println("Coroutine work")
                    } catch (e: Exception) {
                        if (e is CancellationException) throw e
                        println("Caught: \${'$'}e")
                    }
                }
            }
        """.trimIndent()
        Assertions.assertEquals(0, rule.lint(code).size)
    }

    @Test
    fun `coroutine lambda with ensureActive should not warn`() {
        val code = """
            fun coroutineLambdaWithEnsureActive() {
                GlobalScope.launch {
                    delay(1)
                    try {
                        println("Coroutine work")
                    } catch (e: Exception) {
                        ensureActive()
                        println("Caught: \${'$'}e")
                    }
                }
            }
        """.trimIndent()
        Assertions.assertEquals(0, rule.lint(code).size)
    }

    @Test
    fun `withContext without cancellation handling should warn`() {
        val code = """
            suspend fun withContextWithoutRethrow() {
                try {
                    withContext(Dispatchers.IO) {
                        println("Working in IO")
                    }
                } catch (e: Exception) {
                    println("Caught: \${'$'}e")
                }
            }
        """.trimIndent()
        Assertions.assertEquals(1, rule.lint(code).size)
    }

    @Test
    fun `withContext with rethrow should not warn`() {
        val code = """
            suspend fun withContextWithRethrow() {
                try {
                    withContext(Dispatchers.IO) {
                        println("Working in IO")
                    }
                } catch (e: Exception) {
                    if (e is CancellationException) throw e
                    println("Caught: \${'$'}e")
                }
            }
        """.trimIndent()
        Assertions.assertEquals(0, rule.lint(code).size)
    }

    @Test
    fun `nested coroutine builder without cancellation handling should warn`() {
        val code = """
            fun nestedCoroutineWithoutRethrow() {
                GlobalScope.launch {
                    delay(1)
                    try {
                        withContext(Dispatchers.IO) {
                            println("Doing work")
                        }
                    } catch (e: Exception) {
                        println("Caught: \${'$'}e")
                    }
                }
            }
        """.trimIndent()
        Assertions.assertEquals(1, rule.lint(code).size)
    }

    @Test
    fun `nested coroutine builder with ensureActive should not warn`() {
        val code = """
            fun nestedCoroutineWithEnsureActive() {
                GlobalScope.launch {
                    delay(1)
                    try {
                        withContext(Dispatchers.IO) {
                            println("Doing work")
                        }
                    } catch (e: Exception) {
                        ensureActive()
                        println("Caught: \${'$'}e")
                    }
                }
            }
        """.trimIndent()
        Assertions.assertEquals(0, rule.lint(code).size)
    }

    @Test
    fun `runCatching inside coroutine builder should warn`() {
        val code = """
            fun runCatchingInsideLaunch() {
                GlobalScope.launch {
                    delay(1)
                    runCatching {
                        withContext(Dispatchers.IO) {
                            println("Doing work")
                        }
                    }
                }
            }
        """.trimIndent()
        Assertions.assertEquals(1, rule.lint(code).size)
    }

    @Test
    fun `runCatching inside suspend function should warn`() {
        val code = """
            suspend fun runCatchingInSuspend() {
                runCatching {
                    delay(1)
                }
            }
        """.trimIndent()
        Assertions.assertEquals(1, rule.lint(code).size)
    }

    @Test
    fun `only runCatching should not warn`() {
        val code = """
            fun someFunc() {
                runCatching {
                    println("Doing work")
                }
            }
        """.trimIndent()
        Assertions.assertEquals(0, rule.lint(code).size)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment