Skip to content

Instantly share code, notes, and snippets.

@WomB0ComB0
Created November 15, 2024 19:04
Show Gist options
  • Save WomB0ComB0/c630fe59ec83418fb6e67fad6caa7758 to your computer and use it in GitHub Desktop.
Save WomB0ComB0/c630fe59ec83418fb6e67fad6caa7758 to your computer and use it in GitHub Desktop.
Rust multi-threading template (will apply opinionated changes soon)
use std::{
collections::HashMap,
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc, Mutex, RwLock,
},
thread,
time::{Duration, Instant},
};
use tokio::{
sync::{mpsc, oneshot},
task,
time::timeout,
};
use futures::future::join_all;
use thiserror::Error;
use tracing::{debug, error, info, warn};
// Custom error types
/// Errors reported by the thread pool and the async batch processor in this file.
///
/// `Display` text for each variant comes from the `#[error(...)]` attribute
/// attached to it.
#[derive(Error, Debug)]
pub enum ThreadPoolError {
/// A task could not be run to completion; the payload is a human-readable
/// description of the failure.
#[error("Task execution failed: {0}")]
TaskError(String),
/// The pool's shutdown flag was already set when work was submitted.
#[error("Pool is shutting down")]
ShuttingDown,
/// A batch did not finish within the caller-supplied duration.
#[error("Task timed out")]
Timeout,
/// Sending a message to, or receiving a reply from, a worker failed; the
/// payload is the stringified channel error.
#[error("Channel send error: {0}")]
ChannelError(String),
}
// Task and result types
/// A unit of work: an identifier plus the payload handed to the task's handler.
#[derive(Debug)]
pub struct Task<T> {
// Caller-chosen identifier; copied into `TaskResult::task_id` once the task has run.
id: String,
// Payload passed by value to the handler closure.
data: T,
// Timestamp supplied by whoever builds the task (the example uses `Instant::now()`).
// No code in this file reads it back.
created_at: Instant,
}
/// Outcome of running one `Task`, together with how long its handler took.
#[derive(Debug)]
pub struct TaskResult<T> {
// Identifier copied from the originating `Task`.
task_id: String,
// Value returned by the handler, or the error it reported.
result: Result<T, ThreadPoolError>,
// Elapsed time measured around the handler call only.
processing_time: Duration,
}
// Thread pool implementation using standard threads
pub struct ThreadPool<T, R>
where
T: Send + 'static,
R: Send + 'static,
{
workers: Vec<Worker>,
sender: mpsc::Sender<Message<T, R>>,
size: usize,
active_tasks: Arc<AtomicUsize>,
shutdown: Arc<AtomicBool>,
}
impl<T, R> ThreadPool<T, R>
where
T: Send + 'static,
R: Send + 'static,
{
pub fn new(size: usize) -> Self {
let (sender, receiver) = mpsc::channel(size * 2);
let receiver = Arc::new(Mutex::new(receiver));
let mut workers = Vec::with_capacity(size);
let active_tasks = Arc::new(AtomicUsize::new(0));
let shutdown = Arc::new(AtomicBool::new(false));
for id in 0..size {
workers.push(Worker::new(
id,
Arc::clone(&receiver),
Arc::clone(&active_tasks),
Arc::clone(&shutdown),
));
}
ThreadPool {
workers,
sender,
size,
active_tasks,
shutdown,
}
}
pub async fn execute<F>(&self, task: Task<T>) -> Result<R, ThreadPoolError>
where
F: FnOnce(T) -> Result<R, ThreadPoolError> + Send + 'static,
{
if self.shutdown.load(Ordering::SeqCst) {
return Err(ThreadPoolError::ShuttingDown);
}
let (response_tx, response_rx) = oneshot::channel();
let message = Message::NewTask {
task,
response: response_tx,
handler: Box::new(move |data| {
F::call_once((F,), (data,))
}),
};
self.sender
.send(message)
.await
.map_err(|e| ThreadPoolError::ChannelError(e.to_string()))?;
self.active_tasks.fetch_add(1, Ordering::SeqCst);
response_rx
.await
.map_err(|e| ThreadPoolError::ChannelError(e.to_string()))?
}
pub async fn shutdown(&self) {
self.shutdown.store(true, Ordering::SeqCst);
// Wait for active tasks to complete
while self.active_tasks.load(Ordering::SeqCst) > 0 {
tokio::time::sleep(Duration::from_millis(100)).await;
}
for _ in 0..self.size {
if let Err(e) = self.sender.send(Message::Terminate).await {
error!("Error sending terminate message: {}", e);
}
}
}
pub fn active_tasks(&self) -> usize {
self.active_tasks.load(Ordering::SeqCst)
}
}
// Worker implementation
struct Worker {
id: usize,
thread: Option<thread::JoinHandle<()>>,
}
impl Worker {
fn new(
id: usize,
receiver: Arc<Mutex<mpsc::Receiver<Message<T, R>>>>,
active_tasks: Arc<AtomicUsize>,
shutdown: Arc<AtomicBool>,
) -> Worker {
let thread = thread::spawn(move || {
debug!("Worker {} started", id);
loop {
let message = {
let receiver = receiver.lock().unwrap();
receiver.blocking_recv()
};
match message {
Some(Message::NewTask {
task,
response,
handler,
}) => {
let start_time = Instant::now();
let result = handler(task.data);
let processing_time = start_time.elapsed();
let task_result = TaskResult {
task_id: task.id,
result,
processing_time,
};
if let Err(e) = response.send(task_result) {
error!("Worker {}: Error sending result: {}", id, e);
}
active_tasks.fetch_sub(1, Ordering::SeqCst);
}
Some(Message::Terminate) => {
debug!("Worker {} terminating", id);
break;
}
None => {
if shutdown.load(Ordering::SeqCst) {
break;
}
}
}
}
});
Worker {
id,
thread: Some(thread),
}
}
}
// Async task processor implementation
pub struct AsyncTaskProcessor<T, R> {
concurrency_limit: usize,
results: Arc<RwLock<HashMap<String, TaskResult<R>>>>,
}
impl<T, R> AsyncTaskProcessor<T, R>
where
T: Send + Sync + 'static,
R: Send + Sync + 'static,
{
pub fn new(concurrency_limit: usize) -> Self {
AsyncTaskProcessor {
concurrency_limit,
results: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn process_batch<F>(
&self,
tasks: Vec<Task<T>>,
timeout_duration: Duration,
f: F,
) -> Result<Vec<TaskResult<R>>, ThreadPoolError>
where
F: Fn(T) -> Result<R, ThreadPoolError> + Send + Sync + Clone + 'static,
{
let mut futures = Vec::with_capacity(tasks.len());
let semaphore = Arc::new(tokio::sync::Semaphore::new(self.concurrency_limit));
for task in tasks {
let permit = semaphore.clone().acquire_owned().await.unwrap();
let f = f.clone();
let future = task::spawn(async move {
let start_time = Instant::now();
let result = f(task.data);
let processing_time = start_time.elapsed();
let task_result = TaskResult {
task_id: task.id,
result,
processing_time,
};
drop(permit);
task_result
});
futures.push(future);
}
let timeout_future = timeout(timeout_duration, join_all(futures));
match timeout_future.await {
Ok(results) => {
let mut processed_results = Vec::new();
for result in results {
match result {
Ok(task_result) => {
processed_results.push(task_result);
}
Err(e) => {
error!("Task execution error: {}", e);
return Err(ThreadPoolError::TaskError(e.to_string()));
}
}
}
Ok(processed_results)
}
Err(_) => Err(ThreadPoolError::Timeout),
}
}
}
// Example usage
#[tokio::main]
async fn main() {
// Initialize tracing
tracing_subscriber::fmt::init();
// Create thread pool
let pool: ThreadPool<String, String> = ThreadPool::new(4);
// Example task
let task = Task {
id: "task-1".to_string(),
data: "Hello, World!".to_string(),
created_at: Instant::now(),
};
// Execute task
let result = pool
.execute(task, |data| {
Ok(format!("Processed: {}", data))
})
.await;
match result {
Ok(output) => println!("Task completed: {}", output),
Err(e) => eprintln!("Task failed: {}", e),
}
// Create async task processor
let processor: AsyncTaskProcessor<String, String> = AsyncTaskProcessor::new(4);
// Create batch of tasks
let tasks = vec![
Task {
id: "1".to_string(),
data: "Task 1".to_string(),
created_at: Instant::now(),
},
Task {
id: "2".to_string(),
data: "Task 2".to_string(),
created_at: Instant::now(),
},
];
// Process batch
let results = processor
.process_batch(
tasks,
Duration::from_secs(10),
|data| Ok(format!("Processed: {}", data)),
)
.await;
match results {
Ok(results) => {
for result in results {
println!(
"Task {} completed in {:?}: {:?}",
result.task_id,
result.processing_time,
result.result
);
}
}
Err(e) => eprintln!("Batch processing failed: {}", e),
}
// Shutdown pool
pool.shutdown().await;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment