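// lib.rs (crate root): qq, a small job queue backed by SQLite via sqlx, with tokio worker tasks.
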
use const_format::formatc;
use serde::de::DeserializeOwned;
use sqlx::{Acquire, Pool, Sqlite, SqliteConnection};
use std::fmt::Debug;
use std::str::FromStr;
use thiserror::Error;
use tokio::task::JoinError;
use worker::Worker;

mod worker;
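
/// Configuration for [`Qq::new`]: the SQLite database path, the maximum number of
/// delivery attempts per job, and whether to use an in-memory database.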
pub struct Options {
    database: String,
    max_attempts: i64,
    in_memory: bool,
}

impl Default for Options {
    fn default() -> Self {
        Self {
            database: "qq.db".to_string(),
            max_attempts: 3,
            in_memory: false,
        }
    }
}
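
/// Errors surfaced by queue operations: sqlx/SQLite failures, JSON (de)serialization
/// of job args, and joining spawned tasks.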
#[derive(Debug, Error)]
pub enum Error {
    #[error("sqlite")]
    Sqlite(#[from] sqlx::Error),
    #[error("serializing")]
    Serialization(#[from] serde_json::Error),
    #[error("joining task")]
    Join(#[from] JoinError),
}
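
/// A row of the `qq_jobs` table as read back from SQLite.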
#[derive(sqlx::FromRow)]
pub struct Job {
    id: i64,
    args: String,
    queue: String,
    attempts: i64,
    inserted_at: String,
    updated_at: String,
}
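
/// A job built with [`Worker::new`] that has not yet been enqueued; it only carries
/// the worker's argument payload.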
#[derive(Debug)]
pub struct NewJob<W>
where
    W: Worker,
{
    args: W::T,
}
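
/// The queue handle: owns the SQLite connection pool and the retry limit shared by
/// the worker tasks it spawns.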
pub struct Qq {
    pool: Pool<Sqlite>,
    max_attempts: i64,
}

impl Qq {
    /// Open (or create) the database, run the schema setup in a transaction, and
    /// return a queue handle.
    pub async fn new(options: Options) -> Result<Qq, Error> {
        assert!(options.max_attempts > 0);

        let db_name = if options.in_memory {
            // note: a plain in-memory SQLite database is normally private to a single
            // connection, so a pool may not share it across connections unless a
            // shared cache is used
            "sqlite::memory:".to_string()
        } else {
            options.database
        };

        let opts = sqlx::sqlite::SqliteConnectOptions::from_str(&db_name)?
            .busy_timeout(std::time::Duration::from_secs(5))
            .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
            .create_if_missing(true)
            .foreign_keys(true)
            .in_memory(options.in_memory);

        let pool = sqlx::SqlitePool::connect_with(opts).await?;

        let mut conn = pool.acquire().await?;
        let mut tx = conn.begin().await?;
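
        // schema: one `qq_jobs` table holding the serialized args, the queue name,
        // a status flag (see `JobStatus`), an attempt counter, and timestamps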
        sqlx::query(
            "
            create table if not exists qq_jobs (
                id integer primary key,
                args text not null,
                queue text not null,
                status integer not null default 0,
                attempts integer not null default 0,
                inserted_at datetime not null default(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                updated_at datetime not null default(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
            )",
        )
        .execute(&mut *tx)
        .await?;

        sqlx::query(
            "
            create index if not exists queue_idx on qq_jobs(queue);
            ",
        )
        .execute(&mut *tx)
        .await?;

        sqlx::query(
            "
            create index if not exists status_idx on qq_jobs(status);
            ",
        )
        .execute(&mut *tx)
        .await?;

        tx.commit().await?;

        Ok(Qq {
            pool,
            max_attempts: options.max_attempts,
        })
    }
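
    /// Spawn `concurrency` background tasks that poll this worker's queue roughly once
    /// a second, claim one ready job at a time, run [`Worker::perform`], and record the
    /// result: done on success or cancel, back to ready on a retriable error.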
    pub async fn add_worker<W>(self, concurrency: usize, state: W::S) -> Self
    where
        W: Worker + 'static,
    {
        assert!(concurrency > 0);

        // TODO wire up control of worker tasks
        let queue_name = W::queue().to_string();
        // copy the retry limit out so the spawned 'static tasks don't have to capture `self`
        let max_attempts = self.max_attempts;

        // todo store in self
        // todo run N tasks as configured
        // todo somehow get access to connection pool via Arc<Mutex> and checkout conns as needed
        for i in 0..concurrency {
            let pool = self.pool.clone();
            let queue_name = queue_name.clone();
            let state = state.clone();

            tokio::spawn(async move {
                println!("started task {i}");

                loop {
                    let mut conn = pool.acquire().await.unwrap();

                    if let Some(args) =
                        Self::try_receive(&mut conn, &queue_name, max_attempts)
                            .await
                            .unwrap()
                    {
                        let job_id = args.job_id;
                        println!("processing job {job_id} in task {i}");

                        match W::perform(args, state.clone()).await {
                            Ok(()) => {
                                sqlx::query(formatc!(
                                    "
                                    update qq_jobs
                                    set status = {}
                                    where id = ?
                                    ",
                                    JobStatus::Done as isize
                                ))
                                .bind(job_id)
                                .execute(&mut *conn)
                                .await
                                .unwrap();
                            }
                            Err(e) => match e {
                                // a cancelled job is marked done and is never retried
                                worker::Error::Cancel => {
                                    sqlx::query(formatc!(
                                        "
                                        update qq_jobs
                                        set status = {}
                                        where id = ?
                                        ",
                                        JobStatus::Done as isize
                                    ))
                                    .bind(job_id)
                                    .execute(&mut *conn)
                                    .await
                                    .unwrap();
                                }
                                // a failed job goes back to ready so it can be picked up
                                // again, up to the configured attempt limit
                                worker::Error::Error => {
                                    sqlx::query(formatc!(
                                        "
                                        update qq_jobs
                                        set status = {}
                                        where id = ?
                                        ",
                                        JobStatus::Ready as isize
                                    ))
                                    .bind(job_id)
                                    .execute(&mut *conn)
                                    .await
                                    .unwrap();
                                }
                            },
                        }
                    }

                    tokio::time::sleep(std::time::Duration::from_secs(1)).await;
                }
            });
        }

        self
    }
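
    /// Serialize the job's args to JSON and insert the job on its worker's queue with
    /// the default ready status.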
    // #[tracing::instrument(skip(self))]
    pub async fn enqueue<W>(&self, job: NewJob<W>) -> Result<(), Error>
    where
        // T: Serialize + Debug,
        W: Worker,
    {
        let args_json = serde_json::to_string(&job.args)?;

        let mut conn = self.pool.acquire().await?;

        sqlx::query("insert into qq_jobs (args, queue) values (?, ?)")
            .bind(&args_json)
            .bind(W::queue())
            .execute(&mut *conn)
            .await?;

        println!("inserted");

        Ok(())
    }
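
    /// Atomically claim the oldest ready job on `queue`, if any: a single
    /// `update ... returning` statement bumps the attempt counter, flips the status to
    /// locked, and returns the row, whose args are then deserialized into [`worker::Args`].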
    #[tracing::instrument(skip(conn, max_attempts))]
    async fn try_receive<T>(
        conn: &mut SqliteConnection,
        queue: &str,
        max_attempts: i64,
    ) -> Result<Option<worker::Args<T>>, Error>
    where
        T: DeserializeOwned,
    {
        let job: Option<Job> = sqlx::query_as(formatc!(
            "
            update qq_jobs
            set
                attempts = attempts + 1,
                status = {},
                updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
            where id = (
                select
                    id
                from qq_jobs
                where queue = ?
                and status = {}
                and attempts <= ?
                order by inserted_at asc
                limit 1
            )
            returning
                id,
                args,
                queue,
                attempts,
                inserted_at,
                updated_at;
            ",
            JobStatus::Locked as isize,
            JobStatus::Ready as isize,
        ))
        .bind(queue)
        .bind(max_attempts)
        .fetch_optional(&mut *conn)
        .await?;

        if let Some(job) = job {
            let args: T = serde_json::from_str(&job.args)?;

            Ok(Some(worker::Args {
                args,
                job_id: job.id,
                attempts: job.attempts,
                inserted_at: job.inserted_at,
                updated_at: job.updated_at,
            }))
        } else {
            Ok(None)
        }
    }
}
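
/// Lifecycle of a row in `qq_jobs`, stored in the integer `status` column: jobs are
/// inserted as ready, claimed as locked, and finish as done (or, eventually, failed).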
enum JobStatus {
    Ready = 0,
    Locked = 1,
    Done = 2,
    // not yet written by any code path here
    Failed = 3,
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::worker::{Args, Worker};
    use serde::{Deserialize, Serialize};
    use tokio::sync::Mutex;

    #[tokio::test]
    async fn workers() {
        #[derive(Debug, Deserialize, Serialize)]
        struct MyArgs {
            foo: i64,
        }

        #[derive(Serialize, Deserialize, Debug)]
        struct MyWorker;

        // #[derive(Serialize, Deserialize, Debug)]
        // struct MyWorker2;

        struct MyState {
            i: usize,
        }

        impl Worker for MyWorker {
            type T = MyArgs;
            type S = Arc<Mutex<MyState>>;

            async fn perform(job: Args<Self::T>, state: Self::S) -> Result<(), worker::Error> {
                let mut state = state.lock().await;
                println!("XXXXXXXXXXXXXXXXXXXXX job: {:?}", job);
                println!("XXXXXXXXXXXXXXXXXXXXX old state: {}", state.i);
                state.i += 1;
                println!("XXXXXXXXXXXXXXXXXXXXX new state: {}", state.i);
                drop(state);
                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
                Ok(())
            }

            fn queue() -> &'static str {
                "abc"
            }
        }

        let options = Options {
            in_memory: true,
            ..Options::default()
        };

        let q: Qq = Qq::new(options)
            .await
            .unwrap()
            .add_worker::<MyWorker>(5, Arc::new(Mutex::new(MyState { i: 8 })))
            .await;

        tokio::time::sleep(std::time::Duration::from_secs(3)).await;

        q.enqueue(MyWorker::new(MyArgs { foo: 1 })).await.unwrap();

        tokio::time::sleep(std::time::Duration::from_secs(5)).await;
    }
}
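
// worker.rs (the `worker` module): the Worker trait plus the argument and error types
// handed to workers.
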
use serde::{Serialize, de::DeserializeOwned};
use thiserror::Error;

use crate::NewJob;

/// All workers implement this trait.
pub trait Worker {
    /// the worker's argument type; `Debug` is required so that `#[derive(Debug)]`
    /// on `NewJob` can format it
    type T: Serialize + DeserializeOwned + Send + std::fmt::Debug;
    /// shared state handed to every invocation of `perform`
    type S: Clone + Send + Sync;

    fn perform(
        args: Args<Self::T>,
        state: Self::S,
    ) -> impl std::future::Future<Output = Result<(), Error>> + Send + Sync + 'static
    where
        Self: Sized;

    fn queue() -> &'static str;

    // fn name() -> &'static str {
    //     std::any::type_name::<Self>()
    // }

    fn new(args: Self::T) -> NewJob<Self>
    where
        Self: Sized,
    {
        NewJob { args }
    }
}
/// A blank default state object for workers that need no shared state; it derives
/// `Clone` so it satisfies the `Worker::S` bounds.
#[derive(Clone, Copy, Debug, Default)]
pub struct NoState;

/// The args that are passed to a worker.
#[derive(Debug)]
pub struct Args<T> {
    pub args: T,
    pub(crate) job_id: i64,
    pub(crate) attempts: i64,
    pub(crate) inserted_at: String,
    pub(crate) updated_at: String,
}

/// The error a worker must return when it does not complete successfully.
#[derive(Error, Debug)]
pub enum Error {
    /// drop the job: it is marked done and never retried
    #[error("cancel")]
    Cancel,
    /// retriable failure: the job is put back in the ready state
    #[error("error")]
    Error,
}