Created
January 28, 2019 21:28
-
-
Save stsievert/ce8f8835040744c8d77d3a14f5d0c2b7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from scipy.signal import convolve | |
def _random_signal(rng, dtype, n):
    """Draw a length-``n`` signal of ``dtype`` from the generator ``rng``.

    Boolean and integer dtypes get random 0/1 values; every other dtype gets
    standard-normal samples cast to ``dtype``.
    """
    if np.dtype(dtype).kind in 'biu':
        return rng.choice([0, 1], size=n).astype(dtype)
    return rng.randn(n).astype(dtype)


def test(type1, type2, n=3000, atol=1e-2, rtol=1e-5):
    """Compare scipy.signal.convolve's 'fft' and 'direct' methods for two dtypes.

    Two length-``n`` random signals are drawn from a fixed seed (42), convolved
    with both methods, and compared with ``np.allclose``.

    Parameters
    ----------
    type1, type2 : str, type or np.dtype
        Dtypes of the first and second signal (anything ``np.dtype`` accepts,
        e.g. ``'uint8'`` or ``np.float32``).
    n : int, optional
        Length of each signal (default 3000).
    atol, rtol : float, optional
        Tolerances forwarded to ``np.allclose`` (defaults 1e-2 and 1e-5, the
        latter being NumPy's own default).

    Returns
    -------
    None or tuple of float
        ``None`` when both methods agree within tolerance; otherwise
        ``(median, max)`` of the absolute elementwise difference, which is also
        printed after the two dtypes.
    """
    # Private generator seeded like the original np.random.seed(42) call: the
    # draws are identical, but the global RNG state is left untouched.
    rng = np.random.RandomState(42)
    x1 = _random_signal(rng, type1, n)
    x2 = _random_signal(rng, type2, n)
    y = {method: convolve(x1, x2, method=method) for method in ['fft', 'direct']}

    # Promote to at least float64 before comparing: subtracting the raw outputs
    # wraps around for unsigned dtypes (3 - 5 -> 254 in uint8) and raises
    # TypeError for bool arrays. Complex outputs stay complex.
    work_dtype = np.promote_types(np.result_type(y['fft'], y['direct']), np.float64)
    fft_out = y['fft'].astype(work_dtype)
    direct_out = y['direct'].astype(work_dtype)

    # Explicit check rather than `assert` inside a bare `except:`; asserts are
    # stripped under `python -O`, and the bare except also hid real errors and
    # KeyboardInterrupt.
    if np.allclose(fft_out, direct_out, atol=atol, rtol=rtol):
        return None
    error = np.abs(fft_out - direct_out)
    stats = (np.median(error), error.max())
    print(type1, type2, *stats)
    return stats


# `test` takes dtype arguments, not pytest fixtures; stop pytest from
# collecting it if this file is ever run under pytest.
test.__test__ = False
if __name__ == "__main__":
    # Run the fft-vs-direct comparison for every ordered pair of dtypes,
    # first dtype in the outer position; `test` prints the dtype pair and
    # error statistics whenever the two methods disagree.
    dtype_names = [
        'bool',
        'uint8', 'uint16', 'uint32', 'uint64',
        'int8', 'int16', 'int32', 'int64',
        'float16', 'float32',
    ]
    dtype_pairs = [(first, second) for first in dtype_names for second in dtype_names]
    for first, second in dtype_pairs:
        test(first, second)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment