CUDA Python Simple Graph
from typing import Any

import numpy as np
from cuda import cuda, cudart, nvrtc

cuda_code = """
extern "C" __global__ void simple() {
    printf("this is a test\\n");
}
"""

prog_name = "simple"
grid = (1, 1, 1)
block = (1, 1, 1)
device_id = 0
# Error checking helper: cuda-python calls return a tuple whose first element is the status code
def checkCudaErrors(result: tuple[Any, ...]) -> Any:
    def _cudaGetErrorEnum(error: Any) -> Any:
        if isinstance(error, cuda.CUresult):
            err, name = cuda.cuGetErrorName(error)
            return name if err == cuda.CUresult.CUDA_SUCCESS else "<unknown>"
        elif isinstance(error, cudart.cudaError_t):
            return cudart.cudaGetErrorName(error)[1]
        elif isinstance(error, nvrtc.nvrtcResult):
            return nvrtc.nvrtcGetErrorString(error)[1]
        else:
            raise RuntimeError("Unknown error type: {}".format(error))

    if result[0].value:
        raise RuntimeError(
            "CUDA error code={}({})".format(result[0].value, _cudaGetErrorEnum(result[0]))
        )
    if len(result) == 1:
        return None
    elif len(result) == 2:
        return result[1]
    else:
        return result[1:]
# Init CUDA
checkCudaErrors(cuda.cuInit(0))

# Get device handle
nv_device = checkCudaErrors(cuda.cuDeviceGet(device_id))

# Create context
nv_context = checkCudaErrors(cuda.cuCtxCreate(0, nv_device))
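# cuCtxCreate also makes the new context current on the calling thread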
# Create stream
nv_stream = checkCudaErrors(cuda.cuStreamCreate(cuda.CUstream_flags.CU_STREAM_DEFAULT))
# Create program
nv_prog = checkCudaErrors(nvrtc.nvrtcCreateProgram(cuda_code.encode(), b"test.cu", 0, [], []))

# Compile code
checkCudaErrors(nvrtc.nvrtcCompileProgram(nv_prog, 0, []))
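# If compilation fails, the NVRTC build log can be fetched with the same
# preallocated-buffer pattern used for the PTX below (a sketch, not in the
# original gist; call nvrtcCompileProgram unwrapped so it does not raise first):
#   log_size = checkCudaErrors(nvrtc.nvrtcGetProgramLogSize(nv_prog))
#   log = b" " * log_size
#   checkCudaErrors(nvrtc.nvrtcGetProgramLog(nv_prog, log))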
# Get PTX from compilation
nv_ptx_size = checkCudaErrors(nvrtc.nvrtcGetPTXSize(nv_prog))
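# nvrtcGetPTX copies into a caller-allocated buffer, so preallocate a byte string of the reported size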
ptx = b" " * nv_ptx_size
checkCudaErrors(nvrtc.nvrtcGetPTX(nv_prog, ptx))

# Load PTX as module data
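# Wrapping the PTX bytes in a NumPy char array gives a stable host pointer (ptx.ctypes.data) for cuModuleLoadData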
ptx = np.char.array(ptx)
ret = cuda.cuModuleLoadData(ptx.ctypes.data)
nv_module = checkCudaErrors(ret)

# Get kernel from module
nv_kernel = checkCudaErrors(cuda.cuModuleGetFunction(nv_module, prog_name.encode()))

# Create graph
nv_graph = checkCudaErrors(cuda.cuGraphCreate(0))

# Create graph kernel node
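# The kernel takes no arguments, so kernelParams stays 0 (no parameter pointers are passed)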
nv_args = 0
nv_kernel_node_params = cuda.CUDA_KERNEL_NODE_PARAMS()
nv_kernel_node_params.func = nv_kernel
nv_kernel_node_params.gridDimX = grid[0]
nv_kernel_node_params.gridDimY = grid[1]
nv_kernel_node_params.gridDimZ = grid[2]
nv_kernel_node_params.blockDimX = block[0]
nv_kernel_node_params.blockDimY = block[1]
nv_kernel_node_params.blockDimZ = block[2]
nv_kernel_node_params.sharedMemBytes = 0
nv_kernel_node_params.kernelParams = nv_args
kern_node = checkCudaErrors(cuda.cuGraphAddKernelNode(nv_graph, None, 0, nv_kernel_node_params))

# Launch graph
print("*** LAUNCHING GRAPH ***")
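# The graph was built with driver API calls, but cudart can instantiate and launch it: driver and runtime graph handles are interchangeable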
nv_graph_exec = checkCudaErrors(cudart.cudaGraphInstantiate(nv_graph, 0))
checkCudaErrors(cudart.cudaGraphLaunch(nv_graph_exec, nv_stream))

# Synchronize with device before exiting
checkCudaErrors(cuda.cuStreamSynchronize(nv_stream))
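Once the stream synchronizes, the device-side printf output ("this is a test") is flushed to stdout. The gist relies on process exit for cleanup; a minimal teardown sketch for the handles created above, assuming the same cuda-python bindings, could look like this:

# Optional explicit teardown (not in the original gist)
checkCudaErrors(cudart.cudaGraphExecDestroy(nv_graph_exec))
checkCudaErrors(cuda.cuGraphDestroy(nv_graph))
checkCudaErrors(cuda.cuModuleUnload(nv_module))
checkCudaErrors(nvrtc.nvrtcDestroyProgram(nv_prog))
checkCudaErrors(cuda.cuStreamDestroy(nv_stream))
checkCudaErrors(cuda.cuCtxDestroy(nv_context))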