@apowers313
Created October 10, 2024 14:23
CUDA Python Two Node Graph - Memcpy and Kernel with Arguments
import ctypes
from typing import Any
import numpy as np
from cuda import cuda, cudart, nvrtc
cuda_code = """
extern "C" __global__ void simple(char *str) {
printf("this is a test\\n");
printf("passed argument was: %s\\n", str);
}
"""
str_arg = "hello from host"
prog_name = "simple"
grid = (1, 1, 1)
block = (1, 1, 1)
device_id = 0
# Error checking helper: cuda-python calls return a tuple whose first element is the
# status code; unpack the remaining return values and raise on any error.
def checkCudaErrors(result: tuple[Any, ...]) -> Any:
    def _cudaGetErrorEnum(error: Any) -> Any:
        if isinstance(error, cuda.CUresult):
            err, name = cuda.cuGetErrorName(error)
            return name if err == cuda.CUresult.CUDA_SUCCESS else "<unknown>"
        elif isinstance(error, nvrtc.nvrtcResult):
            return nvrtc.nvrtcGetErrorString(error)[1]
        else:
            raise RuntimeError("Unknown error type: {}".format(error))

    if result[0].value:
        raise RuntimeError(
            "CUDA error code={}({})".format(result[0].value, _cudaGetErrorEnum(result[0]))
        )
    if len(result) == 1:
        return None
    elif len(result) == 2:
        return result[1]
    else:
        return result[1:]
# Init CUDA
checkCudaErrors(cuda.cuInit(0))
# Create device
nv_device = checkCudaErrors(cuda.cuDeviceGet(device_id))
# Create context
nv_context = checkCudaErrors(cuda.cuCtxCreate(0, nv_device))
# Create stream
nv_stream = checkCudaErrors(cuda.cuStreamCreate(cuda.CUstream_flags.CU_STREAM_DEFAULT))
# Create program
nv_prog = checkCudaErrors(nvrtc.nvrtcCreateProgram(cuda_code.encode(), b"test.cu", 0, [], []))
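# Not part of the original gist: nvrtcCompileProgram accepts NVRTC options, and matching
# the target to the actual device (queried via cuDeviceGetAttribute) is a common variation.
# Sketch only, assuming the same nv_device and nv_prog handles:
# major = checkCudaErrors(cuda.cuDeviceGetAttribute(
#     cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, nv_device))
# minor = checkCudaErrors(cuda.cuDeviceGetAttribute(
#     cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, nv_device))
# opts = [f"--gpu-architecture=compute_{major}{minor}".encode()]
# compile_result = checkCudaErrors(nvrtc.nvrtcCompileProgram(nv_prog, len(opts), opts))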
# Compile code
compile_result = checkCudaErrors(nvrtc.nvrtcCompileProgram(nv_prog, 0, []))
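# Not part of the original gist: fetch the NVRTC log to surface any compiler warnings
# (on a failed compile, checkCudaErrors raises above before reaching this point).
nv_log_size = checkCudaErrors(nvrtc.nvrtcGetProgramLogSize(nv_prog))
if nv_log_size > 1:
    nv_log = b" " * nv_log_size
    checkCudaErrors(nvrtc.nvrtcGetProgramLog(nv_prog, nv_log))
    print(nv_log.decode())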
# Get PTX from compilation
nv_ptx_size = checkCudaErrors(nvrtc.nvrtcGetPTXSize(nv_prog))
ptx = b" " * nv_ptx_size
checkCudaErrors(nvrtc.nvrtcGetPTX(nv_prog, ptx))
# Load PTX as module data
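# np.char.array gives the PTX bytes a buffer whose address (.ctypes.data) can be passed
# to cuModuleLoadData as the module image pointer.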
ptx = np.char.array(ptx)
ret = cuda.cuModuleLoadData(ptx.ctypes.data)
nv_module = checkCudaErrors(ret)
# Get kernel from module
nv_kernel = checkCudaErrors(cuda.cuModuleGetFunction(nv_module, prog_name.encode()))
# Create string argument
str_arg_buffer = bytearray(str_arg.encode())
str_arg_buffer.append(0)  # trailing NUL so the kernel's %s sees a properly terminated C string
str_arg_len = len(str_arg_buffer)  # length of the encoded string including the NUL terminator
# Allocate device memory
# TODO: cuMemAlloc causes cuLaunchKernel to fail with code=700(b'CUDA_ERROR_ILLEGAL_ADDRESS')
# nv_device_memory = checkCudaErrors(cuda.cuMemAlloc(str_arg_len))
nv_device_memory = checkCudaErrors(cudart.cudaMalloc(str_arg_len))
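# nv_device_memory is the device pointer that the memcpy node below writes into and
# that the kernel receives as its char *str argument.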
# Create graph
nv_graph = checkCudaErrors(cuda.cuGraphCreate(0))
# Create graph kernel node
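# cuda-python accepts kernelParams as a pair of tuples (argument values, matching ctypes
# types) and packs them into the pointer array the driver expects.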
arg_data = [nv_device_memory]
arg_types = [ctypes.c_void_p]
nv_args = (tuple(arg_data), tuple(arg_types))
nv_kernel_node_params = cuda.CUDA_KERNEL_NODE_PARAMS()
nv_kernel_node_params.func = nv_kernel
nv_kernel_node_params.gridDimX = grid[0]
nv_kernel_node_params.gridDimY = grid[1]
nv_kernel_node_params.gridDimZ = grid[2]
nv_kernel_node_params.blockDimX = block[0]
nv_kernel_node_params.blockDimY = block[1]
nv_kernel_node_params.blockDimZ = block[2]
nv_kernel_node_params.sharedMemBytes = 0
nv_kernel_node_params.kernelParams = nv_args
kern_node = checkCudaErrors(cuda.cuGraphAddKernelNode(nv_graph, None, 0, nv_kernel_node_params))
# Create memcpy node to copy string to device
memcpy_node = checkCudaErrors(
    cudart.cudaGraphAddMemcpyNode1D(
        nv_graph,
        None,
        0,
        nv_device_memory,
        str_arg_buffer,
        str_arg_len,
        cudart.cudaMemcpyKind.cudaMemcpyHostToDevice,
    )
)
# Kernel node depends on memcpy node: cudaGraphAddDependencies adds edges from the first
# node list to the second, so the copy finishes before the kernel reads the string
checkCudaErrors(cudart.cudaGraphAddDependencies(nv_graph, [memcpy_node], [kern_node], 1))
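# Not part of the original gist: dumping the graph to a DOT file is one way to inspect
# the two-node structure before launching (sketch, assuming a CUDA version that supports
# the debug dot API):
# checkCudaErrors(cudart.cudaGraphDebugDotPrint(nv_graph, b"graph.dot", 0))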
# Launch graph
print("*** LAUNCHING GRAPH ***")
nv_graph_exec = checkCudaErrors(cudart.cudaGraphInstantiate(nv_graph, 0))
checkCudaErrors(cudart.cudaGraphLaunch(nv_graph_exec, nv_stream))
# Synchronize with the device so the kernel finishes and its printf output is flushed before exiting
checkCudaErrors(cuda.cuStreamSynchronize(nv_stream))
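# Not part of the original gist: a minimal cleanup sketch for the handles created above,
# assuming nothing else still references them.
checkCudaErrors(cudart.cudaGraphExecDestroy(nv_graph_exec))
checkCudaErrors(cuda.cuGraphDestroy(nv_graph))
checkCudaErrors(cudart.cudaFree(nv_device_memory))
checkCudaErrors(cuda.cuModuleUnload(nv_module))
checkCudaErrors(nvrtc.nvrtcDestroyProgram(nv_prog))
checkCudaErrors(cuda.cuStreamDestroy(nv_stream))
checkCudaErrors(cuda.cuCtxDestroy(nv_context))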