Last active: January 3, 2025 14:30
Save SecretX33/8f1e37ec18af5394cad843a5ae1c0503 to your computer and use it in GitHub Desktop.
CoroutineScheduledExecutorService - Bridge between Kotlin Coroutines and Java ScheduledExecutorService
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import arrow.core.nonFatalOrThrow
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.CoroutineStart
import kotlinx.coroutines.Job
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.util.concurrent.AbstractExecutorService
import java.util.concurrent.Callable
import java.util.concurrent.CompletableFuture
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.Delayed
import java.util.concurrent.RejectedExecutionException
import java.util.concurrent.RunnableFuture
import java.util.concurrent.RunnableScheduledFuture
import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicReference
import kotlin.coroutines.cancellation.CancellationException
import kotlin.time.measureTime
/**
 * [ScheduledExecutorService] that runs every task as a coroutine launched in [coroutineScope],
 * bridging Java APIs that expect an executor to Kotlin coroutines.
 *
 * Behavior worth knowing:
 * - Delays and periods are converted to whole milliseconds and waited for with [delay].
 * - An exception thrown by a one-shot task completes that task's future exceptionally; it does not
 *   fail the launching coroutine (which would cancel a non-supervisor [coroutineScope]).
 * - Periodic tasks log non-fatal exceptions and keep running until they are cancelled.
 * - [shutdown] cancels every pending and running task instead of letting them finish, and
 *   [shutdownNow] always returns an empty list.
 * - Cancelling [coroutineScope] cancels all tasks and completes their futures exceptionally.
 *
 * @param coroutineScope scope in which the task coroutines are launched.
 */
class CoroutineScheduledExecutorService(
    private val coroutineScope: CoroutineScope,
) : AbstractExecutorService(), ScheduledExecutorService {
    /** Set once by [shutdown]; new tasks are rejected afterwards. */
    private val isShutdown = AtomicBoolean(false)

    /** Source of the ids assigned to [CoroutineScheduledFuture.id]. */
    private val taskId = AtomicInteger(0)

    /**
     * Futures created by this executor whose coroutine has not finished yet. Launched tasks are removed
     * by their coroutine's completion handler; never-launched ones are removed by [shutdown].
     */
    private val scheduledTasks = ConcurrentHashMap.newKeySet<CoroutineScheduledFuture<*>>()

    /**
     * Runs [task]. A [CoroutineScheduledFuture] (e.g. one produced by [newTaskFor]) is started directly;
     * any other [Runnable] is scheduled with no delay.
     *
     * @throws RejectedExecutionException if the executor is shut down and [task] has to be scheduled.
     */
    override fun execute(task: Runnable) {
        when (task) {
            is CoroutineScheduledFuture<*> -> task.run() // Already wraps a coroutine; no dispatch needed
            else -> schedule(task, 0, TimeUnit.MILLISECONDS)
        }
    }

    override fun isShutdown(): Boolean = isShutdown.get()

    /** True once [shutdown] was called and every task coroutine has actually completed. */
    override fun isTerminated(): Boolean = isShutdown() && scheduledTasks.isEmpty()

    /** Runs [task] as soon as possible; the returned future completes with [Unit]. */
    override fun submit(task: Runnable): CoroutineScheduledFuture<*> = submit(task.toCallable())

    /** Runs [task] as soon as possible; the returned future completes with [result]. */
    override fun <T> submit(task: Runnable, result: T): CoroutineScheduledFuture<T> =
        submit(Callable { task.run(); result })

    /** Runs [task] as soon as possible; the returned future completes with its result. */
    override fun <T> submit(task: Callable<T>): CoroutineScheduledFuture<T> =
        schedule(task, 0, TimeUnit.MILLISECONDS)

    /** Runs [task] once after a delay of [timeout] [unit]s; the future completes with [Unit]. */
    override fun schedule(
        task: Runnable,
        timeout: Long,
        unit: TimeUnit,
    ): CoroutineScheduledFuture<*> = schedule(task.toCallable(), timeout, unit)

    /**
     * Runs [task] once after a delay of [timeout] [unit]s (truncated to milliseconds).
     *
     * @return a future completed with the task's result, or exceptionally if it throws or is cancelled.
     * @throws RejectedExecutionException if the executor has been shut down.
     */
    override fun <T> schedule(
        task: Callable<T>,
        timeout: Long,
        unit: TimeUnit,
    ): CoroutineScheduledFuture<T> = executeTask {
        delay(unit.toMillis(timeout))
        complete(task.call())
    }

    /**
     * Runs [task] after [initialDelay], then every [period], subtracting each run's duration from the
     * following wait (never waiting a negative time). Runs until the returned future is cancelled.
     *
     * @throws RejectedExecutionException if the executor has been shut down.
     */
    override fun scheduleAtFixedRate(
        task: Runnable,
        initialDelay: Long,
        period: Long,
        unit: TimeUnit,
    ): CoroutineScheduledFuture<*> = executeTask<Unit>(isPeriodic = true) {
        delay(unit.toMillis(initialDelay))
        val periodMillis = unit.toMillis(period)
        while (isActive) {
            var keepRunning = true
            val elapsed = measureTime { keepRunning = runPeriodicIteration(task) }
            if (!keepRunning) return@executeTask
            delay((periodMillis - elapsed.inWholeMilliseconds).coerceAtLeast(0))
        }
    }

    /**
     * Runs [task] after [initialDelay], then repeatedly, waiting [period] between the end of one run
     * and the start of the next. Runs until the returned future is cancelled.
     *
     * @throws RejectedExecutionException if the executor has been shut down.
     */
    override fun scheduleWithFixedDelay(
        task: Runnable,
        initialDelay: Long,
        period: Long,
        unit: TimeUnit,
    ): CoroutineScheduledFuture<*> = executeTask<Unit>(isPeriodic = true) {
        delay(unit.toMillis(initialDelay))
        while (isActive) {
            if (!runPeriodicIteration(task)) return@executeTask
            delay(unit.toMillis(period))
        }
    }

    /**
     * Rejects new tasks and cancels every known task, pending or running. Unlike a JDK executor,
     * already-submitted tasks are not allowed to finish. [isTerminated] turns true only once the
     * cancelled coroutines have actually completed.
     */
    override fun shutdown() {
        if (!isShutdown.compareAndSet(false, true)) return
        // ConcurrentHashMap iterators tolerate the concurrent removals done by completion handlers.
        for (future in scheduledTasks) {
            future.cancel(false)
            // A never-launched future has no completion handler that would remove it.
            if (!future.isLaunched) scheduledTasks.remove(future)
        }
    }

    /** Same as [shutdown]; never returns the tasks that did not start (always an empty list). */
    override fun shutdownNow(): List<Runnable> {
        shutdown()
        return emptyList()
    }

    /**
     * Blocks until [isTerminated] is true or [timeout] elapses, polling every few milliseconds.
     *
     * @return `true` if the executor terminated within the timeout.
     * @throws InterruptedException if the waiting thread is interrupted.
     */
    override fun awaitTermination(timeout: Long, unit: TimeUnit): Boolean {
        val deadline = System.nanoTime() + unit.toNanos(timeout)
        while (!isTerminated()) {
            // Difference-based check stays correct when `deadline` overflows (e.g. Long.MAX_VALUE timeouts).
            val remainingNanos = deadline - System.nanoTime()
            if (remainingNanos <= 0) return false
            TimeUnit.NANOSECONDS.sleep(minOf(remainingNanos, POLL_INTERVAL_NANOS))
        }
        return true
    }

    /** Creates a not-yet-started future; its coroutine is launched when the future's [Runnable.run] is called. */
    override fun <T> newTaskFor(callable: Callable<T>): RunnableFuture<T> =
        queueTask { complete(callable.call()) }

    /** Creates a not-yet-started future that completes with [value] after running [runnable]. */
    override fun <T> newTaskFor(runnable: Runnable, value: T): RunnableFuture<T> =
        queueTask { runnable.run(); complete(value) }

    /** Creates a task future via [queueTask] and immediately launches its coroutine. */
    private fun <T> executeTask(
        isPeriodic: Boolean = false,
        execute: suspend CoroutineTaskScheduleDsl<T>.() -> Unit,
    ): CoroutineScheduledFuture<T> = queueTask(isPeriodic, execute).apply { run() }

    /**
     * Creates and registers a future whose coroutine (running [execute]) is launched on its first `run()`.
     *
     * @throws RejectedExecutionException if the executor has been shut down.
     */
    private fun <T> queueTask(
        isPeriodic: Boolean = false,
        execute: suspend CoroutineTaskScheduleDsl<T>.() -> Unit,
    ): CoroutineScheduledFuture<T> {
        ensureNotShutdown()
        val scheduledFuture = CoroutineScheduledFuture<T>(
            id = getNextTaskId(),
            isPeriodic = isPeriodic,
            run = { future -> launchTask(future, execute) },
        )
        scheduledTasks += scheduledFuture
        // shutdown() may have run between the check above and the add; don't leave an orphan task behind.
        if (isShutdown()) {
            scheduledTasks.remove(scheduledFuture)
            ensureNotShutdown() // throws RejectedExecutionException
        }
        return scheduledFuture
    }

    /**
     * Launches the coroutine backing [scheduledFuture] and ties its outcome to the future: task failures
     * and coroutine cancellation complete the future exceptionally, and the future is unregistered
     * once the coroutine finishes.
     */
    private fun <T> launchTask(
        scheduledFuture: CoroutineScheduledFuture<T>,
        execute: suspend CoroutineTaskScheduleDsl<T>.() -> Unit,
    ) {
        // Start lazily so the Job is published on the future before any task code can run
        // (with an immediate/unconfined dispatcher the body would otherwise run inside `launch`).
        val job = coroutineScope.launch(start = CoroutineStart.LAZY) {
            // Cancelled or shut down before the coroutine got to run: nothing to do.
            if (scheduledFuture !in scheduledTasks || scheduledFuture.isDone) return@launch
            try {
                coroutineScope {
                    val scheduleDsl = CoroutineTaskScheduleDsl(
                        taskId = scheduledFuture.id,
                        scheduleDsl = ScheduleDsl<T> { scheduledFuture.complete(it) },
                        coroutineScope = this,
                    )
                    scheduleDsl.execute()
                }
            } catch (e: CancellationException) {
                throw e
            } catch (e: Throwable) {
                // Report failures through the future instead of failing this coroutine, which would
                // cancel a non-supervisor parent scope. Fatal errors are still rethrown.
                scheduledFuture.completeExceptionally(e.nonFatalOrThrow())
            }
        }
        scheduledFuture.task = job
        job.invokeOnCompletion { cause ->
            // No-op if the future is already complete; otherwise prevents it from hanging forever
            // when the coroutine is cancelled (e.g. by the owning scope) or fails fatally.
            if (cause != null) scheduledFuture.completeExceptionally(cause)
            scheduledTasks.remove(scheduledFuture)
        }
        job.start()
    }

    /**
     * Runs one iteration of a periodic [task], logging (and swallowing) non-fatal failures.
     *
     * @return `false` if the task threw a [CancellationException] while this coroutine is no longer
     *   active, meaning the periodic loop should stop; `true` otherwise.
     */
    private fun CoroutineTaskScheduleDsl<*>.runPeriodicIteration(task: Runnable): Boolean =
        try {
            task.run()
            true
        } catch (e: Throwable) {
            if (e is CancellationException && !isActive) {
                log.warn("Scheduled task was cancelled", e)
                false
            } else {
                log.error("Error running scheduled task '${this.taskId}'", e)
                e.nonFatalOrThrow() // rethrows fatal errors (and, per Arrow, cancellation)
                true
            }
        }

    private fun getNextTaskId(): Int = taskId.getAndIncrement()

    private fun Runnable.toCallable(): Callable<Unit> = Callable { run() }

    /** @throws RejectedExecutionException if [shutdown] has been called. */
    private fun ensureNotShutdown() {
        if (isShutdown()) throw RejectedExecutionException("Executor has been shut down")
    }

    /** Lets a task body complete its future with a value. */
    private fun interface ScheduleDsl<in T> {
        fun complete(value: T)
    }

    /** Receiver of task bodies: exposes the task id, [ScheduleDsl.complete] and the task's [CoroutineScope]. */
    private class CoroutineTaskScheduleDsl<T>(
        val taskId: Int,
        val scheduleDsl: ScheduleDsl<T>,
        val coroutineScope: CoroutineScope,
    ) : ScheduleDsl<T> by scheduleDsl, CoroutineScope by coroutineScope

    /**
     * Future for a task run by [CoroutineScheduledExecutorService]. Calling [run] launches the backing
     * coroutine at most once; that coroutine completes this future.
     *
     * [getDelay] always reports 0 and [compareTo] treats all futures as equal: the waiting happens
     * inside the coroutine, not in a delay queue.
     *
     * @property id identifier assigned by the executor.
     * @param run launches the coroutine for this future; invoked at most once, by [run].
     */
    data class CoroutineScheduledFuture<T>(
        val id: Int,
        private val isPeriodic: Boolean,
        private val run: (CoroutineScheduledFuture<T>) -> Unit, // Responsible for completing this future
    ) : CompletableFuture<T>(), RunnableScheduledFuture<T> {
        /** Coroutine running this task; assigned right before it starts. Volatile: [cancel] may run on any thread. */
        @Volatile
        lateinit var task: Job

        /** Ensures [run] launches at most once; also set when the future is completed or cancelled first. */
        private val isStarted = AtomicBoolean(false)

        /** Whether [task] has been assigned, i.e. whether the coroutine was launched. */
        internal val isLaunched: Boolean
            get() = this::task.isInitialized

        override fun run() {
            if (!isStarted.compareAndSet(false, true)) return
            run(this)
        }

        override fun isPeriodic(): Boolean = isPeriodic

        override fun complete(value: T?): Boolean {
            isStarted.set(true)
            return super.complete(value)
        }

        override fun completeExceptionally(ex: Throwable?): Boolean {
            isStarted.set(true)
            return super.completeExceptionally(ex)
        }

        /**
         * Cancels this future and, if its coroutine was launched, the coroutine too. [mayInterruptIfRunning]
         * is ignored: cancellation is always delivered to the coroutine, but blocking code is not interrupted.
         *
         * @return `true` if this call moved the future to the cancelled state.
         */
        override fun cancel(mayInterruptIfRunning: Boolean): Boolean {
            isStarted.set(true) // A cancelled future must never launch its coroutine afterwards
            val cancellationException = CancellationException()
            if (!completeExceptionally(cancellationException)) return false
            // The coroutine may not exist yet (e.g. a newTaskFor future that was never executed).
            if (isLaunched) task.cancel(cancellationException)
            return true
        }

        override fun getDelay(unit: TimeUnit): Long = 0

        override fun compareTo(other: Delayed?): Int = 0
    }

    private companion object {
        /** Upper bound on how long [awaitTermination] sleeps between checks (10 ms). */
        const val POLL_INTERVAL_NANOS: Long = 10_000_000L

        val log: Logger = LoggerFactory.getLogger(CoroutineScheduledExecutorService::class.java)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.