Last active
July 25, 2025 07:22
-
-
Save hamurcuabi/c462a890b05ba098df95b7d2e86ee33b to your computer and use it in GitHub Desktop.
A custom Detekt rule that requires rethrowing CancellationException (or calling ensureActive()) inside try/catch blocks in suspend functions and inside runCatching blocks used within coroutines.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import io.gitlab.arturbosch.detekt.api.CodeSmell
import io.gitlab.arturbosch.detekt.api.Config
import io.gitlab.arturbosch.detekt.api.Debt
import io.gitlab.arturbosch.detekt.api.Entity
import io.gitlab.arturbosch.detekt.api.Issue
import io.gitlab.arturbosch.detekt.api.Rule
import io.gitlab.arturbosch.detekt.api.Severity
import io.gitlab.arturbosch.detekt.api.internal.RequiresTypeResolution
import org.jetbrains.kotlin.builtins.isSuspendFunctionType
import org.jetbrains.kotlin.com.intellij.psi.PsiElement
import org.jetbrains.kotlin.descriptors.FunctionDescriptor
import org.jetbrains.kotlin.lexer.KtTokens
import org.jetbrains.kotlin.psi.KtCallExpression
import org.jetbrains.kotlin.psi.KtExpression
import org.jetbrains.kotlin.psi.KtLambdaExpression
import org.jetbrains.kotlin.psi.KtNameReferenceExpression
import org.jetbrains.kotlin.psi.KtNamedFunction
import org.jetbrains.kotlin.psi.KtThrowExpression
import org.jetbrains.kotlin.psi.KtTryExpression
import org.jetbrains.kotlin.psi.psiUtil.anyDescendantOfType
import org.jetbrains.kotlin.resolve.BindingContext
import org.jetbrains.kotlin.types.KotlinType
/**
 * Detekt rule that flags places in coroutine code where a `CancellationException` can be swallowed.
 *
 * Two constructs are checked:
 * - `catch` clauses whose caught type also catches `CancellationException`
 *   (see `CANCELLATION_CATCHING_TYPES`). A clause is reported when its body does none of the following:
 *   - rethrows the caught parameter (`throw e`);
 *   - throws while referencing `CancellationException`;
 *   - calls `ensureActive()`.
 * - `runCatching { }` lambdas whose body neither throws while referencing
 *   `CancellationException` nor calls `ensureActive()`.
 *
 * Code counts as "inside a coroutine" when one of these holds:
 * - the nearest enclosing named function is `suspend`;
 * - an enclosing lambda is a suspend lambda (needs type resolution);
 * - an enclosing lambda is the trailing lambda of a known coroutine builder (see `COROUTINE_BUILDERS`).
 *
 * Handling is detected on the PSI tree, not on raw text. Comments and string literals therefore never
 * count, and `throw error` is not mistaken for `throw e`.
 */
@RequiresTypeResolution
class CancellationExceptionInCoroutineCatchRule(
    config: Config = Config.empty,
) : Rule(config) {

    override val issue = Issue(
        id = "MissingCancellationExceptionRethrow",
        severity = Severity.Defect,
        description = "Inside a coroutine or suspend function, always rethrow CancellationException in catch or runCatching blocks.",
        debt = Debt.TWENTY_MINS,
    )

    override fun visitTryExpression(expression: KtTryExpression) {
        super.visitTryExpression(expression)
        if (expression.catchClauses.isEmpty()) return
        if (!isInsideCoroutine(expression)) return

        for (catchClause in expression.catchClauses) {
            val parameter = catchClause.catchParameter ?: continue
            val paramName = parameter.name ?: continue
            // Compare simple names so `kotlin.Exception` / `java.lang.Throwable` are matched as well.
            val caughtType = parameter.typeReference?.text?.substringAfterLast('.') ?: continue
            if (caughtType !in CANCELLATION_CATCHING_TYPES) continue
            val catchBody = catchClause.catchBody ?: continue
            if (!catchBody.handlesCancellation(caughtName = paramName)) {
                reportViolation(catchClause)
            }
        }
    }

    override fun visitCallExpression(expression: KtCallExpression) {
        super.visitCallExpression(expression)
        if (expression.calleeExpression?.text != "runCatching") return
        if (!isInsideCoroutine(expression)) return

        // Accept both `runCatching { ... }` (possibly labeled) and `runCatching({ ... })`.
        val lambda = expression.lambdaArguments.firstOrNull()?.getLambdaExpression()
            ?: expression.valueArguments.lastOrNull()?.getArgumentExpression() as? KtLambdaExpression
            ?: return
        val lambdaBody = lambda.bodyExpression ?: return
        if (!lambdaBody.handlesCancellation(caughtName = null)) {
            reportViolation(expression)
        }
    }

    private fun reportViolation(element: PsiElement) {
        report(
            CodeSmell(
                issue = issue,
                entity = Entity.from(element),
                message = MESSAGE,
            ),
        )
    }

    /**
     * Walks up from [element] and decides whether it executes as part of a coroutine.
     *
     * The nearest enclosing named (or anonymous `fun`) function is a hard boundary. Its `suspend`
     * modifier decides the answer, so code in a plain local function nested inside a suspend function
     * is not reported. Lambdas are not a boundary, because a lambda passed to an inline function
     * (for example `forEach`) still runs in the enclosing coroutine.
     */
    private fun isInsideCoroutine(element: PsiElement): Boolean {
        for (ancestor in generateSequence(element) { it.parent }) {
            when (ancestor) {
                is KtNamedFunction -> return ancestor.hasModifier(KtTokens.SUSPEND_KEYWORD)
                is KtLambdaExpression -> if (ancestor.isCoroutineLambda()) return true
            }
        }
        return false
    }

    /**
     * True when this lambda runs as a coroutine body. Resolved types are used when available.
     * Otherwise the check falls back to the callee name of the call the lambda is attached to.
     */
    private fun KtLambdaExpression.isCoroutineLambda(): Boolean {
        // With type resolution, the literal's own descriptor / type says whether it suspends.
        // (The previous implementation inspected the lambda's *return* type, which never matched.)
        val descriptor: FunctionDescriptor? = bindingContext[BindingContext.FUNCTION, functionLiteral]
        if (descriptor?.isSuspend == true) return true
        val lambdaType: KotlinType? = bindingContext.getType(this)
        if (lambdaType?.isSuspendFunctionType == true) return true

        // Fallback: trailing lambda of a known builder, i.e. KtCallExpression -> KtLambdaArgument -> lambda.
        val builderCall = parent?.parent as? KtCallExpression
        return builderCall?.calleeExpression?.text in COROUTINE_BUILDERS
    }

    /**
     * True when this catch / lambda body lets cancellation propagate.
     *
     * @param caughtName name of the caught exception parameter, or `null` for `runCatching` bodies.
     */
    private fun KtExpression.handlesCancellation(caughtName: String?): Boolean {
        if (anyDescendantOfType<KtCallExpression> { it.calleeExpression?.text == ENSURE_ACTIVE }) return true

        // `throw e` on the caught parameter rethrows everything, cancellation included.
        if (caughtName != null &&
            anyDescendantOfType<KtThrowExpression> { it.thrownExpression?.text == caughtName }
        ) {
            return true
        }

        // A throw together with a real (non-comment) reference to CancellationException, e.g.
        // `if (e is CancellationException) throw e.also { ... }` or `throw CancellationException()`.
        return anyDescendantOfType<KtThrowExpression>() &&
            anyDescendantOfType<KtNameReferenceExpression> { it.getReferencedName() == CANCELLATION_EXCEPTION }
    }

    companion object {
        private const val ENSURE_ACTIVE = "ensureActive"
        private const val CANCELLATION_EXCEPTION = "CancellationException"

        /** Calls whose trailing lambda runs as a coroutine; used when type resolution is unavailable. */
        private val COROUTINE_BUILDERS = setOf(
            "launch",
            "async",
            "withContext",
            "runBlocking",
            "coroutineScope",
            "supervisorScope",
            "withTimeout",
            "withTimeoutOrNull",
        )

        /**
         * Simple names of catch types that also catch cancellation. On the JVM, `CancellationException`
         * is `java.util.concurrent.CancellationException`, which extends `IllegalStateException`, which
         * extends `RuntimeException`. The exception type itself is included too.
         */
        private val CANCELLATION_CATCHING_TYPES = setOf(
            "Throwable",
            "Exception",
            "RuntimeException",
            "IllegalStateException",
            "CancellationException",
        )

        private const val MESSAGE =
            "You must rethrow CancellationException or call ensureActive() inside coroutine try-catch or runCatching blocks."
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import io.gitlab.arturbosch.detekt.api.Config | |
| import io.gitlab.arturbosch.detekt.test.lint | |
| import org.junit.jupiter.api.Assertions | |
| import org.junit.Test | |
/**
 * Tests for [CancellationExceptionInCoroutineCatchRule].
 *
 * Each case lints a small snippet and asserts how many findings the rule reports.
 */
class CancellationExceptionInCoroutineCatchRuleTest {

    // The rule's constructor takes a detekt Config. The previous fixture passed another rule
    // instance as the config argument, which does not compile.
    private val rule = CancellationExceptionInCoroutineCatchRule(Config.empty)

    @Test
    fun `suspend function without rethrow should warn`() {
        val code = """
            suspend fun suspendWithoutRethrow() {
                try {
                    delay(1)
                    println("Doing work")
                } catch (e: Exception) {
                    println("Caught: \${'$'}e")
                }
            }
        """.trimIndent()
        Assertions.assertEquals(1, rule.lint(code).size)
    }

    @Test
    fun `suspend function with rethrow should not warn`() {
        val code = """
            suspend fun suspendWithRethrow() {
                try {
                    delay(1)
                    println("Doing work")
                } catch (e: Exception) {
                    if (e is CancellationException) throw e
                    println("Caught: \${'$'}e")
                }
            }
        """.trimIndent()
        Assertions.assertEquals(0, rule.lint(code).size)
    }

    @Test
    fun `suspend function with ensureActive should not warn`() {
        val code = """
            suspend fun suspendWithEnsureActive() {
                try {
                    delay(1)
                    println("Doing work")
                } catch (e: Exception) {
                    ensureActive()
                    println("Caught: \${'$'}e")
                }
            }
        """.trimIndent()
        Assertions.assertEquals(0, rule.lint(code).size)
    }

    @Test
    fun `non suspend function should not warn`() {
        val code = """
            fun nonSuspendFunction() {
                try {
                    println("Doing work")
                } catch (e: Exception) {
                    println("Caught: \${'$'}e")
                }
            }
        """.trimIndent()
        Assertions.assertEquals(0, rule.lint(code).size)
    }

    @Test
    fun `coroutine lambda without rethrow should warn`() {
        val code = """
            fun coroutineLambdaWithoutRethrow() {
                GlobalScope.launch {
                    delay(1)
                    try {
                        println("Coroutine work")
                    } catch (e: Exception) {
                        println("Caught: \${'$'}e")
                    }
                }
            }
        """.trimIndent()
        Assertions.assertEquals(1, rule.lint(code).size)
    }

    @Test
    fun `coroutine lambda with rethrow should not warn`() {
        val code = """
            fun coroutineLambdaWithRethrow() {
                GlobalScope.launch {
                    delay(1)
                    try {
                        println("Coroutine work")
                    } catch (e: Exception) {
                        if (e is CancellationException) throw e
                        println("Caught: \${'$'}e")
                    }
                }
            }
        """.trimIndent()
        Assertions.assertEquals(0, rule.lint(code).size)
    }

    @Test
    fun `coroutine lambda with ensureActive should not warn`() {
        val code = """
            fun coroutineLambdaWithEnsureActive() {
                GlobalScope.launch {
                    delay(1)
                    try {
                        println("Coroutine work")
                    } catch (e: Exception) {
                        ensureActive()
                        println("Caught: \${'$'}e")
                    }
                }
            }
        """.trimIndent()
        Assertions.assertEquals(0, rule.lint(code).size)
    }

    @Test
    fun `withContext without cancellation handling should warn`() {
        val code = """
            suspend fun withContextWithoutRethrow() {
                try {
                    withContext(Dispatchers.IO) {
                        println("Working in IO")
                    }
                } catch (e: Exception) {
                    println("Caught: \${'$'}e")
                }
            }
        """.trimIndent()
        Assertions.assertEquals(1, rule.lint(code).size)
    }

    @Test
    fun `withContext with rethrow should not warn`() {
        val code = """
            suspend fun withContextWithRethrow() {
                try {
                    withContext(Dispatchers.IO) {
                        println("Working in IO")
                    }
                } catch (e: Exception) {
                    if (e is CancellationException) throw e
                    println("Caught: \${'$'}e")
                }
            }
        """.trimIndent()
        Assertions.assertEquals(0, rule.lint(code).size)
    }

    @Test
    fun `nested coroutine builder without cancellation handling should warn`() {
        val code = """
            fun nestedCoroutineWithoutRethrow() {
                GlobalScope.launch {
                    delay(1)
                    try {
                        withContext(Dispatchers.IO) {
                            println("Doing work")
                        }
                    } catch (e: Exception) {
                        println("Caught: \${'$'}e")
                    }
                }
            }
        """.trimIndent()
        Assertions.assertEquals(1, rule.lint(code).size)
    }

    @Test
    fun `nested coroutine builder with ensureActive should not warn`() {
        val code = """
            fun nestedCoroutineWithEnsureActive() {
                GlobalScope.launch {
                    delay(1)
                    try {
                        withContext(Dispatchers.IO) {
                            println("Doing work")
                        }
                    } catch (e: Exception) {
                        ensureActive()
                        println("Caught: \${'$'}e")
                    }
                }
            }
        """.trimIndent()
        Assertions.assertEquals(0, rule.lint(code).size)
    }

    @Test
    fun `run coroutine builder with runCatching should warn`() {
        val code = """
            fun nestedCoroutineWithEnsureActive() {
                GlobalScope.launch {
                    delay(1)
                    runCatching {
                        withContext(Dispatchers.IO) {
                            println("Doing work")
                        }
                    }
                }
            }
        """.trimIndent()
        Assertions.assertEquals(1, rule.lint(code).size)
    }

    @Test
    fun `only runCatching should not warn`() {
        val code = """
            fun someFunc() {
                runCatching {
                    println("Doing work")
                }
            }
        """.trimIndent()
        Assertions.assertEquals(0, rule.lint(code).size)
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment