Created August 9, 2019 20:15
-
-
Save pentschev/c1443d3fd9f5eb088541b1df1a8f5e4f to your computer and use it in GitHub Desktop.
Blog Post - Parallelizing Custom CuPy Kernels with Dask - Complete
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dask.distributed import Client | |
from dask_cuda import LocalCUDACluster | |
from dask.array.utils import assert_eq | |
import dask.array as da | |
import cupy | |
# CUDA C source for an element-wise "x + y" where y is broadcast across rows.
# One thread computes one output element: idx0 selects the column and idx1 the
# row. xdim0 / zdim0 are the row strides of x and z, measured in elements, and
# y is indexed by column only, so the same y[idx0] is added to every row.
# The kernel performs no bounds check: the launch grid must not extend past
# the arrays. Index arithmetic uses 32-bit `int`.
_ADD_BROADCAST_SOURCE = r'''
extern "C" __global__
void add_broadcast_kernel(
const float* x, const float* y, float* z,
const int xdim0, const int zdim0)
{
int idx0 = blockIdx.x * blockDim.x + threadIdx.x;
int idx1 = blockIdx.y * blockDim.y + threadIdx.y;
z[idx1 * zdim0 + idx0] = x[idx1 * xdim0 + idx0] + y[idx0];
}
'''

# Launched as add_broadcast_kernel(grid, block, (x, y, z, xdim0, zdim0)).
add_broadcast_kernel = cupy.RawKernel(
    code=_ADD_BROADCAST_SOURCE,
    name='add_broadcast_kernel',
)
def dispatch_add_broadcast(x, y):
    """Return ``x + y`` computed on the GPU with ``add_broadcast_kernel``.

    ``y`` supplies one value per column of ``x`` and is added to every row,
    i.e. the result matches CuPy's ``x + y`` for ``x`` of shape
    ``(nrows, ncols)`` and ``y`` of shape ``(ncols,)`` or ``(1, ncols)``.
    ``x`` may be a view whose rows are not adjacent in memory (such as a Dask
    chunk sliced out of a larger array); its row stride is handed to the
    kernel, so no copy is made as long as columns have unit stride.

    The kernel does no bounds checking, so it is launched only over the
    largest leading sub-array whose sides are whole multiples of the thread
    block; leftover rows/columns are filled with an ordinary CuPy add. The
    kernel uses 32-bit ``int`` index arithmetic, so arrays spanning 2**31 or
    more elements are not supported.

    Parameters
    ----------
    x : cupy.ndarray
        2-D ``float32`` array.
    y : cupy.ndarray
        ``float32`` array with exactly ``x.shape[1]`` elements.

    Returns
    -------
    cupy.ndarray
        New C-contiguous ``float32`` array with shape ``x.shape``.

    Raises
    ------
    TypeError
        If ``x`` or ``y`` is not ``float32`` (the kernel reads ``float*``).
    ValueError
        If ``x`` is not 2-D or ``y`` does not have ``x.shape[1]`` elements.
    """
    if x.dtype != cupy.float32 or y.dtype != cupy.float32:
        raise TypeError(
            "dispatch_add_broadcast requires float32 inputs, got "
            f"{x.dtype} and {y.dtype}"
        )
    if x.ndim != 2:
        raise ValueError(f"x must be 2-D, got shape {x.shape}")
    nrows, ncols = x.shape
    if y.size != ncols:
        raise ValueError(
            f"y must have {ncols} elements (one per column of x), "
            f"got shape {y.shape}"
        )

    z = cupy.empty(x.shape, x.dtype)
    if z.size == 0:
        # Nothing to compute; also avoids launching a kernel on an empty grid.
        return z

    # The kernel addresses x[row * xdim0 + col], which assumes consecutive
    # columns are one element apart. Copy only when that does not hold
    # (e.g. transposed, column-strided or broadcast views).
    if x.strides[1] != x.itemsize:
        x = cupy.ascontiguousarray(x)
    # The kernel reads y[col]: present y as a contiguous 1-D vector.
    y = cupy.ascontiguousarray(y).reshape(ncols)

    # Row strides in elements, passed explicitly as 32-bit ints to match the
    # kernel's `const int` parameters.
    xdim0 = cupy.int32(x.strides[0] // x.itemsize)
    zdim0 = cupy.int32(z.strides[0] // z.itemsize)

    block_size = (32, 32)  # (threads along columns, threads along rows)
    main_cols = ncols - ncols % block_size[0]
    main_rows = nrows - nrows % block_size[1]
    if main_cols and main_rows:
        # grid.x spans columns (idx0 in the kernel), grid.y spans rows (idx1).
        grid_size = (main_cols // block_size[0], main_rows // block_size[1])
        add_broadcast_kernel(grid_size, block_size, (x, y, z, xdim0, zdim0))

    # Ragged edges the unchecked kernel cannot cover without going out of
    # bounds; without this they would be left uninitialised.
    if main_rows < nrows:
        z[main_rows:] = x[main_rows:] + y
    if main_cols < ncols:
        z[:main_rows, main_cols:] = x[:main_rows, main_cols:] + y[main_cols:]
    return z
if __name__ == "__main__":
    # Context managers shut down the Dask client and the local CUDA cluster
    # when the demo finishes (or an assertion fails) instead of leaking them.
    with LocalCUDACluster() as cluster, Client(cluster):
        # A 4096 x 1024 matrix and a 1 x 1024 row vector broadcast over it.
        x = cupy.arange(4096 * 1024, dtype=cupy.float32).reshape((4096, 1024))
        y = cupy.arange(1024, dtype=cupy.float32).reshape(1, 1024)

        # Reference result from CuPy's built-in broadcasting add.
        res_cupy = x + y

        # Single-array check of the custom kernel.
        res_add_broadcast = dispatch_add_broadcast(x, y)
        assert_eq(res_cupy, res_add_broadcast)

        # Same computation split into Dask chunks. asarray=False keeps the
        # chunks as CuPy arrays instead of converting them with np.asarray.
        dx = da.from_array(x, chunks=(1024, 512), asarray=False)
        dy = da.from_array(y, chunks=(1, 512), asarray=False)
        res = da.map_blocks(dispatch_add_broadcast, dx, dy, dtype=dx.dtype)
        res = res.compute()
        assert_eq(res, res_cupy)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.