Skip to content

Instantly share code, notes, and snippets.

@josepsmartinez
Last active March 27, 2021 01:24
Show Gist options
  • Save josepsmartinez/ef246aad4b5757e157c6ef854c443c98 to your computer and use it in GitHub Desktop.
Save josepsmartinez/ef246aad4b5757e157c6ef854c443c98 to your computer and use it in GitHub Desktop.
Python class for online computing of `np.mean`
import numpy as np
class OnlineStatsCompute:
    """Running (online) element-wise mean of a stream of equally-shaped arrays.

    Elements are added one at a time with `add_e`, or in batches with
    `add_l` (or through the constructor). After the first non-empty
    addition the instance has two attributes:

    mean : numpy.ndarray or numpy scalar
        Element-wise mean of every element added so far.
    N : int
        Number of elements added so far.

    The shape of the first element added fixes the shape every later
    element must have. Before anything is added, neither attribute exists.
    """

    def __init__(self, *elements):
        """Create the accumulator, optionally seeding it with `elements`.

        Parameters
        ----------
        *elements : array_like
            Initial elements; same as calling `add_l(*elements)`.
        """
        if elements:
            self.add_l(*elements)

    def add_e(self, e):
        """Include a single element `e` in the running mean.

        Updates attributes `N` and `mean`.

        Parameters
        ----------
        e : array_like
            The new element. It is copied with `np.array`, so later changes
            to the caller's object do not affect the statistics.

        Raises
        ------
        ValueError
            If `e` does not have the shape of the previously added elements.
        """
        e = np.array(e)
        if not hasattr(self, 'mean'):
            self.mean = e
            self.N = 1
            return
        if e.shape != self.mean.shape:
            raise ValueError(f"Expected shape {self.mean.shape}, got shape {e.shape}")
        # Out-of-place arithmetic: the in-place form (`mean *= N; mean += e`)
        # fails when `mean` is an integer array and `e` holds floats, because
        # NumPy will not cast the float result back into the int buffer.
        self.mean = (self.mean * self.N + e) / (self.N + 1)
        self.N += 1

    def add_l(self, *elements):
        """Include a sequence of elements in the running mean.

        Equivalent to applying `self.add_e(e) for e in elements`, but faster
        because the whole batch is reduced with a single NumPy call. Calling
        it with no elements is a no-op.

        Parameters
        ----------
        *elements : array_like
            Elements to add; all must have the same shape.

        Raises
        ------
        ValueError
            If the elements' shape does not match that of the previously
            added elements. On a mismatch, `mean` and `N` are left unchanged.
        """
        if not elements:
            # Nothing to add. np.mean of an empty batch would store NaN
            # (with a warning) and N=0, poisoning every later update.
            return
        if not hasattr(self, 'mean'):
            self.mean = np.mean(elements, axis=0)
            self.N = len(elements)
            return
        batch = np.array(elements)
        # Compare the full per-element shape, not just its first axis, so
        # scalar (0-d) and multi-dimensional elements are validated too.
        if batch.shape[1:] != self.mean.shape:
            raise ValueError(f"Expected shape {self.mean.shape}, got shape {batch.shape}")
        # Out-of-place update, as in add_e: avoids int/float casting errors.
        self.mean = (self.mean * self.N + batch.sum(axis=0)) / (self.N + len(batch))
        self.N += len(batch)
def test():
    """Self-check OnlineStatsCompute against np.mean on a fixed dataset.

    Exercises element-by-element addition (`add_e`) including its shape
    validation, construction from the whole dataset, and construction from
    half the data followed by `add_l` with the other half.
    """
    rows = [
        [1, 2, 3],
        [3, 2, 1],
        [2, 2, 2],
        [10, 4, 6],
    ]
    gold = np.mean(rows, axis=0)

    # One element at a time; a wrongly-shaped element must be rejected.
    a = OnlineStatsCompute()
    a.add_e(rows[0])
    try:
        a.add_e([3, 2, 1, 1])
    except ValueError:
        pass
    else:
        raise AssertionError("add_e accepted an element of the wrong shape")
    for row in rows[1:]:
        a.add_e(row)
    assert (a.mean == gold).all()

    # Whole dataset passed to the constructor (previously a bare tuple was
    # built here by mistake, so `b.mean` raised AttributeError).
    b = OnlineStatsCompute(*rows)
    assert (b.mean == gold).all()

    # Half in the constructor, half via add_l.
    c = OnlineStatsCompute(*rows[:2])
    c.add_l(*rows[2:])
    assert (c.mean == gold).all()

    assert a.N == b.N == c.N == len(rows)
# Run the self-test only when executed as a script, not on import.
if __name__ == "__main__":
    test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment