Last active
December 17, 2015 04:29
-
-
Save ahmadia/5550933 to your computer and use it in GitHub Desktop.
numba for USR kernel (dirty)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from numba import autojit | |
import numpy as np | |
import numpy.testing as npt | |
import time | |
import gc | |
def usr(x, y, s=0.9, n=10):
    """Return the ``n`` highest USR similarity scores of the rows of ``x`` against ``y``.

    Vectorised NumPy reference implementation.  For each row ``x[i]`` the
    score is ``1 / (1 + (1/12) * sum_j |x[i, j] - y[j]|)``.  The 1/12 factor
    is hard-coded (presumably the 12-element USR descriptor length -- TODO
    confirm), so it is not adjusted for other column counts.

    Parameters
    ----------
    x : array_like, shape (m, k)
        Candidate descriptors, one per row.
    y : array_like, shape (k,)
        Query descriptor, broadcast against every row of ``x``.
    s : float
        Inclusive similarity threshold; scores below ``s`` are discarded.
    n : int
        Maximum number of scores to return.

    Returns
    -------
    ndarray
        At most ``n`` scores in descending order.  The result is shorter than
        ``n`` when fewer than ``n`` rows reach the threshold, and empty when
        ``n == 0``.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    scores = 1.0 / (1.0 + 1 / 12.0 * np.abs(x - y).sum(axis=1))
    scores = scores[scores >= s]
    # Sort descending, then keep the head.  Taking the tail with [-n:] would
    # return *every* score for n == 0, because -0 == 0.
    return np.sort(scores)[::-1][:n]
def usr_numba(x, y, S, num_best):
    """Return the ``num_best`` highest USR similarity scores using explicit loops.

    Loop-based counterpart of ``usr``, written in a style numba can compile.
    For each row ``x[i]`` the score is
    ``1 / (1 + (1/12) * sum_j |x[i, j] - y[j]|)``.  Scores that reach the
    threshold are kept in a small descending buffer by insertion, so the full
    score vector is never materialised.

    Runtime is O(m * (n + num_best)) for ``x`` of shape (m, n).

    Parameters
    ----------
    x : ndarray, shape (m, n)
        Candidate descriptors, one per row.
    y : ndarray, shape (n,)
        Query descriptor.
    S : float
        Inclusive similarity threshold, matching ``s`` in ``usr``.
    num_best : int
        Number of scores to keep.

    Returns
    -------
    best : ndarray, shape (num_best,)
        Scores in descending order.  Unlike ``usr``, slots not filled by a
        qualifying row are left at 0.0 rather than trimmed.
    """
    m, n = x.shape
    best = np.zeros(num_best)
    if num_best == 0:
        # Nothing to keep; the insertion code below would index an empty buffer.
        return best
    best_low = 0.0  # smallest score currently held, i.e. best[num_best - 1]
    for i in range(m):
        # Start from 0.0 (rather than |x[i, 0] - y[0]|) so zero-column input
        # does not index past the end of the row.
        d = 0.0
        for j in range(n):
            d += abs(x[i, j] - y[j])
        d = 1.0 / (1.0 + 1 / 12.0 * d)
        if d > best_low and d >= S:
            # First slot holding a smaller score; it exists because d > best_low.
            k = 0
            while best[k] >= d:
                k += 1
            # Shift the smaller scores down one slot, dropping the last one.
            for p in range(num_best - 1, k, -1):
                best[p] = best[p - 1]
            best[k] = d
            best_low = best[num_best - 1]
    return best
# Jitted version of usr_numba: `autojit()` builds a decorator, which is then
# applied to the pure-Python kernel to produce the callable `_usr`.
# NOTE(review): presumably compilation happens lazily on the first call for
# each argument-type combination; `autojit` is a legacy numba API that newer
# numba releases replace with `numba.jit` / `numba.njit` -- confirm against
# the installed numba version.
_lazy_jit = autojit()
_usr = _lazy_jit(usr_numba)
del _lazy_jit  # keep the module namespace unchanged
def test_kernel(seed=0):
    """Check that ``usr``, ``usr_numba`` and the jitted ``_usr`` agree.

    Draws a random 12-element query and 100 rows scattered tightly around
    it, then compares the top-10 scores (threshold 0.9) from all three
    implementations with ``assert_array_almost_equal``.

    Parameters
    ----------
    seed : int or None
        Seed for a private ``RandomState`` so the check is reproducible
        without touching NumPy's global random state; ``None`` draws fresh
        randomness.
    """
    rng = np.random.RandomState(seed)
    N = 100
    m = 12
    y = rng.randn(m)
    x = 0.1 * rng.randn(N, m) + y
    s1 = usr(x, y, 0.9, 10)
    s2 = usr_numba(x, y, 0.9, 10)
    npt.assert_array_almost_equal(s1, s2)
    s3 = _usr(x, y, 0.9, 10)
    npt.assert_array_almost_equal(s1, s3)
    npt.assert_array_almost_equal(s2, s3)
    # Function-call form works on both Python 2 and Python 3.
    print('jitted kernel checks out')
def time_kernel(func, N=100, trials=3, dtype=np.double, warmup=True):
    """Return the mean wall-clock time of one ``func(x, y, 0.9, 10)`` call.

    Builds a random 12-element query ``y`` and ``N`` candidate rows ``x``
    scattered around it. It then times ``trials`` back-to-back calls with the
    garbage collector disabled.

    Parameters
    ----------
    func : callable
        Kernel with the ``(x, y, s, n)`` signature of ``usr`` / ``usr_numba``.
    N : int or float
        Number of candidate rows; converted with ``int`` so ``1e6`` works.
    trials : int
        Number of timed calls; must be >= 1.
    dtype : numpy dtype
        dtype of the ``x`` and ``y`` arrays passed to ``func``.
    warmup : bool
        If true, make one untimed call first so one-off costs (e.g. JIT
        compilation for a new argument type) are not counted.

    Returns
    -------
    tuple
        ``(seconds_per_trial,)``.

    Raises
    ------
    ValueError
        If ``trials < 1``.
    """
    from timeit import default_timer  # timeit's recommended timer on Py2 and Py3

    N = int(N)  # NumPy rejects float shapes such as 1e6
    if trials < 1:
        raise ValueError("trials must be >= 1")
    m = 12
    y = np.random.randn(m)
    # asarray avoids copying the (possibly multi-GB) x when dtype already matches.
    x = np.asarray(0.1 * np.random.randn(N, m) + y, dtype=dtype)
    y = np.asarray(y, dtype=dtype)
    if warmup:
        func(x, y, 0.9, 10)
    gc_was_enabled = gc.isenabled()
    gc.disable()
    try:
        start = default_timer()
        for _ in range(trials):
            func(x, y, 0.9, 10)
        elapsed = default_timer() - start
    finally:
        # Restore the collector even if func raises.
        if gc_was_enabled:
            gc.enable()
    return (elapsed / trials,)
def test_timer(N=1000):
    """Print a timing comparison of the NumPy ``usr`` and jitted ``_usr`` kernels.

    Parameters
    ----------
    N : int or float
        Number of candidate rows; converted with ``int`` because NumPy array
        shapes must be integers (callers pass values like ``1e6``).
    """
    N = int(N)
    trials = 3
    dtype = np.double
    (t_numpy,) = time_kernel(usr, N, trials, dtype)
    print("")
    print("N = %d" % N)
    print("usr (s): %g" % t_numpy)
    (t_numba,) = time_kernel(_usr, N, trials, dtype)
    print("numba (s): %g" % t_numba)
    if t_numba > 0:
        # %.1f rather than %.2g: %g switches to exponent form at >= 100X.
        print("%.1fX speedup" % (t_numpy / t_numba))
    else:
        # A coarse system clock can report 0 s for a fast kernel.
        print("speedup not measurable (numba time below timer resolution)")
# Script entry: verify the kernels agree, then benchmark increasing sizes.
# Sizes are integer literals because NumPy rejects float shapes such as 1e6.
test_kernel()
test_timer(10**6)
test_timer(2 * 10**6)
test_timer(4 * 10**6)
test_timer(8 * 10**6)
test_timer(10**7)
test_timer(2 * 10**7)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Performance on a 2.6 GHz Core i7 with 8 GB of RAM:
jitted kernel checks out
N = 1000000
usr (s): 0.233645
numba (s): 0.0115586
20X speedup
N = 2000000
usr (s): 0.566954
numba (s): 0.023487
24X speedup
N = 4000000
usr (s): 1.14992
numba (s): 0.0472016
24X speedup
N = 8000000
usr (s): 2.34968
numba (s): 0.092601
25X speedup
N = 10000000
usr (s): 2.96395
numba (s): 0.116032
26X speedup
N = 20000000
usr (s): 17.4779
numba (s): 0.236304
74X speedup