Created
August 1, 2023 13:46
-
-
Save Qqwy/b2d5d17b5e459e5e29621cf598b57aad to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use backdrop::BackdropStrategy; | |
use stable_deref_trait::StableDeref; | |
use std::sync::Arc; | |
use streaming_iterator::StreamingIterator; | |
/// A struct that wraps any `impl StreamingIterator<item = T>` | |
/// together with its data source (where it borrows the references to T from). | |
/// | |
/// The data source will be dropped once iteration is finished. | |
/// To make this safe, the data source will need to be in a stable location in memory. | |
/// In other words: Wrapped in a container that implements `StableDeref`. | |
/// Most frequently, this means behind a `Rc` or `Arc`. | |
pub struct OwningStreamingIter<T, Source, Iter, StableSourceContainer> | |
where | |
T: 'static, | |
Iter: StreamingIterator<Item = T>, | |
StableSourceContainer: StableDeref<Target = Source>, | |
{ | |
// Actually has lifetime of '_source' | |
// declared first so it is dropped first | |
iter: Iter, | |
_source: StableSourceContainer, | |
} | |
impl<T, Source, Iter> OwningStreamingIter<T, Source, Iter, Arc<Source>> | |
where | |
T: 'static, | |
Source: 'static, | |
Iter: StreamingIterator<Item = T>, | |
{ | |
/// Creates a new OwningStreamingIter | |
/// | |
/// # SAFETY | |
/// | |
/// Do not store the parameter given to the `iter_creation_fn` callback elsewhere! | |
/// The reference will live as long as this OwnedStreamingIter object, | |
/// (And this is what makes it OK to return a StreamingIterator for anything referencing any part of it) | |
/// but it is not truly static! | |
/// | |
/// (Pretty much the only way to do this wrong, is to deliberately write some of its contents a datatype with interior mutability (RefCell, Mutex). But I am not aware of any way to write out the function signature to convince Rust that the &'a Source lives longer than Iter without resorting to &'static. And therefore this function has to be unsafe.) | |
pub unsafe fn new<'a>( | |
source: Arc<Source>, | |
iter_creation_fn: impl Fn(&'static Source) -> Iter, | |
) -> Self { | |
let iter = { | |
// SAFETY: Pretend like 'source' is static for a little while. | |
// This is correct, as from the viewpoint of 'iter' it will behave identically: | |
// - As long as 'source' is around, any references into it will remain valid as its memory location is fixed (courtesy of StableDeref). | |
// - 'source' will be around for longer than 'iter': it is dropped later because of struct field ordering. | |
let static_source_ref: &'static Source = | |
unsafe { std::mem::transmute::<&'_ Source, &'static Source>(&*source) }; | |
iter_creation_fn(static_source_ref) | |
}; | |
OwningStreamingIter { | |
iter, | |
_source: source, | |
} | |
} | |
} | |
impl<T, Source, Iter, Strat> OwningStreamingIter<T, Source, Iter, backdrop_arc::Arc<Source, Strat>> | |
where | |
T: 'static, | |
Source: 'static, | |
Iter: StreamingIterator<Item = T>, | |
Strat: BackdropStrategy<Box<Source>>, | |
{ | |
pub unsafe fn new_backdrop<'a>( | |
source: backdrop_arc::Arc<Source, Strat>, | |
iter_creation_fn: impl Fn(&'static Source) -> Iter, | |
) -> Self { | |
let iter = { | |
// SAFETY: Pretend like 'source' is static for a little while. | |
// This is correct, as from the viewpoint of 'iter' it will behave identically: | |
// - As long as 'source' is around, any references into it will remain valid as its memory location is fixed (courtesy of StableDeref). | |
// - 'source' will be around for longer than 'iter': it is dropped later because of struct field ordering. | |
let static_source_ref: &'static Source = | |
unsafe { std::mem::transmute::<&'_ Source, &'static Source>(&*source) }; | |
iter_creation_fn(static_source_ref) | |
}; | |
OwningStreamingIter { | |
iter, | |
_source: source, | |
} | |
} | |
} | |
/// A streaming iterator that yields a single borrowed item exactly once.
///
/// This is the by-reference variant of `streaming_iterator::once`; construct it with
/// `once_ref`.
#[derive(Debug)]
pub struct OnceRef<'a, T> {
    // `true` until `advance` has been called for the first time.
    first: bool,
    // The borrowed item; reset to `None` once the iterator is exhausted.
    item: Option<&'a T>,
}

// Implemented by hand: `derive(Clone)` would needlessly require `T: Clone`,
// while only the reference is copied.
impl<'a, T> Clone for OnceRef<'a, T> {
    fn clone(&self) -> Self {
        OnceRef {
            first: self.first,
            item: self.item,
        }
    }
}
impl<'a, T> StreamingIterator for OnceRef<'a, T> { | |
type Item = T; | |
#[inline] | |
fn advance(&mut self) { | |
if self.first { | |
self.first = false; | |
} else { | |
self.item = None; | |
} | |
} | |
#[inline] | |
fn get(&self) -> Option<&Self::Item> { | |
self.item | |
} | |
#[inline] | |
fn size_hint(&self) -> (usize, Option<usize>) { | |
let len = self.first as usize; | |
(len, Some(len)) | |
} | |
} | |
impl<'a, T> streaming_iterator::DoubleEndedStreamingIterator for OnceRef<'a, T> { | |
#[inline] | |
fn advance_back(&mut self) { | |
self.advance(); | |
} | |
} | |
/// Creates an iterator that returns exactly one item from a reference. | |
/// | |
/// ``` | |
/// # use streaming_iterator::StreamingIterator; | |
/// let mut streaming_iter = streaming_iterator::once(1); | |
/// assert_eq!(streaming_iter.next(), Some(&1)); | |
/// assert_eq!(streaming_iter.next(), None); | |
/// ``` | |
#[inline] | |
pub fn once_ref<'a, T>(item: &'a T) -> OnceRef<'a, T> { | |
OnceRef { | |
first: true, | |
item: Some(item), | |
} | |
} | |
impl<'a, Source, StableSourceContainer> | |
OwningStreamingIter<Source, Source, OnceRef<'a, Source>, StableSourceContainer> | |
where | |
Source: 'a, | |
StableSourceContainer: StableDeref<Target = Source>, | |
{ | |
pub fn new2(source: StableSourceContainer) -> Self { | |
let iter = { | |
let long_source_ref: &'a Source = | |
unsafe { std::mem::transmute::<&'_ Source, &'a Source>(&*source) }; | |
let iter = once_ref(long_source_ref); | |
iter | |
}; | |
Self { | |
iter, | |
_source: source, | |
} | |
} | |
} | |
// impl<T, Source, Iter, StableSourceContainer> Drop | |
// for OwningStreamingIter<T, Source, Iter, StableSourceContainer> | |
// where | |
// Iter: StreamingIterator<Item = T>, | |
// StableSourceContainer: StableDeref<Target = Source>, | |
// { | |
// fn drop(&mut self) { | |
// println!( | |
// "Dropping OwningIterator {:?}", | |
// &self._source as *const StableSourceContainer | |
// ); | |
// } | |
// } | |
impl<T, Source, Iter, StableSourceContainer> StreamingIterator | |
for OwningStreamingIter<T, Source, Iter, StableSourceContainer> | |
where | |
Iter: StreamingIterator<Item = T>, | |
StableSourceContainer: StableDeref<Target = Source>, | |
{ | |
type Item = T; | |
fn get(&self) -> Option<&Self::Item> { | |
self.iter.get() | |
} | |
fn advance(&mut self) { | |
// println!( | |
// "Advancing OwningIterator {:?}", | |
// &self._source as *const StableSourceContainer | |
// ); | |
self.iter.advance() | |
} | |
} | |
/// Helper struct to help the compiler figure out | |
/// that a Box<dyn StreamingIterator> can be used as a StreamingIterator | |
/// | |
/// Useful for constructing recursive (streaming) iterators. | |
pub struct BoxedSI<'storage, K>(Box<dyn StreamingIterator<Item = K> + 'storage + Send>); | |
impl<'storage, K> BoxedSI<'storage, K> { | |
pub fn new(iterator: impl StreamingIterator<Item = K> + 'storage + Send) -> Self { | |
BoxedSI(Box::from(iterator)) | |
} | |
pub fn from_box(boxed_iter: Box<dyn StreamingIterator<Item = K> + 'storage + Send>) -> Self { | |
BoxedSI(boxed_iter) | |
} | |
} | |
impl<'storage, K> StreamingIterator for BoxedSI<'storage, K> { | |
type Item = K; | |
fn get(&self) -> Option<&Self::Item> { | |
self.0.get() | |
} | |
fn advance(&mut self) { | |
self.0.advance() | |
} | |
} | |
/// Extends the lifetime of a reference passed into a mapping closure of an `OwningStreamingIter`.
///
/// When calling `.map` or `.flat_map` etc. on an `OwningStreamingIter`,
/// Rust will often tell you that the return type captures a reference (`'2`) from the passed
/// input reference (`'1`), and that this is a problem since there is no way to know whether
/// `'1` will outlive `'2`.
///
/// ```ignore
/// error: lifetime may not live long enough
///   --> src/foo.rs:99:13
///    |
/// 93 |   my_owning_streaming_iter.map(|node| match node {
///    |                                 ----- return type of closure is std::slice::Iter<'2, T>
///    |                                 |
///    |                                 has type `&'1 Node<T>`
/// ...
/// 99 |               convert_ref(children.iter())
///    |               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ returning this value requires that `'1` must outlive `'2`
///
/// ```
///
/// However, the whole point of using `OwningStreamingIter` is to make sure that `'1` *does*
/// outlive `'2`!
///
/// Unfortunately, current Rust has no `'this` lifetime (or similar) to indicate that the
/// reference will be around exactly as long as the `OwningStreamingIter` object itself.
///
/// This function tells the compiler that the reference passed to it will indeed live long
/// enough. It also works for unsized referents such as `str` and slices. Using it is unsafe,
/// however, as misuse leads to dangling references.
///
/// # Safety
///
/// ## 1. Only use this function to extend the lifetime of the *passed-in* reference.
///
/// Using it to extend the lifetime of any other reference (like a reference *returned* from a
/// mapping operation) can very easily lead to accidental dangling references being created in
/// safe Rust code later on (such as in a later iterator-transforming operation).
///
/// ## 2. Only use the extended reference to construct the outcome of the mapping function.
///
/// Specifically, do not store (any subpart of) the extended reference inside any data structure
/// outside of the closure. This goes for both `mut` variables and for types with interior
/// mutability (`RefCell`, `Mutex`, etc.), as they might outlive the `OwningStreamingIter`
/// object and the compiler will not catch this!
///
/// The easiest way to uphold this is to keep the transformation function pure.
/// Upholding this guarantee is not that difficult. Keep in mind that logging on the spot is
/// fine, but logging lazily needs to be done very carefully.
#[inline]
#[must_use]
pub unsafe fn extend_lifetime<'long, T: ?Sized>(short_ref: &T) -> &'long T {
    // SAFETY: The caller guarantees (see `# Safety`) that the referent outlives every use of
    // the returned reference. Going through a raw pointer only changes the lifetime; the
    // address (and, for unsized `T`, the pointer metadata) is preserved.
    unsafe { &*(short_ref as *const T) }
}
#[cfg(test)]
pub mod tests {
    use super::*;

    #[test]
    pub fn iter_vec() {
        let vec: Arc<Vec<usize>> = (1..1000).collect::<Vec<_>>().into();
        // SAFETY: the callback only uses its reference to build the returned iterator.
        let owning_iterator = unsafe {
            OwningStreamingIter::new(vec.clone(), |vec| {
                streaming_iterator::convert_ref(vec.iter())
            })
        };
        let result: Vec<_> = owning_iterator.cloned().collect();
        assert_eq!(&result, &*vec);
    }

    #[test]
    pub fn iter_vec2() {
        let vec: Arc<Vec<usize>> = (1..1000).collect::<Vec<_>>().into();
        let owning_iterator = OwningStreamingIter::new2(vec.clone()).flat_map(|vec| {
            // SAFETY: Used on the passed-in reference, and closure is pure
            let vec = unsafe { extend_lifetime(vec) };
            streaming_iterator::convert_ref(vec.iter())
        });
        let result: Vec<_> = owning_iterator.cloned().collect();
        assert_eq!(&result, &*vec);
    }

    // Deliberately breaks the safety contract of `OwningStreamingIter::new` by leaking the
    // `&'static` reference into `data`, which outlives the source. MIRI catches this :-)
    //
    // Intentionally NOT marked `#[test]`. Outside of MIRI the final `assert_eq!` fails, and
    // its panic message would debug-print the dangling reference, which is a real
    // use-after-free. Run it manually under MIRI (e.g. by temporarily re-adding `#[test]`).
    // #[test]
    pub fn disallow_copying_data_outside_of_initialization_function() {
        use std::sync::Mutex;
        let data: Mutex<Option<&Vec<usize>>> = None.into();
        {
            let vec: Arc<Vec<usize>> = (1..1000).collect::<Vec<_>>().into();
            let owning_iterator =
                // SAFETY BROKEN: Deliberately breaking the safety contract here
                // by writing to 'data' inside the callback
                unsafe {
                    OwningStreamingIter::new(vec.clone(), |vec| {
                        *data.lock().unwrap() = Some(vec);
                        streaming_iterator::convert_ref(vec.iter())
                    })
                };
            let result: Vec<_> = owning_iterator.cloned().collect();
            assert_eq!(&result, &*vec);
        }
        assert_eq!(None, *data.lock().unwrap());
    }

    // Counterpart of the test above using `new2` + `extend_lifetime`. Unlike that test, this
    // one upholds the safety contract: `extend_lifetime` is only applied to the passed-in
    // reference and the closure is pure. That is why it is enabled.
    #[test]
    pub fn disallow_copying_data_outside_of_initialization_function2() {
        let vec: Arc<Vec<usize>> = (1..1000).collect::<Vec<_>>().into();
        let owning_iterator = OwningStreamingIter::new2(vec.clone()).flat_map(|vec| {
            // SAFETY: Used on the passed-in reference, and closure is pure
            let vec = unsafe { extend_lifetime(vec) };
            streaming_iterator::convert_ref(vec.iter())
        });
        let result: Vec<_> = owning_iterator.cloned().collect();
        assert_eq!(&result, &*vec);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment