Skip to content

Instantly share code, notes, and snippets.

Forked from awni/
Created August 11, 2024 07:50
Show Gist options
  • Save andrewssobral/1d6661b422d65fb5500b4551bd5515fd to your computer and use it in GitHub Desktop.
Compile and call a Metal GPU kernel from Python
# Requires:
# pip install pyobjc-framework-Metal
import numpy as np
import Metal
# Get the default GPU device
device = Metal.MTLCreateSystemDefaultDevice()
if device is None:
    # Without a device nothing below can run; fail with a clear message
    # instead of an AttributeError on None.
    raise RuntimeError("No Metal-capable GPU device is available")

# Make a command queue to encode command buffers to
command_queue = device.newCommandQueue()

# Compile the source code into a library.
# The kernel is Metal Shading Language: each GPU thread adds the single
# element found at its own position in the dispatch grid.
library, err = device.newLibraryWithSource_options_error_(
    """
[[kernel]] void add(
    device const float* a,
    device const float* b,
    device float* c,
    uint index [[thread_position_in_grid]]) {
  c[index] = a[index] + b[index];
}
""",
    None,
    None,
)
if library is None:
    # A missing library is the reliable failure signal; `err` describes why.
    raise RuntimeError(f"Failed to compile the Metal library: {err}")
if err:
    # Compilation produced a library but still reported diagnostics.
    print(err)

# Get the compiled "add" kernel
function = library.newFunctionWithName_("add")
if function is None:
    raise RuntimeError('Kernel function "add" was not found in the library')
kernel, err = device.newComputePipelineStateWithFunction_error_(function, None)
if kernel is None:
    raise RuntimeError(f"Failed to create the compute pipeline state: {err}")
if err:
    print(err)

# Make the command buffer and compute command encoder
command_buffer = command_queue.commandBuffer()
compute_encoder = command_buffer.computeCommandEncoder()

# Setup the problem data: two random float32 vectors of length n
n = 4096
a = np.random.uniform(size=(n,)).astype(np.float32)
b = np.random.uniform(size=(n,)).astype(np.float32)
def np_to_mtl_buffer(x):
    """Create a Metal buffer initialized with the bytes of NumPy array ``x``.

    The array is serialized with ``memoryview(x).tobytes()`` and handed to
    the module-level ``device`` together with ``x.nbytes`` as the length.
    The buffer is requested with shared storage mode.

    Args:
        x: NumPy array whose raw bytes become the buffer contents.

    Returns:
        The buffer object returned by
        ``device.newBufferWithBytes_length_options_``.
    """
    opts = Metal.MTLResourceOptions(Metal.MTLResourceStorageModeShared)
    return device.newBufferWithBytes_length_options_(
        memoryview(x).tobytes(), x.nbytes, opts,
    )
def mtl_buffer(size):
    """Allocate a Metal buffer of ``size`` bytes on the module-level ``device``.

    The buffer is requested with shared storage mode, matching
    ``np_to_mtl_buffer``.

    Args:
        size: Length of the buffer in bytes.

    Returns:
        The buffer object returned by ``device.newBufferWithLength_options_``.
    """
    opts = Metal.MTLResourceOptions(Metal.MTLResourceStorageModeShared)
    return device.newBufferWithLength_options_(size, opts)
def mtl_buffer_to_np(buf, dtype=np.float32):
    """Expose the contents of a Metal buffer as a one-dimensional NumPy array.

    Args:
        buf: Object providing ``contents().as_buffer(nbytes)`` and
            ``length()`` (the buffer size in bytes), e.g. a Metal buffer.
        dtype: Element type used to interpret the bytes. Defaults to
            ``np.float32``, the type the rest of this script works with.

    Returns:
        The array produced by ``np.frombuffer`` over ``buf.length()`` bytes
        of the buffer contents.
    """
    raw = buf.contents().as_buffer(buf.length())
    return np.frombuffer(raw, dtype=dtype)
# Upload the inputs and allocate an output buffer of the same byte size
a_buf = np_to_mtl_buffer(a)
b_buf = np_to_mtl_buffer(b)
c_buf = mtl_buffer(a.nbytes)

# Bind the compiled kernel and its arguments; the indices 0, 1, 2 follow
# the order of the kernel's buffer parameters (a, b, c).
compute_encoder.setComputePipelineState_(kernel)
compute_encoder.setBuffer_offset_atIndex_(a_buf, 0, 0)
compute_encoder.setBuffer_offset_atIndex_(b_buf, 0, 1)
compute_encoder.setBuffer_offset_atIndex_(c_buf, 0, 2)

# Dispatch the kernel with the correct number of threads: one thread per
# element, with the threadgroup width capped at what this pipeline allows
# rather than a hard-coded 1024.
grid_dims = Metal.MTLSize(n, 1, 1)
group_dims = Metal.MTLSize(min(n, kernel.maxTotalThreadsPerThreadgroup()), 1, 1)
compute_encoder.dispatchThreads_threadsPerThreadgroup_(grid_dims, group_dims)

# End the encoding and commit the buffer
compute_encoder.endEncoding()
command_buffer.commit()

# Wait until the computation is finished
command_buffer.waitUntilCompleted()
if command_buffer.error() is not None:
    raise RuntimeError(f"Metal command buffer failed: {command_buffer.error()}")

# Read the result back and check it against the NumPy reference before
# printing, so the output shown is the verified GPU result.
c = mtl_buffer_to_np(c_buf)
if not np.allclose(c, a + b):
    raise RuntimeError("GPU result does not match the NumPy reference a + b")
print(c)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment