Last active
January 6, 2017 06:04
-
-
Save chadselph/f0e1559ecc2b178b83ab02dc50fc41ca to your computer and use it in GitHub Desktop.
tiny slick-based distributed task queue prototype
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.sql.Timestamp | |
import java.time.{Duration, Instant} | |
import slick.driver.JdbcProfile | |
import slick.profile.SqlProfile.ColumnOption.Nullable | |
import scala.concurrent.{ExecutionContext, Future} | |
import scala.util.{Failure, Success} | |
/**
  * A tiny Slick-backed task queue prototype.
  *
  * Every run of a named task is stored as one [[TaskRun]] row in `tableName`, keyed by
  * `(name, runId)`. To start a run, [[startTaskIf]] reads the latest row for the task, asks a
  * [[ShouldRun]] predicate whether a new run is due and, if so, inserts the row for the next
  * run id. Because `(name, runId)` is the primary key, a conflicting insert (for example when
  * another worker has already claimed that run id) fails and is reported as [[InsertFailed]]
  * instead of starting the task twice.
  *
  * Created by chad on 1/5/17.
  *
  * @param tableName name of the table that stores the task runs
  * @param driver    Slick profile used to build and execute the queries
  * @param database  database the queries are run against
  * @tparam Driver   concrete Slick JDBC profile type
  */
class TaskQueue[Driver <: JdbcProfile](tableName: String = "tasks",
                                       val driver: Driver,
                                       database: Driver#Backend#Database) {
  import driver.api._

  type TaskName = String

  /** Decides, given the most recent run of a task (if any), whether a new run should start. */
  type ShouldRun = Option[TaskRun] => Boolean

  /**
    * Use this if you want your tasks to start every `duration` without caring whether or
    * not the previous one has finished. Keep in mind, if your task takes longer than `duration`
    * you run the risk of runs piling up on top of each other.
    *
    * A task that has never run is always allowed to start.
    *
    * @param duration how long to wait after the previous run ''started'' before running again
    */
  case class HasNotStartedInLast(duration: Duration) extends ShouldRun {
    override def apply(lastRun: Option[TaskRun]): Boolean =
      lastRun.forall(run => wasMoreThanDurationAgo(duration, run.startedAt))
  }

  /**
    * Use this to keep runs from stacking on top of each other, i.e. when a run may take longer
    * than the wait time. A new run only starts once the previous run has finished, or once the
    * previous run has been unfinished for so long that it is considered dead.
    *
    * A task that has never run is always allowed to start.
    *
    * @param duration   how long to wait after the previous run finished before running again
    * @param expiration how long to wait after the `startedAt` time of an unfinished run before
    *                   running again anyway
    */
  case class HasNotFinishedInLast(duration: Duration, expiration: Duration)
      extends ShouldRun {
    override def apply(lastRun: Option[TaskRun]): Boolean = lastRun match {
      case None =>
        true
      case Some(TaskRun(_, _, _, started, None)) =>
        // Still unfinished: only take over once the run has outlived `expiration`.
        wasMoreThanDurationAgo(expiration, started)
      case Some(TaskRun(_, _, _, _, Some(finished))) =>
        wasMoreThanDurationAgo(duration, finished)
    }
  }

  /** Outcome of a [[startTaskIf]] call. */
  sealed trait RunResult extends Product with Serializable

  /** A new run was recorded and the action was invoked. */
  case class StartRun(previousRun: Option[TaskRun], thisRun: TaskRun)
      extends RunResult

  /** The [[ShouldRun]] predicate declined to start a new run. */
  case class SkipRun(previousRun: Option[TaskRun]) extends RunResult

  /** The row for the new run could not be inserted; the action was not invoked. */
  case class InsertFailed(cause: Throwable) extends RunResult

  /**
    * One recorded run of a task.
    *
    * @param name       task name
    * @param runId      run number of this task; forms the primary key together with `name`
    * @param worker     identifier of the worker that started the run
    * @param startedAt  when the run was recorded as started
    * @param finishedAt when the run was marked finished, if it has been
    */
  case class TaskRun(name: String,
                     runId: Int,
                     worker: String,
                     startedAt: Instant = Instant.now,
                     finishedAt: Option[Instant] = None)

  /** Slick table mapping for [[TaskRun]] rows. */
  class TaskRuns(tag: Tag) extends Table[TaskRun](tag, tableName) {
    // Instants are stored as SQL timestamps.
    implicit val instantColumnType: BaseColumnType[Instant] =
      MappedColumnType.base[Instant, Timestamp](Timestamp.from, _.toInstant)

    def name = column[TaskName]("name")
    def runId = column[Int]("run_id")
    def worker = column[String]("worker")
    def startedAt = column[Instant]("started_at")
    def finishedAt = column[Instant]("finished_at", Nullable)

    def * =
      (name, runId, worker, startedAt, finishedAt.?) <> (TaskRun.tupled, TaskRun.unapply)

    def pk = primaryKey("pk_name_run_id", (name, runId))
  }

  val tasks = TableQuery[TaskRuns]

  /**
    * Builds the action that stamps `taskRun` as finished at the current time.
    *
    * The update is restricted to the row identified by `(name, runId)`; an unfiltered update
    * would overwrite every row in the table with this run's values.
    *
    * @param taskRun the run to mark as finished
    * @return a database action yielding the number of updated rows
    */
  def markFinished(taskRun: TaskRun) =
    tasks
      .filter(_.name === taskRun.name)
      .filter(_.runId === taskRun.runId)
      .update(taskRun.copy(finishedAt = Some(Instant.now())))

  /**
    * Starts a new run of `taskName` if `shouldRun` approves, runs `doAction` and then marks
    * the run as finished.
    *
    * The run is marked finished whether `doAction` succeeds, returns a failed future or throws
    * synchronously. A failure of the action itself is not surfaced to the caller: the result is
    * still the [[StartRun]]. Handle failures inside `doAction` if they matter.
    *
    * @param shouldRun predicate applied to the latest recorded run of the task, if any
    * @param taskName  name of the task
    * @param worker    identifier recorded as the worker of the new run
    * @param doAction  the work to perform once the run has been recorded
    * @return [[SkipRun]] if the predicate declined, [[InsertFailed]] if the new run could not
    *         be recorded, otherwise the [[StartRun]] once the action has completed and the run
    *         has been marked finished. The future fails if a database action other than the
    *         insert fails.
    */
  def startTaskIf[A](shouldRun: ShouldRun, taskName: TaskName, worker: String)(
      doAction: (StartRun) => Future[A])(implicit ec: ExecutionContext): Future[RunResult] = {
    val latestRunId = tasks.filter(_.name === taskName).map(_.runId).max
    val latestRun = tasks
      .filter(_.name === taskName)
      .filter(_.runId === latestRunId)
      .result
      .headOption

    val insertTaskRun = latestRun.map {
      case last if !shouldRun(last) =>
        SkipRun(last)
      case None =>
        StartRun(None, TaskRun(taskName, 1, worker))
      case Some(last) =>
        StartRun(Some(last), TaskRun(taskName, last.runId + 1, worker))
    }.flatMap {
      case r @ StartRun(_, thisRun) =>
        // The insert is the claim on this run id; a failed insert means the run is not ours.
        (tasks += thisRun).asTry.map {
          case Failure(ex) => InsertFailed(ex)
          case _           => r
        }
      case r => DBIO.successful(r)
    }

    database.run(insertTaskRun).flatMap {
      case sr @ StartRun(_, thisRun) =>
        // Treat an action that throws before returning its Future like a failed Future, so the
        // run below is still marked as finished.
        val actionOutcome: Future[A] =
          scala.util.Try(doAction(sr)) match {
            case Success(running) => running
            case Failure(ex)      => Future.failed(ex)
          }
        for {
          // Convert the failure case into a successful value so we still mark the run finished.
          _ <- actionOutcome.map(Success.apply).recover { case ex => Failure(ex) }
          _ <- database.run(markFinished(thisRun))
        } yield sr
      case other => Future.successful(other)
    }
  }

  /**
    * Helper to see whether an [[Instant]] was more than [[Duration]] ago,
    * i.e. "was 2017-01-01 13:13 more than 5 hours ago?"
    *
    * @return `true` if `instant + duration` lies strictly before the current time
    */
  private def wasMoreThanDurationAgo(duration: Duration,
                                     instant: Instant): Boolean =
    instant.plus(duration).isBefore(Instant.now())
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.time.Duration | |
import slick.driver.H2Driver | |
import slick.driver.H2Driver.api._ | |
import scala.concurrent.Await | |
import scala.concurrent.ExecutionContext.Implicits.global | |
import scala.concurrent.duration._ | |
/**
  * Minimal end-to-end example: creates the task table in an in-memory H2 database and tries to
  * start the "print-hello" task, whose action prints a greeting and then fails.
  *
  * Created by chad on 1/5/17.
  */
object ExampleUsage extends App {
  import scala.concurrent.Future

  val db = Database.forURL("jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1")
  val queue = new TaskQueue[H2Driver]("tasks", H2Driver, db)
  val createDb = db.run(queue.tasks.schema.create)
  import queue._

  val done =
    createDb.flatMap { _ =>
      queue.startTaskIf(HasNotStartedInLast(Duration.ofMinutes(10)), "print-hello", "worker-1") { start =>
        println(s"Last print was at $start.")
        println("HELLO")
        // The action must report failure through the Future it returns; throwing here would
        // escape the queue's failure handling for the returned Future.
        Future.failed(new Exception("SDF"))
      }
    }

  // Block only at the very edge of the program, and release the database resources afterwards
  // even if the result is a failure or the wait times out.
  try println(Await.result(done, 10.seconds))
  finally db.close()
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment