-
-
Save schwa/7e3585db1d670d57f1a2a89adc453462 to your computer and use it in GitHub Desktop.
Basic example of metal function stitching
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Foundation
import Metal
// Example of Metal Function Stitching.
// Combine basic Metal Shading Language functions at runtime to create a new function.
// Mentioned very quickly in this WWDC session https://developer.apple.com/videos/play/wwdc2021/10229
// and otherwise documented here https://developer.apple.com/documentation/metal/mtlfunctionstitchinggraph
// Create a device. Force-unwrapped: this demo requires a Metal-capable GPU and
// crashes immediately if none is available.
let device = MTLCreateSystemDefaultDevice()!
// Metal Shading Language source for the node functions. The [[stitchable]]
// attribute makes each function usable as a node in an MTLFunctionStitchingGraph.
let nodeSource = """
#include <metal_stdlib>
#include <metal_logging>
using namespace metal;
[[stitchable]]
int add(int a, int b) {
os_log_default.log("add() %d %d", a, b);
return a + b;
}
[[stitchable]]
int multiply(int a, int b) {
os_log_default.log("multiply() %d %d", a, b);
return a * b;
}
"""
// Compile the node library and load 'add' and 'multiply' functions from it.
// enableLogging is needed for the os_log_default calls in the shaders to produce output.
let options = MTLCompileOptions()
options.enableLogging = true
let nodeLibrary = try device.makeLibrary(source: nodeSource, options: options)
let add = nodeLibrary.makeFunction(name: "add")!
let multiply = nodeLibrary.makeFunction(name: "multiply")!
// Make a graph with some inputs...
// An input node forwards the stitched function's argument at the given index.
let input0 = MTLFunctionStitchingInputNode(argumentIndex: 0)
let input1 = MTLFunctionStitchingInputNode(argumentIndex: 1)
let input2 = MTLFunctionStitchingInputNode(argumentIndex: 2)
let input3 = MTLFunctionStitchingInputNode(argumentIndex: 3)
// Some nodes that use these inputs and have names...
// NOTE(review): node0 and node1 both compute add(input0, input1); input2/input3 are
// never consumed by the graph and only pad the stitched function's signature to four
// arguments. This matches the expected output at the bottom of the file (result = 9),
// so it appears deliberate — but confirm; using input2/input3 in node1 would compute
// (a+b)*(c+d) instead.
let node0 = MTLFunctionStitchingFunctionNode(name: "add", arguments: [input0, input1], controlDependencies: [])
let node1 = MTLFunctionStitchingFunctionNode(name: "add", arguments: [input0, input1], controlDependencies: [])
let node2 = MTLFunctionStitchingFunctionNode(name: "multiply", arguments: [node0, node1], controlDependencies: [])
// And a named graph that uses these nodes, and will produce a result using the output node
let graph = MTLFunctionStitchingGraph(functionName: "result", nodes: [node0, node1], outputNode: node2, attributes: [])
// Now we stitch the functions from earlier and the graph together into a library...
let stitchedLibraryDescriptor = MTLStitchedLibraryDescriptor()
stitchedLibraryDescriptor.functions = [add, multiply]
stitchedLibraryDescriptor.functionGraphs = [graph]
let stitchedLibrary = try device.makeLibrary(stitchedDescriptor: stitchedLibraryDescriptor)
// And extract the result function from it.
let resultFunction = stitchedLibrary.makeFunction(name: "result")!
// Our compute kernel shader that calls the result function. Here `result` is only a
// declaration marked [[visible]]; the real implementation is supplied via
// MTLLinkedFunctions at pipeline-creation time (the nodes themselves are the
// [[stitchable]] functions in the node source above).
// NOTE(review): the kernel declares `output` as uint while the CPU side reads it
// back as Int32 — same size, so small non-negative results round-trip, but the
// types should probably match. Confirm intent.
let kernelSource = """
#include <metal_stdlib>
#include <metal_logging>
using namespace metal;
[[visible]] int result(int, int, int, int);
[[kernel]]
void kernel_main(
const device int &input0 [[buffer(1)]],
const device int &input1 [[buffer(2)]],
const device int &input2 [[buffer(3)]],
const device int &input3 [[buffer(4)]],
device uint *output [[buffer(5)]]
) {
os_log_default.log("kernel_main() Inputs: %d %d %d %d", input0, input1, input2, input3);
*output = result(input0, input1, input2, input3);
os_log_default.log("kernel_main() Result: %d", *output);
}
"""
let kernelLibrary = try device.makeLibrary(source: kernelSource, options: options)
let kernel = kernelLibrary.makeFunction(name: "kernel_main")!
// Now create a compute pipeline and tell it about the kernel function and the linked function we stitched together earlier.
let pipelineDescriptor = MTLComputePipelineDescriptor()
pipelineDescriptor.computeFunction = kernel
let linkedFunctions = MTLLinkedFunctions()
linkedFunctions.functions = [resultFunction]
pipelineDescriptor.linkedFunctions = linkedFunctions
// .bindingInfo requests reflection data so we can look up buffer indices by name below.
let (pipeline, reflection) = try device.makeComputePipelineState(descriptor: pipelineDescriptor, options: .bindingInfo)
// We'll need a command queue to do work.
let commandQueue = device.makeCommandQueue()!
// We want to log from the shader, so we need a log state with a log handler.
let loggingDescriptor = MTLLogStateDescriptor()
loggingDescriptor.bufferSize = 64 * 1024
loggingDescriptor.level = .info
let logState = try! device.makeLogState(descriptor: loggingDescriptor)
// Forward GPU-side os_log messages to stdout, prefixed so they're distinguishable
// from CPU prints.
logState.addLogHandler { (subsystem, category, level, message) in
    print("[GPU] \(message)")
}
// We'll also need a command buffer to encode work into.
let commandBufferDescriptor = MTLCommandBufferDescriptor()
commandBufferDescriptor.logState = logState
let commandBuffer = commandQueue.makeCommandBuffer(descriptor: commandBufferDescriptor)!
// And here's how we encode the work to be done.
let encoder = commandBuffer.makeComputeCommandEncoder()!
// Tell the encoder which pipeline to use (the one we made earlier).
encoder.setComputePipelineState(pipeline)
// Let's provide some input parameters. These will be passed into our stitched function and then the nodes.
// Use explicit Int32 literals: the kernel reads 4-byte `int`s, whereas a bare Swift
// integer literal would be inferred as Int and copied as 8 bytes into a 4-byte slot
// (which only works by accident on little-endian hardware).
encoder.setUnsafeBytes(of: Int32(1), at: reflection!.binding(for: "input0"))
encoder.setUnsafeBytes(of: Int32(2), at: reflection!.binding(for: "input1"))
encoder.setUnsafeBytes(of: Int32(3), at: reflection!.binding(for: "input2"))
encoder.setUnsafeBytes(of: Int32(4), at: reflection!.binding(for: "input3"))
// Let's provide a buffer to store the output. We'll initialize it to zero and it'll be the size of an Int32
var output: Int32 = 0
let outputBuffer = device.makeBuffer(bytes: &output, length: MemoryLayout<Int32>.size, options: .storageModeShared)!
encoder.setBuffer(outputBuffer, offset: 0, index: reflection!.binding(for: "output"))
// Let's encode some work — a single thread is all this demo needs.
encoder.dispatchThreads(MTLSize(width: 1, height: 1, depth: 1), threadsPerThreadgroup: MTLSize(width: 1, height: 1, depth: 1))
// And that's it for encoding
encoder.endEncoding()
// Now we'll commit the command buffer and wait for it to complete.
commandBuffer.commit()
commandBuffer.waitUntilCompleted()
// Check the output buffer to see the result (.storageModeShared lets the CPU read it directly).
output = outputBuffer.contents().bindMemory(to: Int32.self, capacity: 1)[0]
print("[CPU] Result: \(output)")
// When run you should see this output:
//
// [GPU] kernel_main() Inputs: 1 2 3 4
// [GPU] add() 1 2
// [GPU] add() 1 2
// [GPU] multiply() 3 3
// [GPU] kernel_main() Result: 9
// [CPU] Result: 9
// Some helper functions to make the code more readable.
extension MTLComputeCommandEncoder {
    /// Convenience wrapper over `setBytes`: copies the raw bytes of `value` into
    /// the argument table slot at `index`.
    func setUnsafeBytes<T>(of value: T, at index: Int) {
        // setBytes needs a pointer, so stage the value in a mutable copy first.
        var staged = value
        setBytes(&staged, length: MemoryLayout<T>.size, index: index)
    }
}
extension MTLComputePipelineReflection {
    /// Returns the buffer index of the binding named `name`.
    /// Traps with a descriptive message if no such binding exists — the previous
    /// bare force-unwrap crashed with no indication of which name was missing.
    func binding(for name: String) -> Int {
        guard let match = bindings.first(where: { $0.name == name }) else {
            fatalError("No binding named '\(name)'; available bindings: \(bindings.map(\.name))")
        }
        return match.index
    }
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.