Skip to content

Instantly share code, notes, and snippets.

@schwa
Last active November 25, 2024 09:09
Show Gist options
  • Save schwa/7e3585db1d670d57f1a2a89adc453462 to your computer and use it in GitHub Desktop.
Save schwa/7e3585db1d670d57f1a2a89adc453462 to your computer and use it in GitHub Desktop.
Basic example of metal function stitching
import Foundation
import Metal
// Example of Metal Function Stitching.
// Combine basic Metal Shading Language functions at runtime to create a new function.
// Mentioned very quickly in this WWDC session https://developer.apple.com/videos/play/wwdc2021/10229
// and otherwise documented here https://developer.apple.com/documentation/metal/mtlfunctionstitchinggraph
// Create a device.
let device = MTLCreateSystemDefaultDevice()!
// MSL source for the node functions. [[stitchable]] marks a function as usable as a
// node inside an MTLFunctionStitchingGraph; os_log lets us trace each call on the GPU.
let nodeSource = """
#include <metal_stdlib>
#include <metal_logging>
using namespace metal;
[[stitchable]]
int add(int a, int b) {
os_log_default.log("add() %d %d", a, b);
return a + b;
}
[[stitchable]]
int multiply(int a, int b) {
os_log_default.log("multiply() %d %d", a, b);
return a * b;
}
"""
// Compile the node library and load 'add' and 'multiply' functions from it.
let options = MTLCompileOptions()
// Enable shader logging so the os_log calls in the MSL above reach our log handler.
options.enableLogging = true
let nodeLibrary = try device.makeLibrary(source: nodeSource, options: options)
let add = nodeLibrary.makeFunction(name: "add")!
let multiply = nodeLibrary.makeFunction(name: "multiply")!
// Make a graph with some inputs...
// Each input node forwards one positional argument of the stitched function.
let input0 = MTLFunctionStitchingInputNode(argumentIndex: 0)
let input1 = MTLFunctionStitchingInputNode(argumentIndex: 1)
let input2 = MTLFunctionStitchingInputNode(argumentIndex: 2)
let input3 = MTLFunctionStitchingInputNode(argumentIndex: 3)
// Some nodes that use these inputs and have names...
// NOTE(review): node0 and node1 are both add(input0, input1), and input2/input3 are
// never wired into the graph. That matches the expected output below (add 1 2 runs
// twice, then multiply 3 3 == 9) — presumably intentional for the demo; confirm.
let node0 = MTLFunctionStitchingFunctionNode(name: "add", arguments: [input0, input1], controlDependencies: [])
let node1 = MTLFunctionStitchingFunctionNode(name: "add", arguments: [input0, input1], controlDependencies: [])
let node2 = MTLFunctionStitchingFunctionNode(name: "multiply", arguments: [node0, node1], controlDependencies: [])
// And a named graph that uses these nodes, and will produce a result using the output node
let graph = MTLFunctionStitchingGraph(functionName: "result", nodes: [node0, node1], outputNode: node2, attributes: [])
// Now we stitch the functions from earlier and the graph together into a library...
let stitchedLibraryDescriptor = MTLStitchedLibraryDescriptor()
stitchedLibraryDescriptor.functions = [add, multiply]
stitchedLibraryDescriptor.functionGraphs = [graph]
let stitchedLibrary = try device.makeLibrary(stitchedDescriptor: stitchedLibraryDescriptor)
// And extract the result function from it.
let resultFunction = stitchedLibrary.makeFunction(name: "result")!
// Our compute kernel shader that calls the result function. It is declared
// [[visible]] here (in the kernel source) so the compiler knows its signature;
// the implementation is the stitched function we link in via MTLLinkedFunctions.
let kernelSource = """
#include <metal_stdlib>
#include <metal_logging>
using namespace metal;
[[visible]] int result(int, int, int, int);
[[kernel]]
void kernel_main(
const device int &input0 [[buffer(1)]],
const device int &input1 [[buffer(2)]],
const device int &input2 [[buffer(3)]],
const device int &input3 [[buffer(4)]],
device uint *output [[buffer(5)]]
) {
os_log_default.log("kernel_main() Inputs: %d %d %d %d", input0, input1, input2, input3);
*output = result(input0, input1, input2, input3);
os_log_default.log("kernel_main() Result: %d", *output);
}
"""
// Reuses the same compile options so shader logging stays enabled for this library too.
let kernelLibrary = try device.makeLibrary(source: kernelSource, options: options)
let kernel = kernelLibrary.makeFunction(name: "kernel_main")!
// Now create a compute pipeline and tell it about the kernel function and the linked function we stitched together earlier.
let pipelineDescriptor = MTLComputePipelineDescriptor()
pipelineDescriptor.computeFunction = kernel
// Linking resultFunction satisfies the [[visible]] declaration in the kernel source.
let linkedFunctions = MTLLinkedFunctions()
linkedFunctions.functions = [resultFunction]
pipelineDescriptor.linkedFunctions = linkedFunctions
// .bindingInfo asks for reflection data so we can look up buffer indices by name below.
let (pipeline, reflection) = try device.makeComputePipelineState(descriptor: pipelineDescriptor, options: .bindingInfo)
// We'll need a command queue to do work.
let commandQueue = device.makeCommandQueue()!
// We want to log from the shader, so we need a log state with a log handler.
let loggingDescriptor = MTLLogStateDescriptor()
loggingDescriptor.bufferSize = 64 * 1024
loggingDescriptor.level = .info
let logState = try! device.makeLogState(descriptor: loggingDescriptor)
// Each GPU-side os_log message is delivered to this handler on the CPU.
logState.addLogHandler { (subsystem, category, level, message) in
print("[GPU] \(message)")
}
// We'll also need a command buffer to encode work into.
// The log state must be attached via the descriptor for GPU logs to be captured.
let commandBufferDescriptor = MTLCommandBufferDescriptor()
commandBufferDescriptor.logState = logState
let commandBuffer = commandQueue.makeCommandBuffer(descriptor: commandBufferDescriptor)!
// And here's how we encode the work to be done.
let encoder = commandBuffer.makeComputeCommandEncoder()!
// Tell the encoder which pipeline to use (the one we made earlier).
encoder.setComputePipelineState(pipeline)
// Let's provide some input parameters. These will be passed into our stitched function and then the nodes.
// Buffer indices are resolved by name from the reflection data (see helpers below).
encoder.setUnsafeBytes(of: 1, at: reflection!.binding(for: "input0"))
encoder.setUnsafeBytes(of: 2, at: reflection!.binding(for: "input1"))
encoder.setUnsafeBytes(of: 3, at: reflection!.binding(for: "input2"))
encoder.setUnsafeBytes(of: 4, at: reflection!.binding(for: "input3"))
// Let's provide a buffer to store the output. We'll initialize it to zero and it'll be the size of an Int32
// .storageModeShared so the CPU can read the result back directly after completion.
var output: Int32 = 0
let outputBuffer = device.makeBuffer(bytes: &output, length: MemoryLayout<Int32>.size, options: .storageModeShared)!
encoder.setBuffer(outputBuffer, offset: 0, index: reflection!.binding(for: "output"))
// Let's encode some work — a single thread is enough for this demo.
encoder.dispatchThreads(MTLSize(width: 1, height: 1, depth: 1), threadsPerThreadgroup: MTLSize(width: 1, height: 1, depth: 1))
// And that's it for encoding
encoder.endEncoding()
// Now we'll commit the command buffer and wait for it to complete.
commandBuffer.commit()
commandBuffer.waitUntilCompleted()
// Check the output buffer to see the result.
output = outputBuffer.contents().bindMemory(to: Int32.self, capacity: 1)[0]
print("[CPU] Result: \(output)")
// When run you should see this output:
//
// [GPU] kernel_main() Inputs: 1 2 3 4
// [GPU] add() 1 2
// [GPU] add() 1 2
// [GPU] multiply() 3 3
// [GPU] kernel_main() Result: 9
// [CPU] Result: 9
// Some helper functions to make the code more readable.
extension MTLComputeCommandEncoder {
    /// Copies the raw bytes of `value` into the encoder at the given buffer index.
    ///
    /// Convenience wrapper around `setBytes(_:length:index:)` so callers can pass a
    /// plain Swift value instead of managing a raw pointer and byte count themselves.
    func setUnsafeBytes<T>(of value: T, at index: Int) {
        withUnsafeBytes(of: value) { rawBuffer in
            setBytes(rawBuffer.baseAddress!, length: rawBuffer.count, index: index)
        }
    }
}
extension MTLComputePipelineReflection {
    /// Returns the buffer index of the binding with the given name.
    ///
    /// - Parameter name: The binding name as declared in the shader source
    ///   (e.g. "input0", "output").
    /// - Returns: The `index` of the first binding whose `name` matches.
    /// - Note: Traps with a descriptive message when no binding matches, instead
    ///   of the anonymous crash a bare force-unwrap would produce — a misspelled
    ///   name now reports itself and lists the bindings that do exist.
    func binding(for name: String) -> Int {
        guard let match = bindings.first(where: { $0.name == name }) else {
            fatalError("No binding named '\(name)' in pipeline reflection; available bindings: \(bindings.map(\.name))")
        }
        return match.index
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment