Skip to content

Instantly share code, notes, and snippets.

@msaroufim
Created February 26, 2025 19:13
Show Gist options
  • Save msaroufim/399490c301e9df1c9998c851462cb09d to your computer and use it in GitHub Desktop.
Save msaroufim/399490c301e9df1c9998c851462cb09d to your computer and use it in GitHub Desktop.
import torch
from torch.utils.cpp_extension import load_inline
# C++ translation unit: forward-declares the ``to_gray`` host function (defined
# in the CUDA sources) and hand-writes the pybind11 module that exposes it to
# Python as ``to_gray``.
# NOTE: load_inline auto-generates its own PYBIND11_MODULE when called with
# ``functions=``; combining that with this hand-written module defines the
# module twice, so use only one of the two mechanisms.
cpp_code = (
    "\n"
    "torch::Tensor to_gray(torch::Tensor input);\n"
    "PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n"
    'm.def("to_gray", &to_gray, "Convert RGB to Grayscale (CUDA)");\n'
    "}\n"
)
# CUDA translation unit: host-side definition of ``to_gray``.
# NOTE(review): this is a stub. It allocates an
# (input.size(0), input.size(1)) tensor with the input's options via
# torch::empty (contents are uninitialized) and returns it without launching
# any kernel or reading ``input`` -- no RGB->grayscale conversion happens.
cuda_kernel_code = (
    "\n"
    "torch::Tensor to_gray(torch::Tensor input) {\n"
    "auto output = torch::empty({input.size(0), input.size(1)}, input.options());\n"
    "return output;\n"
    "}\n"
)
# JIT-build the extension from the inline sources above and import it as
# ``cuda_module``.
# ``cpp_code`` already contains a hand-written PYBIND11_MODULE that binds
# ``to_gray`` (with its docstring), so ``functions=`` is deliberately NOT
# passed: load_inline would otherwise append a second, auto-generated
# PYBIND11_MODULE to the C++ sources, redefining the module's init symbol and
# failing the build.
cuda_module = load_inline(
    name="to_gray_cuda",
    cpp_sources=cpp_code,
    cuda_sources=cuda_kernel_code,
    with_cuda=True,
    verbose=True,
    # Forwarded to the C++ compiler for cpp_sources (not to nvcc).
    # -ftime-report prints per-phase compile timings for that translation unit.
    extra_cflags=["-std=c++17", "-ftime-report"],
    # Forwarded to nvcc for cuda_sources. sm_90 is compute capability 9.0.
    # NOTE(review): hard-coded GPU architecture; adjust for other devices.
    extra_cuda_cflags=["-arch=sm_90"],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment