Created
December 21, 2024 18:20
-
-
Save peterc/4e91a3405d87615bf4bf5b7fb9382ed8 to your computer and use it in GitHub Desktop.
Multiply two arrays of numbers element-wise using the GPU via Metal on macOS
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Metal

// Inline Metal Shading Language source for the `multiply` compute kernel.
// It is compiled at runtime by `main()` via `makeLibrary(source:options:)`.
//
// Each GPU thread handles one element: it reads `a[id]` and `b[id]` from
// buffers 0 and 1 and writes their product into buffer 2 at the same index.
// The kernel performs no bounds check on `id`, so the dispatched grid must not
// contain more threads than there are elements in the three buffers.
let shaderSource = """
kernel void multiply(const device float* a [[ buffer(0) ]],
const device float* b [[ buffer(1) ]],
device float* result [[ buffer(2) ]],
uint id [[ thread_position_in_grid ]]) {
result[id] = a[id] * b[id];
}
"""
/// Multiplies two small `Float` arrays element-wise on the GPU and prints the products.
///
/// The function picks the first device returned by `MTLCopyAllDevices()`, compiles the
/// inline `multiply` kernel from `shaderSource`, uploads the inputs into shared-storage
/// buffers, dispatches one GPU thread per element, blocks until the command buffer
/// finishes, and reads the result buffer back.
///
/// Nothing is thrown or returned: every failure path prints a diagnostic and returns early.
func main() {
    // Step 1: Get the first available Metal device
    guard let device = MTLCopyAllDevices().first else {
        print("No Metal devices available.")
        return
    }
    print("Using Metal device: \(device.name)")

    // Step 2: Create Metal library and compute pipeline.
    // `do`/`catch` instead of `try?` so shader compile errors are reported, not discarded.
    let library: MTLLibrary
    do {
        library = try device.makeLibrary(source: shaderSource, options: nil)
    } catch {
        print("Failed to create Metal library.")
        print("Underlying error: \(error)")
        return
    }
    guard let function = library.makeFunction(name: "multiply") else {
        print("Failed to create Metal function.")
        return
    }
    let pipeline: MTLComputePipelineState
    do {
        pipeline = try device.makeComputePipelineState(function: function)
    } catch {
        print("Failed to create Metal pipeline.")
        print("Underlying error: \(error)")
        return
    }

    // Step 3: Create command queue
    guard let commandQueue = device.makeCommandQueue() else {
        print("Failed to create command queue.")
        return
    }

    // Step 4: Prepare data buffers
    let a: [Float] = [2.0, 3.0]
    let b: [Float] = [4.0, 5.0]
    // The kernel indexes both inputs with the same thread id, so lengths must agree.
    precondition(a.count == b.count, "Input arrays must have the same number of elements.")
    let elementCount = a.count
    let byteLength = MemoryLayout<Float>.stride * elementCount

    // `makeBuffer` returns an optional; fail loudly instead of binding a nil buffer.
    // The result buffer is allocated by length only, so no host-side scratch array is needed.
    guard let bufferA = device.makeBuffer(bytes: a, length: byteLength, options: .storageModeShared),
          let bufferB = device.makeBuffer(bytes: b, length: byteLength, options: .storageModeShared),
          let bufferResult = device.makeBuffer(length: byteLength, options: .storageModeShared) else {
        print("Failed to create data buffers.")
        return
    }

    // Step 5: Create and encode commands
    guard let commandBuffer = commandQueue.makeCommandBuffer(),
          let encoder = commandBuffer.makeComputeCommandEncoder() else {
        print("Failed to create command buffer or encoder.")
        return
    }
    encoder.setComputePipelineState(pipeline)
    encoder.setBuffer(bufferA, offset: 0, index: 0)
    encoder.setBuffer(bufferB, offset: 0, index: 1)
    encoder.setBuffer(bufferResult, offset: 0, index: 2)

    // One thread per element. The threadgroup width is capped at the element count so
    // this tiny workload still forms a single, exactly-sized threadgroup.
    let gridSize = MTLSize(width: elementCount, height: 1, depth: 1)
    let threadGroupWidth = max(1, min(pipeline.threadExecutionWidth, elementCount))
    let threadGroupSize = MTLSize(width: threadGroupWidth, height: 1, depth: 1)
    encoder.dispatchThreads(gridSize, threadsPerThreadgroup: threadGroupSize)
    encoder.endEncoding()

    // Step 6: Execute commands and block until the GPU is done
    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()
    guard commandBuffer.status == .completed else {
        // Without this check a failed run would print whatever is in the result buffer.
        if let error = commandBuffer.error {
            print("GPU execution failed: \(error)")
        } else {
            print("GPU execution failed.")
        }
        return
    }

    // Step 7: Fetch and print results
    let resultPointer = bufferResult.contents().bindMemory(to: Float.self, capacity: elementCount)
    let finalResult = Array(UnsafeBufferPointer(start: resultPointer, count: elementCount))
    print("Results: \(finalResult)")
}
// Script entry point: run the GPU multiplication demo.
main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment