Created January 10, 2015 12:52
-
-
Save Answeror/df564fa6a47c86f94ec9 to your computer and use it in GitHub Desktop.
Statistical accumulator in Python (http://stackoverflow.com/q/3774315/238472)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
import numpy as np | |
class Accumulators(object):
    """Registry-driven bundle of running statistics.

    Accumulator classes are registered globally, via ``register`` or the
    ``register_decorator`` class decorator. Each ``Accumulators`` instance gets
    a fresh instance of every class registered before it was constructed.
    ``add`` feeds one sample to all of them. A statistic is read with
    ``compute(name)`` or, as a shortcut, attribute access (``accu.mean``).
    """

    # Class-level registry shared by all instances: name -> accumulator factory.
    _accumulator_library = {}

    def __init__(self):
        # Per-instance accumulators: one freshly constructed object per
        # registration that exists at construction time.
        self.accumulator_library = {}
        for key, value in Accumulators._accumulator_library.items():
            self.accumulator_library[key] = value()

    @staticmethod
    def register(name, accumulator):
        """Register ``accumulator`` (a zero-argument callable) under ``name``.

        Only instances created after this call will include it.
        """
        Accumulators._accumulator_library[name] = accumulator

    def keys(self):
        """Return the names of the accumulators held by this instance."""
        return self.accumulator_library.keys()

    def add(self, x):
        """Feed one sample ``x`` to every accumulator."""
        for accumulator in self.accumulator_library.values():
            accumulator.add(x)

    def compute(self, name):
        """Return the current value of statistic ``name``.

        The accumulator receives this object as ``host``, so it can derive its
        result from other statistics. Raises KeyError for an unknown name.
        """
        return self.accumulator_library[name].compute(self)

    def __getattr__(self, name):
        """Expose statistics as attributes: ``accu.mean`` == ``accu.compute('mean')``.

        Python only calls this when normal attribute lookup fails. Unknown
        names must raise AttributeError, not KeyError, so that ``hasattr``,
        ``getattr(obj, name, default)`` and the ``copy`` module keep working.
        The library is read from ``__dict__`` directly. Going through
        ``self.accumulator_library`` would re-enter ``__getattr__`` forever on
        an instance where that attribute is not set yet (e.g. one created via
        ``__new__`` during copying or unpickling).
        """
        library = self.__dict__.get('accumulator_library')
        if library is None or name not in library:
            raise AttributeError(
                "%r object has no attribute %r" % (type(self).__name__, name))
        return library[name].compute(self)

    @staticmethod
    def register_decorator(name):
        """Return a class decorator that registers the class under ``name``."""
        def wrap(cls):
            Accumulators.register(name, cls)
            return cls
        return wrap
@Accumulators.register_decorator("min")
class Min(object):
    """Running element-wise minimum of all samples (``None`` before any)."""

    def __init__(self):
        self.min = None

    def add(self, x):
        """Fold sample ``x`` into the running minimum.

        An ndarray first sample is copied. Otherwise the caller's array would
        be stored by reference, and a later in-place change to that buffer
        would silently rewrite the stored minimum. ``np.minimum`` compares
        element-wise (NaN propagates, as with ``np.min``) and returns a new
        object. It also broadcasts compatible shapes.
        """
        if self.min is None:
            self.min = x.copy() if isinstance(x, np.ndarray) else x
        else:
            self.min = np.minimum(self.min, x)

    def compute(self, host):
        """Return the running minimum; ``host`` is unused."""
        return self.min
@Accumulators.register_decorator("max")
class Max(object):
    """Running element-wise maximum of all samples (``None`` before any)."""

    def __init__(self):
        self.max = None

    def add(self, x):
        """Fold sample ``x`` into the running maximum.

        An ndarray first sample is copied. Otherwise the caller's array would
        be stored by reference, and a later in-place change to that buffer
        would silently rewrite the stored maximum. ``np.maximum`` compares
        element-wise (NaN propagates, as with ``np.max``) and returns a new
        object. It also broadcasts compatible shapes.
        """
        if self.max is None:
            self.max = x.copy() if isinstance(x, np.ndarray) else x
        else:
            self.max = np.maximum(self.max, x)

    def compute(self, host):
        """Return the running maximum; ``host`` is unused."""
        return self.max
@Accumulators.register_decorator("count")
class Count(object):
    """Number of samples fed so far."""

    def __init__(self):
        self.count = 0

    def add(self, x):
        # The sample's value is irrelevant; only its arrival is counted.
        self.count = self.count + 1

    def compute(self, host):
        # The host is not consulted; the tally is self-contained.
        return self.count
@Accumulators.register_decorator("sum")
class Sum(object):
    """Running (element-wise for ndarray samples) total of all samples."""

    def __init__(self):
        self.sum = 0

    def add(self, x):
        # Rebind instead of ``+=``. An in-place add into an integer ndarray
        # total raises a casting error when a float sample arrives, whereas
        # ``a + b`` lets numpy promote the dtype. Rebinding also never
        # aliases or mutates the caller's array.
        self.sum = self.sum + x

    def compute(self, host):
        """Return the running total; ``host`` is unused."""
        return self.sum
@Accumulators.register_decorator("mean")
class Mean(object):
    """Arithmetic mean derived from the host's "sum" and "count" statistics.

    Keeps no state of its own. With no samples, the division is 0 / 0 and
    raises ZeroDivisionError for plain-int totals.
    """

    def add(self, x):
        # Nothing to record: sum and count are tracked by their own accumulators.
        pass

    def compute(self, host):
        total = host.compute('sum')
        n = host.compute('count')
        return total / n
@Accumulators.register_decorator("var")
class Var(object):
    """Population variance (ddof=0, like ``np.var``) via Welford's online update.

    The formula ``E[x**2] - E[x]**2`` subtracts two large, nearly equal numbers
    whenever the mean is large relative to the spread. That loses most
    significant digits and can even produce a slightly negative variance,
    which then makes ``sqrt`` return NaN. Welford's recurrence instead tracks
    a running mean and the sum of squared deviations from it, which stays
    accurate. It works element-wise for ndarray samples.
    """

    def __init__(self):
        self._count = 0    # number of samples seen
        self._mean = 0.0   # running mean of the samples
        self._m2 = 0.0     # running sum of squared deviations from the mean

    def add(self, x):
        self._count += 1
        delta = x - self._mean
        # Rebind (no in-place ops) so the caller's array is never mutated.
        self._mean = self._mean + delta / self._count
        self._m2 = self._m2 + delta * (x - self._mean)

    def compute(self, host):
        """Return the variance of all samples seen; ``host`` is not needed.

        Raises ZeroDivisionError when no samples have been added.
        """
        return self._m2 / self._count
@Accumulators.register_decorator("std")
class Std(object):
    """Population standard deviation: square root of the host's "var" statistic."""

    def add(self, x):
        # Stateless: derived entirely from the "var" accumulator.
        pass

    def compute(self, host):
        # Clamp at zero. A variance computed as E[x**2] - E[x]**2 can land a
        # hair below zero through rounding, and sqrt of it would be NaN.
        # Genuine NaNs still propagate through np.maximum.
        return np.sqrt(np.maximum(host.compute('var'), 0))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
import numpy as np | |
from nose.tools import assert_almost_equal | |
from numpy.testing import assert_allclose | |
from ..accumulators import Accumulators | |
def check_accumulators_scalar(name):
    """Feed 42 random scalars one by one and compare statistic ``name`` with numpy."""
    samples = np.random.rand(42)
    accumulators = Accumulators()
    for value in samples:
        accumulators.add(value)
    expected = getattr(np, name)(samples)
    actual = getattr(accumulators, name)
    assert_almost_equal(expected, actual)
def check_accumulators_vector(name):
    """Feed a 42x13 random matrix row by row; compare ``name`` with numpy's column-wise result."""
    matrix = np.random.rand(42, 13)
    accumulators = Accumulators()
    for row in matrix:
        accumulators.add(row)
    expected = getattr(np, name)(matrix, axis=0)
    actual = getattr(accumulators, name)
    assert_allclose(expected, actual)
def check_accumulators(name):
    """Run both the scalar and the row-vector checks for statistic ``name``."""
    for check in (check_accumulators_scalar, check_accumulators_vector):
        check(name)
def test_accumulators():
    """Generator test (nose style): one check per statistic with a numpy equivalent.

    "sum" is registered by the accumulator module and has ``np.sum`` as a
    reference, so it is covered too. "count" has no ``np.count`` counterpart
    and therefore cannot go through this numpy comparison.
    """
    for name in ['min', 'max', 'sum', 'mean', 'var', 'std']:
        yield check_accumulators, name
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.