Skip to content

Instantly share code, notes, and snippets.

@ahmadia
Last active December 17, 2015 04:29
Show Gist options
  • Save ahmadia/5550933 to your computer and use it in GitHub Desktop.
Save ahmadia/5550933 to your computer and use it in GitHub Desktop.
Numba-accelerated USR (Ultrafast Shape Recognition) scoring kernel (quick and dirty)
from numba import autojit
import numpy as np
import numpy.testing as npt
import time
import gc
def usr(x, y, s=0.9, n=10):
    """Return up to ``n`` best similarity scores of the rows of ``x`` against ``y``.

    Reference (vectorized NumPy) implementation.  Each row's score is
    ``1 / (1 + sum(|x_i - y|) / 12)``.  Only scores ``>= s`` are kept.  At most
    ``n`` of them are returned in descending order; fewer are returned if fewer
    rows pass the threshold.

    Parameters
    ----------
    x : array with at least 2 dimensions; rows are broadcast against ``y``.
    y : query array broadcastable against each row of ``x``.
    s : inclusive score threshold.
    n : maximum number of scores to return (``n >= 0``); ``n=0`` returns an
        empty array.

    Note: the 1/12 factor is fixed and does not depend on the number of columns.
    """
    scores = 1.0 / (1.0 + 1/12.0 * np.abs(x - y).sum(axis=1))
    # Sort ascending, then reverse for a descending view.
    scores = np.sort(scores[scores >= s])[::-1]
    # Slice from the front: the old ``scores[-n:]`` form returned *every*
    # score when n == 0, because ``-0:`` selects the whole array.
    return scores[:n]
def usr_numba(x, y, S, num_best):
    """Explicit-loop version of `usr`, written with plain loops so numba can compile it.

    Each row of ``x`` is scored against ``y`` with the same formula as `usr`:
    ``1 / (1 + sum(|x_i - y|) / 12)``.  A running top-``num_best`` list is kept
    in descending order by insertion.  A score must be ``>= S`` to be kept,
    matching the inclusive threshold used by `usr`.

    Unlike `usr`, the result always has length ``num_best``.  Slots that no
    row filled stay 0.0.

    Parameters
    ----------
    x : 2-D array, one candidate per row (unpacked as ``m, n = x.shape``).
    y : 1-D query array indexed alongside the columns of ``x``.
    S : inclusive score threshold.
    num_best : number of top scores to keep.
    """
    m, n = x.shape
    best = np.zeros(num_best)
    if num_best == 0:
        # Nothing to keep; also avoids indexing best[num_best - 1] on an
        # empty array below.
        return best
    best_low = 0.0  # smallest score currently held (the last slot of `best`)
    for i in range(m):
        # Start from 0.0 (not x[i, 0]) so a zero-column x is handled.
        d = 0.0
        for j in range(n):
            d += abs(x[i, j] - y[j])
        d = 1.0 / (1.0 + 1/12.0 * d)
        if d > best_low and d >= S:
            # Find the first slot holding a smaller score.  Such a slot always
            # exists because d > best_low == best[num_best - 1].
            k = 0
            for k in range(num_best):
                if d > best[k]:
                    break
            # Shift the tail down one slot to open position k.
            for dst in range(num_best - 1, k, -1):
                best[dst] = best[dst - 1]
            best[k] = d
            best_low = best[num_best - 1]
    return best
# Jitted variant of the explicit-loop kernel, produced by numba's `autojit`.
# It is called with the same arguments as usr_numba (see test_kernel and
# test_timer).
# NOTE(review): presumably compiled lazily on first call for the argument types
# seen; later numba releases dropped `autojit` (`numba.jit` is the usual
# replacement) — confirm against the numba version in use.
_usr = autojit()(usr_numba)
def test_kernel():
    """Check that the loop kernels agree with the NumPy reference `usr`.

    Compares `usr`, the pure-Python `usr_numba`, and the jitted `_usr` on a
    100 x 12 dataset clustered around a random query.  A fixed-seed generator
    makes any failure reproducible.

    `usr` returns only the scores that pass the threshold, while the loop
    kernels always return ``num_best`` entries with unused slots left at 0.0.
    So only the leading ``len(s1)`` entries are compared, and the remaining
    entries must be zero.
    """
    N = 100
    m = 12
    rng = np.random.RandomState(0)
    y = rng.randn(m)
    x = 0.1 * rng.randn(N, m) + y
    s1 = usr(x, y, 0.9, 10)
    s2 = usr_numba(x, y, 0.9, 10)
    s3 = _usr(x, y, 0.9, 10)
    k = len(s1)
    for s in (s2, s3):
        npt.assert_array_almost_equal(s1, s[:k])
        npt.assert_array_almost_equal(s[k:], np.zeros(len(s) - k))
    npt.assert_array_almost_equal(s2, s3)
    print('jitted kernel checks out')
def time_kernel(func, N=100, trials=3, dtype=np.double):
    """Return the mean wall-clock time of ``func(x, y, 0.9, 10)`` over ``trials`` calls.

    A random query ``y`` of length 12 and an (N, 12) array ``x`` (``y`` plus
    0.1-scaled Gaussian noise) are generated once, cast to ``dtype``, and
    reused for every trial.  The garbage collector is disabled while timing.
    Its previous state is restored afterwards, even if ``func`` raises.

    Parameters
    ----------
    func : callable taking ``(x, y, threshold, num_best)``.
    N : number of rows; floats such as ``1e6`` are accepted and truncated to int.
    trials : number of timed calls (must be >= 1).
    dtype : dtype of the generated ``x`` and ``y`` arrays.

    Returns
    -------
    1-tuple ``(seconds_per_call,)``.

    Raises
    ------
    ValueError if ``trials`` < 1.
    """
    if trials < 1:
        raise ValueError("trials must be >= 1")
    m = 12
    N = int(N)  # callers pass e.g. 1e6; NumPy array shapes must be integers
    y = np.random.randn(m)
    x = 0.1 * np.random.randn(N, m) + y
    # Honor the dtype parameter (previously accepted but ignored).
    x = np.asarray(x, dtype=dtype)
    y = np.asarray(y, dtype=dtype)
    gc_was_enabled = gc.isenabled()
    gc.disable()
    try:
        tic = time.time()
        for _ in range(trials):
            func(x, y, 0.9, 10)
        elapsed = time.time() - tic
    finally:
        # Restore the caller's GC state even if func raised.
        if gc_was_enabled:
            gc.enable()
    return (elapsed / trials,)
def test_timer(N=1000):
    """Time the NumPy reference `usr` against the jitted `_usr` and print a report.

    Prints the row count, the mean seconds per call for each implementation
    (3 trials each, double precision), and the resulting speedup factor.

    Parameters
    ----------
    N : number of rows.  Floats such as ``1e6`` are accepted and converted to
        int, because NumPy array shapes must be integers.
    """
    N = int(N)
    trials = 3
    dtype = np.double
    s1, = time_kernel(usr, N, trials, dtype)
    print("")
    print("N = %d" % N)
    print("usr (s): %g" % (s1))
    s2, = time_kernel(_usr, N, trials, dtype)
    print("numba (s): %g" % (s2))
    print("%.2gX speedup" % (s1/s2))
# Script driver: check correctness once, then benchmark increasing row counts.
# The row counts stay floats, exactly as in the original list of calls.
test_kernel()
for n_rows in (1e6, 2e6, 4e6, 8e6, 1e7, 2e7):
    test_timer(n_rows)
@ahmadia
Copy link
Author

ahmadia commented May 9, 2013

Performance on a 2.6 GHz Intel Core i7 with 8 GB of RAM:

jitted kernel checks out

N = 1000000
usr (s): 0.233645
numba (s): 0.0115586
20X speedup

N = 2000000
usr (s): 0.566954
numba (s): 0.023487
24X speedup

N = 4000000
usr (s): 1.14992
numba (s): 0.0472016
24X speedup

N = 8000000
usr (s): 2.34968
numba (s): 0.092601
25X speedup

N = 10000000
usr (s): 2.96395
numba (s): 0.116032
26X speedup

N = 20000000
usr (s): 17.4779
numba (s): 0.236304
74X speedup

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment