Last active
September 1, 2020 07:11
-
-
Save rjenc29/e0b15089d10a2b9045c397793f22f35a to your computer and use it in GitHub Desktop.
numba ewma - speed comparison
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
from numba import njit | |
import time | |
@njit
def ewma_version_1(x, halflife):
    """Exponentially weighted moving average along axis 0 (row-wise form).

    Each output row is the weighted mean of the current and all earlier
    rows, where a row that is ``k`` steps old has weight ``decay ** k`` and
    ``decay = 0.5 ** (1 / halflife)``. The weights are normalised by their
    running sum, so the first output row equals the first input row.

    This variant updates a whole row per step with a single array
    expression (``out[i] = ...``) rather than looping over columns.

    Parameters
    ----------
    x : ndarray
        Input samples; the recursion runs over the first axis.
    halflife : number
        Number of steps after which a sample's weight has halved.

    Returns
    -------
    ndarray
        float64 array with the same shape as ``x``.
    """
    decay = np.exp(np.log(0.5) / halflife)
    result = np.empty_like(x, dtype=np.float64)
    n_rows = result.shape[0]
    if n_rows == 0:
        # Nothing to average; hand back the (empty) buffer.
        return result

    # Seed the recursion: with a single sample the average is that sample.
    result[0] = x[0]
    weight_total = 1.0   # sum of decay**k over the rows seen so far
    oldest_weight = 1.0  # decay**row, i.e. the weight of the very first row

    for row in range(1, n_rows):
        oldest_weight *= decay
        new_total = weight_total + oldest_weight
        # result[row - 1] * weight_total recovers the previous weighted sum;
        # ageing it by one step and adding the new row gives the new sum.
        result[row] = (decay * result[row - 1] * weight_total + x[row]) / new_total
        weight_total = new_total
    return result
@njit
def ewma_version_2(x, halflife):
    """Exponentially weighted moving average along axis 0 (explicit-loop form).

    Computes the same weight-normalised average as the row-wise variant:
    a row that is ``k`` steps old has weight ``decay ** k`` with
    ``decay = 0.5 ** (1 / halflife)``. Here every element is written
    individually inside an inner loop over the columns, so ``x`` must be
    two-dimensional.

    Parameters
    ----------
    x : ndarray
        2-D input; the recursion runs over rows, independently per column.
    halflife : number
        Number of steps after which a sample's weight has halved.

    Returns
    -------
    ndarray
        float64 array with the same shape as ``x``.
    """
    decay = np.exp(np.log(0.5) / halflife)
    result = np.empty_like(x, dtype=np.float64)
    n_rows = result.shape[0]
    if n_rows == 0:
        # Nothing to average; hand back the (empty) buffer.
        return result

    n_cols = x.shape[1]

    # Seed the recursion element by element: the first average is the
    # first sample itself.
    for col in range(n_cols):
        result[0, col] = x[0, col]

    weight_total = 1.0   # sum of decay**k over the rows seen so far
    oldest_weight = 1.0  # decay**row, i.e. the weight of the very first row

    for row in range(1, n_rows):
        oldest_weight *= decay
        new_total = weight_total + oldest_weight
        # Scalar update per column instead of one whole-row array expression.
        for col in range(n_cols):
            result[row, col] = (decay * result[row - 1, col] * weight_total + x[row, col]) / new_total
        weight_total = new_total
    return result
if __name__ == '__main__':
    # Create some sample data. Draw from the seeded generator so that every
    # run benchmarks and validates exactly the same input (previously the
    # seeded `rng` was created but the data came from the global, unseeded
    # `np.random` state).
    rng = np.random.RandomState(0)
    data = rng.randn(5_000, 10)
    halflife = 10

    # Check the output meets expectations: both implementations must agree
    # with pandas' exponentially weighted mean for the same halflife.
    # These first calls also trigger numba's JIT compilation, so the timed
    # calls below measure execution only, not compile time.
    output_1 = ewma_version_1(data, halflife)
    output_2 = ewma_version_2(data, halflife)
    expected = pd.DataFrame(data).ewm(halflife=halflife).mean().values
    np.testing.assert_allclose(output_1, expected)
    np.testing.assert_allclose(output_2, expected)

    # Time version 1 (single call; elapsed time reported in microseconds).
    s = time.perf_counter()
    ewma_version_1(data, halflife)
    e = time.perf_counter()
    elapsed_1 = 1_000_000 * (e - s)
    print(f'ewma_version_1: {elapsed_1:.2f}')

    # Time version 2 (single call; elapsed time reported in microseconds).
    s = time.perf_counter()
    ewma_version_2(data, halflife)
    e = time.perf_counter()
    elapsed_2 = 1_000_000 * (e - s)
    print(f'ewma_version_2: {elapsed_2:.2f}')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment