Created
October 10, 2024 14:20
-
-
Save apowers313/1e93fb1a6ba9a9ac5c2ac42c8e8087a8 to your computer and use it in GitHub Desktop.
CUDA Python Simple Graph With Argument
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ctypes | |
from typing import Any | |
import numpy as np | |
from cuda import cuda, cudart, nvrtc | |
# CUDA C source for a trivial kernel, compiled at runtime via NVRTC.
# It prints a fixed message and then the string argument supplied by the host.
cuda_code = """
extern "C" __global__ void simple(char *str) {
printf("this is a test\\n");
printf("passed argument was: %s\\n", str);
}
"""

# Host string copied into device memory and passed to the kernel.
str_arg = "hello from host"
# Name of the __global__ function to look up in the compiled module.
prog_name = "simple"
# Launch configuration: a single block containing a single thread.
grid = (1, 1, 1)
block = (1, 1, 1)
# Ordinal of the CUDA device to use.
device_id = 0
# Error checking helper
def checkCudaErrors(result: tuple[Any, ...]) -> Any:
    """Unpack a cuda-python result tuple, raising on a non-zero status.

    cuda-python calls return ``(status, value1, value2, ...)``.  Raises
    ``RuntimeError`` when ``status`` is an error; otherwise returns ``None``
    (no payload), the single payload value, or the payload tuple.
    """

    def _cudaGetErrorEnum(error: Any) -> Any:
        # Map a status enum to a human-readable name for the error message.
        if isinstance(error, cuda.CUresult):
            err, name = cuda.cuGetErrorName(error)
            return name if err == cuda.CUresult.CUDA_SUCCESS else "<unknown>"
        elif isinstance(error, cudart.cudaError_t):
            # Runtime-API statuses (cudaMalloc, cudaGraphInstantiate,
            # cudaGraphLaunch, ...) previously fell through to the
            # "Unknown error type" RuntimeError below even though this
            # script checks cudart results; name them properly instead.
            return cudart.cudaGetErrorName(error)[1]
        elif isinstance(error, nvrtc.nvrtcResult):
            return nvrtc.nvrtcGetErrorString(error)[1]
        else:
            raise RuntimeError("Unknown error type: {}".format(error))

    # A zero .value means success for CUresult, cudaError_t and nvrtcResult.
    if result[0].value:
        raise RuntimeError(
            "CUDA error code={}({})".format(result[0].value, _cudaGetErrorEnum(result[0]))
        )
    if len(result) == 1:
        return None
    elif len(result) == 2:
        return result[1]
    else:
        return result[1:]
# Init CUDA driver API -- must precede every other driver call.
checkCudaErrors(cuda.cuInit(0))
# Get a handle to the selected device.
nv_device = checkCudaErrors(cuda.cuDeviceGet(device_id))
# Create a context on that device (flags=0).
nv_context = checkCudaErrors(cuda.cuCtxCreate(0, nv_device))
# Create a stream; the graph is later launched on it.
nv_stream = checkCudaErrors(cuda.cuStreamCreate(cuda.CUstream_flags.CU_STREAM_DEFAULT))
# Create an NVRTC program from the CUDA C source (no headers / include names).
nv_prog = checkCudaErrors(nvrtc.nvrtcCreateProgram(cuda_code.encode(), b"test.cu", 0, [], []))
# Compile with no extra options.
# NOTE(review): on failure this raises without dumping the compile log --
# nvrtcGetProgramLog would give the actual diagnostics.
compile_result = checkCudaErrors(nvrtc.nvrtcCompileProgram(nv_prog, 0, []))
# Get PTX from compilation: query the size, then fill a pre-sized byte buffer.
nv_ptx_size = checkCudaErrors(nvrtc.nvrtcGetPTXSize(nv_prog))
ptx = b" " * nv_ptx_size
checkCudaErrors(nvrtc.nvrtcGetPTX(nv_prog, ptx))
# Load PTX as module data.
# NOTE(review): np.char.array appears to be used to obtain a stable host
# address (.ctypes.data) for the PTX bytes -- confirm this is required.
ptx = np.char.array(ptx)
ret = cuda.cuModuleLoadData(ptx.ctypes.data)
nv_module = checkCudaErrors(ret)
# Get kernel from module
nv_kernel = checkCudaErrors(cuda.cuModuleGetFunction(nv_module, prog_name.encode()))

# Create string argument as a NUL-terminated C string.
str_arg_buffer = bytearray(str_arg.encode())
str_arg_buffer.append(0)  # trailing NUL so the kernel's printf("%s") finds the end
# Byte count including the NUL -- taken from the encoded buffer, not
# len(str_arg), so multi-byte characters are counted correctly.
str_arg_len = len(str_arg_buffer)

# Allocate device memory
# TODO: cuMemAlloc causes cuLaunchKernel to fail with code=700(b'CUDA_ERROR_ILLEGAL_ADDRESS')
# nv_device_memory = checkCudaErrors(cuda.cuMemAlloc(str_arg_len))
nv_device_memory = checkCudaErrors(cudart.cudaMalloc(str_arg_len))

# Copy string from host to device.  The status was previously discarded;
# check it like every other CUDA call in this script.
checkCudaErrors(
    cuda.cuMemcpyHtoD(
        nv_device_memory,
        str_arg_buffer,
        str_arg_len,
    )
)
# Create an empty graph (flags=0).
nv_graph = checkCudaErrors(cuda.cuGraphCreate(0))
# Create graph kernel node.
# cuda-python accepts kernelParams as a pair of tuples:
# (argument values, matching ctypes types).  The single argument here is the
# device pointer holding the NUL-terminated string.
arg_data = [nv_device_memory]
arg_types = [ctypes.c_void_p]
nv_args = (tuple(arg_data), tuple(arg_types))
nv_kernel_node_params = cuda.CUDA_KERNEL_NODE_PARAMS()
nv_kernel_node_params.func = nv_kernel
nv_kernel_node_params.gridDimX = grid[0]
nv_kernel_node_params.gridDimY = grid[1]
nv_kernel_node_params.gridDimZ = grid[2]
nv_kernel_node_params.blockDimX = block[0]
nv_kernel_node_params.blockDimY = block[1]
nv_kernel_node_params.blockDimZ = block[2]
nv_kernel_node_params.sharedMemBytes = 0
nv_kernel_node_params.kernelParams = nv_args
# Add the kernel node with no dependencies (None, count 0).
kern_node = checkCudaErrors(cuda.cuGraphAddKernelNode(nv_graph, None, 0, nv_kernel_node_params))
# Launch graph
print("*** LAUNCHING GRAPH ***")
# NOTE(review): this instantiates/launches a driver-API CUgraph through the
# runtime API (cudart) while the rest of the graph setup uses the driver
# API -- the script relies on the two APIs being interoperable; confirm.
nv_graph_exec = checkCudaErrors(cudart.cudaGraphInstantiate(nv_graph, 0))
checkCudaErrors(cudart.cudaGraphLaunch(nv_graph_exec, nv_stream))
# Synchronize with device before exiting so the kernel's printf output flushes.
checkCudaErrors(cuda.cuStreamSynchronize(nv_stream))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment