Created
August 30, 2018 21:50
-
-
Save maedoc/3417c87bd964378044ad0cb813d3b396 to your computer and use it in GitHub Desktop.
Loop fusion in Numba
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import numpy as np | |
import numba as nb | |
@nb.njit
def k1(a):
    """Add 1 to every element of ``a`` in place.

    ``a`` is indexed as ``a[i]`` for ``i`` in ``range(a.size)``, so it is
    expected to be a 1-D array. Nothing is returned; the caller's array is
    modified.
    """
    for i in range(a.size):
        a[i] += 1
@nb.njit
def k2(a):
    """Multiply every element of ``a`` by 2 in place.

    ``a`` is indexed as ``a[i]`` for ``i`` in ``range(a.size)``, so it is
    expected to be a 1-D array. Nothing is returned; the caller's array is
    modified.
    """
    for i in range(a.size):
        a[i] *= 2
@nb.njit
def k3(a):
    """Apply ``k1``, ``k2``, ``k1``, ``k2`` to ``a`` in sequence, in place.

    Each call is a separate full pass over the array, so written this way the
    data is traversed four times. The net effect per element is
    ``2 * (2 * (x + 1) + 1)``.
    """
    k1(a)
    k2(a)
    k1(a)
    k2(a)
@nb.njit
def k123(a):
    """Hand-fused, in-place equivalent of the four passes in ``k3``.

    Computes ``2 * (2 * (x + 1) + 1)`` for every element in a single loop
    over ``a`` (indexed as ``a[i]``, so a 1-D array is expected). Nothing is
    returned; the caller's array is modified.
    """
    for i in range(a.size):
        a[i] = 2 * (2 * (a[i] + 1) + 1)
# a.nbytes > L3 cache size | |
a = np.random.randn(1000000) | |
# let jit compile | |
k3(a) | |
k123(a) | |
print('k3') | |
%timeit k3(a) # 1780 us | |
print('k123') | |
%timeit k123(a) # 485 us | |
# hand-fused loop is 4x faster | |
# A different test case: kernels that allocate and return a new output array
@nb.njit
def l1(a):
    """Return a new array holding ``math.sin`` of each element of ``a``.

    The result is allocated with ``np.empty_like(a)`` and filled by index
    (``a[i]`` for ``i`` in ``range(a.size)``, so a 1-D array is expected).
    ``a`` itself is not modified.
    """
    out = np.empty_like(a)
    for i in range(a.size):
        out[i] = math.sin(a[i])
    return out
@nb.njit
def l2(a):
    """Return a new array holding ``math.cos`` of each element of ``a``.

    The result is allocated with ``np.empty_like(a)`` and filled by index
    (``a[i]`` for ``i`` in ``range(a.size)``, so a 1-D array is expected).
    ``a`` itself is not modified.
    """
    out = np.empty_like(a)
    for i in range(a.size):
        out[i] = math.cos(a[i])
    return out
@nb.njit
def l3(a):
    """Return ``sin(cos(sin(cos(a))))`` by composing ``l1`` and ``l2``.

    Written as four nested calls; each of ``l1``/``l2`` allocates and returns
    its own output array, and ``a`` is left unmodified.
    """
    return l1(l2(l1(l2(a))))
@nb.njit
def l123(a):
    """Hand-fused equivalent of ``l3``: one loop, one output allocation.

    Returns a new array (allocated with ``np.empty_like(a)``) whose elements
    are ``sin(cos(sin(cos(x))))`` for each ``x`` in ``a``. ``a`` is indexed as
    ``a[i]``, so a 1-D array is expected; it is not modified.
    """
    out = np.empty_like(a)
    for i in range(a.size):
        out[i] = math.sin(math.cos(math.sin(math.cos(a[i]))))
    return out
# a.nbytes > L3 cache size | |
a = np.random.randn(1000000) | |
# let jit compile | |
l3(a); | |
l123(a); | |
print('l3') | |
%timeit l3(a) # 50 ms | |
print('l123') | |
%timeit l123(a) # 70 ms | |
# loop appears fused |
Changing the k1–k3 loops to the out-allocating style used in l1–l3 results in two loops being fused (k3 is 2x slower than k123), but overall it is 2.5x slower than in-place modification.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
With Numba version 0.39.0