Skip to content

Instantly share code, notes, and snippets.

@ckampfe
Last active February 24, 2025 20:05
Show Gist options
  • Save ckampfe/5ae7ac4423851cddda2c8aafb4d3594c to your computer and use it in GitHub Desktop.
Save ckampfe/5ae7ac4423851cddda2c8aafb4d3594c to your computer and use it in GitHub Desktop.
use const_format::formatc;
use serde::de::DeserializeOwned;
use sqlx::{Acquire, Pool, Sqlite, SqliteConnection};
use std::fmt::Debug;
use std::str::FromStr;
use thiserror::Error;
use tokio::task::JoinError;
use worker::Worker;
mod worker;
/// Configuration consumed by `Qq::new` when opening the backing SQLite database.
///
/// Fields are public so callers outside this crate can override individual values,
/// e.g. `Options { in_memory: true, ..Options::default() }`.
#[derive(Debug, Clone)]
pub struct Options {
    /// Connection string / file name of the database; ignored when `in_memory` is set.
    pub database: String,
    /// Attempt limit stored on the queue; `Qq::new` asserts that it is greater than zero.
    pub max_attempts: i64,
    /// When `true`, `Qq::new` connects to `sqlite::memory:` instead of `database`.
    pub in_memory: bool,
}

impl Default for Options {
    /// Defaults: database file `qq.db`, three attempts, on-disk storage.
    fn default() -> Self {
        Self {
            database: String::from("qq.db"),
            max_attempts: 3,
            in_memory: false,
        }
    }
}
#[derive(Debug, Error)]
pub enum Error {
#[error("sqlite")]
Sqlite(#[from] sqlx::Error),
#[error("serializing")]
Serialization(#[from] serde_json::Error),
#[error("sljksdfj")]
Join(#[from] JoinError),
}
/// One row of the `qq_jobs` table, as produced by the `returning` clause of the
/// claim statement in `try_receive`.
#[derive(sqlx::FromRow)]
pub struct Job {
// Primary key of the row.
id: i64,
// Job arguments as stored text; `try_receive` parses this as JSON.
args: String,
// Name of the queue the job was inserted into.
queue: String,
// Value of the `attempts` column (incremented each time the job is claimed).
attempts: i64,
// Insertion timestamp as text; the table default formats it as `%Y-%m-%d %H:%M:%f`.
inserted_at: String,
// Last-update timestamp as text, same format as `inserted_at`.
updated_at: String,
}
/// A job that has not been stored yet: the arguments for worker type `W`.
///
/// Built by `Worker::new` and consumed by `Qq::enqueue`.
#[derive(Debug)]
pub struct NewJob<W>
where
W: Worker,
{
// Arguments that `enqueue` serializes to JSON for the `args` column.
args: W::T,
}
/// Handle to the SQLite-backed job queue.
pub struct Qq {
// Connection pool shared by `enqueue` and the spawned worker tasks.
pool: Pool<Sqlite>,
// Attempt limit passed to `try_receive` when workers claim jobs.
max_attempts: i64,
}
impl Qq {
/// Opens the database described by `options`, creates the `qq_jobs` table and its
/// indexes when they do not exist yet, and returns the queue handle.
///
/// # Panics
/// Panics if `options.max_attempts` is not greater than zero.
///
/// # Errors
/// Returns `Error::Sqlite` if the connection string cannot be parsed, the pool cannot
/// connect, or any of the schema statements fails.
pub async fn new(options: Options) -> Result<Qq, Error> {
    assert!(options.max_attempts > 0);

    let db_name = if options.in_memory {
        "sqlite::memory:".to_string()
    } else {
        options.database
    };

    let opts = sqlx::sqlite::SqliteConnectOptions::from_str(&db_name)?
        .busy_timeout(std::time::Duration::from_secs(5))
        .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
        .create_if_missing(true)
        .foreign_keys(true)
        .in_memory(options.in_memory);

    let pool = sqlx::SqlitePool::connect_with(opts).await?;

    // Run all schema statements in one transaction so a partially created schema
    // is never left behind.
    let mut conn = pool.acquire().await?;
    let mut tx = conn.begin().await?;

    // The statement has no placeholders, so nothing is bound to it.
    sqlx::query(
        "
        create table if not exists qq_jobs (
            id integer primary key,
            args text not null,
            queue text not null,
            status integer not null default 0,
            attempts integer not null default 0,
            inserted_at datetime not null default(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
            updated_at datetime not null default(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
        )",
    )
    .execute(&mut *tx)
    .await?;

    sqlx::query(
        "
        create index if not exists queue_idx on qq_jobs(queue);
        ",
    )
    .execute(&mut *tx)
    .await?;

    sqlx::query(
        "
        create index if not exists status_idx on qq_jobs(status);
        ",
    )
    .execute(&mut *tx)
    .await?;

    tx.commit().await?;

    Ok(Qq {
        pool,
        max_attempts: options.max_attempts,
    })
}
pub async fn add_worker<W>(self, concurrency: usize, state: W::S) -> Self
where
W: Worker + 'static,
{
assert!(concurrency > 0);
// TODO wire up control of worker tasks
let queue_name = W::queue().to_string();
// todo store in self
// todo run N tasks as configured
// todo somehow get access to connection pool via Arc<Mutex> and checkout conns as needed
for i in 0..concurrency {
let pool = self.pool.clone();
let queue_name = queue_name.clone();
let state = state.clone();
tokio::spawn(async move {
println!("started task {i}");
let queue_name = queue_name;
let pool = pool;
loop {
let mut conn = pool.acquire().await.unwrap();
if let Some(args) = Self::try_receive(&mut conn, &queue_name, self.max_attempts)
.await
.unwrap()
{
let job_id = args.job_id;
println!("processing job {job_id} in task {i}");
match W::perform(args, state.clone()).await {
Ok(()) => {
sqlx::query(formatc!(
"
update qq_jobs
set status = {}
where id = ?
",
JobStatus::Done as isize
))
.bind(job_id)
.execute(&mut *conn)
.await
.unwrap();
}
Err(e) => match e {
worker::Error::Cancel => {
sqlx::query(formatc!(
"
update qq_jobs
set status = {}
where id = ?
",
JobStatus::Done as isize
))
.bind(job_id)
.execute(&mut *conn)
.await
.unwrap();
}
worker::Error::Error => {
sqlx::query(formatc!(
"
update qq_jobs
set status = {}
where id = ?
",
JobStatus::Ready as isize
))
.bind(job_id)
.execute(&mut *conn)
.await
.unwrap();
}
},
}
}
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
});
}
self
}
/// Stores `job` for later processing: its arguments are serialized to JSON and
/// inserted into `qq_jobs` together with the queue name reported by `W::queue()`.
///
/// # Errors
/// Returns `Error::Serialization` if the arguments cannot be encoded as JSON and
/// `Error::Sqlite` if acquiring a connection or the insert fails.
pub async fn enqueue<W>(&self, job: NewJob<W>) -> Result<(), Error>
where
    W: Worker,
{
    let payload = serde_json::to_string(&job.args)?;
    let mut connection = self.pool.acquire().await?;

    let insert = sqlx::query("insert into qq_jobs (args, queue) values (?, ?)")
        .bind(&payload)
        .bind(W::queue());
    insert.execute(&mut *connection).await?;

    println!("inserted");
    Ok(())
}
/// Atomically claims the oldest `Ready` job on `queue`, if there is one.
///
/// A single `update ... returning` statement flips the job to `Locked`, increments
/// its `attempts` counter and refreshes `updated_at`. Only jobs that have been
/// claimed fewer than `max_attempts` times are eligible, so a job is attempted at
/// most `max_attempts` times.
///
/// Returns `Ok(None)` when no eligible job exists.
///
/// # Errors
/// Returns `Error::Sqlite` if the statement fails and `Error::Serialization` if the
/// stored `args` text cannot be deserialized into `T`.
#[tracing::instrument(skip(conn, max_attempts))]
async fn try_receive<T>(
    conn: &mut SqliteConnection,
    queue: &str,
    max_attempts: i64,
) -> Result<Option<worker::Args<T>>, Error>
where
    T: DeserializeOwned,
{
    // `attempts < ?` (not `<=`): the counter holds the number of attempts already
    // made, so equality means the limit has been reached. `id` breaks ties between
    // rows that share an `inserted_at` timestamp.
    let job: Option<Job> = sqlx::query_as(formatc!(
        "
        update qq_jobs
        set
            attempts = attempts + 1,
            status = {},
            updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
        where id = (
            select
                id
            from qq_jobs
            where queue = ?
            and status = {}
            and attempts < ?
            order by inserted_at asc, id asc
            limit 1
        )
        returning
            id,
            args,
            queue,
            attempts,
            inserted_at,
            updated_at;
        ",
        JobStatus::Locked as isize,
        JobStatus::Ready as isize,
    ))
    .bind(queue)
    .bind(max_attempts)
    .fetch_optional(&mut *conn)
    .await?;

    let Some(job) = job else {
        return Ok(None);
    };

    let args: T = serde_json::from_str(&job.args)?;

    Ok(Some(worker::Args {
        args,
        job_id: job.id,
        attempts: job.attempts,
        inserted_at: job.inserted_at,
        updated_at: job.updated_at,
    }))
}
}
/// Lifecycle states of a job. The discriminants are the integers written to the
/// `status` column of `qq_jobs`, so they must not be renumbered.
enum JobStatus {
// Waiting to be claimed; also the column default for new rows.
Ready = 0,
// Claimed by a worker (set by `try_receive`).
Locked = 1,
// Finished; written after the worker returns.
Done = 2,
// NOTE(review): presumably meant for jobs that ran out of attempts — confirm intended use.
Failed = 3,
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::worker::{Args, Worker};
    use serde::{Deserialize, Serialize};
    use tokio::sync::Mutex;

    /// End-to-end check: a job enqueued on an in-memory queue is picked up by one of
    /// the worker tasks and performed exactly once.
    #[tokio::test]
    async fn workers() {
        #[derive(Debug, Deserialize, Serialize)]
        struct MyArgs {
            foo: i64,
        }

        #[derive(Serialize, Deserialize, Debug)]
        struct MyWorker;

        struct MyState {
            i: usize,
        }

        impl Worker for MyWorker {
            type T = MyArgs;
            type S = Arc<Mutex<MyState>>;

            async fn perform(job: Args<Self::T>, state: Self::S) -> Result<(), worker::Error> {
                let mut state = state.lock().await;
                println!("XXXXXXXXXXXXXXXXXXXXX job: {:?}", job);
                println!("XXXXXXXXXXXXXXXXXXXXX old state: {}", state.i);
                state.i += 1;
                println!("XXXXXXXXXXXXXXXXXXXXX new state: {}", state.i);
                // Release the lock before simulating work.
                drop(state);
                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
                Ok(())
            }

            fn queue() -> &'static str {
                "abc"
            }
        }

        let options = Options {
            in_memory: true,
            ..Options::default()
        };

        // Keep a handle to the shared state so the effect of the job can be checked.
        let state = Arc::new(Mutex::new(MyState { i: 8 }));

        let q: Qq = Qq::new(options)
            .await
            .unwrap()
            .add_worker::<MyWorker>(5, Arc::clone(&state))
            .await;

        tokio::time::sleep(std::time::Duration::from_secs(3)).await;

        q.enqueue(MyWorker::new(MyArgs { foo: 1 })).await.unwrap();

        // Workers poll once per second and the job body sleeps for one second, so
        // five seconds leaves ample time for the single job to finish.
        tokio::time::sleep(std::time::Duration::from_secs(5)).await;

        assert_eq!(
            state.lock().await.i,
            9,
            "the enqueued job should have been performed exactly once"
        );
    }
}
use serde::{Serialize, de::DeserializeOwned};
use thiserror::Error;
use crate::NewJob;
/// The trait every worker implements.
///
/// A worker names the queue it consumes, the argument type carried by its jobs and
/// the state type handed to each `perform` call.
pub trait Worker {
/// Job argument type; bounded so it can be serialized for storage and deserialized
/// again when the job is received.
type T: Serialize + DeserializeOwned + Send;
/// State type passed by value to every `perform` call; must be cloneable.
type S: Clone + Send + Sync;
/// Processes one job. `args` carries the deserialized arguments plus job metadata,
/// `state` is the worker's state value. Return `Err` to report a non-ok outcome.
fn perform(
args: Args<Self::T>,
state: Self::S,
) -> impl std::future::Future<Output = Result<(), Error>> + Send + Sync + 'static
where
Self: Sized;
/// Name of the queue this worker's jobs are stored under.
fn queue() -> &'static str;
// fn name() -> &'static str {
// std::any::type_name::<Self>()
// }
/// Wraps `args` in a `NewJob` for this worker type.
fn new(args: Self::T) -> NewJob<Self>
where
Self: Sized,
{
NewJob { args }
}
}
/// A blank default state object for workers that need no shared state.
///
/// Derives `Clone` (plus `Copy`, `Debug` and `Default`) so that it satisfies the
/// `Clone + Send + Sync` bound on `Worker::S` and can actually be used as a state type.
#[derive(Clone, Copy, Debug, Default)]
pub struct NoState;
/// The arguments passed to a worker's `perform`: the deserialized payload plus
/// metadata about the job row it came from.
#[derive(Debug)]
pub struct Args<T> {
/// The deserialized job payload.
pub args: T,
// Row id of the job in the queue table.
pub(crate) job_id: i64,
// Attempt counter of the job at the time it was handed to the worker.
pub(crate) attempts: i64,
// Insertion timestamp of the job, as text.
pub(crate) inserted_at: String,
// Last-update timestamp of the job, as text.
pub(crate) updated_at: String,
}
/// The error a worker returns from `perform` when it ends in a non-ok state.
#[derive(Error, Debug)]
pub enum Error {
/// The worker cancelled the job.
#[error("cancel")]
Cancel,
/// The worker failed while processing the job.
#[error("error")]
Error,
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment