Last active
February 11, 2021 16:43
-
-
Save benkay86/fbfc84babca9b0996d6aee66087e59c4 to your computer and use it in GitHub Desktop.
Sidestream collects items from a wrapped stream (or iterator) into a queue on a separate thread.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//! Module for creating sidestreams. A sidestream is a stream over another
//! stream or an iterator. The items from the enclosed stream or iterator are
//! collected into a queue by a separate thread (for iterators) or task (for
//! streams). An optional count parameter is incremented as each item is
//! queued. The items are asynchronously dequeued by the enclosing sidestream.
//! Each item is ready to be dequeued as soon as it is yielded by the enclosed
//! stream or iterator; the sidestream does *not* wait for the collection
//! thread to join. This pattern is useful when you need to know the total
//! number of items in a stream, e.g. to display progress when processing an
//! iterator over a list of files. It is also useful for converting iterators
//! into asynchronous streams.
//!
//! Note that it is safe to drop the sidestream before the collection thread has
//! joined. In this case the collection thread is cancelled gracefully: it stops
//! the next time it tries to queue an item.
//!
//! See [`SideStreamExtForIterator::sidestream_with_count()`] for iterators.
//! See [`SideStreamExtForStream::sidestream_with_count()`] for streams.
use futures_core::stream::Stream; | |
use futures_util::stream::StreamExt; | |
use std::pin::Pin; | |
use std::task::{Context, Poll}; | |
use std::sync::Arc; | |
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; | |
/// Stream produced by [`sidestream()`](SideStreamExtForStream::sidestream()) | |
/// [_et al_](SideStreamExtForIterator::sidestream()). | |
pub struct UnboundedSideStream<T> where T: Send { | |
rx: tokio::sync::mpsc::UnboundedReceiver<T>, | |
// Incremented each time message is sent over the channel. | |
count: Option<Arc<AtomicUsize>>, | |
// Incremented each time message is sent over channel and decremented each | |
// time item is dequeued from channel, represents number of items left in | |
// stream. | |
size: Arc<AtomicUsize>, | |
// Set to true when thread holding sending side of channel is joined. | |
joined: Arc<AtomicBool> | |
} | |
impl<T> UnboundedSideStream<T> where T: Send { | |
/// True if the collection thread for this sidestream has joined. | |
pub fn joined(&self) -> bool { | |
self.joined.load(Ordering::Relaxed) | |
} | |
/// Current value of the count, if any. | |
pub fn count(&self) -> Option<usize> { | |
match &self.count { | |
Some(count) => Some(count.load(Ordering::Relaxed)), | |
None => None | |
} | |
} | |
} | |
impl<T> Stream for UnboundedSideStream<T> where T: Send { | |
type Item = T; | |
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { | |
// Get mutable reference to self. | |
let this = self.get_mut(); | |
// Poll inner receiver. | |
match this.rx.poll_recv(cx) { | |
// Pass through pending value. | |
Poll::Pending => Poll::Pending, | |
// Decrement size if value is ready. | |
std::task::Poll::Ready(val) => { | |
this.size.fetch_sub(1, Ordering::Release); | |
std::task::Poll::Ready(val) | |
} | |
} | |
} | |
fn size_hint(&self) -> (usize, Option<usize>) { | |
match self.joined.load(Ordering::Relaxed) { | |
// Sending thread is joined, so size of stream will not increase. | |
// Upper bound Some(size) is therefore known. | |
true => { | |
let size = self.size.load(Ordering::Acquire); | |
(size, Some(size)) | |
}, | |
// Size of stream could increase, so upper bound None is unknown. | |
false => (self.size.load(Ordering::Acquire), None) | |
} | |
} | |
} | |
/// Combinator trait to convert iterators into sidestreams. | |
pub trait SideStreamExtForIterator { | |
type Item: Send + 'static; | |
/// Convert an iterator into an unbounded sidestream. If a count variable | |
/// is provided then it will be atomically incremented each time an item is | |
/// collected from the iterator. | |
fn sidestream_with_count(self, count: Option<Arc<AtomicUsize>>) -> UnboundedSideStream<Self::Item>; | |
/// Shortcut for calling `sidestream_with_count(None)`. | |
/// See [`sidestream_with_count()`](SideStreamExtForIterator::sidestream_with_count()). | |
fn sidestream(self) -> UnboundedSideStream<Self::Item> where Self: Sized { | |
self.sidestream_with_count(None) | |
} | |
} | |
impl<S, T> SideStreamExtForIterator for S | |
where | |
S: IntoIterator<Item = T> + Send + 'static, | |
T: Send + 'static | |
{ | |
type Item = T; | |
fn sidestream_with_count(self, count: Option<Arc<AtomicUsize>>) -> UnboundedSideStream<Self::Item> { | |
// Make the size and joined variables. | |
let size = Arc::new(AtomicUsize::new(0)); | |
let joined = Arc::new(AtomicBool::new(false)); | |
// Make the channel. | |
let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<Self::Item>(); | |
// Spawn a thread on the send side of the channel. | |
{ | |
// Clone atomic variables to be shared with thread. | |
let count = count.clone(); | |
let size = size.clone(); | |
let joined = joined.clone(); | |
// Spawn thread. | |
tokio::task::spawn_blocking(move || { | |
// Traverse the iterator. | |
for item in self { | |
// Send item over the channel. | |
if tx.send(item).is_err() { | |
// Cancel this thread if the stream was cancelled. | |
break; | |
} | |
// Increment size. | |
size.fetch_add(1, Ordering::Release); | |
// Increment count. | |
if let Some(count) = &count { | |
count.fetch_add(1, Ordering::Relaxed); | |
} | |
} | |
// Iterator is exhausted. About to join this thread. | |
joined.store(true, Ordering::Relaxed); | |
}); | |
} | |
// Return UnboundedSideStream holding the receive side of the channel. | |
UnboundedSideStream { rx, count, size, joined } | |
} | |
} | |
/// Combinator trait to convert streams into sidestreams. | |
pub trait SideStreamExtForStream { | |
type Item: Send + 'static; | |
/// Convert a stream into an unbounded sidestream. If a count variable | |
/// is provided then it will be atomically incremented each time an item | |
/// is collected from the iterator. | |
fn sidestream_with_count(self, count: Option<Arc<AtomicUsize>>) -> UnboundedSideStream<Self::Item>; | |
/// Shortcut for calling `sidestream_with_count(None)`. | |
/// See [`sidestream_with_count()`](SideStreamExtForStream::sidestream_with_count()). | |
fn sidestream(self) -> UnboundedSideStream<Self::Item> where Self: Sized { | |
self.sidestream_with_count(None) | |
} | |
} | |
impl<S, T> SideStreamExtForStream for S | |
where | |
S: Stream<Item = T> + Send + 'static, | |
T: Send + 'static | |
{ | |
type Item = T; | |
fn sidestream_with_count(self, count: Option<Arc<AtomicUsize>>) -> UnboundedSideStream<Self::Item> { | |
// Make the size and joined variables. | |
let size = Arc::new(AtomicUsize::new(0)); | |
let joined = Arc::new(AtomicBool::new(false)); | |
// Make the channel. | |
let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<Self::Item>(); | |
// Spawn a thread on the send side of the channel. | |
{ | |
// Clone atomic variables to be shared with thread. | |
let count = count.clone(); | |
let size = size.clone(); | |
let joined = joined.clone(); | |
// Spawn thread. | |
tokio::task::spawn(async move { | |
// Pin stream to this thread's stack. | |
let s = self; | |
tokio::pin!(s); | |
// Iterate over the elements of the stream. | |
while let Some(item) = (s.next()).await { | |
// Send item over the channel. | |
if tx.send(item).is_err() { | |
// Cancel this thread if the stream was cancelled. | |
break; | |
} | |
// Increment size. | |
size.fetch_add(1, Ordering::Release); | |
// Increment count. | |
if let Some(count) = &count { | |
count.fetch_add(1, Ordering::Relaxed); | |
} | |
} | |
// Iterator is exhausted. About to join this thread. | |
joined.store(true, Ordering::Relaxed); | |
}); | |
} | |
// Return UnboundedSideStream holding the receive side of the channel. | |
UnboundedSideStream { rx, count, size, joined } | |
} | |
} | |
#[cfg(test)]
mod tests {
    use super::*;

    /// Yields to the runtime until the sidestream's collector reports that it
    /// has finished traversing its source.
    async fn wait_until_joined<T: Send>(s: &UnboundedSideStream<T>) {
        while !s.joined() {
            tokio::task::yield_now().await;
            std::hint::spin_loop();
        }
    }

    #[tokio::test]
    async fn test_unbounded_sidestream_iter() {
        // Create sidestream from iterator over vector.
        let v: Vec<i32> = vec![1, 2, 3];
        let s = v.into_iter().sidestream();
        // Collect sidestream and compare to vector.
        assert_eq!(s.collect::<Vec<i32>>().await, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_unbounded_sidestream_stream() {
        // Create sidestream from stream over vector.
        let v: Vec<i32> = vec![1, 2, 3];
        let s = futures_util::stream::iter(v).sidestream();
        // Collect sidestream and compare to vector.
        assert_eq!(s.collect::<Vec<i32>>().await, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_unbounded_sidestream_iter_count() {
        // Create sidestream with count from iterator over vector.
        let count = Arc::new(AtomicUsize::new(0));
        let v: Vec<i32> = vec![1, 2, 3];
        let s = v.into_iter().sidestream_with_count(Some(count.clone()));
        wait_until_joined(&s).await;
        // Check count, both through the sidestream and the shared handle.
        assert_eq!(s.count(), Some(3));
        assert_eq!(count.load(Ordering::Relaxed), 3);
    }

    #[tokio::test]
    async fn test_unbounded_sidestream_stream_count() {
        // Create sidestream with count from stream over vector.
        let count = Arc::new(AtomicUsize::new(0));
        let v: Vec<i32> = vec![1, 2, 3];
        let s = futures_util::stream::iter(v).sidestream_with_count(Some(count.clone()));
        wait_until_joined(&s).await;
        // Check count, both through the sidestream and the shared handle.
        assert_eq!(s.count(), Some(3));
        assert_eq!(count.load(Ordering::Relaxed), 3);
    }

    #[tokio::test]
    async fn test_unbounded_sidestream_iter_size() {
        // Create sidestream from iterator over vector.
        let v: Vec<i32> = vec![1, 2, 3];
        let mut s = v.into_iter().sidestream();
        wait_until_joined(&s).await;
        // Check size of stream.
        assert_eq!(s.size_hint(), (3, Some(3)));
        // Take items from the stream, check size each time.
        assert_eq!(s.next().await, Some(1));
        assert_eq!(s.size_hint(), (2, Some(2)));
        assert_eq!(s.next().await, Some(2));
        assert_eq!(s.size_hint(), (1, Some(1)));
        assert_eq!(s.next().await, Some(3));
        assert_eq!(s.size_hint(), (0, Some(0)));
        assert_eq!(s.next().await, None);
    }

    #[tokio::test]
    async fn test_unbounded_sidestream_stream_size() {
        // Create sidestream from stream over vector.
        let v: Vec<i32> = vec![1, 2, 3];
        let mut s = futures_util::stream::iter(v).sidestream();
        wait_until_joined(&s).await;
        // Check size of stream.
        assert_eq!(s.size_hint(), (3, Some(3)));
        // Take items from the stream, check size each time.
        assert_eq!(s.next().await, Some(1));
        assert_eq!(s.size_hint(), (2, Some(2)));
        assert_eq!(s.next().await, Some(2));
        assert_eq!(s.size_hint(), (1, Some(1)));
        assert_eq!(s.next().await, Some(3));
        assert_eq!(s.size_hint(), (0, Some(0)));
        assert_eq!(s.next().await, None);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment