Skip to content

Instantly share code, notes, and snippets.

@SecretX33
Last active January 3, 2025 14:30
Show Gist options
  • Save SecretX33/8f1e37ec18af5394cad843a5ae1c0503 to your computer and use it in GitHub Desktop.
Save SecretX33/8f1e37ec18af5394cad843a5ae1c0503 to your computer and use it in GitHub Desktop.
CoroutineScheduledExecutorService - Bridge between Kotlin Coroutines and Java ScheduledExecutorService
import arrow.core.nonFatalOrThrow
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.util.concurrent.AbstractExecutorService
import java.util.concurrent.Callable
import java.util.concurrent.CompletableFuture
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.Delayed
import java.util.concurrent.RejectedExecutionException
import java.util.concurrent.RunnableFuture
import java.util.concurrent.RunnableScheduledFuture
import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicReference
import kotlin.coroutines.cancellation.CancellationException
import kotlin.time.measureTime
/**
 * Bridges Java's [ScheduledExecutorService] API onto Kotlin coroutines: every task becomes a coroutine launched
 * in [coroutineScope], so the scope's context decides where tasks run and cancelling the scope cancels them.
 *
 * Each task is represented by a [CoroutineScheduledFuture]. Exceptions thrown by a task complete its future
 * exceptionally instead of failing the launched coroutine, so one failing task does not propagate into (and
 * possibly cancel) [coroutineScope].
 *
 * Deviations from the JDK contract:
 * - [shutdown] cancels every pending and running task, so it behaves like [shutdownNow].
 * - [shutdownNow] always returns an empty list.
 * - Periodic tasks keep running after a non-fatal exception; the exception is only logged.
 * - [CoroutineScheduledFuture.getDelay] always reports `0` and all futures compare as equal.
 * - Delays are converted with `TimeUnit.toMillis`, so sub-millisecond delays round down.
 *
 * @param coroutineScope scope in which every task coroutine is launched.
 */
class CoroutineScheduledExecutorService(
    private val coroutineScope: CoroutineScope,
) : AbstractExecutorService(), ScheduledExecutorService {
    private val isShutdown = AtomicBoolean(false)
    private val taskId = AtomicInteger(0)

    /** Futures created by this executor that have not completed yet (whether or not they were started). */
    private val scheduledTasks = ConcurrentHashMap.newKeySet<CoroutineScheduledFuture<*>>()

    /** Coroutines launched by this executor that have not finished yet; drives [isTerminated]. */
    private val activeJobs = ConcurrentHashMap.newKeySet<Job>()

    /**
     * Runs [task] asynchronously. A [CoroutineScheduledFuture] (e.g. one produced by [newTaskFor]) is started
     * directly; any other runnable is scheduled with no delay, and its failure is logged because no caller
     * holds the resulting future.
     */
    override fun execute(task: Runnable) {
        when (task) {
            is CoroutineScheduledFuture<*> -> task.run() // Already wrapped, no extra dispatch needed
            else -> {
                val future = schedule(task.toCallable(), 0, TimeUnit.MILLISECONDS)
                future.whenComplete { _, error ->
                    if (error != null && error !is CancellationException) {
                        log.error("Error running task '${future.id}'", error)
                    }
                }
            }
        }
    }

    override fun isShutdown(): Boolean = isShutdown.get()

    /** `true` once [shutdown] was called and every coroutine launched by this executor has finished. */
    override fun isTerminated(): Boolean = isShutdown() && activeJobs.isEmpty()

    override fun submit(task: Runnable): CoroutineScheduledFuture<*> = submit(task.toCallable())

    override fun <T> submit(task: Runnable, result: T): CoroutineScheduledFuture<T> =
        submit(Callable { task.run(); result })

    override fun <T> submit(task: Callable<T>): CoroutineScheduledFuture<T> =
        schedule(task, 0, TimeUnit.MILLISECONDS)

    override fun schedule(
        task: Runnable,
        timeout: Long,
        unit: TimeUnit,
    ): CoroutineScheduledFuture<*> = schedule(task.toCallable(), timeout, unit)

    /** Runs [task] once after [timeout] [unit]s; its result (or exception) completes the returned future. */
    override fun <T> schedule(
        task: Callable<T>,
        timeout: Long,
        unit: TimeUnit,
    ): CoroutineScheduledFuture<T> = executeTask {
        delay(unit.toMillis(timeout))
        complete(task.call())
    }

    /**
     * Runs [task] after [initialDelay], then repeatedly so that consecutive starts are roughly [period] apart
     * (never overlapping). Non-fatal exceptions are logged and do not stop the schedule.
     *
     * @throws IllegalArgumentException if [period] is not positive (a zero period would spin without suspending).
     */
    override fun scheduleAtFixedRate(
        task: Runnable,
        initialDelay: Long,
        period: Long,
        unit: TimeUnit,
    ): CoroutineScheduledFuture<*> {
        require(period > 0) { "period must be positive, was $period" }
        return executeTask<Unit>(isPeriodic = true) {
            delay(unit.toMillis(initialDelay))
            while (isActive) {
                val taskElapsedTime = measureTime {
                    try {
                        task.run()
                    } catch (e: Throwable) {
                        if (e is CancellationException && !isActive) {
                            log.warn("Scheduled task was cancelled", e)
                            return@executeTask
                        }
                        log.error("Error running scheduled task '$taskId'", e)
                        e.nonFatalOrThrow() // Rethrows fatal errors (and foreign cancellations)
                    }
                }
                // Subtract the run time so the next start stays on the fixed-rate grid.
                delay((unit.toMillis(period) - taskElapsedTime.inWholeMilliseconds).coerceAtLeast(0))
            }
        }
    }

    /**
     * Runs [task] after [initialDelay], then repeatedly with [period] between the end of one run and the start
     * of the next. Non-fatal exceptions are logged and do not stop the schedule.
     *
     * @throws IllegalArgumentException if [period] is not positive (a zero delay would spin without suspending).
     */
    override fun scheduleWithFixedDelay(
        task: Runnable,
        initialDelay: Long,
        period: Long,
        unit: TimeUnit,
    ): CoroutineScheduledFuture<*> {
        require(period > 0) { "period must be positive, was $period" }
        return executeTask<Unit>(isPeriodic = true) {
            delay(unit.toMillis(initialDelay))
            while (isActive) {
                try {
                    task.run()
                } catch (e: Throwable) {
                    if (e is CancellationException && !isActive) {
                        log.warn("Scheduled task was cancelled", e)
                        return@executeTask
                    }
                    log.error("Error running scheduled task '$taskId'", e)
                    e.nonFatalOrThrow() // Rethrows fatal errors (and foreign cancellations)
                }
                delay(unit.toMillis(period))
            }
        }
    }

    /**
     * Rejects new tasks and cancels every task created so far, pending or running. Coroutine cancellation is
     * cooperative, so a task may still be running (e.g. inside a blocking `Runnable.run`) when this returns;
     * use [awaitTermination] to wait for it.
     */
    override fun shutdown() {
        if (!isShutdown.compareAndSet(false, true)) return
        // Completed/cancelled futures remove themselves from scheduledTasks, so no explicit clear() is needed.
        scheduledTasks.forEach { it.cancel(false) }
    }

    /** Same as [shutdown]; never-started tasks are cancelled rather than returned, so the list is always empty. */
    override fun shutdownNow(): List<Runnable> {
        shutdown()
        return emptyList()
    }

    /**
     * Blocks until [isTerminated] is `true` or the timeout elapses, by polling.
     *
     * @return `true` if the executor terminated, `false` if the timeout elapsed first.
     * @throws InterruptedException if the waiting thread is interrupted.
     */
    override fun awaitTermination(timeout: Long, unit: TimeUnit): Boolean {
        val deadline = System.nanoTime() + unit.toNanos(timeout)
        while (!isTerminated()) {
            // Compare via subtraction: nanoTime-based deadlines may overflow (e.g. timeout = Long.MAX_VALUE).
            val remainingNanos = deadline - System.nanoTime()
            if (remainingNanos <= 0) return false
            TimeUnit.NANOSECONDS.sleep(minOf(remainingNanos, AWAIT_POLL_INTERVAL_NANOS))
        }
        return true
    }

    override fun <T> newTaskFor(callable: Callable<T>): RunnableFuture<T> =
        queueTask { complete(callable.call()) }

    override fun <T> newTaskFor(runnable: Runnable, value: T): RunnableFuture<T> =
        queueTask { runnable.run(); complete(value) }

    /** Creates a task future (see [queueTask]) and starts it immediately. */
    private fun <T> executeTask(
        isPeriodic: Boolean = false,
        execute: suspend CoroutineTaskScheduleDsl<T>.() -> Unit,
    ): CoroutineScheduledFuture<T> = queueTask(isPeriodic, execute).apply { run() }

    /**
     * Creates and registers a not-yet-started future whose [CoroutineScheduledFuture.run] launches [execute].
     *
     * @throws RejectedExecutionException if the executor has been shut down.
     */
    private fun <T> queueTask(
        isPeriodic: Boolean = false,
        execute: suspend CoroutineTaskScheduleDsl<T>.() -> Unit,
    ): CoroutineScheduledFuture<T> {
        ensureNotShutdown()
        val scheduledFuture = CoroutineScheduledFuture<T>(
            id = getNextTaskId(),
            isPeriodic = isPeriodic,
            run = { future -> launchTask(future, execute) },
        )
        scheduledTasks += scheduledFuture
        // Forget the future once it completes for any reason, including cancellation before it ever ran.
        scheduledFuture.whenComplete { _, _ -> scheduledTasks.remove(scheduledFuture) }
        // shutdown() may have iterated scheduledTasks between ensureNotShutdown() and the insertion above.
        if (isShutdown()) {
            scheduledFuture.cancel(false)
            throw RejectedExecutionException("Executor has been shut down")
        }
        return scheduledFuture
    }

    /** Launches the coroutine backing [scheduledFuture] and wires the coroutine's outcome back into the future. */
    private fun <T> launchTask(
        scheduledFuture: CoroutineScheduledFuture<T>,
        block: suspend CoroutineTaskScheduleDsl<T>.() -> Unit,
    ) {
        val job = coroutineScope.launch {
            // The future may have been cancelled (e.g. by shutdown) between run() and this coroutine starting.
            if (scheduledFuture.isDone) return@launch
            try {
                coroutineScope {
                    val scheduleDsl = CoroutineTaskScheduleDsl(
                        taskId = scheduledFuture.id,
                        scheduleDsl = ScheduleDsl<T> { scheduledFuture.complete(it) },
                        coroutineScope = this,
                    )
                    scheduleDsl.block()
                }
            } catch (e: CancellationException) {
                throw e
            } catch (e: Throwable) {
                // Report the failure through the future instead of failing this coroutine, which would
                // propagate into (and possibly cancel) the caller-supplied scope. Fatal errors are rethrown.
                scheduledFuture.completeExceptionally(e.nonFatalOrThrow())
            }
        }
        activeJobs += job
        scheduledFuture.task = job
        job.invokeOnCompletion { cause ->
            when (cause) {
                null -> {}
                is CancellationException -> scheduledFuture.cancel(true)
                else -> scheduledFuture.completeExceptionally(cause)
            }
            scheduledTasks.remove(scheduledFuture)
            activeJobs.remove(job)
        }
        // A concurrent cancel() that ran before `task` was assigned could not cancel the job itself.
        if (scheduledFuture.isCancelled) job.cancel()
    }

    private fun getNextTaskId(): Int = taskId.getAndIncrement()

    private fun Runnable.toCallable(): Callable<Unit> = Callable { run() }

    private fun ensureNotShutdown() {
        if (isShutdown()) throw RejectedExecutionException("Executor has been shut down")
    }

    /** Callback a task body uses to deliver its result to the owning future. */
    private fun interface ScheduleDsl<in T> {
        fun complete(value: T)
    }

    /** Receiver of task bodies: exposes the task id, [complete], and the coroutine scope the body runs in. */
    private class CoroutineTaskScheduleDsl<T>(
        val taskId: Int,
        val scheduleDsl: ScheduleDsl<T>,
        val coroutineScope: CoroutineScope,
    ) : ScheduleDsl<T> by scheduleDsl, CoroutineScope by coroutineScope

    /**
     * Future handle for a task run by [CoroutineScheduledExecutorService].
     *
     * The backing coroutine is started at most once, by the first call to [run]; completing or cancelling the
     * future before that prevents it from ever starting. Cancelling also cancels the backing coroutine.
     *
     * NOTE(review): as a data class, equals/hashCode/toString are derived from [id], `isPeriodic` and `run`
     * rather than from identity/completion state.
     *
     * @property id sequential id assigned by the owning executor.
     */
    data class CoroutineScheduledFuture<T>(
        val id: Int,
        private val isPeriodic: Boolean,
        private val run: (CoroutineScheduledFuture<T>) -> Unit, // Responsible for completing this future
    ) : CompletableFuture<T>(), RunnableScheduledFuture<T> {
        /** Coroutine backing this future; assigned by the executor once the task has been started. */
        @Volatile
        lateinit var task: Job

        private val isStarted = AtomicBoolean(false)

        override fun run() {
            if (!isStarted.compareAndSet(false, true)) return
            run(this)
        }

        override fun isPeriodic(): Boolean = isPeriodic

        override fun complete(value: T?): Boolean {
            isStarted.set(true)
            return super.complete(value)
        }

        override fun completeExceptionally(ex: Throwable?): Boolean {
            isStarted.set(true)
            return super.completeExceptionally(ex)
        }

        /**
         * Cancels this future and, if it was started, its coroutine.
         *
         * @return `true` if this call cancelled the future, `false` if it had already completed.
         */
        override fun cancel(mayInterruptIfRunning: Boolean): Boolean {
            // Mark as started so a future cancelled before it ran can never be started afterwards.
            isStarted.set(true)
            val cancellationException = CancellationException()
            if (!completeExceptionally(cancellationException)) return false
            // `task` is unassigned if the coroutine was never launched (or not yet assigned); the executor
            // re-checks isCancelled right after assigning it, so the coroutine is still cancelled in that case.
            if (::task.isInitialized) task.cancel(cancellationException)
            return true
        }

        /** Scheduling time is not tracked, so the remaining delay is always reported as `0`. */
        override fun getDelay(unit: TimeUnit): Long = 0

        override fun compareTo(other: Delayed?): Int = 0
    }

    private companion object {
        /** How often [awaitTermination] re-checks [isTerminated] (10 ms). */
        const val AWAIT_POLL_INTERVAL_NANOS = 10_000_000L

        val log: Logger = LoggerFactory.getLogger(CoroutineScheduledExecutorService::class.java)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment