Skip to content

Instantly share code, notes, and snippets.

@matthewjberger
Created August 27, 2024 16:52
Show Gist options
  • Save matthewjberger/4c3a7c1ddae0b25f9b13df59f30bfbe4 to your computer and use it in GitHub Desktop.
Rust Job Graphs
// [dependencies]
// futures = "0.3.30"
// petgraph = "0.6.5"
// tokio = { version = "1.39.3", features = ["full"] }
use futures::future::join_all;
use petgraph::graph::{DiGraph, NodeIndex};
use std::{
collections::HashMap,
sync::{Arc, Mutex},
};
use tokio::sync::{mpsc, Mutex as AsyncMutex};
/// Thread-safe string key/value store shared between jobs.
///
/// Cloning is cheap: clones share the same underlying map via `Arc`.
#[derive(Debug, Clone)]
struct JobContext {
    shared_data: Arc<Mutex<HashMap<String, String>>>,
}

impl JobContext {
    /// Creates an empty context.
    fn new() -> Self {
        Self {
            shared_data: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Stores `value` under `key`, overwriting any previous entry.
    fn set(&self, key: &str, value: &str) {
        self.shared_data
            .lock()
            .unwrap()
            .insert(key.to_owned(), value.to_owned());
    }

    /// Returns a copy of the value stored under `key`, if present.
    fn get(&self, key: &str) -> Option<String> {
        self.shared_data.lock().unwrap().get(key).map(String::clone)
    }
}
/// The unit of work the runner executes. Derives `Eq`/`Hash` so identical
/// jobs can be deduplicated in `JobGraph::node_map` — two jobs with the
/// same payload are treated as the same graph node.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum Job {
    /// Named placeholder task (sleeps 1s to simulate work).
    SimpleTask(String),
    /// Simulated processing of a byte payload (sleeps 2s).
    DataProcessing { data: Vec<u8> },
    /// Simulated network call; no real request is made (sleeps 3s).
    NetworkRequest { url: String },
    /// Writes `value` to the shared `JobContext` when `Some`, reads
    /// `key` from it when `None`.
    ContextOperation { key: String, value: Option<String> },
}
/// Small expression DSL for describing a job graph before it is
/// materialized into petgraph nodes/edges by `JobGraph::from_dsl`.
#[derive(Debug, Clone)]
enum JobGraphDsl {
    /// A single job node.
    Job(Job),
    /// Run the first expression, then the second (adds an edge).
    Then(Box<JobGraphDsl>, Box<JobGraphDsl>),
    /// A group of expressions; see `from_dsl` for how they are wired.
    Parallel(Vec<JobGraphDsl>),
}
/// Directed graph of jobs plus a lookup table that deduplicates nodes:
/// adding the same `Job` twice reuses the existing node index.
struct JobGraph {
    graph: DiGraph<Job, ()>,
    node_map: HashMap<Job, NodeIndex>,
}

impl JobGraph {
    /// Creates an empty job graph.
    fn new() -> Self {
        JobGraph {
            graph: DiGraph::new(),
            node_map: HashMap::new(),
        }
    }

    /// Inserts `job` as a node, returning the existing index if an
    /// identical job was already added (deduplicated via `node_map`).
    pub fn add_job(&mut self, job: Job) -> NodeIndex {
        if let Some(&node_index) = self.node_map.get(&job) {
            return node_index;
        }
        let node_index = self.graph.add_node(job.clone());
        self.node_map.insert(job, node_index);
        node_index
    }

    /// Records an edge from `job` to `dependency`, adding either node
    /// as needed.
    ///
    /// NOTE(review): this edge points job -> dependency, while `Then`
    /// in `from_dsl` points predecessor -> successor — the two
    /// conventions are inverted relative to each other. Confirm which
    /// orientation consumers expect before relying on edge direction.
    #[allow(dead_code)]
    pub fn add_dependency(&mut self, job: Job, dependency: Job) {
        let job_index = self.add_job(job);
        let dependency_index = self.add_job(dependency);
        self.graph.add_edge(job_index, dependency_index, ());
    }

    /// Materializes a DSL expression into nodes/edges and returns the
    /// index of the expression's final node.
    ///
    /// # Panics
    /// Panics if a `Parallel` variant contains no jobs, because there
    /// is then no node index to return.
    pub fn from_dsl(&mut self, dsl: JobGraphDsl) -> NodeIndex {
        match dsl {
            JobGraphDsl::Job(job) => self.add_job(job),
            JobGraphDsl::Then(job1_dsl, job2_dsl) => {
                let job1_index = self.from_dsl(*job1_dsl);
                let job2_index = self.from_dsl(*job2_dsl);
                // `Then` edges run predecessor -> successor.
                self.graph.add_edge(job1_index, job2_index, ());
                job2_index
            }
            JobGraphDsl::Parallel(jobs_dsl) => {
                // NOTE(review): despite the name, sibling branches are
                // chained with edges in declaration order here; the
                // current `JobRunner` never reads edges, so this has no
                // runtime effect today — confirm intent before changing.
                let mut last_index = None;
                for job_dsl in jobs_dsl {
                    let job_index = self.from_dsl(job_dsl);
                    if let Some(last_index) = last_index {
                        self.graph.add_edge(last_index, job_index, ());
                    }
                    last_index = Some(job_index);
                }
                // Was a bare `unwrap()`: an empty Parallel now panics
                // with a diagnostic instead of an opaque message.
                last_index.expect("JobGraphDsl::Parallel must contain at least one job")
            }
        }
    }
}
/// Drives execution of every job in a `JobGraph` on the tokio runtime.
struct JobRunner {
    graph: Arc<JobGraph>,
    // Marks jobs already claimed by a spawned task so each job runs at
    // most once, even though one task is spawned per graph node.
    in_progress: Arc<AsyncMutex<HashMap<Job, bool>>>,
    // Shared, append-only execution log. A std Mutex is acceptable here
    // because no guard is held across an .await point.
    log: Arc<Mutex<Vec<String>>>,
    // Key/value store visible to every job (see Job::ContextOperation).
    context: Arc<JobContext>,
}

impl JobRunner {
    /// Builds a runner over `graph`, appending progress lines to `log`.
    fn new(graph: Arc<JobGraph>, log: Arc<Mutex<Vec<String>>>) -> Self {
        JobRunner {
            graph,
            in_progress: Arc::new(AsyncMutex::new(HashMap::new())),
            log,
            context: Arc::new(JobContext::new()),
        }
    }

    /// Spawns one tokio task per graph node and waits for all of them.
    ///
    /// NOTE(review): graph edges are never consulted here — every node
    /// is spawned immediately, so dependencies recorded in the graph do
    /// not order execution. Observed ordering follows node insertion
    /// order plus tokio scheduling; confirm whether topological ordering
    /// was intended.
    async fn run(&self) {
        let mut tasks = Vec::new();
        let graph = Arc::clone(&self.graph);
        // Completion notifications; capacity 100 is an arbitrary bound.
        let (tx, mut rx) = mpsc::channel::<Job>(100);
        for node_index in graph.graph.node_indices() {
            let job = graph.graph.node_weight(node_index).unwrap().clone();
            let in_progress = Arc::clone(&self.in_progress);
            let log = Arc::clone(&self.log);
            let context = Arc::clone(&self.context);
            let tx = tx.clone();
            let task = tokio::spawn(async move {
                // Claim the job under the async lock, then explicitly
                // release the lock before doing the (slow) work.
                let mut in_progress = in_progress.lock().await;
                if !in_progress.contains_key(&job) {
                    in_progress.insert(job.clone(), true);
                    drop(in_progress);
                    Self::execute_job(job.clone(), log.clone(), context).await;
                    tx.send(job.clone()).await.unwrap();
                }
            });
            tasks.push(task);
        }
        // Detached logger task; it terminates once every sender clone
        // (and the original `tx`, dropped when `run` returns) is gone.
        tokio::spawn(async move {
            while let Some(job) = rx.recv().await {
                println!("Job completed: {:?}", job);
            }
        });
        let _ = join_all(tasks).await;
    }

    /// Performs the work for one job, appending a line to `log`.
    /// The sleeps stand in for real work.
    async fn execute_job(job: Job, log: Arc<Mutex<Vec<String>>>, context: Arc<JobContext>) {
        match job {
            Job::SimpleTask(name) => {
                // Inner block scopes the guard so it drops before .await.
                {
                    let mut log = log.lock().unwrap();
                    log.push(format!("Executing Simple Task: {}", name));
                }
                tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
            }
            Job::DataProcessing { data } => {
                {
                    let mut log = log.lock().unwrap();
                    log.push(format!("Processing Data: {:?}", data));
                }
                tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
            }
            Job::NetworkRequest { url } => {
                {
                    let mut log = log.lock().unwrap();
                    log.push(format!("Making Network Request to: {}", url));
                }
                tokio::time::sleep(tokio::time::Duration::from_secs(3)).await;
            }
            Job::ContextOperation { key, value } => match value {
                // Some(value): write to the shared context.
                Some(val) => {
                    context.set(&key, &val);
                    let mut log = log.lock().unwrap();
                    log.push(format!("Set context: {} = {}", key, val));
                }
                // None: read the key back (logs a miss if absent).
                None => {
                    if let Some(val) = context.get(&key) {
                        let mut log = log.lock().unwrap();
                        log.push(format!("Get context: {} = {}", key, val));
                    } else {
                        let mut log = log.lock().unwrap();
                        log.push(format!("Context key not found: {}", key));
                    }
                }
            },
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A lone job should yield exactly one log line.
    #[tokio::test]
    async fn test_simple_execution() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let mut graph = JobGraph::new();
        graph.from_dsl(JobGraphDsl::Job(Job::SimpleTask("Task 1".to_string())));
        JobRunner::new(Arc::new(graph), Arc::clone(&log)).run().await;

        let entries = log.lock().unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0], "Executing Simple Task: Task 1");
    }

    /// Two parallel jobs both run; ordering is not asserted.
    #[tokio::test]
    async fn test_parallel_execution() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let dsl = JobGraphDsl::Parallel(vec![
            JobGraphDsl::Job(Job::SimpleTask("Task 1".to_string())),
            JobGraphDsl::Job(Job::SimpleTask("Task 2".to_string())),
        ]);
        let mut graph = JobGraph::new();
        graph.from_dsl(dsl);
        JobRunner::new(Arc::new(graph), Arc::clone(&log)).run().await;

        let entries = log.lock().unwrap();
        assert_eq!(entries.len(), 2);
        assert!(entries.contains(&"Executing Simple Task: Task 1".to_string()));
        assert!(entries.contains(&"Executing Simple Task: Task 2".to_string()));
    }

    /// Chained jobs are expected to log in declaration order.
    #[tokio::test]
    async fn test_sequential_execution() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let dsl = JobGraphDsl::Then(
            Box::new(JobGraphDsl::Job(Job::SimpleTask("Task 1".to_string()))),
            Box::new(JobGraphDsl::Job(Job::SimpleTask("Task 2".to_string()))),
        );
        let mut graph = JobGraph::new();
        graph.from_dsl(dsl);
        JobRunner::new(Arc::new(graph), Arc::clone(&log)).run().await;

        let entries = log.lock().unwrap();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0], "Executing Simple Task: Task 1");
        assert_eq!(entries[1], "Executing Simple Task: Task 2");
    }

    /// A sequential job followed by a parallel group.
    #[tokio::test]
    async fn test_complex_execution() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let dsl = JobGraphDsl::Then(
            Box::new(JobGraphDsl::Job(Job::SimpleTask("Task 1".to_string()))),
            Box::new(JobGraphDsl::Parallel(vec![
                JobGraphDsl::Job(Job::DataProcessing {
                    data: vec![1, 2, 3],
                }),
                JobGraphDsl::Job(Job::NetworkRequest {
                    url: "https://example.com".to_string(),
                }),
            ])),
        );
        let mut graph = JobGraph::new();
        graph.from_dsl(dsl);
        JobRunner::new(Arc::new(graph), Arc::clone(&log)).run().await;

        let entries = log.lock().unwrap();
        assert_eq!(entries.len(), 3);
        assert_eq!(entries[0], "Executing Simple Task: Task 1");
        assert!(entries.contains(&"Processing Data: [1, 2, 3]".to_string()));
        assert!(entries.contains(&"Making Network Request to: https://example.com".to_string()));
    }

    /// A context write followed by a read of the same key.
    #[tokio::test]
    async fn test_context_operations() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let dsl = JobGraphDsl::Then(
            Box::new(JobGraphDsl::Job(Job::ContextOperation {
                key: "test_key".to_string(),
                value: Some("test_value".to_string()),
            })),
            Box::new(JobGraphDsl::Job(Job::ContextOperation {
                key: "test_key".to_string(),
                value: None,
            })),
        );
        let mut graph = JobGraph::new();
        graph.from_dsl(dsl);
        JobRunner::new(Arc::new(graph), Arc::clone(&log)).run().await;

        let entries = log.lock().unwrap();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0], "Set context: test_key = test_value");
        assert_eq!(entries[1], "Get context: test_key = test_value");
    }
}
/// Builds a sample job graph, runs it, and prints the log plus the
/// final shared-context value.
#[tokio::main]
async fn main() {
    println!("Job Graph Program");

    // Sample pipeline: an init task followed by a parallel group.
    let dsl = JobGraphDsl::Then(
        Box::new(JobGraphDsl::Job(Job::SimpleTask("Initialize".to_string()))),
        Box::new(JobGraphDsl::Parallel(vec![
            JobGraphDsl::Job(Job::ContextOperation {
                key: "data".to_string(),
                value: Some("important_value".to_string()),
            }),
            JobGraphDsl::Job(Job::DataProcessing {
                data: vec![1, 2, 3],
            }),
            JobGraphDsl::Job(Job::NetworkRequest {
                url: "https://example.com".to_string(),
            }),
        ])),
    );
    let mut graph = JobGraph::new();
    graph.from_dsl(dsl);

    // Run every job, collecting progress lines into a shared log.
    let log = Arc::new(Mutex::new(Vec::new()));
    let runner = JobRunner::new(Arc::new(graph), Arc::clone(&log));
    runner.run().await;

    // Dump the log, numbered from 1.
    println!("Execution Log:");
    for (position, entry) in log.lock().unwrap().iter().enumerate() {
        println!("{}. {}", position + 1, entry);
    }

    // Report the value the ContextOperation job stored, if it ran.
    match runner.context.get("data") {
        Some(value) => println!("Final value in shared context: data = {}", value),
        None => println!("No value found for 'data' in shared context"),
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment