rust_async_task-1990/869774/428654
//Cargo.toml
//[package]
//name = "async-task"
//version = "0.1.0"
//edition = "2021"
//
//[dependencies]
//env_logger = "0.8.2"
//log = "0.4.13"
//tokio = { version = "1", features = ["full"] }
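// This example shares a FIFO work queue between asynchronous workers.
// The queue is a VecDeque behind an Arc<Mutex<..>>, so every clone of the
// handle points at the same underlying data. main() fills the queue, spawns
// two flavours of worker ("x" and "y") on the Tokio runtime, and collects
// the doubled results over std::sync::mpsc channels.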
use log::{info, trace};
use std::collections::VecDeque;
use std::sync::{mpsc, Arc, Mutex};
use std::time::Duration;
use tokio;
/// Create a WorkQueue of any type that holds all the work to be done
#[derive(Clone)]
struct WorkQueue<T> {
    queue: Arc<Mutex<VecDeque<T>>>,
}
impl<T> WorkQueue<T> {
    /// Create a new empty queue
    fn new() -> Self {
        Self {
            queue: Arc::new(Mutex::new(VecDeque::new())),
        }
    }
    /// Add work to the queue
    fn add_work(&self, work: T) -> Result<(), ()> {
        let queue = self.queue.lock();
        if let Ok(mut q) = queue {
            q.push_back(work);
            Ok(())
        } else {
            Err(())
        }
    }
    /// Get the first available work
    fn get_work(&self) -> Option<T> {
        // Lock the queue to fetch a piece of work and prevent other threads
        // from fetching the same work.
        let queue = self.queue.lock();
        if let Ok(mut q) = queue {
            // Remove the first available item, following FIFO order
            q.pop_front()
        } else {
            None
        }
    }
    /// Count the work left
    fn length(&self) -> Option<usize> {
        let queue = self.queue.lock();
        if let Ok(q) = queue {
            Some(q.len())
        } else {
            None
        }
    }
}
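// A small sanity check for the WorkQueue API above; a sketch of how the type
// can be exercised on its own (run with `cargo test`). It only uses the
// methods defined in this file.
#[cfg(test)]
mod work_queue_tests {
    use super::WorkQueue;

    #[test]
    fn clones_share_the_same_fifo_queue() {
        let queue = WorkQueue::new();
        let clone = queue.clone();
        queue.add_work(1).unwrap();
        queue.add_work(2).unwrap();
        // Both handles point at the same underlying VecDeque.
        assert_eq!(clone.length(), Some(2));
        // Work comes back in FIFO order, regardless of which handle asks.
        assert_eq!(clone.get_work(), Some(1));
        assert_eq!(queue.get_work(), Some(2));
        assert_eq!(queue.get_work(), None);
        assert_eq!(queue.length(), Some(0));
    }
}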
/// A very complex calculation that takes too much time to execute
async fn calculate_y(y: i32, duration: u64) -> i32 {
    trace!("worker_y:calculate");
    // Use tokio::time::sleep instead of thread::sleep to avoid blocking the
    // entire thread.
    tokio::time::sleep(Duration::from_millis(duration)).await;
    y * 2
}
/// A very complex calculation that takes too much time to execute
async fn calculate_x(x: i32, duration: u64) -> i32 {
    trace!("worker_x:calculate");
    // Use tokio::time::sleep instead of thread::sleep to avoid blocking the
    // entire thread.
    tokio::time::sleep(Duration::from_millis(duration)).await;
    x * 2
}
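// The two calculators above are identical apart from their trace labels: each
// doubles its input after simulating latency with an async sleep, so the
// label is what lets the log output distinguish x work from y work.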
async fn create_worker_y(
    i: u32,
    queue_clone: WorkQueue<i32>,
    max_work_async: i32,
    tx_clone: mpsc::Sender<i32>,
) {
    // How much work has this worker done
    let mut work_done: i32 = 0;
    let mut current_work: i32 = 0;
    // Check if there is more work to be done
    while queue_clone.length().unwrap() > 0 {
        trace!("worker_y:check_work_avail");
        let mut tasks = Vec::new();
        while current_work < max_work_async {
            if let Some(work) = queue_clone.get_work() {
                trace!("worker_y:get_work");
                let task = tokio::task::spawn(calculate_y(work, 1000));
                tasks.push(task);
                work_done += 1;
                current_work += 1;
            } else {
                break;
            }
        }
        trace!("worker_y:wait_for_task_completion");
        for task in tasks {
            let result = task.await.unwrap();
            tx_clone.send(result).unwrap();
        }
        current_work = 0;
    }
    trace!("worker_y:thread:{:?}:work_done:{:?}", i, work_done);
}
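// create_worker_x below mirrors create_worker_y: each worker repeatedly pulls
// up to max_work_async items off the shared queue, spawns one Tokio task per
// item, awaits the whole batch, and sends every result back over its channel.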
async fn create_worker_x(
    i: u32,
    queue_clone: WorkQueue<i32>,
    max_work_async: i32,
    tx_clone: mpsc::Sender<i32>,
) {
    // How much work has this worker done
    let mut work_done: i32 = 0;
    let mut current_work: i32 = 0;
    // Check if there is more work to be done
    while queue_clone.length().unwrap() > 0 {
        trace!("worker_x:check_work");
        let mut tasks = Vec::new();
        while current_work < max_work_async {
            if let Some(work) = queue_clone.get_work() {
                trace!("worker_x:get_work");
                let task = tokio::task::spawn(calculate_x(work, 1000));
                tasks.push(task);
                work_done += 1;
                current_work += 1;
            } else {
                break;
            }
        }
        trace!("worker_x:wait_for_task_completion");
        for task in tasks {
            let result = task.await.unwrap();
            tx_clone.send(result).unwrap();
        }
        current_work = 0;
    }
    trace!("worker_x:thread:{:?}:work_done:{:?}", i, work_done);
}
#[tokio::main]
async fn main() {
    // Don't forget to set the environment variable, e.g. in PowerShell:
    // $env:RUST_LOG="TRACE"
    env_logger::init();
    log::warn!("[root] warn");
    log::info!("[root] info");
    log::debug!("[root] debug");
    info!("Start");
    // Store the results of the calculations
    let mut data: Vec<i32> = Vec::new();
    // Create one channel per worker flavour to receive the results.
    let (tx, rx) = mpsc::channel();
    let (ty, ry) = mpsc::channel();
    // Set the number of workers to spawn per flavour.
    // Create 12 as my CPU has 12 logical cores
    let total_threads: u32 = 12;
    // Set the maximum amount of work that a worker can run concurrently
    let max_work_async = 10;
    // Set the amount of work to do
    let work = 240;
    // Keep track of how much work has to be done
    let mut work_remaining = 0;
    // Create a new work queue
    let queue = WorkQueue::new();
    for i in 0..work {
        queue.add_work(i).unwrap();
        work_remaining += 1;
    }
    println!("work to be done:{:?}", queue.length());
    // Store the handles of all the worker tasks
    let mut handles = Vec::new();
    for i in 0..total_threads {
        let tx_clone = tx.clone();
        let ty_clone = ty.clone();
        // This is just another handle to the queue, as the queue is an Arc<Mutex>
        let queue_clone = queue.clone();
        trace!("Create Worker");
        let h = tokio::spawn(create_worker_x(
            i,
            queue_clone.clone(),
            max_work_async,
            tx_clone.clone(),
        ));
        handles.push(h);
        let h = tokio::spawn(create_worker_y(
            i,
            queue_clone.clone(),
            max_work_async,
            ty_clone.clone(),
        ));
        handles.push(h);
    }
    // Drop the original senders so that recv() below returns Err once every
    // worker has finished and dropped its own clone, instead of blocking
    // forever on a channel that will never receive another result.
    drop(tx);
    drop(ty);
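    // Note: rx/ry are std::sync::mpsc receivers, so the recv() calls below
    // block the thread running main(). That is tolerable here because the
    // multi-threaded Tokio runtime keeps other worker threads free for the
    // spawned tasks; a tokio::sync::mpsc channel would avoid the blocking
    // entirely.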
trace!("Poll the results");
// Keep receiving until all the work has been done
while work_remaining > 0 {
match rx.recv() {
Ok(result) => {
data.push(result);
work_remaining -= 1;
}
Err(_) => {}
}
trace!("worker_x:work_remaining:{}", work_remaining);
match ry.recv() {
Ok(result) => {
data.push(result);
work_remaining -= 1;
}
Err(_) => {}
}
trace!("worker_y:work_remaining:{}", work_remaining);
}
    // Make sure all the worker tasks have finished
    for h in handles {
        h.await.unwrap();
    }
    // Check that all work has been done correctly
    let mut total = 0;
    for i in data {
        total += i;
    }
    let mut expected_total = 0;
    for i in 0..work {
        expected_total += i * 2;
    }
    trace!("Expected: {:?}, Result: {:?}", expected_total, total);
    info!("End");
}
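// To build and run with tracing enabled (a Unix-shell equivalent of the
// PowerShell command near the top of main):
//   RUST_LOG=trace cargo run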