Last active
May 19, 2024 21:04
-
-
Save theoknock/5fdf7236c101d9d00448a2a96b51eb68 to your computer and use it in GitHub Desktop.
Executing Metal Compute Shaders: A barebones SwiftUI app that executes Metal compute shaders and returns their output to a SwiftUI view for display.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import MetalKit | |
/// Generates sine-wave samples on the GPU by running the `sineWave` Metal compute kernel.
///
/// The output buffer is allocated once in `init` and reused by every call to
/// `generateSineWave()`, so concurrent calls on the same instance write into the same memory.
///
/// NOTE(review): `arraySize` is never passed to the kernel, so the kernel presumably writes one
/// value per dispatched thread without a bounds check — TODO confirm against the `.metal` source.
/// Because the grid is rounded up to whole threadgroups, `resultBuffer` is sized for every
/// dispatched thread so the surplus threads in the last threadgroup cannot write past it.
class MetalSineWaveGenerator {
    let device: MTLDevice
    let commandQueue: MTLCommandQueue
    let computePipelineState: MTLComputePipelineState
    /// Number of samples returned by `generateSineWave()`; always greater than zero.
    let arraySize: Int
    /// Passed to the kernel at buffer index 1.
    let frequency: Float
    /// Passed to the kernel at buffer index 2.
    let sampleRate: Float
    /// Kernel output, bound at buffer index 0. Holds one `Float` per dispatched thread, which can
    /// be more than `arraySize`; only the first `arraySize` values are returned to callers.
    let resultBuffer: MTLBuffer

    /// Threads per threadgroup actually used for dispatch (256 clamped to the pipeline's limit).
    private let threadgroupWidth: Int

    /// Threadgroup width used whenever the pipeline allows it.
    private static let preferredThreadgroupWidth = 256

    /// Number of `width`-sized threadgroups needed to cover `count` threads (ceiling division).
    private static func threadgroupCount(for count: Int, width: Int) -> Int {
        (count + width - 1) / width
    }

    /// Creates the Metal device, command queue, compute pipeline and output buffer.
    ///
    /// - Parameters:
    ///   - arraySize: Number of samples to produce; must be greater than zero.
    ///   - frequency: Forwarded to the kernel unchanged.
    ///   - sampleRate: Forwarded to the kernel unchanged.
    /// - Returns: `nil` in any of these cases:
    ///   - `arraySize <= 0`.
    ///   - No Metal device or command queue is available.
    ///   - The default library or its `sineWave` function cannot be loaded.
    ///   - The pipeline cannot be created.
    ///   - The output buffer cannot be allocated.
    init?(arraySize: Int, frequency: Float, sampleRate: Float) {
        guard arraySize > 0,
              let device = MTLCreateSystemDefaultDevice(),
              let commandQueue = device.makeCommandQueue(),
              let library = device.makeDefaultLibrary(),
              let function = library.makeFunction(name: "sineWave"),
              let computePipelineState = try? device.makeComputePipelineState(function: function) else {
            return nil
        }

        // Metal requires threadsPerThreadgroup <= maxTotalThreadsPerThreadgroup, which for some
        // kernels/devices is below 256.
        let threadgroupWidth = min(MetalSineWaveGenerator.preferredThreadgroupWidth,
                                   computePipelineState.maxTotalThreadsPerThreadgroup)
        let dispatchedThreads = MetalSineWaveGenerator.threadgroupCount(for: arraySize, width: threadgroupWidth)
            * threadgroupWidth

        // Size the buffer for every dispatched thread, not just `arraySize`, so the rounded-up
        // tail of the grid stays inside the allocation.
        guard let resultBuffer = device.makeBuffer(length: dispatchedThreads * MemoryLayout<Float>.stride,
                                                   options: .storageModeShared) else {
            return nil
        }

        self.device = device
        self.commandQueue = commandQueue
        self.computePipelineState = computePipelineState
        self.arraySize = arraySize
        self.frequency = frequency
        self.sampleRate = sampleRate
        self.resultBuffer = resultBuffer
        self.threadgroupWidth = threadgroupWidth
    }

    /// Encodes one dispatch of `sineWave` and blocks until the GPU finishes.
    ///
    /// The wait uses `waitUntilCompleted()`, which blocks the calling thread.
    ///
    /// - Returns: A copy of the first `arraySize` floats of `resultBuffer`. Returns an empty
    ///   array if the command buffer or encoder cannot be created, or if the command buffer
    ///   did not complete successfully.
    func generateSineWave() -> [Float] {
        guard let commandBuffer = commandQueue.makeCommandBuffer(),
              let computeEncoder = commandBuffer.makeComputeCommandEncoder() else {
            return []
        }

        // setBytes copies from an inout pointer, so the values need mutable locals.
        var frequency = self.frequency
        var sampleRate = self.sampleRate

        computeEncoder.setComputePipelineState(computePipelineState)
        computeEncoder.setBuffer(resultBuffer, offset: 0, index: 0)
        computeEncoder.setBytes(&frequency, length: MemoryLayout<Float>.stride, index: 1)
        computeEncoder.setBytes(&sampleRate, length: MemoryLayout<Float>.stride, index: 2)

        // Dispatch enough whole threadgroups to cover `arraySize` threads.
        let threadsPerThreadgroup = MTLSize(width: threadgroupWidth, height: 1, depth: 1)
        let threadgroups = MTLSize(
            width: MetalSineWaveGenerator.threadgroupCount(for: arraySize, width: threadgroupWidth),
            height: 1,
            depth: 1)
        computeEncoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadsPerThreadgroup)
        computeEncoder.endEncoding()

        commandBuffer.commit()
        commandBuffer.waitUntilCompleted()

        // On a GPU error the buffer contents are undefined; don't hand them back as samples.
        guard commandBuffer.status == .completed else {
            return []
        }

        let samples = resultBuffer.contents().bindMemory(to: Float.self, capacity: arraySize)
        return Array(UnsafeBufferPointer(start: samples, count: arraySize))
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Foundation | |
import MetalKit | |
/// Generates sine-wave samples on the GPU by running the `sineWave` Metal compute kernel.
///
/// Each call to `generateSineWave()` allocates its own output buffer.
///
/// NOTE(review): `arraySize` is never passed to the kernel, so the kernel presumably writes one
/// value per dispatched thread without a bounds check — TODO confirm against the `.metal` source.
/// Because the grid is rounded up to whole threadgroups, the output buffer is sized for every
/// dispatched thread so the surplus threads in the last threadgroup cannot write past it.
class MetalSineWaveGenerator {
    let device: MTLDevice
    let commandQueue: MTLCommandQueue
    let computePipelineState: MTLComputePipelineState
    /// Number of samples returned by `generateSineWave()`; always greater than zero.
    let arraySize: Int
    /// Passed to the kernel at buffer index 1.
    let frequency: Float
    /// Passed to the kernel at buffer index 2.
    let sampleRate: Float

    /// Creates the Metal device, command queue and compute pipeline.
    ///
    /// - Parameters:
    ///   - arraySize: Number of samples to produce; must be greater than zero.
    ///   - frequency: Forwarded to the kernel unchanged.
    ///   - sampleRate: Forwarded to the kernel unchanged.
    /// - Returns: `nil` (instead of trapping) in any of these cases:
    ///   - `arraySize <= 0`.
    ///   - No Metal device or command queue is available.
    ///   - The default library or its `sineWave` function cannot be loaded.
    ///   - The pipeline cannot be created.
    init?(arraySize: Int, frequency: Float, sampleRate: Float) {
        guard arraySize > 0,
              let device = MTLCreateSystemDefaultDevice(),
              let commandQueue = device.makeCommandQueue() else {
            return nil
        }
        self.device = device
        self.commandQueue = commandQueue
        self.arraySize = arraySize
        self.frequency = frequency
        self.sampleRate = sampleRate

        // Load the compute function from the shader; a missing .metallib is reported, not fatal.
        guard let library = device.makeDefaultLibrary(),
              let function = library.makeFunction(name: "sineWave") else {
            print("Error loading compute function \"sineWave\" from the default library")
            return nil
        }
        do {
            self.computePipelineState = try device.makeComputePipelineState(function: function)
        } catch {
            print("Error creating compute pipeline state: \(error)")
            return nil
        }
    }

    /// Allocates an output buffer, encodes one dispatch of `sineWave`, and blocks until the
    /// GPU finishes.
    ///
    /// The wait uses `waitUntilCompleted()`, which blocks the calling thread.
    ///
    /// - Returns: The first `arraySize` floats written by the kernel. Returns an empty array
    ///   (and prints a message) if any of these fail:
    ///   - allocating the output buffer,
    ///   - creating the command buffer or encoder,
    ///   - completing the command buffer successfully.
    func generateSineWave() -> [Float] {
        // Metal requires threadsPerThreadgroup <= maxTotalThreadsPerThreadgroup, which for some
        // kernels/devices is below 256.
        let threadgroupWidth = min(256, computePipelineState.maxTotalThreadsPerThreadgroup)
        let threadgroupCount = (arraySize + threadgroupWidth - 1) / threadgroupWidth

        // Size the buffer for every dispatched thread, not just `arraySize`, so the rounded-up
        // tail of the grid stays inside the allocation.
        let bufferLength = threadgroupCount * threadgroupWidth * MemoryLayout<Float>.stride
        guard let resultBuffer = device.makeBuffer(length: bufferLength, options: .storageModeShared),
              let commandBuffer = commandQueue.makeCommandBuffer(),
              let computeEncoder = commandBuffer.makeComputeCommandEncoder() else {
            print("Error creating Metal result buffer or compute command encoder")
            return []
        }

        // setBytes copies from an inout pointer, so the values need mutable locals.
        var frequency = self.frequency
        var sampleRate = self.sampleRate

        computeEncoder.setComputePipelineState(computePipelineState)
        computeEncoder.setBuffer(resultBuffer, offset: 0, index: 0)
        computeEncoder.setBytes(&frequency, length: MemoryLayout<Float>.stride, index: 1)
        computeEncoder.setBytes(&sampleRate, length: MemoryLayout<Float>.stride, index: 2)

        // Dispatch enough whole threadgroups to cover `arraySize` threads.
        let threadsPerThreadgroup = MTLSize(width: threadgroupWidth, height: 1, depth: 1)
        let threadgroups = MTLSize(width: threadgroupCount, height: 1, depth: 1)
        computeEncoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadsPerThreadgroup)
        computeEncoder.endEncoding()

        commandBuffer.commit()
        commandBuffer.waitUntilCompleted()

        // On a GPU error the buffer contents are undefined; don't hand them back as samples.
        guard commandBuffer.status == .completed else {
            print("Error executing compute command buffer: \(String(describing: commandBuffer.error))")
            return []
        }

        let samples = resultBuffer.contents().bindMemory(to: Float.self, capacity: arraySize)
        return Array(UnsafeBufferPointer(start: samples, count: arraySize))
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment