Last active
October 24, 2024 20:00
-
-
Save metasim/82de345ea1b1746b996aac369e7289b4 to your computer and use it in GitHub Desktop.
Reformulation of https://github.com/FutureSDR/FutureSDR/blob/main/examples/egui/src/keep_1_in_n.rs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use futuresdr::anyhow::Result; | |
use futuresdr::macros::async_trait; | |
use futuresdr::runtime::BlockMeta; | |
use futuresdr::runtime::BlockMetaBuilder; | |
use futuresdr::runtime::Kernel; | |
use futuresdr::runtime::MessageIo; | |
use futuresdr::runtime::MessageIoBuilder; | |
use futuresdr::runtime::StreamIo; | |
use futuresdr::runtime::StreamIoBuilder; | |
use futuresdr::runtime::TypedBlock; | |
use futuresdr::runtime::WorkIo; | |
/// Per-lane exponential moving average over fixed-width chunks of `f32` samples,
/// decimated by `history_size`.
///
/// The input stream is treated as consecutive chunks of `WIDTH` samples. Lane `k` of every
/// chunk is folded into its own running average:
/// `avg[k] = (1 - decay_factor) * avg[k] + decay_factor * x[k]`. A non-finite sample
/// (NaN or ±infinity) instead just decays `avg[k]` by `(1 - decay_factor)`.
///
/// After every `history_size` input chunks, one `WIDTH`-sample chunk holding the current
/// averages is written to the output, so the output rate is `1 / history_size` of the input
/// rate. The averages are *not* reset after an emission: earlier chunks keep contributing
/// with exponentially decreasing weight. It is not a fixed-size sliding window.
#[derive(Debug, Clone)]
pub struct MovingAvg<const WIDTH: usize> {
    /// Weight given to each newly folded-in sample; expected to lie in `[0.0, 1.0]`.
    decay_factor: f32,
    /// Number of input chunks consumed per emitted output chunk.
    history_size: usize,
    /// Number of input chunks folded in since the last emission.
    i: usize,
    /// Current running average, one entry per lane of a chunk.
    avg: [f32; WIDTH],
}
impl<const WIDTH: usize> MovingAvg<WIDTH> { | |
/// Instantiate moving average block. | |
/// | |
/// # Arguments | |
/// | |
/// * `decay_factor`: amount current value should contribute to the rolling average. | |
/// Must be in `[0.0, 1.0]`. | |
/// * `history_size`: number of chunks to average over | |
/// | |
/// Typical parameter values might be `decay_factor=0.1` and `history_size=3` | |
pub fn new_typed(decay_factor: f32, history_size: usize) -> TypedBlock<Self> { | |
assert!( | |
(0.0..=1.0).contains(&decay_factor), | |
"decay_factor must be in [0, 1]" | |
); | |
TypedBlock::new( | |
BlockMetaBuilder::new("WindowedDecay").build(), | |
StreamIoBuilder::new() | |
.add_input::<f32>("in") | |
.add_output::<f32>("out") | |
.build(), | |
MessageIoBuilder::new().build(), | |
Self { | |
decay_factor, | |
history_size, | |
i: 0, | |
avg: [0.0; WIDTH], | |
}, | |
) | |
} | |
} | |
#[async_trait] | |
impl<const N: usize> Kernel for MovingAvg<N> { | |
async fn work( | |
&mut self, | |
io: &mut WorkIo, | |
sio: &mut StreamIo, | |
_mio: &mut MessageIo<Self>, | |
_meta: &mut BlockMeta, | |
) -> Result<()> { | |
let input = sio.input(0).slice::<f32>(); | |
let output = sio.output(0).slice::<f32>(); | |
let mut consumed = 0; | |
let mut produced = 0; | |
while (consumed + 1) * N <= input.len() { | |
for i in 0..N { | |
let t = input[consumed * N + i]; | |
if t.is_finite() { | |
self.avg[i] = (1.0 - self.decay_factor) * self.avg[i] + self.decay_factor * t; | |
} else { | |
self.avg[i] *= 1.0 - self.decay_factor; | |
} | |
} | |
self.i += 1; | |
if self.i == self.history_size { | |
if (produced + 1) * N <= output.len() { | |
output[produced * N..(produced + 1) * N].clone_from_slice(&self.avg); | |
self.i = 0; | |
produced += 1; | |
} else { | |
break; | |
} | |
} | |
consumed += 1; | |
} | |
if sio.input(0).finished() && consumed == input.len() / N { | |
io.finished = true; | |
} | |
sio.input(0).consume(consumed * N); | |
sio.output(0).produce(produced * N); | |
Ok(()) | |
} | |
} | |
#[cfg(test)]
mod tests {
    use super::*;
    use futuresdr::runtime::Mocker;

    /// Asserts that `actual` and `expected` have the same length and agree element-wise
    /// within a small absolute tolerance. Exact `f32` equality is too brittle for values
    /// produced by repeated floating-point multiply-adds.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: got {actual:?}, expected {expected:?}"
        );
        for (idx, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (a - e).abs() <= 1e-6,
                "element {idx}: got {a}, expected {e} (full output {actual:?})"
            );
        }
    }

    #[test]
    fn moving_avg_correct_output() {
        let block = MovingAvg::<3>::new_typed(0.1, 3);
        let mut mocker = Mocker::new(block);
        mocker.input::<f32>(0, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
        mocker.init_output::<f32>(0, 3);
        mocker.run();
        assert_close(&mocker.output::<f32>(0), &[0.271, 0.542, 0.813]);
    }

    #[test]
    fn moving_avg_handles_non_finite_values() {
        let block = MovingAvg::<3>::new_typed(0.1, 3);
        let mut mocker = Mocker::new(block);
        mocker.input::<f32>(
            0,
            vec![1.0, f32::NAN, 3.0, 1.0, f32::INFINITY, 3.0, 1.0, 2.0, 3.0],
        );
        mocker.init_output::<f32>(0, 3);
        mocker.run();
        assert_close(&mocker.output::<f32>(0), &[0.271, 0.2, 0.813]);
    }

    #[test]
    fn moving_avg_ignores_trailing_partial_chunk() {
        let block = MovingAvg::<3>::new_typed(0.1, 3);
        let mut mocker = Mocker::new(block);
        // Nine samples form three whole chunks; the tenth is an incomplete chunk.
        mocker.input::<f32>(
            0,
            vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 5.0],
        );
        mocker.init_output::<f32>(0, 3);
        mocker.run();
        assert_close(&mocker.output::<f32>(0), &[0.271, 0.542, 0.813]);
    }

    #[test]
    fn moving_avg_emits_nothing_before_history_is_full() {
        let block = MovingAvg::<3>::new_typed(0.1, 3);
        let mut mocker = Mocker::new(block);
        // Only two of the three required chunks.
        mocker.input::<f32>(0, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
        mocker.init_output::<f32>(0, 3);
        mocker.run();
        assert_close(&mocker.output::<f32>(0), &[]);
    }

    #[test]
    #[should_panic(expected = "decay_factor must be in [0, 1]")]
    fn moving_avg_rejects_out_of_range_decay_factor() {
        let _block = MovingAvg::<3>::new_typed(1.5, 3);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment