Created November 10, 2024 22:48
Save RandyMcMillan/fdf6a6fe7b0f71ab665e58a8282865dc to your computer and use it in GitHub Desktop.
rust_async_task-1990/869774/428654
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Cargo.toml
// [package]
// name = "async-task"
// version = "0.1.0"
// edition = "2021"
//
// [dependencies]
// env_logger = "0.8.2"
// log = "0.4.13"
// tokio = { version = "1", features = ["full"] }
use log::{info, trace}; | |
use std::collections::VecDeque; | |
use std::sync::{mpsc, Arc, Mutex}; | |
use std::time::Duration; | |
use tokio; | |
/// A FIFO queue of pending work items that can be shared between tasks.
///
/// Cloning a `WorkQueue` only clones the inner `Arc`, so every clone refers
/// to the same `Mutex<VecDeque<T>>`: work added through one handle can be
/// taken through any other.
#[derive(Clone)]
struct WorkQueue<T> {
    queue: Arc<Mutex<VecDeque<T>>>,
}

impl<T> WorkQueue<T> {
    /// Build an empty queue.
    fn new() -> Self {
        let shared = Arc::new(Mutex::new(VecDeque::new()));
        WorkQueue { queue: shared }
    }

    /// Append `work` to the back of the queue.
    ///
    /// # Errors
    /// Returns `Err(())` if the mutex is poisoned (a previous holder panicked).
    fn add_work(&self, work: T) -> Result<(), ()> {
        self.queue
            .lock()
            .map(|mut pending| pending.push_back(work))
            .map_err(|_| ())
    }

    /// Remove and return the oldest item (FIFO order).
    ///
    /// The pop happens while the lock is held, so two callers never receive
    /// the same item. Returns `None` when the queue is empty *or* when the
    /// mutex is poisoned; callers cannot tell these two cases apart.
    fn get_work(&self) -> Option<T> {
        self.queue
            .lock()
            .ok()
            .and_then(|mut pending| pending.pop_front())
    }

    /// Number of items still waiting in the queue.
    ///
    /// Returns `None` if the mutex is poisoned.
    fn length(&self) -> Option<usize> {
        self.queue.lock().ok().map(|pending| pending.len())
    }
}
/// A very complex calculation that takes too much time to execute | |
async fn calculate_y(y: i32, duration: u64) -> i32 { | |
trace!("worker_y:calculate"); | |
// Use tokio::time::sleep instead of thread::sleep to avoid blocking the | |
// entire thread. | |
tokio::time::sleep(Duration::from_millis(duration)).await; | |
y * 2 | |
} | |
/// A very complex calculation that takes too much time to execute | |
async fn calculate_x(x: i32, duration: u64) -> i32 { | |
trace!("worker_x:calculate"); | |
// Use tokio::time::sleep instead of thread::sleep to avoid blocking the | |
// entire thread. | |
tokio::time::sleep(Duration::from_millis(duration)).await; | |
x * 2 | |
} | |
async fn create_worker_y( | |
i: u32, | |
queue_clone: WorkQueue<i32>, | |
max_work_async: i32, | |
tx_clone: mpsc::Sender<i32>, | |
) { | |
// How much work has this thread done | |
let mut work_done: i32 = 0; | |
let mut current_work: i32 = 0; | |
// Check if there is more work to be done | |
while queue_clone.length().unwrap() > 0 { | |
trace!("worker_y:check_work_avail"); | |
let mut tasks = Vec::new(); | |
while current_work < max_work_async { | |
if let Some(work) = queue_clone.get_work() { | |
trace!("worker_y:get_work"); | |
let task = tokio::task::spawn(calculate_x(work, 1000)); | |
tasks.push(task); | |
work_done += 1; | |
current_work += 1; | |
} else { | |
break; | |
} | |
} | |
trace!("worker_y:wait_for_task_completion"); | |
for task in tasks { | |
let result = task.await.unwrap(); | |
tx_clone.send(result).unwrap(); | |
} | |
current_work = 0; | |
} | |
trace!( | |
"worker_y:thread:{:?}:work_done:{:?}", | |
i, | |
work_done | |
); | |
} | |
async fn create_worker_x( | |
i: u32, | |
queue_clone: WorkQueue<i32>, | |
max_work_async: i32, | |
tx_clone: mpsc::Sender<i32>, | |
) { | |
// How much work has this thread done | |
let mut work_done: i32 = 0; | |
let mut current_work: i32 = 0; | |
// Check if there is more work to be done | |
while queue_clone.length().unwrap() > 0 { | |
trace!("worker_x:check_work"); | |
let mut tasks = Vec::new(); | |
while current_work < max_work_async { | |
if let Some(work) = queue_clone.get_work() { | |
trace!("worker_x:get_work"); | |
let task = tokio::task::spawn(calculate_y(work, 1000)); | |
tasks.push(task); | |
work_done += 1; | |
current_work += 1; | |
} else { | |
break; | |
} | |
} | |
trace!("worker_x:wait_for_task_completion"); | |
for task in tasks { | |
let result = task.await.unwrap(); | |
tx_clone.send(result).unwrap(); | |
} | |
current_work = 0; | |
} | |
trace!( | |
"worker_x:thread:{:?}:work_done:{:?}", | |
i, | |
work_done | |
); | |
} | |
#[tokio::main] | |
async fn main() { | |
// Dont fonrget to set the environment variable | |
// $env:RUST_LOG="TRACE" | |
env_logger::init(); | |
log::warn!("[root] warn"); | |
log::info!("[root] info"); | |
log::debug!("[root] debug"); | |
info!("Start"); | |
// Store the result of the calculations | |
let mut data: Vec<i32> = Vec::new(); | |
// Create a channel to receive data from the calculations. | |
let (tx, rx) = mpsc::channel(); | |
let (ty, ry) = mpsc::channel(); | |
// Set the maximum number of threads. | |
// Create 12 threads as my CPU has 12 logical cores | |
let total_threads: u32 = 12; | |
// Set a maximum amount of work that a thread can do async | |
let max_work_async = 10; | |
// Set the amount of work to do | |
let work = 240; | |
// Keep track of how much work has to be done | |
let mut work_remaining = 0; | |
// Create a new work queue | |
let queue = WorkQueue::new(); | |
for i in 0..work { | |
queue.add_work(i).unwrap(); | |
work_remaining += 1; | |
} | |
println!("work to be done:{:?}", queue.length()); | |
// Store the handles of all the threads | |
let mut handles = Vec::new(); | |
for i in 0..total_threads { | |
let tx_clone = tx.clone(); | |
let ty_clone = ty.clone(); | |
// This is just a reference to the queue as the queue is a Arc Mutex | |
let queue_clone = queue.clone(); | |
trace!("Create Worker"); | |
let h = tokio::spawn(create_worker_x( | |
i, | |
queue_clone.clone(), | |
max_work_async, | |
tx_clone.clone(), | |
)); | |
handles.push(h); | |
let h = tokio::spawn(create_worker_y( | |
i, | |
queue_clone.clone(), | |
max_work_async, | |
ty_clone.clone(), | |
)); | |
handles.push(h); | |
} | |
trace!("Poll the results"); | |
// Keep receiving until all the work has been done | |
while work_remaining > 0 { | |
match rx.recv() { | |
Ok(result) => { | |
data.push(result); | |
work_remaining -= 1; | |
} | |
Err(_) => {} | |
} | |
trace!("worker_x:work_remaining:{}", work_remaining); | |
match ry.recv() { | |
Ok(result) => { | |
data.push(result); | |
work_remaining -= 1; | |
} | |
Err(_) => {} | |
} | |
trace!("worker_y:work_remaining:{}", work_remaining); | |
} | |
// Make sure all the threads have finished | |
for h in handles { | |
h.await.unwrap(); | |
} | |
// Check that all work has been done correctly | |
let mut total = 0; | |
for i in data { | |
total += i; | |
} | |
let mut expected_total = 0; | |
for i in 0..work { | |
expected_total += i * 2; | |
} | |
trace!("Expected: {:?}, Result: {:?}", expected_total, total); | |
info!("End"); | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.