Last active
December 15, 2023 22:53
-
-
Save smrfeld/ab02f4be59b6f0e4f7455d1a728e31a0 to your computer and use it in GitHub Desktop.
C++/Obj-C extension calling Metal shader
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <torch/extension.h>

#include <Metal/Metal.h>
#include <Foundation/Foundation.h>

#include <cstring>
#include <iostream>
// Define a function to add tensors using Metal | |
// Adds two float32 tensors element-wise on the GPU via a runtime-compiled Metal shader.
//
// Parameters:
//   a, b           - input tensors; must live on the MPS device, be float32, and
//                    contain the same number of elements.
//   shaderFilePath - path to a .metal source file defining a kernel named
//                    "addTensors" that reads buffers 0 and 1 and writes buffer 2.
//
// Returns: a 1-D float32 tensor of a.numel() elements on the MPS device holding
//          the element-wise sum. NOTE: the result is flattened (matching the
//          original behavior); callers reshape if they need a's shape back.
//
// Throws std::runtime_error on validation, shader-load, compilation, or
// pipeline-creation failure.
torch::Tensor add_tensors_metal(torch::Tensor a, torch::Tensor b, const std::string& shaderFilePath) {
    // Both operands must already live on the MPS (Apple GPU) device.
    if (a.device().type() != torch::kMPS || b.device().type() != torch::kMPS) {
        throw std::runtime_error("Error: tensors must be on MPS device.");
    }
    // Buffer sizing below assumes matching element counts and 32-bit floats.
    if (a.numel() != b.numel()) {
        throw std::runtime_error("Error: tensors must have the same number of elements.");
    }
    if (a.scalar_type() != torch::kFloat || b.scalar_type() != torch::kFloat) {
        throw std::runtime_error("Error: tensors must be float32.");
    }

    const int64_t numElements = a.numel();
    // Dispatching a zero-sized grid/threadgroup is invalid; short-circuit.
    if (numElements == 0) {
        return torch::empty({0}, torch::TensorOptions().dtype(torch::kFloat).device(torch::kMPS));
    }

    // Stage the inputs in host memory: data_ptr() of an MPS tensor is not a
    // CPU-readable pointer, so passing it to newBufferWithBytes: (which copies
    // from host memory) read garbage in the original. .contiguous() guarantees
    // a dense layout for the byte copy.
    torch::Tensor aHost = a.to(torch::kCPU).contiguous();
    torch::Tensor bHost = b.to(torch::kCPU).contiguous();

    // Scope the Objective-C objects in an autorelease pool, since this entry
    // point is called from C++ (pybind11) with no ambient pool guaranteed.
    @autoreleasepool {
        id<MTLDevice> device = MTLCreateSystemDefaultDevice();
        if (!device) {
            throw std::runtime_error("Error: no Metal device available.");
        }

        // Load the shader source. Check the nil return, not the NSError
        // out-param (Cocoa convention: error is only meaningful on failure).
        NSError* error = nil;
        NSString* shaderSource =
            [NSString stringWithContentsOfFile:[NSString stringWithUTF8String:shaderFilePath.c_str()]
                                      encoding:NSUTF8StringEncoding
                                         error:&error];
        if (!shaderSource) {
            throw std::runtime_error("Failed to load Metal shader: " +
                                     std::string(error.localizedDescription.UTF8String));
        }

        // Compile the source and look up the kernel entry point.
        id<MTLLibrary> library = [device newLibraryWithSource:shaderSource options:nil error:&error];
        if (!library) {
            throw std::runtime_error("Error compiling Metal shader: " +
                                     std::string(error.localizedDescription.UTF8String));
        }
        id<MTLFunction> function = [library newFunctionWithName:@"addTensors"];
        if (!function) {
            throw std::runtime_error("Error: Metal function addTensors not found.");
        }
        id<MTLComputePipelineState> pipelineState =
            [device newComputePipelineStateWithFunction:function error:&error];
        if (!pipelineState) {
            // The original passed error:nil and never checked this; a nil
            // pipeline state would have crashed the encoder below.
            throw std::runtime_error("Error creating compute pipeline state: " +
                                     std::string(error.localizedDescription.UTF8String));
        }

        const NSUInteger byteLength = (NSUInteger)numElements * sizeof(float);
        // Shared-storage buffers are visible to both CPU and GPU.
        id<MTLBuffer> aBuffer = [device newBufferWithBytes:aHost.data_ptr()
                                                    length:byteLength
                                                   options:MTLResourceStorageModeShared];
        id<MTLBuffer> bBuffer = [device newBufferWithBytes:bHost.data_ptr()
                                                    length:byteLength
                                                   options:MTLResourceStorageModeShared];
        id<MTLBuffer> resultBuffer = [device newBufferWithLength:byteLength
                                                         options:MTLResourceStorageModeShared];

        // Encode a single dispatch of the add kernel over all elements.
        id<MTLCommandQueue> commandQueue = [device newCommandQueue];
        id<MTLCommandBuffer> commandBuffer = [commandQueue commandBuffer];
        id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
        [encoder setComputePipelineState:pipelineState];
        [encoder setBuffer:aBuffer offset:0 atIndex:0];
        [encoder setBuffer:bBuffer offset:0 atIndex:1];
        [encoder setBuffer:resultBuffer offset:0 atIndex:2];

        // One thread per element; clamp the threadgroup width to both the
        // pipeline's hardware limit and the grid size (avoids the original's
        // signed/unsigned comparison as well).
        MTLSize gridSize = MTLSizeMake((NSUInteger)numElements, 1, 1);
        NSUInteger threadGroupWidth = pipelineState.maxTotalThreadsPerThreadgroup;
        if (threadGroupWidth > (NSUInteger)numElements) {
            threadGroupWidth = (NSUInteger)numElements;
        }
        MTLSize threadgroupSize = MTLSizeMake(threadGroupWidth, 1, 1);
        [encoder dispatchThreads:gridSize threadsPerThreadgroup:threadgroupSize];
        [encoder endEncoding];

        // Run the kernel synchronously.
        [commandBuffer commit];
        [commandBuffer waitUntilCompleted];

        // Copy the GPU result back into a host tensor, then move it to MPS.
        // The original wrapped result.data_ptr() with newBufferWithBytesNoCopy:
        // but never transferred any data into it (and that API also requires a
        // page-aligned pointer), so the returned tensor was uninitialized. An
        // explicit copy from the shared buffer's contents fixes that.
        torch::Tensor resultHost =
            torch::empty({numElements}, torch::TensorOptions().dtype(torch::kFloat));
        std::memcpy(resultHost.data_ptr(), resultBuffer.contents, byteLength);
        return resultHost.to(torch::kMPS);
    }
}
// Expose the Metal-backed add to Python as <module>.add_tensors_metal.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("add_tensors_metal",
          &add_tensors_metal,
          "Add two tensors using Metal");
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment