Skip to content

Instantly share code, notes, and snippets.

@saethlin
Created September 9, 2019 22:02
Show Gist options
  • Save saethlin/01867d6dfbd17be8f96ef6629ce1ecbd to your computer and use it in GitHub Desktop.
Save saethlin/01867d6dfbd17be8f96ef6629ce1ecbd to your computer and use it in GitHub Desktop.
import numpy as np
import numba
import time
def add_numpy(a, b):
    """Return the total of ``a**2 + b**2`` computed with vectorised NumPy.

    Despite the name, this is a sum of squares: every element of *a* and *b*
    is squared, the two results are added element-wise (with NumPy
    broadcasting), and everything is reduced with ``np.sum``. This is the
    NumPy baseline that ``add_numba`` is benchmarked against.
    """
    squares = a ** 2 + b ** 2
    return np.sum(squares)
@numba.njit(parallel=True, fastmath=True)
def add_numba(a, b):
    """Return the sum over ``i`` of ``a[i]**2 + b[i]**2``, compiled by numba.

    This is the numba counterpart of ``add_numpy``. ``numba.prange`` spreads
    the loop over threads, and the in-place ``total +=`` update is treated as
    a parallel reduction. Both that and ``fastmath=True`` allow the
    floating-point additions to be reordered, so the result may differ from
    ``add_numpy`` in the last few bits.

    Only the first ``len(a)`` elements of *b* are read, so *b* must be at
    least that long.
    """
    total = 0
    count = len(a)
    for idx in numba.prange(count):
        term = a[idx] ** 2 + b[idx] ** 2
        total += term
    return total
@numba.njit(fastmath=True, nogil=True)
def histogram(data, bins):
    """Count the values of *data* per bin, where *bins* holds each bin's upper edge.

    Each value is counted in the first slot ``b`` for which
    ``value < bins[b]``. This is a linear scan, costing
    O(len(data) * len(bins)). With ascending *bins*, slot ``b`` covers
    ``[bins[b-1], bins[b])`` and slot 0 covers everything below ``bins[0]``.
    Values ``>= bins[-1]`` and NaN never satisfy the comparison, so they are
    not counted.

    ``nogil=True`` allows ``parallel_histogram`` to run this kernel in several
    threads at once.

    Returns an int64 array of length ``bins.size``.
    """
    counts = np.zeros(bins.size, dtype=np.int64)
    n_slots = len(bins)
    for value in data:
        for slot in range(n_slots):
            if value < bins[slot]:
                counts[slot] += 1
                break
    return counts
def parallel_histogram(data, bins):
    """Compute ``histogram(data, bins)`` by splitting *data* across worker threads.

    *data* is cut into at most ``multiprocessing.cpu_count()`` contiguous
    slices. Each slice is binned by ``histogram`` in a ``ThreadPoolExecutor``
    worker, and the partial counts are summed. Threads only run concurrently
    because ``histogram`` is compiled with ``nogil=True``.

    Parameters
    ----------
    data : numpy.ndarray
        Values to bin. They are sliced by ``data.size``, so a 1-D array is
        assumed.
    bins : numpy.ndarray
        Upper bin edges, passed unchanged to ``histogram``.

    Returns
    -------
    numpy.ndarray
        int64 counts of length ``bins.size``. Empty *data* gives all zeros.
    """
    from concurrent.futures import ThreadPoolExecutor
    import multiprocessing

    n_chunks = multiprocessing.cpu_count()
    # Ceiling division produces at most n_chunks slices. max(1, ...) keeps
    # the range() step non-zero when data has fewer elements than there are
    # CPUs, including the empty case.
    chunksize = max(1, (data.size + n_chunks - 1) // n_chunks)

    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(histogram, data[start:start + chunksize], bins)
            for start in range(0, data.size, chunksize)
        ]

    binned = np.zeros(bins.size, dtype=np.int64)
    for future in futures:
        binned += future.result()
    return binned
def _time_calls(label, func, repeats):
    """Call ``func()`` *repeats* times, then print *label* and the total elapsed seconds.

    This uses ``time.perf_counter``, a monotonic high-resolution clock meant
    for measuring durations. ``time.time`` is a wall clock and can jump.
    """
    begin = time.perf_counter()
    for _ in range(repeats):
        func()
    print(label, time.perf_counter() - begin)


x = np.random.rand(1_000_000)
y = np.random.rand(1_000_000)

# 101 evenly spaced edges give 100 bins on [0, 1]. np.histogram takes the full
# edge array; the numba `histogram` kernels take only each bin's upper edge.
# Both arrays are built once so their construction is not part of any timing.
bin_edges = np.linspace(0, 1, 101)
upper_edges = bin_edges[1:]

_time_calls('numpy', lambda: add_numpy(x, y), 100)
add_numba(x, y)  # warm-up call so JIT compilation is not timed
_time_calls('numba', lambda: add_numba(x, y), 100)

# Cross-check the three histogram implementations; the printed differences
# are expected to be all zeros.
np_version = np.histogram(x, bins=bin_edges)[0]
nb_version = histogram(x, upper_edges)
par_nb_version = parallel_histogram(x, upper_edges)
print(nb_version - np_version)
print(nb_version - par_nb_version)

_time_calls('numpy', lambda: np.histogram(x, bins=bin_edges), 10)
histogram(x, bins=upper_edges)  # warm-up call so compilation is never timed
_time_calls('numba', lambda: histogram(x, bins=upper_edges), 10)
_time_calls('parallel_numba', lambda: parallel_histogram(x, bins=upper_edges), 10)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment