Last active
February 16, 2024 16:43
-
-
Save nullhook/11d74c02dc42e061ade9528973fae7f4 to your computer and use it in GitHub Desktop.
compute in metal
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
#define NS_PRIVATE_IMPLEMENTATION | |
#define CA_PRIVATE_IMPLEMENTATION | |
#define MTL_PRIVATE_IMPLEMENTATION | |
#include "Metal.hpp" | |
// Buffers bound to the compute kernel by main(): `outputs` is set at buffer
// index 0 and `input0` at buffer index 1. main() allocates both, copies one
// float into `input0`, and reads the result back from `outputs`.
// NOTE(review): presumably these live at file scope so that a captureless
// completion-handler lambda (see the commented-out handler in main) can reach
// `outputs` -- confirm before moving them into main().
MTL::Buffer* outputs; | |
MTL::Buffer* input0; | |
int main() { | |
// both represents a metal context | |
MTL::Device* device = MTL::CreateSystemDefaultDevice(); | |
MTL::CommandQueue* command_queue = device->newCommandQueue(); | |
MTL::Library* library = device->newDefaultLibrary(); | |
if (!library) assert(false); | |
MTL::Function* E_ = library->newFunction(NS::String::string("E_", NS::StringEncoding::UTF8StringEncoding)); | |
NS::Error* error = nullptr; | |
MTL::ComputePipelineState* pso = device->newComputePipelineState(E_, &error); | |
if (!pso) { | |
std::cerr << error->localizedDescription()->utf8String() << "\n"; | |
assert(false); | |
} | |
// all buffers are effectively shared on apple silicon | |
// due to unified memory architecture you can get away with | |
// forgetting to call didModifyRange. | |
// API contract says you must call didModifyRange, but the driver doesn’t enforce that on apple silicon | |
outputs = device->newBuffer(4, MTL::ResourceStorageModeManaged); | |
input0 = device->newBuffer(4, MTL::ResourceStorageModeManaged); | |
const float a{10.0}; | |
memcpy(input0->contents(), &a, sizeof(float)); | |
// input0->didModifyRange(NS::Range::Make( 0, sizeof(float))); | |
MTL::CommandBuffer* cmd_buff = command_queue->commandBuffer(); | |
MTL::ComputeCommandEncoder* cmd_enc = cmd_buff->computeCommandEncoder(); | |
// pass buffers to compute shaders | |
cmd_enc->setComputePipelineState(pso); | |
cmd_enc->setBuffer(outputs, 0, 0); | |
cmd_enc->setBuffer(input0, 0, 1); | |
cmd_enc->dispatchThreadgroups(MTL::Size({1, 1, 1}), MTL::Size({1, 1, 1})); | |
cmd_enc->endEncoding(); | |
// cmd_buff->addCompletedHandler([](const MTL::CommandBuffer* ignored) { | |
// float* out = static_cast<float*>(outputs->contents()); | |
// if (out != nullptr) { | |
// printf("%.2f\n", *out); | |
// } | |
// }); | |
cmd_buff->commit(); | |
// figure out runloop so we can read output buf | |
cmd_buff->waitUntilCompleted(); | |
float* out = static_cast<float*>(outputs->contents()); | |
if (out != nullptr) { | |
printf("%.2f\n", *out); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Captures a GPU trace (.gputrace) of a compute dispatch via MTL::CaptureManager.
// Also contains commented-out experiments with Metal binary archives.
#include <vector> | |
#include <iostream> | |
#include "sys/mman.h" | |
#include "sys/stat.h" | |
#define NS_PRIVATE_IMPLEMENTATION | |
#define CA_PRIVATE_IMPLEMENTATION | |
#define MTL_PRIVATE_IMPLEMENTATION | |
#include "lib/Metal.hpp" | |
#define max(x,y) ((x>y)?x:y) | |
#define int64 long | |
#define half __fp16 | |
#define uchar unsigned char | |
#define bool uchar | |
#define _NS_PRIVATE_SEL(accessor) (Private::Selector::s_k##accessor) | |
int main() { | |
// both represents a metal context | |
MTL::Device* device = MTL::CreateSystemDefaultDevice(); | |
MTL::CommandQueue* command_queue = device->newCommandQueue(); | |
// capture | |
bool success; | |
MTL::CaptureManager* captureManager = MTL::CaptureManager::sharedCaptureManager(); | |
success = captureManager->supportsDestination( MTL::CaptureDestinationGPUTraceDocument ); | |
if (!success) { | |
__builtin_printf( "Capture support is not enabled\n"); | |
assert( false ); | |
} | |
MTL::CaptureDescriptor* pCaptureDescriptor = MTL::CaptureDescriptor::alloc()->init(); | |
pCaptureDescriptor->setDestination( MTL::CaptureDestinationGPUTraceDocument ); | |
pCaptureDescriptor->setOutputURL( NS::URL::fileURLWithPath(NS::String::string("/tmp/compute.gputrace", NS::StringEncoding::ASCIIStringEncoding)) ); | |
pCaptureDescriptor->setCaptureObject(device); | |
NS::Error *pError = nullptr; | |
success = captureManager->startCapture( pCaptureDescriptor, &pError ); | |
std::cout << device->name()->utf8String() << "\n"; | |
const char* shaderSrc = R"( | |
#include <metal_stdlib> | |
using namespace metal; | |
kernel void E_(device unsigned char* data0, const device char* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) { | |
auto val0 = (long)(*(data1+0)); | |
auto val1 = (long)(*(data1+1)); | |
auto val2 = (long)(*(data1+2)); | |
auto val3 = (long)(*(data1+3)); | |
*(data0+0) = static_cast<uchar>(val0); | |
*(data0+1) = static_cast<uchar>(val1); | |
*(data0+2) = static_cast<uchar>(val2); | |
*(data0+3) = static_cast<uchar>(val3); | |
} | |
)"; | |
// const char* shaderSrc1 = R"( | |
// #include <metal_stdlib> | |
// using namespace metal; | |
// kernel void E_(device float* data0, const device float* data1, const device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) { | |
// float val0 = *(data1+0); | |
// float val1 = *(data2+0); | |
// *(data0+0) = (val0+val1); | |
// } | |
// )"; | |
// HOW TO LOAD METAL BINARY ARCHIVE? | |
NS::Error* libError = nullptr; | |
MTL::CompileOptions* options = MTL::CompileOptions::alloc()->init(); | |
MTL::Library* library = device->newLibrary(NS::String::string(shaderSrc, NS::StringEncoding::UTF8StringEncoding), options, &libError); | |
if (!library) assert(false); | |
// NS::Error* e = nullptr; | |
// dispatch_data_t dispatch_data = dispatch_data_create(xlibrary, sizeof(*xlibrary), dispatch_get_main_queue(), NULL); | |
// MTL::Library* library = device->newLibrary(dispatch_data, &e); | |
// device->newLibrary() | |
// // BinaryArchive | |
// NS::Error* archive_err = nullptr; | |
// MTL::BinaryArchiveDescriptor* archive_desc = MTL::BinaryArchiveDescriptor::alloc()->init(); | |
// MTL::BinaryArchive* archive = device->newBinaryArchive(archive_desc, &archive_err); | |
// MTL::ComputePipelineDescriptor* compute_desc = MTL::ComputePipelineDescriptor::alloc()->init(); | |
// NS::Error* c_err = nullptr; | |
// compute_desc->setComputeFunction(library->newFunction(NS::String::string("E_", NS::StringEncoding::ASCIIStringEncoding))); | |
// archive->addComputePipelineFunctions(compute_desc, &c_err); | |
// NS::Error* url_Error = nullptr; | |
// auto url = NS::URL::alloc()->initFileURLWithPath(NS::String::string("/tmp/tmpyslkfs", NS::StringEncoding::UTF8StringEncoding)); | |
// bool is_success = archive->serializeToURL(url, &url_Error); | |
// if(!is_success) { | |
// std::cerr << "serializeToURL" << "\n"; | |
// assert(false); | |
// } | |
MTL::Function* E_ = library->newFunction(NS::String::string("E_", NS::StringEncoding::ASCIIStringEncoding)); | |
if (!E_) { | |
assert(false); | |
} | |
MTL::ComputePipelineDescriptor* compute_desc = MTL::ComputePipelineDescriptor::alloc()->init(); | |
compute_desc->setComputeFunction(E_); | |
NS::Error* error = nullptr; | |
MTL::ComputePipelineState* pso = device->newComputePipelineState(E_, &error); | |
if (!pso) { | |
std::cerr << error->localizedDescription()->utf8String() << "\n"; | |
assert(false); | |
} | |
MTL::CommandBuffer* cmd_buff = command_queue->commandBuffer(); | |
MTL::ComputeCommandEncoder* cmd_enc = cmd_buff->computeCommandEncoder(); | |
MTL::Buffer* outputs = device->newBuffer(4, MTL::ResourceStorageModeManaged); | |
MTL::Buffer* input0 = device->newBuffer(4, MTL::ResourceStorageModeManaged); | |
std::vector<char> a{-1, -2, -3, -4}; | |
memcpy(input0->contents(), a.data(), a.size() * sizeof(char)); | |
input0->didModifyRange(NS::Range::Make( 0, a.size() * sizeof(char))); | |
// // pass buffers to compute shaders | |
cmd_enc->setComputePipelineState(pso); | |
cmd_enc->setBuffer(outputs, 0, 0); | |
cmd_enc->setBuffer(input0, 0, 1); | |
cmd_enc->dispatchThreadgroups(MTL::Size({1, 1, 1}), MTL::Size({1, 1, 1})); | |
cmd_enc->endEncoding(); | |
// MTL::BlitCommandEncoder* bc_enc = cmd_buff->blitCommandEncoder(); | |
// bc_enc->synchronizeResource(outputs); | |
// bc_enc->endEncoding(); | |
cmd_buff->commit(); | |
cmd_buff->waitUntilCompleted(); | |
// // captureManager->stopCapture(); | |
auto* out = static_cast<unsigned char*>(outputs->contents()); | |
if (out != nullptr) { | |
for (int i=0; i<4; ++i) { | |
printf("%u, ", out[i]); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment