Created
November 24, 2022 20:15
-
-
Save ebrahimebrahim/4e48b17c302bfa4f6c221bd27ffa1b09 to your computer and use it in GitHub Desktop.
running mean with numpy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from numpy.typing import NDArray | |
class RunningMean:
    """Like numpy.mean, but aggregates data one chunk at a time.

    Each call to ``submit_data`` folds another chunk into the mean and
    ``get`` returns the mean of everything submitted so far.

    Attributes are plain instance fields:
      * ``axis`` / ``keepdims`` -- forwarded to ``mean`` / ``sum`` on each chunk.
      * ``current_mean`` -- mean of all data so far, or None before any data.
      * ``current_n`` -- number of items per output element averaged so far.
    """

    def __init__(self, axis=None, keepdims=False):
        """Configure the reduction.

        For ``axis`` and ``keepdims``, see the documentation of these
        argument names in numpy.mean.
        """
        self.axis = axis
        self.keepdims = keepdims
        self.current_mean = None
        self.current_n = 0

    def submit_data(self, data: NDArray) -> None:
        """Add a data chunk to the mean.

        The shape should agree with data previously added, except possibly
        at the axes over which the mean is taking place.

        A chunk that contributes zero items along the reduced axes is
        ignored: taking its mean would produce NaN and, because the item
        count stays at zero, poison every later update.

        Args:
            data: Array (or array-like; it is passed through ``np.asarray``).

        Raises:
            TypeError: If ``axis`` is not None, an integer or an iterable
                of integers.
        """
        data = np.asarray(data)
        chunk_n = self._count_items_being_added_to_mean(data)
        if chunk_n == 0:
            return

        if self.current_mean is None:
            self.current_mean = data.mean(axis=self.axis, keepdims=self.keepdims)
        else:
            chunk_sum = data.sum(axis=self.axis, keepdims=self.keepdims)
            total_n = self.current_n + chunk_n
            # Weighted recombination: old mean times old count restores the
            # old sum, to which the new chunk's sum is added.
            self.current_mean = (self.current_n * self.current_mean + chunk_sum) / total_n
        self.current_n += chunk_n

    def get(self):
        """Return the current running mean.

        Returns None if no (non-empty) data chunks were submitted.
        """
        return self.current_mean

    def _count_items_being_added_to_mean(self, data: NDArray) -> int:
        """Return how many items of ``data`` are averaged per output element.

        Raises:
            TypeError: If ``axis`` is not None, an integer or an iterable
                of integers.
        """
        if self.axis is None:
            return int(np.size(data))

        shape = np.shape(data)
        # NumPy integer scalars are valid axes for numpy.mean, so accept them too.
        if isinstance(self.axis, (int, np.integer)):
            return shape[self.axis]

        try:
            count = 1
            for ax in self.axis:
                count *= shape[ax]
        except TypeError as err:
            raise TypeError("axis should be None or an integer or a tuple of ints") from err
        return count
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment