
@maedoc
Created August 30, 2018 21:50
Loop fusion in Numba
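```python
import math
import numpy as np
import numba as nb

# separate in-place kernels, each one a full pass over the array
@nb.njit
def k1(a):
    for i in range(a.size):
        a[i] += 1

@nb.njit
def k2(a):
    for i in range(a.size):
        a[i] *= 2

@nb.njit
def k3(a):
    # four passes: +1, *2, +1, *2
    k1(a)
    k2(a)
    k1(a)
    k2(a)

@nb.njit
def k123(a):
    # the same four operations hand-fused into one pass (equals 4 * a[i] + 6)
    for i in range(a.size):
        a[i] = 2 * (2 * (a[i] + 1) + 1)

# a.nbytes = 8 MB (1e6 float64 values), larger than the L3 cache
a = np.random.randn(1000000)

# call once so the JIT compiles before timing
k3(a)
k123(a)

# %timeit is an IPython magic: run this in IPython or Jupyter.
# Both kernels modify a in place, so repeated timing calls keep growing its values.
print('k3')
%timeit k3(a)    # 1780 us
print('k123')
%timeit k123(a)  # 485 us
# the hand-fused loop is ~4x faster (1780 us vs 485 us)
```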
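A quick way to confirm that the hand-fused `k123` computes the same result as `k3`: composing the four steps gives a → a + 1 → 2a + 2 → 2a + 3 → 4a + 6, which is exactly `2 * (2 * (a + 1) + 1)`. The sketch below checks this on a small copy of the data. It assumes Numba is installed; the `njit` fallback shim is a hypothetical addition so the check still runs as plain Python without it.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # hypothetical fallback shim: run the kernels as plain Python if numba is missing
    def njit(f):
        return f

@njit
def k1(a):
    for i in range(a.size):
        a[i] += 1

@njit
def k2(a):
    for i in range(a.size):
        a[i] *= 2

@njit
def k3(a):
    # four separate in-place passes
    k1(a)
    k2(a)
    k1(a)
    k2(a)

@njit
def k123(a):
    # one fused in-place pass
    for i in range(a.size):
        a[i] = 2 * (2 * (a[i] + 1) + 1)

x = np.random.randn(1000)
a_separate = x.copy()
a_fused = x.copy()
k3(a_separate)
k123(a_fused)

# same floating-point operations in the same order, so the results match exactly
print(np.array_equal(a_separate, a_fused))  # True
print(np.allclose(a_fused, 4 * x + 6))      # True
```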
```python
# a different test case: out-of-place kernels that allocate a new array
@nb.njit
def l1(a):
    out = np.empty_like(a)
    for i in range(a.size):
        out[i] = math.sin(a[i])
    return out

@nb.njit
def l2(a):
    out = np.empty_like(a)
    for i in range(a.size):
        out[i] = math.cos(a[i])
    return out

@nb.njit
def l3(a):
    # composed kernels, written as four separate passes
    return l1(l2(l1(l2(a))))

@nb.njit
def l123(a):
    # hand-fused: one pass, one output array
    out = np.empty_like(a)
    for i in range(a.size):
        out[i] = math.sin(math.cos(math.sin(math.cos(a[i]))))
    return out

# a.nbytes = 8 MB (1e6 float64 values), larger than the L3 cache
a = np.random.randn(1000000)

# call once so the JIT compiles before timing
l3(a);
l123(a);

print('l3')
%timeit l3(a)    # 50 ms
print('l123')
%timeit l123(a)  # 70 ms
# l3 is no slower than the hand-fused l123, so the loops appear to be fused
```
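`%timeit` is an IPython magic, so the snippets above do not run as a plain `.py` script. A minimal sketch of the same `l3` vs `l123` comparison using the standard-library `timeit` module follows. The `best_time` helper and the `njit` fallback shim are hypothetical, and the array is smaller than in the gist, so absolute timings will not match the comments above.

```python
import math
import timeit
import numpy as np

try:
    from numba import njit
except ImportError:
    # hypothetical fallback shim: plain Python, so timings are only illustrative
    def njit(f):
        return f

@njit
def l1(a):
    out = np.empty_like(a)
    for i in range(a.size):
        out[i] = math.sin(a[i])
    return out

@njit
def l2(a):
    out = np.empty_like(a)
    for i in range(a.size):
        out[i] = math.cos(a[i])
    return out

@njit
def l3(a):
    return l1(l2(l1(l2(a))))

@njit
def l123(a):
    out = np.empty_like(a)
    for i in range(a.size):
        out[i] = math.sin(math.cos(math.sin(math.cos(a[i]))))
    return out

def best_time(fn, arg, number=5, repeat=3):
    """Best per-call time in seconds over `repeat` runs of `number` calls each."""
    fn(arg)  # warm-up call compiles the JIT function outside the timed region
    runs = timeit.repeat(lambda: fn(arg), number=number, repeat=repeat)
    return min(runs) / number

a = np.random.randn(10000)
assert np.allclose(l3(a), l123(a))  # both compute sin(cos(sin(cos(a))))
t_separate = best_time(l3, a)
t_fused = best_time(l123, a)
print(f"l3:   {t_separate * 1e3:.3f} ms")
print(f"l123: {t_fused * 1e3:.3f} ms")
```

Taking the minimum over repeats follows the usual `timeit` advice: the fastest run is the one least disturbed by other activity on the machine.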
maedoc commented Aug 30, 2018

Tested with Numba version 0.39.0.

maedoc commented Aug 30, 2018

Changing the k1-k3 loops to the out-of-place, allocating style used in l1-l3 results in two of the loops being fused (k3 is then only 2x slower than k123), but overall this is 2.5x slower than in-place modification.
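The comment does not include the modified code. One plausible reading, assumed here, is that k1-k3 were rewritten to allocate and return a new array as l1-l3 do; the names `k1_out`, `k2_out`, `k3_out` and `k123_out` are hypothetical. This sketch only checks that the allocating composition matches the fused version; it does not reproduce the timings reported above, and the `njit` fallback shim is likewise a hypothetical addition.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # hypothetical fallback shim: run as plain Python if numba is missing
    def njit(f):
        return f

@njit
def k1_out(a):
    # out-of-place "+1": allocates a new array, as l1/l2 do
    out = np.empty_like(a)
    for i in range(a.size):
        out[i] = a[i] + 1
    return out

@njit
def k2_out(a):
    # out-of-place "*2"
    out = np.empty_like(a)
    for i in range(a.size):
        out[i] = a[i] * 2
    return out

@njit
def k3_out(a):
    # same order as k3 (k1, k2, k1, k2), but each step returns a fresh array
    return k2_out(k1_out(k2_out(k1_out(a))))

@njit
def k123_out(a):
    # hand-fused, still allocating a single output array
    out = np.empty_like(a)
    for i in range(a.size):
        out[i] = 2 * (2 * (a[i] + 1) + 1)
    return out

x = np.random.randn(1000)
x_before = x.copy()
y_separate = k3_out(x)
y_fused = k123_out(x)
print(np.array_equal(y_separate, y_fused))  # True
print(np.array_equal(x, x_before))          # True: the input is left untouched
```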
