Skip to content

Instantly share code, notes, and snippets.

@rjenc29
Last active September 1, 2020 07:11
Show Gist options
  • Save rjenc29/e0b15089d10a2b9045c397793f22f35a to your computer and use it in GitHub Desktop.
Save rjenc29/e0b15089d10a2b9045c397793f22f35a to your computer and use it in GitHub Desktop.
numba ewma - speed comparison
import numpy as np
import pandas as pd
from numba import njit
import time
@njit
def ewma_version_1(x, halflife):
    """Exponentially weighted moving average along the first axis of ``x``.

    Row ``t`` of the result is ``sum_k w**k * x[t - k] / sum_k w**k`` for
    ``k = 0..t``, with ``w = 0.5 ** (1 / halflife)``.  This matches pandas'
    ``DataFrame.ewm(halflife=...).mean()`` with its default ``adjust=True``.

    This version updates a whole row per step with array arithmetic
    (``out[i] = ... x[i] ...``).  ``ewma_version_2`` instead loops over the
    columns explicitly.

    Parameters
    ----------
    x : array
        Observations, with time along axis 0.  Every remaining axis is
        smoothed independently.  A 1-D array also works: ``x[i]`` is then a
        scalar.
    halflife : positive number
        Number of rows after which an observation's weight has halved.

    Returns
    -------
    numpy.ndarray
        float64 array with the same shape as ``x``.

    Raises
    ------
    ValueError
        If ``halflife`` is not strictly positive.

    Notes
    -----
    There is no NaN handling.  Each row depends on the previous one, so a NaN
    in ``x`` propagates to every later row.  pandas instead skips missing
    observations.
    """
    # Reject non-positive (and NaN) halflife, as pandas does.  Otherwise the
    # decay factor is >= 1 or the division below fails opaquely.
    if not halflife > 0:
        raise ValueError('halflife must be positive')
    decay_coefficient = np.exp(np.log(0.5) / halflife)  # == 0.5 ** (1 / halflife)
    out = np.empty_like(x, dtype=np.float64)
    n_rows = out.shape[0]
    if n_rows == 0:
        return out

    # The first row of the average is the first observation itself.
    out[0] = x[0]
    # sum_prior: total weight sum_{k<i} w**k of the rows seen before row i.
    # first_weight: weight w**i now carried by the oldest observation.
    sum_prior = 1.0
    first_weight = 1.0
    for i in range(1, n_rows):
        first_weight *= decay_coefficient
        sum_i = sum_prior + first_weight
        # out[i - 1] * sum_prior recovers the previous weighted numerator.
        # Decay it by one step, add the new observation, then renormalise.
        out[i] = (decay_coefficient * out[i - 1] * sum_prior + x[i]) / sum_i
        sum_prior = sum_i
    return out
@njit
def ewma_version_2(x, halflife):
    """Exponentially weighted moving average down the rows of a 2-D array.

    Row ``t`` of the result is ``sum_k w**k * x[t - k] / sum_k w**k`` for
    ``k = 0..t``, with ``w = 0.5 ** (1 / halflife)``.  This matches pandas'
    ``DataFrame.ewm(halflife=...).mean()`` with its default ``adjust=True``.

    This version is the same recurrence as ``ewma_version_1``, but it updates
    each row with an explicit loop over the columns rather than whole-row
    array arithmetic.  The explicit ``[i, j]`` indexing means ``x`` must be
    2-D.

    Parameters
    ----------
    x : 2-D array
        Observations; rows are time steps and columns are independent series.
    halflife : positive number
        Number of rows after which an observation's weight has halved.

    Returns
    -------
    numpy.ndarray
        float64 array with the same shape as ``x``.

    Raises
    ------
    ValueError
        If ``halflife`` is not strictly positive.

    Notes
    -----
    There is no NaN handling.  A NaN in a column propagates to every later row
    of that column.  pandas instead skips missing observations.
    """
    # Reject non-positive (and NaN) halflife, as pandas does.
    if not halflife > 0:
        raise ValueError('halflife must be positive')
    decay_coefficient = np.exp(np.log(0.5) / halflife)  # == 0.5 ** (1 / halflife)
    out = np.empty_like(x, dtype=np.float64)
    n_rows = out.shape[0]
    n_cols = x.shape[1]
    if n_rows == 0:
        return out

    # The first row of the average is the first observation itself.
    # Copy it column by column.
    for j in range(n_cols):
        out[0, j] = x[0, j]
    # sum_prior: total weight sum_{k<i} w**k of the rows seen before row i.
    # first_weight: weight w**i now carried by the oldest observation.
    sum_prior = 1.0
    first_weight = 1.0
    for i in range(1, n_rows):
        first_weight *= decay_coefficient
        sum_i = sum_prior + first_weight
        # Per-column update instead of ewma_version_1's row arithmetic.
        # Presumably this avoids a temporary array per row; confirm by profiling.
        for j in range(n_cols):
            out[i, j] = (
                decay_coefficient * out[i - 1, j] * sum_prior + x[i, j]
            ) / sum_i
        sum_prior = sum_i
    return out
if __name__ == '__main__':
    # Create reproducible sample data.  Draw from the seeded generator rather
    # than the global np.random state, so every run benchmarks the same input.
    rng = np.random.RandomState(0)
    data = rng.randn(5_000, 10)
    halflife = 10

    # Check both versions against pandas' exponentially weighted mean.
    # These first calls also trigger Numba's lazy JIT compilation of the
    # @njit functions, so compile time is excluded from the timings below.
    output_1 = ewma_version_1(data, halflife)
    output_2 = ewma_version_2(data, halflife)
    expected = pd.DataFrame(data).ewm(halflife=halflife).mean().to_numpy()
    np.testing.assert_allclose(output_1, expected)
    np.testing.assert_allclose(output_2, expected)

    # Time each version.  A single call is short and noisy, so report the
    # best of several runs, in microseconds.
    n_repeats = 10
    for name, func in (('ewma_version_1', ewma_version_1),
                       ('ewma_version_2', ewma_version_2)):
        timings = []
        for _ in range(n_repeats):
            s = time.perf_counter()
            func(data, halflife)
            e = time.perf_counter()
            timings.append(e - s)
        elapsed = 1_000_000 * min(timings)
        print('{0}: {1:.2f} us'.format(name, elapsed))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment