""" | |
Example showing how to use the no_header mode with a TensorBase CUDA extension | |
This example creates a CUDA extension that directly includes ATen/core/TensorBase.h | |
instead of torch/extension.h, resulting in faster compilation with no_header=True | |
""" | |
from datetime import datetime

import torch
import torch.utils.cpp_extension

# C++ code that directly includes TensorBase.h without using torch/extension.h
cpp_source = """
#include <ATen/core/TensorBase.h>
#include <ATen/cuda/EmptyTensor.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>
#include <Python.h>

#include <limits>

// Forward declaration of the CUDA kernel launcher
void launch_add_kernel(const float *, const float *, float *, int64_t);
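// (load_inline compiles the C++ and CUDA sources as separate translation units
// and links them into one extension, so only this declaration is needed here;
// the definition lives in cuda_source below.)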

at::Tensor tensor_base_add(const at::Tensor& x, const at::Tensor& y) {
  // Validate inputs
  TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
  TORCH_CHECK(y.is_cuda(), "y must be a CUDA tensor");
  TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, "x must be a float tensor");
  TORCH_CHECK(y.scalar_type() == at::ScalarType::Float, "y must be a float tensor");
  TORCH_CHECK(x.sizes() == y.sizes(), "x and y must have the same shape");
  TORCH_CHECK(x.is_contiguous() && y.is_contiguous(), "x and y must be contiguous tensors");
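  // The kernel indexes with a 32-bit int, so reject element counts that would overflow it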
  TORCH_CHECK(x.numel() <= std::numeric_limits<int>::max(), "640Kb ought to be enough for everybody!");

  auto output = at::detail::empty_cuda(x.sizes(), x.scalar_type(), x.device(), std::nullopt);

  // Set CUDA device and launch kernel
  const at::cuda::CUDAGuard device_guard(x.device());
  launch_add_kernel(x.const_data_ptr<float>(), y.const_data_ptr<float>(), output.data_ptr<float>(), x.numel());
  return output;
}

TORCH_LIBRARY(my_ops, m) {
  m.def("tensor_base_add", &tensor_base_add);
}
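
// load_inline imports the built shared library as a Python extension module,
// which needs a PyInit_<name> entry point. With torch/extension.h that symbol
// would normally come from a PYBIND11_MODULE block; since no_header skips that
// header, we provide a minimal module by hand. It exports nothing: the operator
// above is reached through torch.ops, not through module attributes.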
PyMODINIT_FUNC PyInit_noname(void) {
  static struct PyModuleDef foo = {PyModuleDef_HEAD_INIT, "noname", nullptr, -1, nullptr};
  return PyModule_Create(&foo);
}
"""

# CUDA source with the raw elementwise kernel and its launcher
cuda_source = """
__global__ void tensor_base_add_kernel(const float* __restrict__ x,
                                       const float* __restrict__ y,
                                       float* __restrict__ out,
                                       const int size) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < size) {
    out[idx] = x[idx] + y[idx] + 1.0f;  // Add 1 to make it distinguishable from a regular add
  }
}

void launch_add_kernel(const float *x_data, const float *y_data, float *output_data, int64_t num_elements) {
  const int threads = 1024;
  const int blocks = (num_elements + threads - 1) / threads;
  tensor_base_add_kernel<<<blocks, threads>>>(x_data, y_data, output_data, num_elements);
}
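
// Note: this launches on the legacy default stream rather than PyTorch's
// current stream. A fuller version would likely launch on
// at::cuda::getCurrentCUDAStream() and check for launch errors, e.g. with
// C10_CUDA_KERNEL_LAUNCH_CHECK(); both are left out to keep the example small.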
""" | |


def main():
    print("Compiling TensorBase CUDA extension with no_header=True...")

    # Load the extension using load_inline with no_header=True
    start_time = datetime.now()
    module = torch.utils.cpp_extension.load_inline(
        name="noname",
        cpp_sources=cpp_source,
        cuda_sources=cuda_source,
        no_header=True,  # Skip including torch/extension.h
    )
    elapsed = datetime.now() - start_time
    print(f"Extension compiled successfully in {elapsed}!")

    # Test the functionality
    print("Testing on CUDA tensors...")
    x = torch.randn(100, device="cuda", dtype=torch.float32)
    y = torch.randn(100, device="cuda", dtype=torch.float32)

    # Call our custom kernel
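    # The TORCH_LIBRARY block registered it as torch.ops.my_ops.tensor_base_add,
    # so no pybind11 bindings are needed to reach it from Python.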
    result = torch.ops.my_ops.tensor_base_add(x, y)

    # Verify result (our kernel adds 1.0 to distinguish it from a regular add)
    expected = x + y + 1.0

    # Check if results match
    if torch.allclose(result, expected):
        print("Test PASSED! ✓")
        print("First few elements of tensors:")
        print(f"x: {x[:5]}")
        print(f"y: {y[:5]}")
        print(f"result: {result[:5]}")
        print(f"expected: {expected[:5]}")
    else:
        print("Test FAILED!")
        max_diff = torch.max(torch.abs(result - expected))
        print(f"Maximum difference: {max_diff}")


if __name__ == "__main__":
    # Check if CUDA is available
    if not torch.cuda.is_available():
        print("CUDA is not available, this example requires CUDA")
    else:
        main()