Compile and call a Metal GPU kernel from Python
# Requires:
# pip install pyobjc-framework-Metal
import numpy as np
import Metal
# Get the default GPU device
device = Metal.MTLCreateSystemDefaultDevice()
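# Note (not in the original gist): MTLCreateSystemDefaultDevice() returns None
# when no Metal device is available, so a guard such as
#   assert device is not None, "No Metal GPU device found"
# can fail fast on unsupported machines.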
# Make a command queue to encode command buffers to
command_queue = device.newCommandQueue()
# Compile the source code into a library
library, err = device.newLibraryWithSource_options_error_(
    r"""
    [[kernel]] void add(
        device const float* a,
        device const float* b,
        device float* c,
        uint index [[thread_position_in_grid]]) {
        c[index] = a[index] + b[index];
    }
    """, None, None)
if err:
    print(err)
    exit(1)
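# Note (not in the original comments): the [[kernel]] attribute spelling needs a
# recent Metal Shading Language version; the classic form is `kernel void add(...)`.
# Because no explicit [[buffer(n)]] attributes are given, Metal assigns buffer
# indices in argument order (a=0, b=1, c=2), matching the setBuffer calls below.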
# Get the compiled "add" kernel
function = library.newFunctionWithName_("add")
kernel, err = device.newComputePipelineStateWithFunction_error_(function, None)
if err:
    print(err)
    exit(1)
# Make the command buffer and compute command encoder
command_buffer = command_queue.commandBuffer()
compute_encoder = command_buffer.computeCommandEncoder()
# Set up the problem data
n = 4096
a = np.random.uniform(size=(n,)).astype(np.float32)
b = np.random.uniform(size=(n,)).astype(np.float32)

def np_to_mtl_buffer(x):
    opts = Metal.MTLResourceOptions(Metal.MTLResourceStorageModeShared)
    return device.newBufferWithBytes_length_options_(
        memoryview(x).tobytes(), x.nbytes, opts,
    )

def mtl_buffer(size):
    opts = Metal.MTLResourceOptions(Metal.MTLResourceStorageModeShared)
    return device.newBufferWithLength_options_(size, opts)

def mtl_buffer_to_np(buf):
    return np.frombuffer(buf.contents().as_buffer(buf.length()), dtype=np.float32)
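# Note on the helpers above (not part of the original comments):
# newBufferWithBytes_length_options_ copies the NumPy data into shared
# CPU/GPU-visible memory, while mtl_buffer_to_np wraps the buffer contents
# without copying, so the returned array is only valid while the buffer lives.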
# Dispatch the kernel with the correct number of threads
compute_encoder.setComputePipelineState_(kernel)
grid_dims = Metal.MTLSize(n, 1, 1)
group_dims = Metal.MTLSize(1024, 1, 1)
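# Note (not in the original gist): 1024 threads per threadgroup can exceed the
# limit of some pipelines/devices; a safer sketch would clamp to the maximum:
#   group_dims = Metal.MTLSize(min(1024, kernel.maxTotalThreadsPerThreadgroup()), 1, 1)
# dispatchThreads_threadsPerThreadgroup_ also handles grids that are not a
# multiple of the threadgroup size, so the kernel needs no bounds check here.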
a_buf = np_to_mtl_buffer(a)
b_buf = np_to_mtl_buffer(b)
c_buf = mtl_buffer(a.nbytes)
compute_encoder.setBuffer_offset_atIndex_(a_buf, 0, 0)
compute_encoder.setBuffer_offset_atIndex_(b_buf, 0, 1)
compute_encoder.setBuffer_offset_atIndex_(c_buf, 0, 2)
compute_encoder.dispatchThreads_threadsPerThreadgroup_(grid_dims, group_dims)
# End the encoding and commit the buffer
compute_encoder.endEncoding()
command_buffer.commit()
# Wait until the computation is finished
command_buffer.waitUntilCompleted()
c = mtl_buffer_to_np(c_buf)
print(a + b)
print(c)
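For a quick correctness check, a minimal follow-up (not part of the original gist) can compare the GPU output against NumPy directly:

# Sanity check: the GPU result should match the CPU sum to float32 precision
assert np.allclose(a + b, c), "GPU result does not match NumPy"
print("GPU add matches NumPy")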