Created
November 15, 2024 19:04
-
-
Save WomB0ComB0/c630fe59ec83418fb6e67fad6caa7758 to your computer and use it in GitHub Desktop.
Rust multi-threading template (opinionated changes will be applied soon)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{ | |
collections::HashMap, | |
sync::{ | |
atomic::{AtomicBool, AtomicUsize, Ordering}, | |
Arc, Mutex, RwLock, | |
}, | |
thread, | |
time::{Duration, Instant}, | |
}; | |
use tokio::{ | |
sync::{mpsc, oneshot}, | |
task, | |
time::timeout, | |
}; | |
use futures::future::join_all; | |
use thiserror::Error; | |
use tracing::{debug, error, info, warn}; | |
// Custom error types | |
#[derive(Error, Debug)] | |
pub enum ThreadPoolError { | |
#[error("Task execution failed: {0}")] | |
TaskError(String), | |
#[error("Pool is shutting down")] | |
ShuttingDown, | |
#[error("Task timed out")] | |
Timeout, | |
#[error("Channel send error: {0}")] | |
ChannelError(String), | |
} | |
// Task and result types
/// A unit of work: an identifier, the payload handed to the handler, and the instant the
/// task was created.
#[derive(Debug)]
pub struct Task<T> {
    id: String,
    data: T,
    created_at: Instant,
}

impl<T> Task<T> {
    /// Creates a task with the given identifier and payload, stamped with `Instant::now()`.
    ///
    /// The fields are private, so this is the only way code outside this module can build
    /// a `Task`.
    pub fn new(id: impl Into<String>, data: T) -> Self {
        Task {
            id: id.into(),
            data,
            created_at: Instant::now(),
        }
    }

    /// The task's identifier; it is copied into the corresponding `TaskResult`.
    pub fn id(&self) -> &str {
        &self.id
    }

    /// Borrows the payload that will be passed to the handler.
    pub fn data(&self) -> &T {
        &self.data
    }

    /// Consumes the task and returns its payload.
    pub fn into_data(self) -> T {
        self.data
    }

    /// The instant at which the task was created.
    pub fn created_at(&self) -> Instant {
        self.created_at
    }

    /// Time elapsed since the task was created (e.g. to measure queueing delay).
    pub fn age(&self) -> Duration {
        self.created_at.elapsed()
    }
}
#[derive(Debug)] | |
pub struct TaskResult<T> { | |
task_id: String, | |
result: Result<T, ThreadPoolError>, | |
processing_time: Duration, | |
} | |
// Thread pool implementation using standard threads | |
pub struct ThreadPool<T, R> | |
where | |
T: Send + 'static, | |
R: Send + 'static, | |
{ | |
workers: Vec<Worker>, | |
sender: mpsc::Sender<Message<T, R>>, | |
size: usize, | |
active_tasks: Arc<AtomicUsize>, | |
shutdown: Arc<AtomicBool>, | |
} | |
impl<T, R> ThreadPool<T, R> | |
where | |
T: Send + 'static, | |
R: Send + 'static, | |
{ | |
pub fn new(size: usize) -> Self { | |
let (sender, receiver) = mpsc::channel(size * 2); | |
let receiver = Arc::new(Mutex::new(receiver)); | |
let mut workers = Vec::with_capacity(size); | |
let active_tasks = Arc::new(AtomicUsize::new(0)); | |
let shutdown = Arc::new(AtomicBool::new(false)); | |
for id in 0..size { | |
workers.push(Worker::new( | |
id, | |
Arc::clone(&receiver), | |
Arc::clone(&active_tasks), | |
Arc::clone(&shutdown), | |
)); | |
} | |
ThreadPool { | |
workers, | |
sender, | |
size, | |
active_tasks, | |
shutdown, | |
} | |
} | |
pub async fn execute<F>(&self, task: Task<T>) -> Result<R, ThreadPoolError> | |
where | |
F: FnOnce(T) -> Result<R, ThreadPoolError> + Send + 'static, | |
{ | |
if self.shutdown.load(Ordering::SeqCst) { | |
return Err(ThreadPoolError::ShuttingDown); | |
} | |
let (response_tx, response_rx) = oneshot::channel(); | |
let message = Message::NewTask { | |
task, | |
response: response_tx, | |
handler: Box::new(move |data| { | |
F::call_once((F,), (data,)) | |
}), | |
}; | |
self.sender | |
.send(message) | |
.await | |
.map_err(|e| ThreadPoolError::ChannelError(e.to_string()))?; | |
self.active_tasks.fetch_add(1, Ordering::SeqCst); | |
response_rx | |
.await | |
.map_err(|e| ThreadPoolError::ChannelError(e.to_string()))? | |
} | |
pub async fn shutdown(&self) { | |
self.shutdown.store(true, Ordering::SeqCst); | |
// Wait for active tasks to complete | |
while self.active_tasks.load(Ordering::SeqCst) > 0 { | |
tokio::time::sleep(Duration::from_millis(100)).await; | |
} | |
for _ in 0..self.size { | |
if let Err(e) = self.sender.send(Message::Terminate).await { | |
error!("Error sending terminate message: {}", e); | |
} | |
} | |
} | |
pub fn active_tasks(&self) -> usize { | |
self.active_tasks.load(Ordering::SeqCst) | |
} | |
} | |
// Worker implementation | |
struct Worker { | |
id: usize, | |
thread: Option<thread::JoinHandle<()>>, | |
} | |
impl Worker { | |
fn new( | |
id: usize, | |
receiver: Arc<Mutex<mpsc::Receiver<Message<T, R>>>>, | |
active_tasks: Arc<AtomicUsize>, | |
shutdown: Arc<AtomicBool>, | |
) -> Worker { | |
let thread = thread::spawn(move || { | |
debug!("Worker {} started", id); | |
loop { | |
let message = { | |
let receiver = receiver.lock().unwrap(); | |
receiver.blocking_recv() | |
}; | |
match message { | |
Some(Message::NewTask { | |
task, | |
response, | |
handler, | |
}) => { | |
let start_time = Instant::now(); | |
let result = handler(task.data); | |
let processing_time = start_time.elapsed(); | |
let task_result = TaskResult { | |
task_id: task.id, | |
result, | |
processing_time, | |
}; | |
if let Err(e) = response.send(task_result) { | |
error!("Worker {}: Error sending result: {}", id, e); | |
} | |
active_tasks.fetch_sub(1, Ordering::SeqCst); | |
} | |
Some(Message::Terminate) => { | |
debug!("Worker {} terminating", id); | |
break; | |
} | |
None => { | |
if shutdown.load(Ordering::SeqCst) { | |
break; | |
} | |
} | |
} | |
} | |
}); | |
Worker { | |
id, | |
thread: Some(thread), | |
} | |
} | |
} | |
// Async task processor implementation | |
pub struct AsyncTaskProcessor<T, R> { | |
concurrency_limit: usize, | |
results: Arc<RwLock<HashMap<String, TaskResult<R>>>>, | |
} | |
impl<T, R> AsyncTaskProcessor<T, R> | |
where | |
T: Send + Sync + 'static, | |
R: Send + Sync + 'static, | |
{ | |
pub fn new(concurrency_limit: usize) -> Self { | |
AsyncTaskProcessor { | |
concurrency_limit, | |
results: Arc::new(RwLock::new(HashMap::new())), | |
} | |
} | |
pub async fn process_batch<F>( | |
&self, | |
tasks: Vec<Task<T>>, | |
timeout_duration: Duration, | |
f: F, | |
) -> Result<Vec<TaskResult<R>>, ThreadPoolError> | |
where | |
F: Fn(T) -> Result<R, ThreadPoolError> + Send + Sync + Clone + 'static, | |
{ | |
let mut futures = Vec::with_capacity(tasks.len()); | |
let semaphore = Arc::new(tokio::sync::Semaphore::new(self.concurrency_limit)); | |
for task in tasks { | |
let permit = semaphore.clone().acquire_owned().await.unwrap(); | |
let f = f.clone(); | |
let future = task::spawn(async move { | |
let start_time = Instant::now(); | |
let result = f(task.data); | |
let processing_time = start_time.elapsed(); | |
let task_result = TaskResult { | |
task_id: task.id, | |
result, | |
processing_time, | |
}; | |
drop(permit); | |
task_result | |
}); | |
futures.push(future); | |
} | |
let timeout_future = timeout(timeout_duration, join_all(futures)); | |
match timeout_future.await { | |
Ok(results) => { | |
let mut processed_results = Vec::new(); | |
for result in results { | |
match result { | |
Ok(task_result) => { | |
processed_results.push(task_result); | |
} | |
Err(e) => { | |
error!("Task execution error: {}", e); | |
return Err(ThreadPoolError::TaskError(e.to_string())); | |
} | |
} | |
} | |
Ok(processed_results) | |
} | |
Err(_) => Err(ThreadPoolError::Timeout), | |
} | |
} | |
} | |
// Example usage | |
#[tokio::main] | |
async fn main() { | |
// Initialize tracing | |
tracing_subscriber::fmt::init(); | |
// Create thread pool | |
let pool: ThreadPool<String, String> = ThreadPool::new(4); | |
// Example task | |
let task = Task { | |
id: "task-1".to_string(), | |
data: "Hello, World!".to_string(), | |
created_at: Instant::now(), | |
}; | |
// Execute task | |
let result = pool | |
.execute(task, |data| { | |
Ok(format!("Processed: {}", data)) | |
}) | |
.await; | |
match result { | |
Ok(output) => println!("Task completed: {}", output), | |
Err(e) => eprintln!("Task failed: {}", e), | |
} | |
// Create async task processor | |
let processor: AsyncTaskProcessor<String, String> = AsyncTaskProcessor::new(4); | |
// Create batch of tasks | |
let tasks = vec![ | |
Task { | |
id: "1".to_string(), | |
data: "Task 1".to_string(), | |
created_at: Instant::now(), | |
}, | |
Task { | |
id: "2".to_string(), | |
data: "Task 2".to_string(), | |
created_at: Instant::now(), | |
}, | |
]; | |
// Process batch | |
let results = processor | |
.process_batch( | |
tasks, | |
Duration::from_secs(10), | |
|data| Ok(format!("Processed: {}", data)), | |
) | |
.await; | |
match results { | |
Ok(results) => { | |
for result in results { | |
println!( | |
"Task {} completed in {:?}: {:?}", | |
result.task_id, | |
result.processing_time, | |
result.result | |
); | |
} | |
} | |
Err(e) => eprintln!("Batch processing failed: {}", e), | |
} | |
// Shutdown pool | |
pool.shutdown().await; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment