Skip to content

Instantly share code, notes, and snippets.

@alvinwan
Last active October 30, 2024 02:29
Show Gist options
  • Save alvinwan/f7bb0cdd26c018f40052f9944fc5c679 to your computer and use it in GitHub Desktop.
Save alvinwan/f7bb0cdd26c018f40052f9944fc5c679 to your computer and use it in GitHub Desktop.
How to write a minimal, standalone Python script to run Metal (GPU) kernels on Mac

Running Metal via Python only, on Mac

Minimal Python-only script for Mac, running a "Hello World" Metal kernel. No need for Xcode, a Swift app, or PyTorch.

# Step 0: Clone the gist
git clone https://gist.github.com/f7bb0cdd26c018f40052f9944fc5c679.git

# Step 1: Install prerequisites
pip install pyobjc

# Step 2: Run the file
python MyMetalKernel.py

You're now done with the Metal example.

  • To learn how to use the Metal library, look up the corresponding Objective-C functions and classes in Apple's Metal documentation.
  • Every function and class in the official Objective-C Metal framework is exposed to Python through the Metal module provided by pyobjc. See the pyobjc documentation for additional help.

For reference, the above script was tested on macOS Ventura 13.2.1 with Python 3.11.5 and pyobjc 10.0.

"""
"Hello world" example of using Metal from Python.
This script can be run from the command line just like any other Python file. No
need for Xcode or any other IDE. Just make sure you have the latest version of
Python 3 installed, along with the PyObjC and pyobjc-framework-Metal packages.
"""
import Metal
import ctypes
import random
from math import log
#####################################
# 1. Setup the Metal kernel itself.
#####################################
# Metal Shading Language source for the kernel. Each GPU thread reads the
# float at its own grid position from `in` (buffer index 0), takes its
# logarithm, and writes the result to the same position in `out` (buffer
# index 1).
kernel_source = """
#include <metal_stdlib>
using namespace metal;
kernel void log_kernel(device float *in [[ buffer(0) ]],
device float *out [[ buffer(1) ]],
uint id [[ thread_position_in_grid ]]) {
out[id] = log(in[id]);
}
"""
# Create a Metal device, library, and kernel function
device = Metal.MTLCreateSystemDefaultDevice()
if device is None:
    raise RuntimeError("No Metal-capable GPU device is available on this system.")
# Passing None for the trailing NSError** argument makes PyObjC return a
# (result, error) tuple. Keep the error so a shader compile failure is reported
# here instead of surfacing later as an AttributeError on None.
library, library_error = device.newLibraryWithSource_options_error_(kernel_source, None, None)
if library is None:
    raise RuntimeError(f"Failed to compile the Metal kernel source: {library_error}")
kernel_function = library.newFunctionWithName_("log_kernel")
if kernel_function is None:
    raise RuntimeError("Kernel function 'log_kernel' was not found in the compiled library.")
#########################################
# 2. Setup the input and output buffers.
#########################################
# Create input and output buffers, each large enough for `array_length` floats
array_length = 1024
buffer_length = array_length * ctypes.sizeof(ctypes.c_float)  # 4 bytes per float
input_buffer = device.newBufferWithLength_options_(buffer_length, Metal.MTLResourceStorageModeShared)
output_buffer = device.newBufferWithLength_options_(buffer_length, Metal.MTLResourceStorageModeShared)
if input_buffer is None or output_buffer is None:
    raise RuntimeError("Failed to allocate the Metal input/output buffers.")
# Populate input buffer with random values.
# random.random() lies in [0.0, 1.0), so 1.0 - random.random() lies in
# (0.0, 1.0]. Excluding 0.0 matters because the reference check later calls
# math.log on every input, and math.log(0.0) raises ValueError.
input_list = [1.0 - random.random() for _ in range(array_length)]  # Create list of random numbers
# Map the Metal buffer to a ctypes float array so Python can write into it
input_array = (ctypes.c_float * array_length).from_buffer(input_buffer.contents().as_buffer(buffer_length))
input_array[:] = input_list  # Copy the random values into the Metal buffer
#####################################
# 3. Call the Metal kernel function.
#####################################
# Create a command queue and command buffer
commandQueue = device.newCommandQueue()
commandBuffer = commandQueue.commandBuffer()
# Build the compute pipeline state. PyObjC returns (result, error) when None is
# passed for the NSError** argument; report the error instead of discarding it.
pso, pso_error = device.newComputePipelineStateWithFunction_error_(kernel_function, None)
if pso is None:
    raise RuntimeError(f"Failed to create the compute pipeline state: {pso_error}")
# Set the kernel function and buffers
computeEncoder = commandBuffer.computeCommandEncoder()
computeEncoder.setComputePipelineState_(pso)
computeEncoder.setBuffer_offset_atIndex_(input_buffer, 0, 0)  # `in`  -> buffer(0)
computeEncoder.setBuffer_offset_atIndex_(output_buffer, 0, 1)  # `out` -> buffer(1)
# Define the grid size (total number of threads: one per array element) and
# the threadgroup size (threads per threadgroup, never wider than the grid).
gridSize = Metal.MTLSizeMake(array_length, 1, 1)
threadgroupSize = Metal.MTLSizeMake(min(pso.maxTotalThreadsPerThreadgroup(), array_length), 1, 1)
# Dispatch the kernel: first argument is threads per grid, second is threads
# per threadgroup.
computeEncoder.dispatchThreads_threadsPerThreadgroup_(gridSize, threadgroupSize)
computeEncoder.endEncoding()
# Commit the command buffer and block until the GPU has finished
commandBuffer.commit()
commandBuffer.waitUntilCompleted()
if commandBuffer.error() is not None:
    raise RuntimeError(f"Metal command buffer failed: {commandBuffer.error()}")
##################################
# 4. Check the output is correct.
##################################
# Map the Metal output buffer to a ctypes float array and copy it into a list
output_data = (ctypes.c_float * array_length).from_buffer(output_buffer.contents().as_buffer(buffer_length))
output_list = list(output_data)
# Compute the reference result on the CPU with math.log
output_python = [log(x) for x in input_list]
# Compare element-wise with an absolute tolerance. An explicit raise is used
# rather than `assert` so the check still runs under `python -O`.
if not all(abs(a - b) < 1e-5 for a, b in zip(output_list, output_python)):
    raise AssertionError("❌ Output does not match reference!")
print("✅ Reference matches output!")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment