Skip to content

Instantly share code, notes, and snippets.

@apowers313
Created October 10, 2024 14:19
Show Gist options
  • Save apowers313/99501f12d8e8babf59e6b567c0e21a2c to your computer and use it in GitHub Desktop.
CUDA Python Simple Graph
from typing import Any
import numpy as np
from cuda import cuda, cudart, nvrtc
# CUDA C source for a trivial kernel, compiled at runtime with NVRTC below.
# The doubled backslash in the Python literal yields "\n" in the CUDA source.
cuda_code = """
extern "C" __global__ void simple() {
printf("this is a test\\n");
}
"""
# Kernel entry-point name; `extern "C"` above keeps it unmangled so it can
# be looked up by this exact string via cuModuleGetFunction.
prog_name = "simple"
# Launch configuration: a single 1x1x1 block in a 1x1x1 grid (one thread).
grid = (1, 1, 1)
block = (1, 1, 1)
# Ordinal of the GPU device to run on.
device_id = 0
# Error checking helper
def checkCudaErrors(result: tuple[Any, ...]) -> Any:
def _cudaGetErrorEnum(error: Any) -> Any:
if isinstance(error, cuda.CUresult):
err, name = cuda.cuGetErrorName(error)
return name if err == cuda.CUresult.CUDA_SUCCESS else "<unknown>"
elif isinstance(error, nvrtc.nvrtcResult):
return nvrtc.nvrtcGetErrorString(error)[1]
else:
raise RuntimeError("Unknown error type: {}".format(error))
if result[0].value:
raise RuntimeError(
"CUDA error code={}({})".format(result[0].value, _cudaGetErrorEnum(result[0]))
)
if len(result) == 1:
return None
elif len(result) == 2:
return result[1]
else:
return result[1:]
# Initialize the CUDA driver API; must precede any other cu* call.
checkCudaErrors(cuda.cuInit(0))
# Get a handle to GPU `device_id`.
nv_device = checkCudaErrors(cuda.cuDeviceGet(device_id))
# Create a context on that device (flags=0); becomes current on this thread.
nv_context = checkCudaErrors(cuda.cuCtxCreate(0, nv_device))
# Create a stream to launch the graph on.
nv_stream = checkCudaErrors(cuda.cuStreamCreate(cuda.CUstream_flags.CU_STREAM_DEFAULT))
# Create an NVRTC program from the CUDA source, named "test.cu",
# with no extra headers (0, [], []).
nv_prog = checkCudaErrors(nvrtc.nvrtcCreateProgram(cuda_code.encode(), b"test.cu", 0, [], []))
# Compile to PTX with no extra compiler options (0, []).
compile_result = checkCudaErrors(nvrtc.nvrtcCompileProgram(nv_prog, 0, []))
# Query the size of the generated PTX, then fetch it into a pre-sized buffer.
nv_ptx_size = checkCudaErrors(nvrtc.nvrtcGetPTXSize(nv_prog))
# NOTE(review): nvrtcGetPTX appears to write into this `bytes` object through
# the binding's buffer handling even though `bytes` is immutable in Python —
# a `bytearray(nv_ptx_size)` would be the safer buffer; confirm against the
# cuda-python NVRTC examples.
ptx = b" " * nv_ptx_size
checkCudaErrors(nvrtc.nvrtcGetPTX(nv_prog, ptx))
# Wrap the PTX in a numpy char array to get a stable contiguous buffer,
# then load it as a module via its raw data pointer.
ptx = np.char.array(ptx)
ret = cuda.cuModuleLoadData(ptx.ctypes.data)
nv_module = checkCudaErrors(ret)
# Look up the compiled kernel in the module by its unmangled name.
nv_kernel = checkCudaErrors(cuda.cuModuleGetFunction(nv_module, prog_name.encode()))
# Create an empty CUDA graph (flags=0).
nv_graph = checkCudaErrors(cuda.cuGraphCreate(0))
# Build the kernel-node parameters. The kernel takes no arguments, so
# kernelParams is set to 0 (null pointer) rather than an argument pack.
nv_args = 0
nv_kernel_node_params = cuda.CUDA_KERNEL_NODE_PARAMS()
nv_kernel_node_params.func = nv_kernel
nv_kernel_node_params.gridDimX = grid[0]
nv_kernel_node_params.gridDimY = grid[1]
nv_kernel_node_params.gridDimZ = grid[2]
nv_kernel_node_params.blockDimX = block[0]
nv_kernel_node_params.blockDimY = block[1]
nv_kernel_node_params.blockDimZ = block[2]
nv_kernel_node_params.sharedMemBytes = 0
nv_kernel_node_params.kernelParams = nv_args
# Add the kernel node with no dependencies (deps=None, numDeps=0).
kern_node = checkCudaErrors(cuda.cuGraphAddKernelNode(nv_graph, None, 0, nv_kernel_node_params))
# Launch graph
print("*** LAUNCHING GRAPH ***")
# NOTE(review): this passes a driver-API CUgraph into the runtime API
# (cudaGraphInstantiate); the handles are presumably interchangeable between
# the two APIs in cuda-python — confirm, or use cuGraphInstantiate for
# consistency with the rest of the script.
nv_graph_exec = checkCudaErrors(cudart.cudaGraphInstantiate(nv_graph, 0))
checkCudaErrors(cudart.cudaGraphLaunch(nv_graph_exec, nv_stream))
# Block until the graph's work (the printf) has finished before the
# process exits; otherwise the kernel output may never appear.
# TODO(review): no teardown — consider destroying the graph exec, graph,
# module, stream, and context on exit.
checkCudaErrors(cuda.cuStreamSynchronize(nv_stream))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment