Skip to content

Instantly share code, notes, and snippets.

@metasim
Last active October 24, 2024 20:00
Show Gist options
  • Save metasim/82de345ea1b1746b996aac369e7289b4 to your computer and use it in GitHub Desktop.
Save metasim/82de345ea1b1746b996aac369e7289b4 to your computer and use it in GitHub Desktop.
use futuresdr::anyhow::Result;
use futuresdr::macros::async_trait;
use futuresdr::runtime::BlockMeta;
use futuresdr::runtime::BlockMetaBuilder;
use futuresdr::runtime::Kernel;
use futuresdr::runtime::MessageIo;
use futuresdr::runtime::MessageIoBuilder;
use futuresdr::runtime::StreamIo;
use futuresdr::runtime::StreamIoBuilder;
use futuresdr::runtime::TypedBlock;
use futuresdr::runtime::WorkIo;
/// Exponential moving average over fixed-width chunks of `f32` samples.
///
/// The input stream is read in chunks of `WIDTH` samples. Each element of a chunk is folded into
/// the corresponding element of a running exponential moving average, and the full
/// `WIDTH`-element average is written to the output once every `history_size` chunks.
///
/// The running average is never reset between outputs, so each emitted value reflects all past
/// input (geometrically weighted by `decay_factor`), not only the most recent `history_size`
/// chunks. Non-finite samples (NaN, ±infinity) are skipped: the affected element's average just
/// decays toward zero for that chunk.
#[derive(Debug, Clone)]
pub struct MovingAvg<const WIDTH: usize> {
    /// Weight of the newest sample: `avg = (1 - decay_factor) * avg + decay_factor * x`.
    decay_factor: f32,
    /// Number of input chunks consumed between consecutive outputs.
    history_size: usize,
    /// Number of chunks consumed since the last output.
    i: usize,
    /// Current per-element running average.
    avg: [f32; WIDTH],
}
impl<const WIDTH: usize> MovingAvg<WIDTH> {
    /// Instantiate moving average block.
    ///
    /// The block reads `f32` samples from stream input `"in"` in chunks of `WIDTH` samples, folds
    /// each chunk element-wise into a running exponential moving average, and writes that
    /// `WIDTH`-element average to stream output `"out"` once every `history_size` chunks.
    ///
    /// # Arguments
    ///
    /// * `decay_factor`: amount the current value contributes to the rolling average, i.e.
    ///   `avg = (1 - decay_factor) * avg + decay_factor * x`. Must be in `[0.0, 1.0]`.
    /// * `history_size`: number of input chunks consumed between consecutive outputs.
    ///   Must be at least 1.
    ///
    /// Typical parameter values might be `decay_factor=0.1` and `history_size=3`.
    ///
    /// # Panics
    ///
    /// Panics if `decay_factor` is outside `[0.0, 1.0]` (this includes NaN), if `history_size`
    /// is 0 (the block would never emit any output), or if `WIDTH` is 0 (there would be no
    /// samples to chunk).
    pub fn new_typed(decay_factor: f32, history_size: usize) -> TypedBlock<Self> {
        assert!(
            (0.0..=1.0).contains(&decay_factor),
            "decay_factor must be in [0, 1]"
        );
        assert!(history_size > 0, "history_size must be at least 1");
        assert!(WIDTH > 0, "WIDTH must be at least 1");
        TypedBlock::new(
            // NOTE: the block is registered as "WindowedDecay", which differs from the type name.
            BlockMetaBuilder::new("WindowedDecay").build(),
            StreamIoBuilder::new()
                .add_input::<f32>("in")
                .add_output::<f32>("out")
                .build(),
            MessageIoBuilder::new().build(),
            Self {
                decay_factor,
                history_size,
                i: 0,
                avg: [0.0; WIDTH],
            },
        )
    }
}
#[async_trait]
impl<const N: usize> Kernel for MovingAvg<N> {
    /// Processes every complete `N`-sample chunk currently available on input 0.
    ///
    /// Each chunk is folded element-wise into the running average. After every `history_size`
    /// chunks, the `N`-element average is copied to output 0.
    ///
    /// A chunk that would trigger an output while the output buffer lacks room for `N` more
    /// samples is left unconsumed for a later call. Samples that do not form a complete chunk
    /// stay in the input buffer. Once the upstream is finished and only such a remainder is left,
    /// the block reports itself finished.
    async fn work(
        &mut self,
        io: &mut WorkIo,
        sio: &mut StreamIo,
        _mio: &mut MessageIo<Self>,
        _meta: &mut BlockMeta,
    ) -> Result<()> {
        let input = sio.input(0).slice::<f32>();
        let output = sio.output(0).slice::<f32>();

        // Whole chunks that can be read / written in this call. `checked_div` yields `None`
        // for `N == 0`; without it the chunk loop below would never terminate.
        let in_chunks = input.len().checked_div(N).unwrap_or(0);
        let out_chunks = output.len().checked_div(N).unwrap_or(0);
        let keep = 1.0 - self.decay_factor;

        let mut consumed = 0;
        let mut produced = 0;
        while consumed < in_chunks {
            // Decide whether this chunk completes an output period *before* touching any state.
            // If there is no room to emit, stop here so the chunk is processed exactly once on a
            // later call. Updating `avg`/`i` first and then bailing out would re-apply the chunk
            // next time and push `i` past `history_size`, silencing the block for good.
            let emits = self.i + 1 == self.history_size;
            if emits && produced == out_chunks {
                break;
            }

            let chunk = &input[consumed * N..(consumed + 1) * N];
            for (avg, &t) in self.avg.iter_mut().zip(chunk) {
                *avg = if t.is_finite() {
                    keep * *avg + self.decay_factor * t
                } else {
                    // Non-finite samples are skipped; the average just decays toward zero.
                    keep * *avg
                };
            }
            consumed += 1;
            self.i += 1;

            if emits {
                output[produced * N..(produced + 1) * N].copy_from_slice(&self.avg);
                self.i = 0;
                produced += 1;
            }
        }

        if sio.input(0).finished() && consumed == in_chunks {
            io.finished = true;
        }
        sio.input(0).consume(consumed * N);
        sio.output(0).produce(produced * N);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futuresdr::runtime::Mocker;

    /// Runs `samples` through a 3-wide `MovingAvg` (decay 0.1, one output every 3 chunks) with
    /// room for exactly one output chunk, and asserts the emitted samples equal `expected`.
    fn check_three_wide(samples: Vec<f32>, expected: Vec<f32>) {
        let mut harness = Mocker::new(MovingAvg::<3>::new_typed(0.1, 3));
        harness.input::<f32>(0, samples);
        harness.init_output::<f32>(0, 3);
        harness.run();
        assert_eq!(harness.output::<f32>(0), expected);
    }

    #[test]
    fn moving_avg_correct_output() {
        let period: [f32; 3] = [1.0, 2.0, 3.0];
        check_three_wide(period.repeat(3), vec![0.271, 0.542, 0.813]);
    }

    #[test]
    fn moving_avg_handles_non_finite_values() {
        let samples = vec![1.0, f32::NAN, 3.0, 1.0, f32::INFINITY, 3.0, 1.0, 2.0, 3.0];
        check_three_wide(samples, vec![0.271, 0.2, 0.813]);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment