Skip to content

Instantly share code, notes, and snippets.

@edgarriba
Created October 24, 2024 18:22
Show Gist options
  • Save edgarriba/4d99beee0515d7fb6150fd352c171dc0 to your computer and use it in GitHub Desktop.
Save edgarriba/4d99beee0515d7fb6150fd352c171dc0 to your computer and use it in GitHub Desktop.
Data pipelines in Rust with Zenoh
use std::{
sync::{Arc, Mutex},
time::Duration,
};
use zenoh::Wait;
/// Error type returned by the `Node` callbacks and by the `Runtime` methods.
#[derive(thiserror::Error, Debug)]
pub enum NodeError {
/// Generic failure carrying a free-form message; displayed as
/// `Node error: <message>`.
#[error("Node error: {0}")]
NodeError(String),
}
/// A unit of work driven by a `NodeWorker` on a dedicated thread.
///
/// The worker calls `start` once, then `iter` repeatedly while its running
/// flag is set, and `stop` once that loop has ended.
pub trait Node: Send + Sync {
/// Performs one iteration of the node's work. Returning `Err` makes the
/// worker log the error and leave its loop.
fn iter(&self, zenoh_session: Arc<zenoh::Session>) -> Result<(), NodeError>;
/// Called once before the first `iter`. Returning `Err` makes the worker
/// log the error and exit without iterating or calling `stop`.
fn start(&self, zenoh_session: Arc<zenoh::Session>) -> Result<(), NodeError>;
/// Called after the worker's loop has ended; an `Err` is logged by the worker.
fn stop(&self) -> Result<(), NodeError>;
/// Name of the node, used in log messages and in the stats table.
fn name(&self) -> String;
/// Loop period in milliseconds. It is read once when the worker is spawned;
/// after an iteration that finished sooner, the worker sleeps for the rest.
fn sleep_time(&self) -> u64;
}
/// Timing statistics accumulated over the iterations of a node.
///
/// Iteration times are stored as whole milliseconds (sub-millisecond
/// precision is truncated).
#[derive(Debug, Clone)]
pub struct NodeStats {
    /// Instant at which this record was created.
    pub start_time: std::time::Instant,
    /// Instant of the most recent [`NodeStats::update`] call (the creation
    /// instant until the first update).
    pub end_time: std::time::Instant,
    /// Number of iterations recorded so far.
    pub num_iterations: u64,
    /// Shortest recorded iteration in milliseconds; `u64::MAX` until the
    /// first update.
    pub min_iteration_time: u64,
    /// Longest recorded iteration in milliseconds; `0` until the first update.
    pub max_iteration_time: u64,
    /// Running mean of the recorded iteration times, in milliseconds.
    pub avg_iteration_time: f32,
}

impl NodeStats {
    /// Creates an empty record whose start and end instants are both "now".
    pub fn new() -> Self {
        // Take a single clock reading so both timestamps start out identical.
        let now = std::time::Instant::now();
        Self {
            start_time: now,
            end_time: now,
            num_iterations: 0,
            min_iteration_time: u64::MAX,
            max_iteration_time: 0,
            avg_iteration_time: 0.0,
        }
    }

    /// Records one iteration that took `iter_time`.
    ///
    /// Increments the iteration count, refreshes `end_time`, and updates the
    /// minimum, maximum and running mean (all in whole milliseconds).
    pub fn update(&mut self, iter_time: Duration) {
        // `as_millis` yields a `u128`; clamp before narrowing so that an
        // oversized duration saturates at `u64::MAX` rather than wrapping.
        let millis = iter_time.as_millis().min(u128::from(u64::MAX)) as u64;

        self.num_iterations += 1;
        self.end_time = std::time::Instant::now();
        self.min_iteration_time = self.min_iteration_time.min(millis);
        self.max_iteration_time = self.max_iteration_time.max(millis);
        // Incremental mean: fold the new sample into the previous average.
        self.avg_iteration_time = (self.avg_iteration_time * (self.num_iterations - 1) as f32
            + millis as f32)
            / self.num_iterations as f32;
    }
}

impl Default for NodeStats {
    /// Equivalent to [`NodeStats::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Handle to a thread that drives a single `Node` until it is stopped.
pub struct NodeWorker {
/// The node being driven; a clone of this `Arc` is moved into the thread.
node: Arc<dyn Node>,
/// Flag polled by the worker loop; storing `false` asks the loop to exit.
running: Arc<std::sync::atomic::AtomicBool>,
/// Join handle of the spawned worker thread.
handle: std::thread::JoinHandle<()>,
/// Iteration timing statistics, updated by the worker thread after each
/// successful iteration.
stats: Arc<Mutex<NodeStats>>,
}
impl NodeWorker {
pub fn new(
node: Arc<dyn Node>,
_name: String,
sleep_time: u64,
zenoh_session: Arc<zenoh::Session>,
) -> Result<Self, NodeError> {
let running = Arc::new(std::sync::atomic::AtomicBool::new(true));
let stats = NodeStats::new();
let stats = Arc::new(Mutex::new(stats));
// the handle is the thread that will run the node
let zenoh_session_clone = zenoh_session.clone();
let handle = std::thread::spawn({
let node = node.clone();
let running = running.clone();
let zenoh_session = zenoh_session_clone.clone();
let stats = stats.clone();
move || {
// start the node
if let Err(e) = node.start(zenoh_session.clone()) {
log::error!("Node worker start error: {}", e);
return;
}
// run the node until something goes wrong or termination is signaled
while running.load(std::sync::atomic::Ordering::Relaxed) {
let start = std::time::Instant::now();
if let Err(e) = node.iter(zenoh_session.clone()) {
log::error!("Node worker run error: {}", e);
break;
}
let elapsed = start.elapsed();
stats.lock().unwrap().update(elapsed); // update the stats
if elapsed < std::time::Duration::from_millis(sleep_time) {
std::thread::sleep(std::time::Duration::from_millis(
sleep_time - elapsed.as_millis() as u64,
));
}
}
// stop the node
if let Err(e) = node.stop() {
log::error!("Node worker stop error: {}", e);
return;
}
}
});
Ok(Self {
node,
handle,
running,
stats,
})
}
pub fn set_running(&self, running: bool) {
self.running
.store(running, std::sync::atomic::Ordering::Relaxed);
}
pub fn join(self) -> Result<(), Box<dyn std::error::Error>> {
if let Err(_e) = self.handle.join() {
log::error!("Node worker join error");
}
log::info!("Node {} joined", self.node.name());
Ok(())
}
}
/// Owns a set of nodes and the worker threads that run them.
pub struct Runtime {
/// Nodes registered through `schedule` that have not been started yet.
scheduled_nodes: Vec<Arc<dyn Node>>,
/// Workers spawned by `run`, one per started node.
running_nodes: Vec<NodeWorker>,
/// Zenoh session handed to every worker; closed by `terminate`.
zenoh_session: Arc<zenoh::Session>,
}
impl Runtime {
pub fn new(zenoh_session: Arc<zenoh::Session>) -> Self {
Self {
scheduled_nodes: Vec::new(),
running_nodes: Vec::new(),
zenoh_session,
}
}
pub fn schedule(&mut self, node: Arc<dyn Node>) -> Result<(), NodeError> {
self.scheduled_nodes.push(node);
Ok(())
}
pub fn run(&mut self) -> Result<(), NodeError> {
// spawn a task for each node
while let Some(node) = self.scheduled_nodes.pop() {
let name = node.name();
let sleep_time = node.sleep_time();
self.running_nodes.push(NodeWorker::new(
node,
name,
sleep_time,
self.zenoh_session.clone(),
)?);
}
Ok(())
}
pub fn terminate(&mut self) -> Result<(), NodeError> {
// build the stats table
let mut table = prettytable::Table::new();
table.add_row(prettytable::Row::new(vec![
prettytable::Cell::new("Node"),
prettytable::Cell::new("Min Time (ms)"),
prettytable::Cell::new("Avg Time (ms)"),
prettytable::Cell::new("Max Time (ms)"),
prettytable::Cell::new("Iterations"),
]));
// stop all the nodes
for node in self.running_nodes.iter() {
// set the running flag to false to stop the node
node.set_running(false);
let stats = node.stats.lock().unwrap();
table.add_row(prettytable::Row::new(vec![
prettytable::Cell::new(&node.node.name()),
prettytable::Cell::new(&format!("{:.2}", stats.min_iteration_time as f32)),
prettytable::Cell::new(&format!("{:.2}", stats.avg_iteration_time)),
prettytable::Cell::new(&format!("{:.2}", stats.max_iteration_time as f32)),
prettytable::Cell::new(&stats.num_iterations.to_string()),
]));
}
// wait for all the nodes to finish
while let Some(node) = self.running_nodes.pop() {
if let Err(e) = node.join() {
// TODO: this is critical, we should not exit with an error
log::error!("Node worker join error: {}", e);
}
}
// close the zenoh session
self.zenoh_session
.close()
.wait()
.map_err(|e| NodeError::NodeError(e.to_string()))?;
table.printstd();
Ok(())
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment