"""
Example showing how to use the no_header mode with a TensorBase CUDA extension
This example creates a CUDA extension that directly includes ATen/core/TensorBase.h
instead of torch/extension.h, resulting in faster compilation with no_header=True
"""
from datetime import datetime
import torch
import torch.utils.cpp_extension
# C++ code that directly includes TensorBase.h without using torch/extension.h
cpp_source = """
#include <ATen/core/TensorBase.h>
#include <ATen/cuda/EmptyTensor.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
#include <Python.h>
// Forward declaration of the CUDA kernel function
void launch_add_kernel(const float *, const float *, float *, int64_t);
at::Tensor tensor_base_add(const at::Tensor& x, const at::Tensor& y) {
// Validate inputs
TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
TORCH_CHECK(y.is_cuda(), "y must be a CUDA tensor");
TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, "x must be a float tensor");
TORCH_CHECK(y.scalar_type() == at::ScalarType::Float, "y must be a float tensor");
TORCH_CHECK(x.sizes() == y.sizes(), "x and y must have the same shape");
TORCH_CHECK(x.is_contiguous() && y.is_contiguous(), "x and y must be contiguous tensors");
TORCH_CHECK(x.numel() <= std::numeric_limits<int>::max(), "640Kb ought to be enough for everybody!");
auto output = at::detail::empty_cuda(x.sizes(), x.scalar_type(), x.device(), std::nullopt);
// Set CUDA device and launch kernel
const at::cuda::CUDAGuard device_guard(x.device());
launch_add_kernel(x.const_data_ptr<float>(), y.const_data_ptr<float>(), output.data_ptr<float>(), x.numel());
return output;
}
TORCH_LIBRARY(my_ops, m) {
m.def("tensor_base_add", &tensor_base_add);
}
PyMODINIT_FUNC PyInit_noname(void) {
static struct PyModuleDef foo = {PyModuleDef_HEAD_INIT, "noname", nullptr, -1, nullptr};
return PyModule_Create(&foo);
}
"""
# CUDA source with the raw kernel; it works on plain pointers and needs no ATen headers
cuda_source = """
__global__ void tensor_base_add_kernel(const float* __restrict__ x,
                                       const float* __restrict__ y,
                                       float* __restrict__ out,
                                       const int size) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < size) {
    out[idx] = x[idx] + y[idx] + 1.0f; // Add 1 to make the result distinguishable from a plain add
  }
}

void launch_add_kernel(const float *x_data, const float *y_data, float *output_data, int64_t num_elements) {
  const int threads = 1024;
  const int blocks = (num_elements + threads - 1) / threads;
  tensor_base_add_kernel<<<blocks, threads>>>(x_data, y_data, output_data, num_elements);
}
"""
def main():
    print("Compiling TensorBase CUDA extension with no_header=True...")
    # Load the extension using load_inline with no_header=True
    start_time = datetime.now()
    module = torch.utils.cpp_extension.load_inline(
        name="noname",
        cpp_sources=cpp_source,
        cuda_sources=cuda_source,
        no_header=True,  # Skip including torch/extension.h
    )
    compile_time = datetime.now() - start_time
    print(f"Extension compiled successfully in {compile_time}!")

    # Test the functionality
    print("Testing on CUDA tensors...")
    x = torch.randn(100, device="cuda", dtype=torch.float32)
    y = torch.randn(100, device="cuda", dtype=torch.float32)

    # Call our custom kernel
    result = torch.ops.my_ops.tensor_base_add(x, y)

    # Verify the result (our kernel adds 1.0 to distinguish it from a regular add)
    expected = x + y + 1.0
    if torch.allclose(result, expected):
        print("Test PASSED! ✓")
        print("First few elements of the tensors:")
        print(f"x: {x[:5]}")
        print(f"y: {y[:5]}")
        print(f"result: {result[:5]}")
        print(f"expected: {expected[:5]}")
    else:
        print("Test FAILED!")
        max_diff = torch.max(torch.abs(result - expected))
        print(f"Maximum difference: {max_diff}")


if __name__ == "__main__":
    # This example requires CUDA
    if not torch.cuda.is_available():
        print("CUDA is not available, this example requires CUDA")
    else:
        main()
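
# For a rough sense of what no_header=True buys, one could time a comparable
# build that goes through the default path (torch/extension.h prepended and
# pybind11 bindings generated from a `functions` list). A sketch, left commented
# out; the "with_header" name and the trivial source are illustrative, and
# timings depend heavily on the machine and on ninja's build cache:
#
#   start = datetime.now()
#   torch.utils.cpp_extension.load_inline(
#       name="with_header",
#       cpp_sources="at::Tensor noop(at::Tensor x) { return x; }",
#       functions=["noop"],
#   )
#   print(f"Compile time with torch/extension.h: {datetime.now() - start}")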