Skip to content

Instantly share code, notes, and snippets.

@gmarkall
Created April 8, 2021 09:27
Show Gist options
  • Save gmarkall/46b0b797762d5fe79ada4130cea07218 to your computer and use it in GitHub Desktop.
Save gmarkall/46b0b797762d5fe79ada4130cea07218 to your computer and use it in GitHub Desktop.
Fast Walsh–Hadamard Transform code from Wikipedia, accelerated with Numba
import numpy as np
from numba import njit
from time import perf_counter
# From https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform
def fwht(a) -> None:
    """In-place Fast Walsh–Hadamard Transform of array a.

    The transform is unnormalized: each butterfly replaces the pair
    ``(x, y)`` with ``(x + y, x - y)`` and no scaling is applied.

    Parameters
    ----------
    a : mutable sequence of numbers (e.g. a 1-D NumPy array or a list)
        Overwritten in place with its transform. Its length must be a power
        of two; an empty sequence is accepted and left unchanged.

    Raises
    ------
    ValueError
        If ``len(a)`` is not zero or a power of two.
    """
    n = len(a)
    # The butterflies index a[j + h] up to n - 1 only when n is a power of
    # two. Check before touching `a`: otherwise pure Python raises IndexError
    # after a partial update, and a Numba-compiled copy (no bounds checking by
    # default) would silently access memory past the end of the array.
    # The message is a constant string so this also compiles under njit.
    if (n & (n - 1)) != 0:
        raise ValueError("fwht: input length must be a power of two")
    h = 1
    while h < n:
        # Combine adjacent blocks of width h into blocks of width 2 * h.
        for i in range(0, n, h * 2):
            for j in range(i, i + h):
                x = a[j]
                y = a[j + h]
                a[j] = x + y
                a[j + h] = x - y
        h *= 2
def _time_call(func, arr):
    """Call ``func(arr)`` once and return the elapsed seconds (time.perf_counter)."""
    start = perf_counter()
    func(arr)
    return perf_counter() - start


def benchmark(m):
    """Time the pure-Python ``fwht`` against its Numba ``njit`` build.

    Both versions transform identical copies of one random float64 array,
    and their outputs are checked for exact equality.

    Parameters
    ----------
    m : int
        Base-2 log of the array length; the transform runs on ``2 ** m``
        values.

    Returns
    -------
    tuple of float
        ``(python_time, njit_time)`` in seconds. The njit time excludes
        compilation.

    Raises
    ------
    ValueError
        If ``m`` is negative.
    AssertionError
        If the two versions produce different results.
    """
    if m < 0:
        # 2 ** m would be a float here and np.random.random would fail obscurely.
        raise ValueError("m must be a non-negative integer")
    print(f"Benchmarking with m = {m}, n = {2 ** m}")
    x = np.random.random(2 ** m)

    # Pure Python version. fwht works in place, so each run gets its own copy.
    print("Running Python version...")
    x_python = x.copy()
    python_time = _time_call(fwht, x_python)

    # njit version
    print("Running njit version...")
    njit_fwht = njit(fwht)
    # The first call triggers JIT compilation; run it on a throwaway copy so
    # compilation time is not included in the measured call.
    njit_fwht(x.copy())
    x_njit = x.copy()
    njit_time = _time_call(njit_fwht, x_njit)

    print("Validating output")
    np.testing.assert_equal(x_python, x_njit)
    print(f"Python time {python_time}")
    print(f"Njit time: {njit_time}")
    return python_time, njit_time
if __name__ == '__main__':
    import sys

    # Optional first command-line argument: the exponent m, so the benchmark
    # runs on 2 ** m values. Defaults to 20 when no argument is given.
    m = int(sys.argv[1]) if len(sys.argv) > 1 else 20
    benchmark(m)
$ python fwht.py
Benchmarking with m = 20, n = 1048576
Running Python version...
Running njit version...
Validating output
Python time 5.71815178997349
Njit time: 0.014272961998358369
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment