Created
July 15, 2020 08:27
-
-
Save coderforlife/e3a5fffa17be71c2f97779e0c41fb5a1 to your computer and use it in GitHub Desktop.
Testing `cupyx.scipy.ndimage.generic_filter()`
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy, cupy | |
import scipy.ndimage as ndi | |
import cupyx.scipy.ndimage as cp_ndi | |
from scipy import LowLevelCallable | |
from numba import cfunc, types, carray | |
##### Root Mean Squared #####
# NOTE: despite the "rms" name, every variant in this section computes the
# mean of the squares (no square root is taken).

# Raw CUDA kernel: sums x[i]*x[i] for i in [0, filter_size) and writes the sum
# divided by filter_size to y[0].
# NOTE(review): presumably x is the neighborhood buffer that generic_filter
# hands to the kernel for one output element -- confirm against the CuPy docs.
rms_raw = cupy.RawKernel('''extern "C" __global__
void rms(const double* x, int filter_size, double* y) {
    double ss = 0;
    for (int i = 0; i < filter_size; ++i) { ss += x[i]*x[i]; }
    y[0] = ss/filter_size;
}''', 'rms')

# Reduction kernel: map each value to its square, reduce with addition starting
# from 0, then divide the total by _in_ind.size() in the post-map step.
rms_red = cupy.ReductionKernel(
    'X x', 'Y y', 'x*x', 'a + b', 'y = a/_in_ind.size()', '0', 'rms')
def rms_fuse_wrapper(filter_size):
    """Build a mean-of-squares function for a fixed ``filter_size``.

    Args:
        filter_size: Divisor applied to the sum of squares. It is captured by
            the returned closure instead of being derived from the input.

    Returns:
        A function ``rms_fuse(x)`` that returns ``(x*x).sum()/filter_size``.
    """
    def rms_fuse(x):
        return (x*x).sum()/filter_size
    return rms_fuse
# C callback compiled with Numba:
#     int callback(double *x, intptr filter_size, double *y, void *user_data)
# It writes the mean of the squares of x[0..filter_size) to y[0] and returns 1.
# NOTE(review): per the SciPy ndimage low-level callback convention a return
# value of 1 signals success -- confirm against the SciPy docs.
@cfunc(types.intc(types.CPointer(types.double), types.intp, types.CPointer(types.double), types.voidptr))
def rms_numba(x, filter_size, y, _):
    ss = 0.0  # start as a float so the accumulator has a single type
    for i in range(filter_size):
        ss += x[i]*x[i]
    y[0] = ss/filter_size
    return 1

# Wrap the compiled callback so it can be passed to scipy.ndimage.generic_filter.
rms_llc = LowLevelCallable(rms_numba.ctypes)
def rms_pyfunc(x):
    """Return the mean of the squares of the elements of ``x``.

    ``x`` must support element-wise ``*``, ``.sum()`` and ``len()``.
    """
    return (x*x).sum()/len(x)
##### Less-Than Middle #####
# Every variant in this section counts how many values are strictly less than
# the value at index filter_size/2 (the middle of the buffer).

# Raw CUDA kernel: reads x[0..filter_size) and writes the count to y[0].
lt_raw = cupy.RawKernel('''extern "C" __global__
void lt(const double* x, int filter_size, double* y) {
    int n = 0;
    double c = x[filter_size/2];
    for (int i = 0; i < filter_size; ++i) { n += c>x[i]; }
    y[0] = n;
}''', 'lt')

# Reduction kernel: map each value to (middle > value) and sum the results as
# ints. The middle value is read through _raw_x at index _in_ind.size()/2.
lt_red = cupy.ReductionKernel(
    'X x', 'Y y', '_raw_x[_in_ind.size()/2]>x', 'a + b', 'y = a', '0', 'lt',
    reduce_type='int')
def lt_fuse_wrapper(filter_size):
    """Build a "less than middle" counter for a fixed ``filter_size``.

    Args:
        filter_size: Number of elements expected in the input; the element at
            index ``filter_size//2`` is used as the comparison value.

    Returns:
        A function ``lt_fuse(x)`` that returns how many elements of ``x`` are
        strictly less than ``x[filter_size//2]``.
    """
    def lt_fuse(x):
        return (x[filter_size//2] > x).sum()
    return lt_fuse
# C callback compiled with Numba:
#     int callback(double *x, intptr filter_size, double *y, void *user_data)
# It writes to y[0] the number of values in x[0..filter_size) that are strictly
# less than x[filter_size//2], and returns 1.
# NOTE(review): per the SciPy ndimage low-level callback convention a return
# value of 1 signals success -- confirm against the SciPy docs.
@cfunc(types.intc(types.CPointer(types.double), types.intp, types.CPointer(types.double), types.voidptr))
def lt_numba(x, filter_size, y, _):
    c = x[filter_size//2]
    n = 0
    for i in range(filter_size):
        n += c > x[i]
    y[0] = n
    return 1

# Wrap the compiled callback so it can be passed to scipy.ndimage.generic_filter.
lt_llc = LowLevelCallable(lt_numba.ctypes)
def lt_pyfunc(x):
    """Return how many elements of ``x`` are strictly less than its middle element.

    The middle element is ``x[len(x)//2]``.
    """
    return (x[len(x)//2] > x).sum()
##### All #####
# Every variant in this section yields a truthy result (1.0 for the kernels)
# when all values are non-zero and a falsy one (0.0) otherwise.
#
# The raw, reduction and Numba kernels used to return the *count* of non-zero
# values, which disagrees with cupy.all / numpy.all below (e.g. 9 vs 1 for a
# fully non-zero 3x3 window). They now compare that count with the number of
# elements so that all five implementations agree.

# Raw CUDA kernel: reads x[0..filter_size) and writes 1.0 or 0.0 to y[0].
all_raw = cupy.RawKernel('''extern "C" __global__
void all(const double* x, int filter_size, double* y) {
    int n = 0;
    for (int i = 0; i < filter_size; ++i) { n += x[i]!=0; }
    y[0] = (n == filter_size) ? 1.0 : 0.0;
}''', 'all')

# Reduction kernel: count the non-zero values as ints, then compare the count
# with _in_ind.size() in the post-map step.
all_red = cupy.ReductionKernel(
    'X x', 'Y y', 'x!=0', 'a + b', 'y = (a == _in_ind.size()) ? 1 : 0', '0',
    'all', reduce_type='int')

all_fuse = cupy.all

# C callback compiled with Numba:
#     int callback(double *x, intptr filter_size, double *y, void *user_data)
# NOTE(review): per the SciPy ndimage low-level callback convention a return
# value of 1 signals success -- confirm against the SciPy docs.
@cfunc(types.intc(types.CPointer(types.double), types.intp, types.CPointer(types.double), types.voidptr))
def all_numba(x, filter_size, y, _):
    n = 0
    for i in range(filter_size):
        n += x[i] != 0
    # Explicit floats keep the store into the double* output unambiguous.
    y[0] = 1.0 if n == filter_size else 0.0
    return 1

# Wrap the compiled callback so it can be passed to scipy.ndimage.generic_filter.
all_llc = LowLevelCallable(all_numba.ctypes)

all_pyfunc = numpy.all
###### Setup for running tests ######
# One entry per test: [test name, CuPy implementations, SciPy implementations].
# The CuPy implementations are ordered to match cp_names and the SciPy ones to
# match sp_names. The *_fuse_wrapper entries are factories that must be called
# with the number of elements in the filter before being used.
funcs = [
    ['rms', [rms_raw, rms_red, rms_fuse_wrapper], [rms_llc, rms_pyfunc]],
    ['lt', [lt_raw, lt_red, lt_fuse_wrapper], [lt_llc, lt_pyfunc]],
    ['all', [all_raw, all_red, all_fuse], [all_llc, all_pyfunc]],
]
cp_names = ['raw', 'red', 'fuse']
sp_names = ['numba', 'py']
###### Setup run timing tests ###### | |
sp_data = numpy.random.rand(1000, 1000) | |
cp_data = cupy.array(sp_data) | |
for size in [3, 15, 25]: | |
for name, cp_funcs, sp_funcs in funcs: | |
print(name, '%dx%d' % (size, size)) | |
for name, func in zip(cp_names, cp_funcs): | |
if func in (rms_fuse_wrapper, lt_fuse_wrapper): func = func(size*size) | |
out = cp_ndi.generic_filter(cp_data, func, size) | |
ref = ndi.generic_filter(sp_data, sp_funcs[0], size) | |
if numpy.allclose(out.get(), ref): | |
print(name, end=' ') | |
else: | |
print(name, '*', end=' ') # asterisks means bad result | |
%timeit cp_ndi.generic_filter(cp_data, func, size); cupy.cuda.Stream.null.synchronize() | |
for name, func in zip(sp_names, sp_funcs): | |
ndi.generic_filter(sp_data, func, size) | |
print(name, end=' ') | |
%timeit ndi.generic_filter(sp_data, func, size) | |
print('----------------------------------------') |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
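Tested on a system with an Intel Xeon Gold 5122 CPU @ 3.60GHz and a Titan V GPU.
The `*` next to `fuse` for `all` indicates that, in the runs recorded below, it produced output that differs from the reference (the Numba result); this still needs to be fixed. The mismatch arises because the raw, reduction and Numba kernels count the non-zero elements, whereas the fused function is `cupy.all`, which returns a boolean.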
rms 3x3 | |
raw 308 µs ± 893 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) | |
red 308 µs ± 206 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) | |
fuse 2.06 ms ± 4.13 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) | |
numba 14.1 ms ± 48.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) | |
py 2.96 s ± 15.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) | |
---------------------------------------- | |
lt 3x3 | |
raw 337 µs ± 940 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) | |
red 338 µs ± 1.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) | |
fuse 1.73 ms ± 11.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) | |
numba 16.3 ms ± 146 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) | |
py 3.91 s ± 391 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) | |
---------------------------------------- | |
all 3x3 | |
raw 339 µs ± 487 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) | |
red 339 µs ± 783 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) | |
fuse * 604 µs ± 1.62 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) | |
numba 16.2 ms ± 160 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) | |
py 2.72 s ± 5.35 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) | |
---------------------------------------- | |
rms 15x15 | |
raw 6.83 ms ± 806 ns per loop (mean ± std. dev. of 7 runs, 100 loops each) | |
red 6.83 ms ± 2.07 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) | |
fuse 15.5 ms ± 54.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) | |
numba 371 ms ± 486 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) | |
py 3.5 s ± 6.81 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) | |
---------------------------------------- | |
lt 15x15 | |
raw 6.8 ms ± 1.67 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) | |
red 6.8 ms ± 1.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) | |
fuse 9.36 ms ± 4.28 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) | |
numba 156 ms ± 676 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) | |
py 3.84 s ± 5.92 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) | |
---------------------------------------- | |
all 15x15 | |
raw 6.88 ms ± 1.27 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) | |
red 6.88 ms ± 1.81 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) | |
fuse * 7.12 ms ± 9.96 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) | |
numba 158 ms ± 225 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) | |
py 3.17 s ± 4.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) | |
---------------------------------------- | |
rms 25x25 | |
raw 18.7 ms ± 5.85 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) | |
red 18.7 ms ± 5.45 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) | |
fuse 39.1 ms ± 16 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) | |
numba 1.02 s ± 416 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) | |
py 4.25 s ± 25.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) | |
---------------------------------------- | |
lt 25x25 | |
raw 19.4 ms ± 9.67 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) | |
red 19.4 ms ± 8.77 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) | |
fuse 24.8 ms ± 12.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) | |
numba 406 ms ± 559 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) | |
py 4.82 s ± 10.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) | |
---------------------------------------- | |
all 25x25 | |
raw 19.5 ms ± 2.85 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) | |
red 19.5 ms ± 3.39 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) | |
fuse * 19.7 ms ± 4.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) | |
numba 405 ms ± 118 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) | |
py 3.89 s ± 4.63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) | |
---------------------------------------- |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment