Created
February 19, 2021 11:59
-
-
Save gmarkall/09e4cbfe6fda4f7a35be07c8dd926f36 to your computer and use it in GitHub Desktop.
Calling PyCUDA kernels on CuPy arrays using the CUDA array interface
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Call a PyCUDA kernel on a CuPy array via the CUDA Array Interface.

The script creates a 4x4 float32 CuPy array, hands its device pointer to a
PyCUDA-compiled kernel that doubles every element in place, and prints the
array before and after to show that both libraries operate on the same
device memory.
"""
from collections import namedtuple

import pycuda.driver as cuda  # noqa
import pycuda.autoinit  # noqa
from pycuda.compiler import SourceModule
import cupy as cp

# PyCUDA will try to get a pointer to data from an object it doesn't recognise
# from its gpudata attribute. The field list is written as a real one-element
# tuple; ('gpudata') without the trailing comma is just a string.
pycuda_wrapper = namedtuple('pycuda_wrapper', ('gpudata',))

# Create a CuPy array (and a copy for comparison later)
cupy_a = cp.random.randn(4, 4).astype(cp.float32)
original = cupy_a.copy()

# Wrap our CuPy array in a PyCUDA wrapper. We get the pointer to the data
# from CuPy using the CUDA Array Interface:
# https://numba.readthedocs.io/en/latest/cuda/cuda_array_interface.html
# The interface's 'data' entry is a (pointer, read_only) pair; element 0 is
# the device pointer as an integer.
pycuda_a = pycuda_wrapper(cupy_a.__cuda_array_interface__['data'][0])

# Create a kernel. The row stride of 4 is hard-coded, so the kernel is only
# valid for the 4x4 array and the (4, 4, 1) block used below.
mod = SourceModule("""
__global__ void doublify(float *a)
{
    int idx = threadIdx.x + threadIdx.y*4;
    a[idx] *= 2;
}
""")
func = mod.get_function("doublify")

# Invoke PyCUDA kernel on wrapped CuPy data: one block of 4x4 threads, one
# thread per array element.
func(pycuda_a, block=(4, 4, 1), grid=(1, 1), shared=0)

# Demonstrate that our CuPy array was modified in place by the PyCUDA kernel
print("original array:")
print(original)
print("doubled with kernel:")
print(cupy_a)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
$ python cupy_pycuda_arrays.py
original array:
[[-2.2803636  -0.11322732  1.3498526  -0.56024694]
 [-0.41064784  0.6268035   0.10337123 -0.05056423]
 [ 0.0339125  -1.5568337   1.1931331   1.6248434 ]
 [ 0.37393996  0.0133023   0.95385337  0.7686548 ]]
doubled with kernel:
[[-4.560727   -0.22645465  2.6997051  -1.1204939 ]
 [-0.8212957   1.253607    0.20674247 -0.10112845]
 [ 0.067825   -3.1136675   2.3862662   3.2496867 ]
 [ 0.7478799   0.02660459  1.9077067   1.5373096 ]]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment