Skip to content

Instantly share code, notes, and snippets.

@ferromir
Created April 2, 2025 20:44
Show Gist options
  • Select an option

  • Save ferromir/2d91aae5fe56e7250ec60b85191f1557 to your computer and use it in GitHub Desktop.

Select an option

Save ferromir/2d91aae5fe56e7250ec60b85191f1557 to your computer and use it in GitHub Desktop.
Lides translated to Rust by Claude 3.7
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::time::sleep;
use anyhow::{Result, anyhow};
/// Failed attempts after which a workflow is marked aborted instead of failed (retryable).
const DEFAULT_MAX_FAILURES: u32 = 3;
/// Default timeout, in milliseconds, stored when a workflow is claimed or makes progress (1 minute).
const DEFAULT_TIMEOUT_MS: u64 = 60 * 1_000;
/// Default pause, in milliseconds, between polls when no workflow is ready to run (1 second).
const DEFAULT_POLL_MS: u64 = 1_000;
/// Lifecycle state of a workflow as stored by the persistence layer.
///
/// Serialized as its lowercase name (`"idle"`, `"running"`, ...), which is also what
/// [`Status::as_str`] returns.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Status {
    /// Stored and ready to be claimed.
    Idle,
    /// Claimed by a poller; may be reclaimed once its timeout has passed.
    Running,
    /// The handler returned an error but the failure limit was not reached yet.
    Failed,
    /// The handler completed successfully.
    Finished,
    /// The handler failed as many times as the configured maximum allows.
    Aborted,
}

impl Status {
    /// Returns the lowercase name of this status (identical to its serde representation).
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Idle => "idle",
            Self::Running => "running",
            Self::Failed => "failed",
            Self::Finished => "finished",
            Self::Aborted => "aborted",
        }
    }
}
/// A workflow handler: receives the workflow's [`Context`] plus its JSON input and returns a
/// boxed `Send` future resolving to the outcome of the workflow body.
pub type HandlerFn = Arc<
    dyn Fn(Context, Value) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync,
>;
#[derive(Clone)]
pub struct Context {
step_fn: Arc<dyn Fn(String, Box<dyn FnOnce() -> Pin<Box<dyn Future<Output = Result<Value>> + Send>>>) -> Pin<Box<dyn Future<Output = Result<Value>> + Send>> + Send + Sync>,
sleep_fn: Arc<dyn Fn(String, u64) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync>,
start_fn: Arc<dyn Fn(String, String, Value) -> Pin<Box<dyn Future<Output = Result<bool>> + Send>> + Send + Sync>,
}
impl Context {
/// Executes a step.
pub async fn step<F, Fut, T>(&self, id: &str, f: F) -> Result<T>
where
F: FnOnce() -> Fut + 'static,
Fut: Future<Output = Result<T>> + Send + 'static,
T: Into<Value> + From<Value> + 'static,
{
let boxed_fn = Box::new(|| -> Pin<Box<dyn Future<Output = Result<Value>> + Send>> {
Box::pin(async move {
let result = f().await?;
Ok(serde_json::to_value(result)?)
})
});
let result = (self.step_fn)(id.to_string(), boxed_fn).await?;
Ok(serde_json::from_value(result)?)
}
/// Puts the workflow to sleep.
pub async fn sleep(&self, id: &str, ms: u64) -> Result<()> {
(self.sleep_fn)(id.to_string(), ms).await
}
/// Starts a new workflow.
pub async fn start<T: Into<Value>>(&self, id: &str, handler: &str, input: T) -> Result<bool> {
(self.start_fn)(id.to_string(), handler.to_string(), input.into()).await
}
}
/// Client for workflow operations
#[derive(Clone)]
pub struct Client {
start_fn: Arc<dyn Fn(String, String, Value) -> Pin<Box<dyn Future<Output = Result<bool>> + Send>> + Send + Sync>,
wait_fn: Arc<dyn Fn(String, Vec<Status>, u32, u64) -> Pin<Box<dyn Future<Output = Result<Option<Status>>> + Send>> + Send + Sync>,
poll_fn: Arc<dyn Fn(Box<dyn Fn() -> bool + Send>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync>,
}
impl Client {
/// It starts a workflow.
pub async fn start<T: Into<Value>>(&self, id: &str, handler: &str, input: T) -> Result<bool> {
(self.start_fn)(id.to_string(), handler.to_string(), input.into()).await
}
/// Returns a matching workflow status if found, it retries for the specified
/// amount of times and it pauses in between.
pub async fn wait(&self, id: &str, status: Vec<Status>, times: u32, ms: u64) -> Result<Option<Status>> {
(self.wait_fn)(id.to_string(), status, times, ms).await
}
/// It starts polling workflows.
pub async fn poll<F>(&self, should_stop: F) -> Result<()>
where
F: Fn() -> bool + Send + 'static,
{
(self.poll_fn)(Box::new(should_stop)).await
}
}
/// What the engine needs in order to (re)run a workflow, as loaded by
/// [`Persistence::find_run_data`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RunData {
    /// Name of the handler to run; looked up in the configured handler map.
    pub handler: String,
    /// JSON input passed to the handler.
    pub input: Value,
    /// Failed attempts so far; `None` is treated as zero.
    pub failures: Option<u32>,
}
/// Settings accepted by [`make_client`]; every `Option` falls back to a built-in default.
pub struct Config {
    /// Workflow handlers keyed by the handler name given when a workflow is started.
    pub handlers: HashMap<String, HandlerFn>,
    /// Storage backend; `make_client` calls its `init` once before use.
    pub persistence: Arc<dyn Persistence>,
    /// Failed attempts after which a workflow is aborted (default: `DEFAULT_MAX_FAILURES`).
    pub max_failures: Option<u32>,
    /// Timeout in milliseconds stored on claim/progress (default: `DEFAULT_TIMEOUT_MS`).
    pub timeout_interval_ms: Option<u64>,
    /// Pause in milliseconds between empty polls (default: `DEFAULT_POLL_MS`).
    pub poll_interval_ms: Option<u64>,
}
/// Storage backend for workflows, their step outputs and nap wake-up times.
///
/// The engine holds the provider behind an `Arc` and calls it from spawned workflow runs,
/// hence the `Send + Sync` supertraits.
#[async_trait]
pub trait Persistence: Send + Sync {
    /// Initializes the persistence provider. `make_client` calls this once before use.
    async fn init(&self) -> Result<()>;
    /// Inserts a workflow that will be run by the handler named `handler` with `input`.
    ///
    /// NOTE(review): the meaning of the returned `bool` is not specified here; presumably
    /// whether a new workflow was inserted (e.g. `false` for an existing id). Confirm per
    /// implementation.
    async fn insert(&self, workflow_id: &str, handler: &str, input: Value) -> Result<bool>;
    /// Atomically claims a workflow that is ready to run. It consists of two actions, which
    /// have to be performed atomically:
    /// 1. Find a workflow that is ready to run.
    /// 2. Set its timeoutAt to `timeout_at` and its status to "running".
    ///
    /// A "ready to run" workflow matches the following condition, with `now` as the current
    /// time:
    /// (status is "idle") OR
    /// (status is "running" AND timeoutAt < now) OR
    /// (status is "failed" AND timeoutAt < now)
    ///
    /// Returns the id of the claimed workflow, or `None` when no workflow is ready.
    async fn claim(&self, now: SystemTime, timeout_at: SystemTime) -> Result<Option<String>>;
    /// Finds the stored output for the given workflow and step, or `None` if no output has
    /// been recorded for that step.
    async fn find_output(&self, workflow_id: &str, step_id: &str) -> Result<Option<Value>>;
    /// Finds the stored wake-up time for the given workflow and nap, or `None` if none has
    /// been recorded for that nap.
    async fn find_wake_up_at(&self, workflow_id: &str, nap_id: &str) -> Result<Option<SystemTime>>;
    /// Finds the handler name, input and failure count required to run the workflow, or
    /// `None` if the workflow cannot be found.
    async fn find_run_data(&self, workflow_id: &str) -> Result<Option<RunData>>;
    /// Sets the status of the workflow to "finished".
    async fn set_as_finished(&self, workflow_id: &str) -> Result<()>;
    /// Finds the status of a workflow, or `None` if it cannot be found.
    async fn find_status(&self, workflow_id: &str) -> Result<Option<Status>>;
    /// Updates the status, timeoutAt, failure count and last error message of a workflow.
    async fn update_status(
        &self,
        workflow_id: &str,
        status: Status,
        timeout_at: SystemTime,
        failures: u32,
        last_error: &str,
    ) -> Result<()>;
    /// Records the output of a step and updates the workflow's timeoutAt.
    async fn update_output(
        &self,
        workflow_id: &str,
        step_id: &str,
        output: Value,
        timeout_at: SystemTime,
    ) -> Result<()>;
    /// Records the wake-up time of a nap and updates the workflow's timeoutAt.
    async fn update_wake_up_at(
        &self,
        workflow_id: &str,
        nap_id: &str,
        wake_up_at: SystemTime,
        timeout_at: SystemTime,
    ) -> Result<()>;
}
/// Asynchronously suspends the current task for `ms` milliseconds.
///
/// Always resolves to `Ok(())`; the `Result` return type only lets callers chain it with `?`.
async fn go_sleep(ms: u64) -> Result<()> {
    let pause = Duration::from_millis(ms);
    sleep(pause).await;
    Ok(())
}
/// Builds the function the poller uses to claim the next workflow that is ready to run.
///
/// Each call asks [`Persistence::claim`] for a workflow, storing `now + timeout_interval_ms`
/// as the claimed workflow's timeout. It resolves to the claimed workflow's id, or `None`
/// when nothing is ready.
///
/// The returned function is `Clone` because `make_poll` requires a cloneable `claim`.
fn make_claim(
    persistence: Arc<dyn Persistence>,
    timeout_interval_ms: u64,
) -> impl Fn() -> Pin<Box<dyn Future<Output = Result<Option<String>>> + Send>> + Send + Sync + Clone {
    move || {
        let persistence = Arc::clone(&persistence);
        Box::pin(async move {
            let now = SystemTime::now();
            let timeout_at = now + Duration::from_millis(timeout_interval_ms);
            persistence.claim(now, timeout_at).await
        })
    }
}
/// Builds the factory for per-workflow `step` operations.
///
/// Calling the returned function with a workflow id yields a closure taking a step id and a
/// thunk:
/// - if an output is already recorded for that step, it is returned and the thunk is not run;
/// - otherwise the thunk runs, and its output is recorded together with a refreshed workflow
///   timeout (`now + timeout_interval_ms`) before being returned.
///
/// # Errors
/// Propagates errors from the thunk and from persistence. Nothing is recorded when the
/// thunk fails.
fn make_make_step(
    persistence: Arc<dyn Persistence>,
    timeout_interval_ms: u64,
) -> impl Fn(String) -> impl Fn(String, Box<dyn FnOnce() -> Pin<Box<dyn Future<Output = Result<Value>> + Send>>>) -> Pin<Box<dyn Future<Output = Result<Value>> + Send>> + Send + Sync {
    move |workflow_id| {
        // `workflow_id` is already owned by this call; only the shared handle is cloned.
        let persistence = persistence.clone();
        move |step_id, f| {
            // The step closure may be invoked many times, so each call gets its own copies of
            // the captured values; `step_id` is an owned argument and needs no copy.
            let persistence = persistence.clone();
            let workflow_id = workflow_id.clone();
            Box::pin(async move {
                if let Some(recorded) = persistence.find_output(&workflow_id, &step_id).await? {
                    // Output already recorded (e.g. by an earlier attempt): replay it.
                    return Ok(recorded);
                }
                let output = f().await?;
                let timeout_at = SystemTime::now() + Duration::from_millis(timeout_interval_ms);
                persistence
                    .update_output(&workflow_id, &step_id, output.clone(), timeout_at)
                    .await?;
                Ok(output)
            })
        }
    }
}
/// Create function to handle workflow sleep operations
fn make_make_sleep(
persistence: Arc<dyn Persistence>,
timeout_interval_ms: u64,
) -> impl Fn(String) -> impl Fn(String, u64) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync {
move |workflow_id| {
let persistence = persistence.clone();
let workflow_id = workflow_id.clone();
move |nap_id, ms| {
let persistence = persistence.clone();
let workflow_id = workflow_id.clone();
let nap_id = nap_id.clone();
Box::pin(async move {
let wake_up_at = persistence.find_wake_up_at(&workflow_id, &nap_id).await?;
let now = SystemTime::now();
if let Some(wake_up_at) = wake_up_at {
if let Ok(remaining_duration) = wake_up_at.duration_since(now) {
let remaining_ms = remaining_duration.as_millis() as u64;
if remaining_ms > 0 {
go_sleep(remaining_ms).await?;
}
return Ok(());
}
return Ok(());
}
let wake_up_at = now + Duration::from_millis(ms);
let timeout_at = wake_up_at + Duration::from_millis(timeout_interval_ms);
persistence.update_wake_up_at(&workflow_id, &nap_id, wake_up_at, timeout_at).await?;
go_sleep(ms).await
})
}
}
}
/// Builds the function that executes one claimed workflow.
///
/// For a workflow id it loads the workflow's [`RunData`] and looks up the handler by name.
/// It then invokes the handler with a [`Context`] whose `step`/`sleep` operations are bound
/// to this workflow (via `make_step`/`make_sleep`) and whose `start` delegates to `start`.
/// Afterwards:
/// - on success the workflow is marked finished;
/// - on a handler error the failure count is incremented and the workflow is stored as
///   [`Status::Failed`] while the count is below `max_failures`, or [`Status::Aborted`]
///   otherwise. The error message and a new timeout (`now + timeout_interval_ms`) are
///   stored with it.
///
/// The returned function is `Clone` because `make_poll` requires a cloneable `run`.
///
/// # Errors
/// Fails if the workflow or its handler cannot be found, or if a persistence call fails.
/// A handler error is not propagated; it is recorded as described above.
fn make_run(
    persistence: Arc<dyn Persistence>,
    handlers: HashMap<String, HandlerFn>,
    make_step: impl Fn(String) -> impl Fn(String, Box<dyn FnOnce() -> Pin<Box<dyn Future<Output = Result<Value>> + Send>>>) -> Pin<Box<dyn Future<Output = Result<Value>> + Send>> + Send + Sync + Clone,
    make_sleep: impl Fn(String) -> impl Fn(String, u64) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync + Clone,
    start: impl Fn(String, String, Value) -> Pin<Box<dyn Future<Output = Result<bool>> + Send>> + Send + Sync + Clone,
    max_failures: u32,
    timeout_interval_ms: u64,
) -> impl Fn(String) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync + Clone {
    // The handler table never changes after configuration. Share it instead of
    // deep-copying the whole map on every run.
    let handlers = Arc::new(handlers);
    move |workflow_id| {
        let persistence = Arc::clone(&persistence);
        let handlers = Arc::clone(&handlers);
        let make_step = make_step.clone();
        let make_sleep = make_sleep.clone();
        let start = start.clone();
        Box::pin(async move {
            let run_data = persistence
                .find_run_data(&workflow_id)
                .await?
                .ok_or_else(|| anyhow!("workflow not found: {}", workflow_id))?;
            let handler = handlers
                .get(&run_data.handler)
                .cloned()
                .ok_or_else(|| anyhow!("handler not found: {}", run_data.handler))?;

            // Bind the step/sleep factories to this workflow's id.
            let step_fn = {
                let workflow_id = workflow_id.clone();
                Arc::new(move |id: String, f| (make_step(workflow_id.clone()))(id, f))
            };
            let sleep_fn = {
                let workflow_id = workflow_id.clone();
                Arc::new(move |id: String, ms: u64| (make_sleep(workflow_id.clone()))(id, ms))
            };
            let start_fn = Arc::new(move |id: String, handler_name: String, input: Value| {
                start(id, handler_name, input)
            });
            let ctx = Context {
                step_fn,
                sleep_fn,
                start_fn,
            };

            // `input` is only needed by the handler, so it is moved rather than cloned.
            match handler(ctx, run_data.input).await {
                Ok(()) => persistence.set_as_finished(&workflow_id).await,
                Err(error) => {
                    let last_error = error.to_string();
                    let failures = run_data.failures.unwrap_or(0).saturating_add(1);
                    let status = if failures < max_failures {
                        Status::Failed
                    } else {
                        Status::Aborted
                    };
                    let timeout_at = SystemTime::now() + Duration::from_millis(timeout_interval_ms);
                    persistence
                        .update_status(&workflow_id, status, timeout_at, failures, &last_error)
                        .await
                }
            }
        })
    }
}
/// Builds the function that registers a new workflow `(workflow_id, handler_name, input)`
/// via [`Persistence::insert`], resolving to the `bool` that `insert` returns.
///
/// The returned function is `Clone` because `make_client` clones it (once for the run
/// context, once for the client).
fn make_start(
    persistence: Arc<dyn Persistence>,
) -> impl Fn(String, String, Value) -> Pin<Box<dyn Future<Output = Result<bool>> + Send>> + Send + Sync + Clone {
    move |workflow_id, handler, input| {
        let persistence = Arc::clone(&persistence);
        Box::pin(async move { persistence.insert(&workflow_id, &handler, input).await })
    }
}
/// Builds the function that waits for a workflow to reach one of the given statuses.
///
/// The status is looked up at most `times` times, pausing `ms` milliseconds *between*
/// consecutive attempts (not after the last one). The first status found that is contained
/// in the wanted list is returned. `None` is returned if no attempt matched, which
/// includes `times == 0`.
///
/// The returned function is `Clone` because `make_client` may clone it.
fn make_wait(
    persistence: Arc<dyn Persistence>,
) -> impl Fn(String, Vec<Status>, u32, u64) -> Pin<Box<dyn Future<Output = Result<Option<Status>>> + Send>> + Send + Sync + Clone {
    move |workflow_id, status, times, ms| {
        let persistence = Arc::clone(&persistence);
        Box::pin(async move {
            for attempt in 0..times {
                // Pause only between attempts; sleeping after the final one would just
                // delay the `None` result.
                if attempt > 0 {
                    go_sleep(ms).await?;
                }
                if let Some(found) = persistence.find_status(&workflow_id).await? {
                    if status.contains(&found) {
                        return Ok(Some(found));
                    }
                }
            }
            Ok(None)
        })
    }
}
/// Builds the polling loop that claims workflows and runs them in the background.
///
/// `should_stop` is checked before every claim. When a workflow is claimed, its run is
/// spawned on the Tokio runtime and not awaited. When nothing is ready, the loop sleeps
/// `poll_interval_ms`, so stopping may be delayed by up to one poll interval.
///
/// `claim` and `run` must be `'static`: they are captured by a boxed (implicitly `'static`)
/// future and `run` is moved into `tokio::spawn`. The returned function is `Clone` because
/// `make_client` may clone it.
///
/// # Errors
/// A failing `claim` ends the loop and its error is returned. Errors from spawned runs are
/// discarded.
fn make_poll(
    claim: impl Fn() -> Pin<Box<dyn Future<Output = Result<Option<String>>> + Send>> + Send + Sync + Clone + 'static,
    run: impl Fn(String) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync + Clone + 'static,
    poll_interval_ms: u64,
) -> impl Fn(Box<dyn Fn() -> bool + Send>) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync + Clone {
    move |should_stop| {
        let claim = claim.clone();
        let run = run.clone();
        Box::pin(async move {
            while !should_stop() {
                match claim().await? {
                    Some(workflow_id) => {
                        // Intentionally not awaited: runs proceed concurrently with polling.
                        // NOTE(review): the run's error (e.g. a persistence failure) is
                        // dropped here; consider logging it.
                        let run = run.clone();
                        tokio::spawn(async move {
                            let _ = run(workflow_id).await;
                        });
                    }
                    None => go_sleep(poll_interval_ms).await?,
                }
            }
            Ok(())
        })
    }
}
/// Creates a client based on the given configuration
pub async fn make_client(config: Config) -> Result<Client> {
config.persistence.init().await?;
let max_failures = config.max_failures.unwrap_or(DEFAULT_MAX_FAILURES);
let timeout_interval_ms = config.timeout_interval_ms.unwrap_or(DEFAULT_TIMEOUT_MS);
let poll_interval_ms = config.poll_interval_ms.unwrap_or(DEFAULT_POLL_MS);
let persistence = config.persistence.clone();
let start = make_start(persistence.clone());
let wait = make_wait(persistence.clone());
let claim = make_claim(persistence.clone(), timeout_interval_ms);
let make_step = make_make_step(persistence.clone(), timeout_interval_ms);
let make_sleep = make_make_sleep(persistence.clone(), timeout_interval_ms);
let run = make_run(
persistence.clone(),
config.handlers,
make_step,
make_sleep,
start.clone(),
max_failures,
timeout_interval_ms,
);
let poll = make_poll(claim, run, poll_interval_ms);
let start_fn = Arc::new(move |id, handler, input| {
let start = start.clone();
Box::pin(async move {
start(id, handler, input).await
})
});
let wait_fn = Arc::new(move |id, status, times, ms| {
let wait = wait.clone();
Box::pin(async move {
wait(id, status, times, ms).await
})
});
let poll_fn = Arc::new(move |should_stop| {
let poll = poll.clone();
Box::pin(async move {
poll(should_stop).await
})
});
Ok(Client {
start_fn,
wait_fn,
poll_fn,
})
}
// For internal testing
pub mod internal_testing {
use super::*;
pub fn make_claim_fn(
persistence: Arc<dyn Persistence>,
timeout_interval_ms: u64,
) -> impl Fn() -> Pin<Box<dyn Future<Output = Result<Option<String>>> + Send>> + Send + Sync {
make_claim(persistence, timeout_interval_ms)
}
pub fn make_make_step_fn(
persistence: Arc<dyn Persistence>,
timeout_interval_ms: u64,
) -> impl Fn(String) -> impl Fn(String, Box<dyn FnOnce() -> Pin<Box<dyn Future<Output = Result<Value>> + Send>>>) -> Pin<Box<dyn Future<Output = Result<Value>> + Send>> + Send + Sync {
make_make_step(persistence, timeout_interval_ms)
}
pub fn make_make_sleep_fn(
persistence: Arc<dyn Persistence>,
timeout_interval_ms: u64,
) -> impl Fn(String) -> impl Fn(String, u64) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> + Send + Sync {
make_make_sleep(persistence, timeout_interval_ms)
}
// Additional test functions would be defined here, similar to above
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment