Last active
January 30, 2020 07:18
-
-
Save sdbondi/cbc43752faa75914b36ae9dd3fd13915 to your computer and use it in GitHub Desktop.
Generic retry stream
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use futures::{ready, stream::FusedStream, task::Context, Future, Stream};
use pin_project::{pin_project, project};
use std::{marker::PhantomData, pin::Pin, task::Poll, time::Duration};
use tokio::time::{delay_for, Delay};
/// Strategy that decides how long a retrying operation should wait before its next attempt.
pub trait Backoff {
    /// Returns the delay to apply before the next attempt.
    ///
    /// `attempts` is the number of attempts that have already been made: `0` before the very
    /// first attempt, `1` after the first failure, and so on.
    fn calculate_backoff(&self, attempts: usize) -> Duration;
}
#[pin_project] | |
enum State<TFut> { | |
Initial, | |
Waiting(#[pin] Delay), | |
Running(#[pin] TFut), | |
Complete, | |
} | |
/// Future which tries to run another future a few times until it succeeds or the maximum attempts is reached | |
#[pin_project] | |
pub struct DelayedRetry<'a, TFutFactory, TFut, TBackoff> { | |
future_factory: TFutFactory, | |
backoff: TBackoff, | |
attempts: usize, | |
max_attempts: usize, | |
#[pin] | |
state: State<TFut>, | |
_lifetime: PhantomData<&'a ()>, | |
} | |
impl<'a, TFutFactory, TFut, TBackoff, T, E> DelayedRetry<'a, TFutFactory, TFut, TBackoff> | |
where | |
TFutFactory: FnMut(usize) -> TFut, | |
TFut: Future<Output = Result<T, E>>, | |
TBackoff: Backoff, | |
{ | |
pub fn new(future_factory: TFutFactory, backoff: TBackoff, max_attempts: usize) -> Self { | |
Self { | |
future_factory, | |
backoff, | |
max_attempts, | |
attempts: 0, | |
state: State::Initial, | |
_lifetime: PhantomData, | |
} | |
} | |
} | |
impl<'a, TFutFactory, TFut, TBackoff, T, E> Stream for DelayedRetry<'a, TFutFactory, TFut, TBackoff> | |
where | |
TFutFactory: FnMut(usize) -> TFut, | |
TFut: Future<Output = Result<T, E>> + 'a, | |
TBackoff: Backoff, | |
{ | |
type Item = Result<T, E>; | |
#[project] | |
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { | |
let mut this = self.project(); | |
let mut current_result = None; | |
loop { | |
#[project] | |
let next_state = match this.state.as_mut().project() { | |
State::Initial => { | |
let backoff_time = this.backoff.calculate_backoff(*this.attempts); | |
if backoff_time.as_micros() > 0 { | |
State::Waiting(delay_for(backoff_time)) | |
} else { | |
State::Running((this.future_factory)(1)) | |
} | |
}, | |
State::Waiting(delay_fut) => { | |
ready!(delay_fut.poll(cx)); | |
State::Running((this.future_factory)(*this.attempts)) | |
}, | |
State::Running(fut) => match ready!(fut.poll(cx)) { | |
Ok(v) => { | |
current_result = Some(Ok(v)); | |
State::Complete | |
}, | |
Err(err) => { | |
current_result = Some(Err(err)); | |
*this.attempts += 1; | |
if this.attempts >= this.max_attempts { | |
// After we emit the final error, the stream ends because we've already reached the maximum | |
// attempts | |
State::Complete | |
} else { | |
let backoff_time = this.backoff.calculate_backoff(*this.attempts); | |
State::Waiting(delay_for(backoff_time)) | |
} | |
}, | |
}, | |
State::Complete => { | |
return Poll::Ready(None); | |
}, | |
}; | |
this.state.set(next_state); | |
if let Some(result) = current_result.take() { | |
return Poll::Ready(Some(result)); | |
} | |
} | |
} | |
} | |
impl<T, E, TFutFactory, TFut, TBackoff> FusedStream for DelayedRetry<'_, TFutFactory, TFut, TBackoff> | |
where | |
TFutFactory: FnMut(usize) -> TFut, | |
TFut: Future<Output = Result<T, E>>, | |
TBackoff: Backoff, | |
{ | |
fn is_terminated(&self) -> bool { | |
match self.state { | |
State::Complete => true, | |
_ => false, | |
} | |
} | |
} | |
#[cfg(test)]
mod test {
    use super::*;
    use crate::backoff::ConstantBackoff;
    use futures::{future, StreamExt};
    use std::{
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        time::Duration,
    };

    /// Every attempt fails: one `Err` is emitted per attempt, and exactly `max_attempts`
    /// attempts are made.
    #[tokio_macros::test_basic]
    async fn never_succeeds() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = call_count.clone();
        let retry = DelayedRetry::new(
            |_| {
                call_count_clone.fetch_add(1, Ordering::Relaxed);
                future::ready(Result::<(), _>::Err(()))
            },
            ConstantBackoff::new(Duration::from_millis(1)),
            3,
        );
        let results = retry.collect::<Vec<_>>().await;
        // 3 results emitted
        assert_eq!(results.len(), 3);
        // ... all of them errors
        assert_eq!(results.into_iter().filter(Result::is_err).count(), 3);
        // ... from exactly 3 attempts
        assert_eq!(call_count.load(Ordering::Relaxed), 3);
    }

    /// The third (last permitted) attempt succeeds after two failures.
    #[tokio_macros::test_basic]
    async fn succeeds_later() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = call_count.clone();
        let retry = DelayedRetry::new(
            |_| match call_count_clone.fetch_add(1, Ordering::Relaxed) {
                2 => future::ready(Ok("Works!")),
                _ => future::ready(Err(())),
            },
            ConstantBackoff::new(Duration::from_millis(1)),
            3,
        );
        let results = retry.collect::<Vec<_>>().await;
        // 3 Results emitted
        assert_eq!(results.len(), 3);
        // ... 2 errors
        assert_eq!(results.iter().filter(|r| r.is_err()).count(), 2);
        // ... ending with a success
        assert_eq!(results.last().unwrap(), &Ok("Works!"));
        // ... after exactly 3 attempts
        assert_eq!(call_count.load(Ordering::Relaxed), 3);
    }

    /// The first attempt succeeds, so no retries happen and the stream ends after one item.
    #[tokio_macros::test_basic]
    async fn succeeds_immediately() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = call_count.clone();
        let retry = DelayedRetry::new(
            |_| {
                call_count_clone.fetch_add(1, Ordering::Relaxed);
                future::ready(Result::<_, ()>::Ok("Works!"))
            },
            ConstantBackoff::new(Duration::from_millis(1)),
            3,
        );
        let results = retry.collect::<Vec<_>>().await;
        // A single result is emitted...
        assert_eq!(results.len(), 1);
        // ... and it is the success from the first attempt
        assert_eq!(results.first(), Some(&Ok("Works!")));
        // ... after exactly 1 attempt
        assert_eq!(call_count.load(Ordering::Relaxed), 1);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment