Skip to content

Instantly share code, notes, and snippets.

@benkay86
Last active February 11, 2021 16:43
Show Gist options
  • Save benkay86/fbfc84babca9b0996d6aee66087e59c4 to your computer and use it in GitHub Desktop.
Save benkay86/fbfc84babca9b0996d6aee66087e59c4 to your computer and use it in GitHub Desktop.
Sidestream collects items from a wrapped stream (or iterator) into a queue on a separate thread.
//! Module for creating sidestreams. A sidestream is a stream over another
//! stream or an iterator. The items from the enclosed stream or iterator are
//! collected into a queue in the background: iterators are drained on a
//! blocking thread, streams on a spawned Tokio task. An optional count parameter is
//! incremented as each item is queued. The items are asynchronously dequeued
//! by the enclosing sidestream. Each item is ready to be dequeued as soon as
//! it is yielded by the enclosed stream or iterator; the sidestream does *not*
//! wait for the collection thread to join. This pattern is useful when you
//! need to know the total number of items in a stream, e.g. to display progress
//! when processing an iterator over a list of files. It is also useful for
//! converting iterators into asynchronous streams.
//!
//! Note that it is safe to drop the sidestream before the collection thread has
//! joined. In this case the collection thread will be gracefully cancelled: it
//! stops the next time it tries to queue an item and finds the sidestream gone.
//!
//! See [`SideStreamExtForIterator::sidestream_with_count()`] for iterators.
//! See [`SideStreamExtForStream::sidestream_with_count()`] for streams.
use futures_core::stream::Stream;
use futures_util::stream::StreamExt;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
/// Stream produced by [`sidestream()`](SideStreamExtForStream::sidestream())
/// [_et al_](SideStreamExtForIterator::sidestream()).
///
/// Yields the items that a background producer pushes into an unbounded
/// channel, and keeps track of how many items are currently queued so that
/// `size_hint()` can report it.
pub struct UnboundedSideStream<T> where T: Send {
// Receiving half of the unbounded channel that the producer feeds.
rx: tokio::sync::mpsc::UnboundedReceiver<T>,
// Optional caller-supplied counter, shared with the producer, which
// increments it each time a message is sent over the channel. It is never
// decremented by this module, so it reflects the running total of items.
count: Option<Arc<AtomicUsize>>,
// Incremented by the producer each time a message is sent over the channel
// and decremented as items are dequeued from it; represents the number of
// items left in the stream.
size: Arc<AtomicUsize>,
// Set to true by the producer when it is done with the sending side of the
// channel, i.e. no further items will be queued.
joined: Arc<AtomicBool>
}
impl<T> UnboundedSideStream<T>
where
    T: Send,
{
    /// Returns `true` once the producer feeding this sidestream has finished,
    /// i.e. no further items will be queued.
    pub fn joined(&self) -> bool {
        // Acquire rather than Relaxed: callers typically read the count or the
        // size hint right after seeing `true`. With an Acquire load, a
        // producer that publishes the flag with Release makes everything it
        // did beforehand visible to such a caller.
        self.joined.load(Ordering::Acquire)
    }

    /// Returns the current value of the shared item counter, or `None` if the
    /// sidestream was created without one.
    pub fn count(&self) -> Option<usize> {
        self.count
            .as_deref()
            .map(|count| count.load(Ordering::Relaxed))
    }
}
impl<T> Stream for UnboundedSideStream<T>
where
    T: Send,
{
    type Item = T;

    /// Polls the receiving half of the channel for the next queued item.
    ///
    /// Returns `Poll::Pending` while the queue is empty but the producer is
    /// still alive, `Poll::Ready(Some(item))` for each dequeued item, and
    /// `Poll::Ready(None)` once the channel is closed and drained.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        let polled = this.rx.poll_recv(cx);
        // Only an actual item shrinks the queue. `Ready(None)` signals the end
        // of the stream; decrementing there would wrap the counter from 0 to
        // `usize::MAX` and corrupt every later `size_hint()`.
        if let Poll::Ready(Some(_)) = &polled {
            this.size.fetch_sub(1, Ordering::Release);
        }
        polled
    }

    /// Reports the number of currently queued items as the lower bound.
    ///
    /// The upper bound is only known (and equal to the lower bound) once the
    /// producer has finished; until then more items may still arrive and the
    /// upper bound is `None`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Read the flag before the size, and with Acquire: if the producer
        // publishes the flag with Release, a `true` here guarantees that the
        // size read below already includes every item it queued.
        let finished = self.joined.load(Ordering::Acquire);
        let queued = self.size.load(Ordering::Acquire);
        let upper = if finished { Some(queued) } else { None };
        (queued, upper)
    }
}
/// Combinator trait to convert iterators into sidestreams.
///
/// This module provides a blanket implementation for every
/// `IntoIterator + Send + 'static` whose items are `Send + 'static`.
pub trait SideStreamExtForIterator {
/// Type of the items yielded by the resulting sidestream.
type Item: Send + 'static;
/// Convert an iterator into an unbounded sidestream. If a count variable
/// is provided then it will be atomically incremented each time an item is
/// collected from the iterator. The caller may keep its own clone of the
/// `Arc` to observe the count from outside the stream.
fn sidestream_with_count(self, count: Option<Arc<AtomicUsize>>) -> UnboundedSideStream<Self::Item>;
/// Shortcut for calling `sidestream_with_count(None)`, i.e. a sidestream
/// without an item counter.
/// See [`sidestream_with_count()`](SideStreamExtForIterator::sidestream_with_count()).
fn sidestream(self) -> UnboundedSideStream<Self::Item> where Self: Sized {
self.sidestream_with_count(None)
}
}
impl<S, T> SideStreamExtForIterator for S
where
    S: IntoIterator<Item = T> + Send + 'static,
    T: Send + 'static,
{
    type Item = T;

    /// Drains `self` on a Tokio blocking thread, forwarding every item over an
    /// unbounded channel to the returned sidestream.
    ///
    /// If `count` is provided it is incremented once for every item that was
    /// successfully queued. The producer stops early when the sidestream has
    /// been dropped. Must be called from within a Tokio runtime, as required
    /// by `tokio::task::spawn_blocking`.
    fn sidestream_with_count(
        self,
        count: Option<Arc<AtomicUsize>>,
    ) -> UnboundedSideStream<Self::Item> {
        // Marks the producer as finished when dropped. Using a guard instead
        // of a trailing store means the flag is also set if the iterator
        // panics or the closure is dropped without ever running, so nobody
        // waits on `joined` forever.
        struct MarkJoined(Arc<AtomicBool>);
        impl Drop for MarkJoined {
            fn drop(&mut self) {
                // Release publishes every preceding size/count update to
                // readers that load the flag with Acquire.
                self.0.store(true, Ordering::Release);
            }
        }

        // State shared between the producer and the sidestream.
        let size = Arc::new(AtomicUsize::new(0));
        let joined = Arc::new(AtomicBool::new(false));
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<Self::Item>();

        {
            // Handles moved into the producer closure.
            let count = count.clone();
            let size = Arc::clone(&size);
            let guard = MarkJoined(Arc::clone(&joined));
            tokio::task::spawn_blocking(move || {
                let _guard = guard;
                for item in self {
                    // Account for the item *before* it becomes visible to the
                    // consumer. Incrementing after the send would let the
                    // consumer dequeue and decrement first, briefly wrapping
                    // the size to `usize::MAX`.
                    size.fetch_add(1, Ordering::Release);
                    if tx.send(item).is_err() {
                        // The sidestream was dropped: undo the accounting for
                        // the item that was never queued and stop producing.
                        size.fetch_sub(1, Ordering::Release);
                        break;
                    }
                    if let Some(count) = &count {
                        count.fetch_add(1, Ordering::Relaxed);
                    }
                }
                // `_guard` is dropped here, flagging the producer as joined.
            });
        }

        // The sidestream keeps the receiving half plus the shared state.
        UnboundedSideStream { rx, count, size, joined }
    }
}
/// Combinator trait to convert streams into sidestreams.
///
/// This module provides a blanket implementation for every
/// `Stream + Send + 'static` whose items are `Send + 'static`.
pub trait SideStreamExtForStream {
/// Type of the items yielded by the resulting sidestream.
type Item: Send + 'static;
/// Convert a stream into an unbounded sidestream. If a count variable
/// is provided then it will be atomically incremented each time an item
/// is collected from the stream. The caller may keep its own clone of the
/// `Arc` to observe the count from outside the stream.
fn sidestream_with_count(self, count: Option<Arc<AtomicUsize>>) -> UnboundedSideStream<Self::Item>;
/// Shortcut for calling `sidestream_with_count(None)`, i.e. a sidestream
/// without an item counter.
/// See [`sidestream_with_count()`](SideStreamExtForStream::sidestream_with_count()).
fn sidestream(self) -> UnboundedSideStream<Self::Item> where Self: Sized {
self.sidestream_with_count(None)
}
}
impl<S, T> SideStreamExtForStream for S
where
    S: Stream<Item = T> + Send + 'static,
    T: Send + 'static,
{
    type Item = T;

    /// Drains `self` on a spawned Tokio task, forwarding every item over an
    /// unbounded channel to the returned sidestream.
    ///
    /// If `count` is provided it is incremented once for every item that was
    /// successfully queued. The producer stops early when the sidestream has
    /// been dropped. Must be called from within a Tokio runtime, as required
    /// by `tokio::task::spawn`.
    fn sidestream_with_count(
        self,
        count: Option<Arc<AtomicUsize>>,
    ) -> UnboundedSideStream<Self::Item> {
        // Marks the producer as finished when dropped. Using a guard instead
        // of a trailing store means the flag is also set if the wrapped stream
        // panics or the task is dropped before completing, so nobody waits on
        // `joined` forever.
        struct MarkJoined(Arc<AtomicBool>);
        impl Drop for MarkJoined {
            fn drop(&mut self) {
                // Release publishes every preceding size/count update to
                // readers that load the flag with Acquire.
                self.0.store(true, Ordering::Release);
            }
        }

        // State shared between the producer and the sidestream.
        let size = Arc::new(AtomicUsize::new(0));
        let joined = Arc::new(AtomicBool::new(false));
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<Self::Item>();

        {
            // Handles moved into the producer task.
            let count = count.clone();
            let size = Arc::clone(&size);
            let guard = MarkJoined(Arc::clone(&joined));
            tokio::task::spawn(async move {
                let _guard = guard;
                // Pin the stream inside the task so it can be polled.
                let s = self;
                tokio::pin!(s);
                while let Some(item) = s.next().await {
                    // Account for the item *before* it becomes visible to the
                    // consumer. Incrementing after the send would let the
                    // consumer dequeue and decrement first, briefly wrapping
                    // the size to `usize::MAX`.
                    size.fetch_add(1, Ordering::Release);
                    if tx.send(item).is_err() {
                        // The sidestream was dropped: undo the accounting for
                        // the item that was never queued and stop producing.
                        size.fetch_sub(1, Ordering::Release);
                        break;
                    }
                    if let Some(count) = &count {
                        count.fetch_add(1, Ordering::Relaxed);
                    }
                }
                // `_guard` is dropped here, flagging the producer as joined.
            });
        }

        // The sidestream keeps the receiving half plus the shared state.
        UnboundedSideStream { rx, count, size, joined }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Yields to the runtime until the producer behind `stream` reports that
    /// it has finished.
    async fn wait_until_joined<T: Send>(stream: &UnboundedSideStream<T>) {
        while !stream.joined() {
            tokio::task::yield_now().await;
            // `std::hint::spin_loop` replaces the deprecated
            // `std::sync::atomic::spin_loop_hint`.
            std::hint::spin_loop();
        }
    }

    /// Drains a sidestream built from `[1, 2, 3]`, checking that the size hint
    /// is exact before each item is taken.
    async fn assert_exact_size_hints(mut stream: UnboundedSideStream<i32>) {
        // Once the producer has joined, all three items are queued and the
        // upper bound is known.
        wait_until_joined(&stream).await;
        assert_eq!(stream.size_hint(), (3, Some(3)));
        assert_eq!(stream.next().await, Some(1));
        assert_eq!(stream.size_hint(), (2, Some(2)));
        assert_eq!(stream.next().await, Some(2));
        assert_eq!(stream.size_hint(), (1, Some(1)));
        assert_eq!(stream.next().await, Some(3));
        assert_eq!(stream.size_hint(), (0, Some(0)));
        assert_eq!(stream.next().await, None);
    }

    #[tokio::test]
    async fn test_unbounded_sidestream_iter() {
        // A sidestream over an iterator yields the same items in order.
        let v: Vec<i32> = vec![1, 2, 3];
        let s = v.into_iter().sidestream();
        assert_eq!(s.collect::<Vec<i32>>().await, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_unbounded_sidestream_stream() {
        // A sidestream over a stream yields the same items in order.
        let v: Vec<i32> = vec![1, 2, 3];
        let s = futures_util::stream::iter(v).sidestream();
        assert_eq!(s.collect::<Vec<i32>>().await, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_unbounded_sidestream_iter_count() {
        // Create sidestream with count from iterator over vector.
        let count = Arc::new(AtomicUsize::new(0));
        let v: Vec<i32> = vec![1, 2, 3];
        let s = v.into_iter().sidestream_with_count(Some(Arc::clone(&count)));
        wait_until_joined(&s).await;
        // Both the stream's view and the caller's handle see all three items.
        assert_eq!(s.count(), Some(3));
        assert_eq!(count.load(Ordering::Relaxed), 3);
    }

    #[tokio::test]
    async fn test_unbounded_sidestream_stream_count() {
        // Create sidestream with count from stream over vector.
        let count = Arc::new(AtomicUsize::new(0));
        let v: Vec<i32> = vec![1, 2, 3];
        let s = futures_util::stream::iter(v).sidestream_with_count(Some(Arc::clone(&count)));
        wait_until_joined(&s).await;
        // Both the stream's view and the caller's handle see all three items.
        assert_eq!(s.count(), Some(3));
        assert_eq!(count.load(Ordering::Relaxed), 3);
    }

    #[tokio::test]
    async fn test_unbounded_sidestream_iter_size() {
        // Create sidestream from iterator over vector.
        let v: Vec<i32> = vec![1, 2, 3];
        assert_exact_size_hints(v.into_iter().sidestream()).await;
    }

    #[tokio::test]
    async fn test_unbounded_sidestream_stream_size() {
        // Create sidestream from stream over vector.
        let v: Vec<i32> = vec![1, 2, 3];
        assert_exact_size_hints(futures_util::stream::iter(v).sidestream()).await;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment