Created
August 27, 2024 16:52
-
-
Save matthewjberger/4c3a7c1ddae0b25f9b13df59f30bfbe4 to your computer and use it in GitHub Desktop.
Rust Job Graphs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// [dependencies] | |
// futures = "0.3.30" | |
// petgraph = "0.6.5" | |
// tokio = { version = "1.39.3", features = ["full"] } | |
use futures::future::join_all; | |
use petgraph::graph::{DiGraph, NodeIndex}; | |
use std::{ | |
collections::HashMap, | |
sync::{Arc, Mutex}, | |
}; | |
use tokio::sync::{mpsc, Mutex as AsyncMutex}; | |
/// Shared key/value store that jobs can read and write during a run.
#[derive(Debug, Clone)]
struct JobContext {
    // Arc + Mutex so every clone of the context observes (and mutates)
    // the same underlying map across tasks.
    shared_data: Arc<Mutex<HashMap<String, String>>>,
}
impl JobContext { | |
fn new() -> Self { | |
JobContext { | |
shared_data: Arc::new(Mutex::new(HashMap::new())), | |
} | |
} | |
fn set(&self, key: &str, value: &str) { | |
let mut data = self.shared_data.lock().unwrap(); | |
data.insert(key.to_string(), value.to_string()); | |
} | |
fn get(&self, key: &str) -> Option<String> { | |
let data = self.shared_data.lock().unwrap(); | |
data.get(key).cloned() | |
} | |
} | |
/// The unit of work the runner executes. `Eq + Hash` let a job value act as
/// its own identity key in `JobGraph::node_map` and the runner's progress map.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum Job {
    /// Named placeholder task (simulated with a 1 s sleep).
    SimpleTask(String),
    /// Simulated data-processing step (2 s sleep).
    DataProcessing { data: Vec<u8> },
    /// Simulated network call (3 s sleep) — no real request is made.
    NetworkRequest { url: String },
    /// Writes to the shared context when `value` is `Some`, reads it otherwise.
    ContextOperation { key: String, value: Option<String> },
}
/// Small expression language for describing a job graph.
#[derive(Debug, Clone)]
enum JobGraphDsl {
    /// A single job node.
    Job(Job),
    /// Evaluate the first expression, then the second (adds an ordering edge).
    Then(Box<JobGraphDsl>, Box<JobGraphDsl>),
    /// A group of expressions intended to run together.
    Parallel(Vec<JobGraphDsl>),
}
/// Directed graph of jobs; edges encode ordering between jobs.
struct JobGraph {
    graph: DiGraph<Job, ()>,
    // Deduplication map: each distinct Job value gets exactly one node.
    node_map: HashMap<Job, NodeIndex>,
}
impl JobGraph { | |
fn new() -> Self { | |
JobGraph { | |
graph: DiGraph::new(), | |
node_map: HashMap::new(), | |
} | |
} | |
pub fn add_job(&mut self, job: Job) -> NodeIndex { | |
if let Some(&node_index) = self.node_map.get(&job) { | |
return node_index; | |
} | |
let node_index = self.graph.add_node(job.clone()); | |
self.node_map.insert(job, node_index); | |
node_index | |
} | |
#[allow(dead_code)] | |
pub fn add_dependency(&mut self, job: Job, dependency: Job) { | |
let job_index = self.add_job(job); | |
let dependency_index = self.add_job(dependency); | |
self.graph.add_edge(job_index, dependency_index, ()); | |
} | |
pub fn from_dsl(&mut self, dsl: JobGraphDsl) -> NodeIndex { | |
match dsl { | |
JobGraphDsl::Job(job) => self.add_job(job), | |
JobGraphDsl::Then(job1_dsl, job2_dsl) => { | |
let job1_index = self.from_dsl(*job1_dsl); | |
let job2_index = self.from_dsl(*job2_dsl); | |
self.graph.add_edge(job1_index, job2_index, ()); | |
job2_index | |
} | |
JobGraphDsl::Parallel(jobs_dsl) => { | |
let mut last_index = None; | |
for job_dsl in jobs_dsl { | |
let job_index = self.from_dsl(job_dsl); | |
if let Some(last_index) = last_index { | |
self.graph.add_edge(last_index, job_index, ()); | |
} | |
last_index = Some(job_index); | |
} | |
last_index.unwrap() | |
} | |
} | |
} | |
} | |
/// Executes the jobs of a `JobGraph` on the tokio runtime.
struct JobRunner {
    // The graph whose nodes are executed.
    graph: Arc<JobGraph>,
    // Guards against executing the same job value twice in one run.
    in_progress: Arc<AsyncMutex<HashMap<Job, bool>>>,
    // Append-only record of what ran; shared with the caller.
    log: Arc<Mutex<Vec<String>>>,
    // Key/value store shared by all jobs in the run.
    context: Arc<JobContext>,
}
impl JobRunner { | |
fn new(graph: Arc<JobGraph>, log: Arc<Mutex<Vec<String>>>) -> Self { | |
JobRunner { | |
graph, | |
in_progress: Arc::new(AsyncMutex::new(HashMap::new())), | |
log, | |
context: Arc::new(JobContext::new()), | |
} | |
} | |
async fn run(&self) { | |
let mut tasks = Vec::new(); | |
let graph = Arc::clone(&self.graph); | |
let (tx, mut rx) = mpsc::channel::<Job>(100); | |
for node_index in graph.graph.node_indices() { | |
let job = graph.graph.node_weight(node_index).unwrap().clone(); | |
let in_progress = Arc::clone(&self.in_progress); | |
let log = Arc::clone(&self.log); | |
let context = Arc::clone(&self.context); | |
let tx = tx.clone(); | |
let task = tokio::spawn(async move { | |
let mut in_progress = in_progress.lock().await; | |
if !in_progress.contains_key(&job) { | |
in_progress.insert(job.clone(), true); | |
drop(in_progress); | |
Self::execute_job(job.clone(), log.clone(), context).await; | |
tx.send(job.clone()).await.unwrap(); | |
} | |
}); | |
tasks.push(task); | |
} | |
tokio::spawn(async move { | |
while let Some(job) = rx.recv().await { | |
println!("Job completed: {:?}", job); | |
} | |
}); | |
let _ = join_all(tasks).await; | |
} | |
async fn execute_job(job: Job, log: Arc<Mutex<Vec<String>>>, context: Arc<JobContext>) { | |
match job { | |
Job::SimpleTask(name) => { | |
{ | |
let mut log = log.lock().unwrap(); | |
log.push(format!("Executing Simple Task: {}", name)); | |
} | |
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; | |
} | |
Job::DataProcessing { data } => { | |
{ | |
let mut log = log.lock().unwrap(); | |
log.push(format!("Processing Data: {:?}", data)); | |
} | |
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; | |
} | |
Job::NetworkRequest { url } => { | |
{ | |
let mut log = log.lock().unwrap(); | |
log.push(format!("Making Network Request to: {}", url)); | |
} | |
tokio::time::sleep(tokio::time::Duration::from_secs(3)).await; | |
} | |
Job::ContextOperation { key, value } => match value { | |
Some(val) => { | |
context.set(&key, &val); | |
let mut log = log.lock().unwrap(); | |
log.push(format!("Set context: {} = {}", key, val)); | |
} | |
None => { | |
if let Some(val) = context.get(&key) { | |
let mut log = log.lock().unwrap(); | |
log.push(format!("Get context: {} = {}", key, val)); | |
} else { | |
let mut log = log.lock().unwrap(); | |
log.push(format!("Context key not found: {}", key)); | |
} | |
} | |
}, | |
} | |
} | |
} | |
#[cfg(test)]
mod tests {
    use super::*;

    /// One job -> exactly one log entry with the expected text.
    #[tokio::test]
    async fn test_simple_execution() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let job_dsl = JobGraphDsl::Job(Job::SimpleTask("Task 1".to_string()));
        let mut job_graph = JobGraph::new();
        job_graph.from_dsl(job_dsl);
        let runner = JobRunner::new(Arc::new(job_graph), Arc::clone(&log));
        runner.run().await;
        let log = log.lock().unwrap();
        assert_eq!(log.len(), 1);
        assert_eq!(log[0], "Executing Simple Task: Task 1");
    }

    /// Both parallel jobs must run; their relative order is deliberately
    /// not asserted.
    #[tokio::test]
    async fn test_parallel_execution() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let job_dsl = JobGraphDsl::Parallel(vec![
            JobGraphDsl::Job(Job::SimpleTask("Task 1".to_string())),
            JobGraphDsl::Job(Job::SimpleTask("Task 2".to_string())),
        ]);
        let mut job_graph = JobGraph::new();
        job_graph.from_dsl(job_dsl);
        let runner = JobRunner::new(Arc::new(job_graph), Arc::clone(&log));
        runner.run().await;
        let log = log.lock().unwrap();
        assert_eq!(log.len(), 2);
        assert!(log.contains(&"Executing Simple Task: Task 1".to_string()));
        assert!(log.contains(&"Executing Simple Task: Task 2".to_string()));
    }

    /// Then(a, b) must log a before b.
    ///
    /// NOTE(review): these indexed assertions rely on the runner actually
    /// sequencing dependent jobs — confirm run() honors graph edges rather
    /// than depending on task spawn order.
    #[tokio::test]
    async fn test_sequential_execution() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let job_dsl = JobGraphDsl::Then(
            Box::new(JobGraphDsl::Job(Job::SimpleTask("Task 1".to_string()))),
            Box::new(JobGraphDsl::Job(Job::SimpleTask("Task 2".to_string()))),
        );
        let mut job_graph = JobGraph::new();
        job_graph.from_dsl(job_dsl);
        let runner = JobRunner::new(Arc::new(job_graph), Arc::clone(&log));
        runner.run().await;
        let log = log.lock().unwrap();
        assert_eq!(log.len(), 2);
        assert_eq!(log[0], "Executing Simple Task: Task 1");
        assert_eq!(log[1], "Executing Simple Task: Task 2");
    }

    /// Mixed graph: the initial task must log first; the two follow-up jobs
    /// may appear in either order.
    #[tokio::test]
    async fn test_complex_execution() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let job_dsl = JobGraphDsl::Then(
            Box::new(JobGraphDsl::Job(Job::SimpleTask("Task 1".to_string()))),
            Box::new(JobGraphDsl::Parallel(vec![
                JobGraphDsl::Job(Job::DataProcessing {
                    data: vec![1, 2, 3],
                }),
                JobGraphDsl::Job(Job::NetworkRequest {
                    url: "https://example.com".to_string(),
                }),
            ])),
        );
        let mut job_graph = JobGraph::new();
        job_graph.from_dsl(job_dsl);
        let runner = JobRunner::new(Arc::new(job_graph), Arc::clone(&log));
        runner.run().await;
        let log = log.lock().unwrap();
        assert_eq!(log.len(), 3);
        assert_eq!(log[0], "Executing Simple Task: Task 1");
        assert!(log.contains(&"Processing Data: [1, 2, 3]".to_string()));
        assert!(log.contains(&"Making Network Request to: https://example.com".to_string()));
    }

    /// A context write followed by a context read: the second job must
    /// observe the value the first job stored in the shared JobContext.
    #[tokio::test]
    async fn test_context_operations() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let job_dsl = JobGraphDsl::Then(
            Box::new(JobGraphDsl::Job(Job::ContextOperation {
                key: "test_key".to_string(),
                value: Some("test_value".to_string()),
            })),
            Box::new(JobGraphDsl::Job(Job::ContextOperation {
                key: "test_key".to_string(),
                value: None,
            })),
        );
        let mut job_graph = JobGraph::new();
        job_graph.from_dsl(job_dsl);
        let runner = JobRunner::new(Arc::new(job_graph), Arc::clone(&log));
        runner.run().await;
        let log = log.lock().unwrap();
        assert_eq!(log.len(), 2);
        assert_eq!(log[0], "Set context: test_key = test_value");
        assert_eq!(log[1], "Get context: test_key = test_value");
    }
}
#[tokio::main] | |
async fn main() { | |
println!("Job Graph Program"); | |
// Create a sample job graph | |
let mut job_graph = JobGraph::new(); | |
let job_dsl = JobGraphDsl::Then( | |
Box::new(JobGraphDsl::Job(Job::SimpleTask("Initialize".to_string()))), | |
Box::new(JobGraphDsl::Parallel(vec![ | |
JobGraphDsl::Job(Job::ContextOperation { | |
key: "data".to_string(), | |
value: Some("important_value".to_string()), | |
}), | |
JobGraphDsl::Job(Job::DataProcessing { | |
data: vec![1, 2, 3], | |
}), | |
JobGraphDsl::Job(Job::NetworkRequest { | |
url: "https://example.com".to_string(), | |
}), | |
])), | |
); | |
job_graph.from_dsl(job_dsl); | |
// Create a shared log | |
let log = Arc::new(Mutex::new(Vec::new())); | |
// Create and run the JobRunner | |
let runner = JobRunner::new(Arc::new(job_graph), Arc::clone(&log)); | |
runner.run().await; | |
// Print the execution log | |
println!("Execution Log:"); | |
let log = log.lock().unwrap(); | |
for (index, entry) in log.iter().enumerate() { | |
println!("{}. {}", index + 1, entry); | |
} | |
// Access the final state of the shared context | |
let context = runner.context.clone(); | |
if let Some(value) = context.get("data") { | |
println!("Final value in shared context: data = {}", value); | |
} else { | |
println!("No value found for 'data' in shared context"); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment