Created
April 8, 2021 09:27
-
-
Save gmarkall/46b0b797762d5fe79ada4130cea07218 to your computer and use it in GitHub Desktop.
Fast Walsh Hadamard Transform code from Wikipedia accelerated with Numba
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from numba import njit | |
from time import perf_counter | |
# From https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform
def fwht(a) -> None:
    """In-place Fast Walsh–Hadamard Transform of array a.

    Overwrites ``a`` with its unnormalised Walsh–Hadamard coefficients in
    natural (Hadamard) order, using the iterative butterfly scheme.  The
    body only uses constructs Numba supports, so it can be wrapped with
    ``njit`` unchanged.

    Parameters
    ----------
    a : mutable indexable sequence of numbers (e.g. a 1-D NumPy array or list)
        Input vector.  Its length must be 0 or a power of two.  It is
        modified in place.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``len(a)`` is not a power of two.  The check happens before any
        element is written, so ``a`` is left untouched.
    """
    n = len(a)
    # Validate up front.  With any other length the butterfly below indexes
    # past the end: in pure Python that is an IndexError after ``a`` has
    # already been partially overwritten, and under ``njit`` (bounds checking
    # is off by default) it is an unchecked out-of-bounds access.
    if (n & (n - 1)) != 0:
        raise ValueError("fwht: len(a) must be a power of two")
    h = 1
    while h < n:
        # Each pass combines pairs of elements that are h apart within
        # consecutive blocks of size 2*h.
        for i in range(0, n, h * 2):
            for j in range(i, i + h):
                x = a[j]
                y = a[j + h]
                a[j] = x + y
                a[j + h] = x - y
        h *= 2
def benchmark(m):
    """Time the pure-Python ``fwht`` against a Numba-compiled copy of it.

    Both versions transform copies of the same random vector of length
    ``2 ** m``.  The two results must match exactly
    (``np.testing.assert_equal``), and then both wall-clock timings are
    printed.

    Parameters
    ----------
    m : int
        Exponent of the problem size; the vector has ``2 ** m`` elements.
    """
    size = 2 ** m
    print(f"Benchmarking with m = {m}, n = {size}")
    source = np.random.random(size)

    def time_one_call(func, data):
        # Wall-clock seconds spent in a single in-place call of func(data).
        t0 = perf_counter()
        func(data)
        return perf_counter() - t0

    # Pure Python version
    print("Running Python version...")
    result_py = source.copy()
    elapsed_py = time_one_call(fwht, result_py)

    # njit version
    print("Running njit version...")
    compiled = njit(fwht)
    # The first call triggers JIT compilation.  Run it on a throwaway copy
    # so compile time is excluded from the measurement below.
    compiled(source.copy())
    result_jit = source.copy()
    elapsed_jit = time_one_call(compiled, result_jit)

    print("Validating output")
    np.testing.assert_equal(result_py, result_jit)
    print(f"Python time {elapsed_py}")
    print(f"Njit time: {elapsed_jit}")
if __name__ == '__main__':
    import sys

    # Optional first command-line argument: the exponent m (problem size
    # n = 2 ** m).  Falls back to 20 when no argument is given.
    m = int(sys.argv[1]) if len(sys.argv) > 1 else 20
    benchmark(m)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
$ python fwht.py
Benchmarking with m = 20, n = 1048576
Running Python version...
Running njit version...
Validating output
Python time 5.71815178997349
Njit time: 0.014272961998358369
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment