Last active
April 7, 2023 01:42
-
-
Save troughton/08f7c3c6e08b8bac9cbcc84fd5eb2d7c to your computer and use it in GitHub Desktop.
MetalStateCaches for concurrency debugging
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//
//  Caches.swift
//  MetalRenderer
//
//  Created by Thomas Roughton on 24/12/17.
//
#if canImport(Metal) | |
@preconcurrency import Metal | |
import SubstrateUtilities | |
/// Caches `MTLFunction` objects created from a single `MTLLibrary`, keyed by `FunctionDescriptor`.
///
/// Lookups are actor-isolated; the actual function creation runs in a detached task so that
/// concurrent requests for the same descriptor share one creation task.
final actor MetalFunctionCache {
    let library: MTLLibrary

    /// Functions that have been created successfully.
    private var functionCache = [FunctionDescriptor : MTLFunction]()
    /// Creation tasks by descriptor. Entries are never removed here, so a failed creation
    /// is not retried: later lookups await the same (already failed) task.
    private var functionTasks = [FunctionDescriptor : Task<Void, Error>]()

    init(library: MTLLibrary) {
        self.library = library
    }

    /// Records a created function in the cache.
    private func setFunction(_ function: MTLFunction, for descriptor: FunctionDescriptor) {
        self.functionCache[descriptor] = function
    }

    /// Creates the function described by `functionDescriptor` and stores it in the cache.
    ///
    /// Kept unoptimised and non-inlined (debugging aid; see the attributes below).
    /// - Throws: Any error thrown by `MTLLibrary.makeFunction(name:constantValues:)`.
    @_optimize(none) @inline(never)
    nonisolated func createFunction(for functionDescriptor: FunctionDescriptor) async throws {
        print("In functionTask for descriptor \(functionDescriptor.name)")
        let function = try await library.makeFunction(name: functionDescriptor.name, constantValues: functionDescriptor.constants.map { MTLFunctionConstantValues($0) } ?? MTLFunctionConstantValues())
        await self.setFunction(function, for: functionDescriptor)
    }

    /// Returns the function for `functionDescriptor`, creating it if necessary.
    ///
    /// - Returns: The cached or newly created function, or `nil` if creation failed
    ///   (the error is printed).
    func function(for functionDescriptor: FunctionDescriptor) async -> MTLFunction? {
        if let function = self.functionCache[functionDescriptor] {
            return function
        }

        // No suspension point occurs between the lookup and the insertion below,
        // so at most one task is ever registered per descriptor.
        let functionTask: Task<Void, Error>
        if let task = self.functionTasks[functionDescriptor] {
            functionTask = task
        } else {
            print("Creating functionTask for descriptor \(functionDescriptor.name)")
            functionTask = Task.detached {
                try await self.createFunction(for: functionDescriptor)
            }
            self.functionTasks[functionDescriptor] = functionTask
        }

        do {
            try await functionTask.value
            // The task stores the function before completing; return an optional rather than
            // force-unwrapping, matching the pipeline caches.
            return self.functionCache[functionDescriptor]
        } catch {
            print("MetalRenderGraph: Error creating function named \(functionDescriptor.name)\(functionDescriptor.constants.map { " with constants \($0)" } ?? ""): \(error)")
            return nil
        }
    }
}
/// Caches `MTLRenderPipelineState` objects and their reflection data.
///
/// States are keyed by `MetalRenderPipelineDescriptor` (pipeline descriptor plus render target);
/// reflection is keyed by the `RenderPipelineDescriptor` alone.
final actor MetalRenderPipelineCache {
    struct RenderPipelineFunctionNames : Hashable {
        var vertexFunction : String
        var fragmentFunction : String
    }

    let device: MTLDevice
    let functionCache: MetalFunctionCache

    private var renderStates = [MetalRenderPipelineDescriptor : MTLRenderPipelineState]()
    private var renderReflection = [RenderPipelineDescriptor: MetalPipelineReflection]()
    /// Creation tasks by descriptor; entries are never removed here.
    private var renderStateTasks = [MetalRenderPipelineDescriptor : Task<Void, Error>]()

    init(device: MTLDevice, functionCache: MetalFunctionCache) {
        self.device = device
        self.functionCache = functionCache
    }

    /// Stores a created state. Reflection is only recorded the first time a state is
    /// created for a given `RenderPipelineDescriptor`.
    private func setState(_ state: MTLRenderPipelineState, reflection: MetalPipelineReflection, for descriptor: MetalRenderPipelineDescriptor) {
        self.renderStates[descriptor] = state
        if self.renderReflection[descriptor.descriptor] == nil {
            self.renderReflection[descriptor.descriptor] = reflection
        }
    }

    /// Builds the pipeline state for `metalDescriptor` and stores it (with reflection) in the cache.
    ///
    /// Returns without storing anything if the Metal descriptor cannot be constructed.
    /// - Throws: Any error thrown by `makeRenderPipelineState(descriptor:options:reflection:)`.
    @_optimize(none) @inline(never)
    nonisolated func createRenderState(metalDescriptor: MetalRenderPipelineDescriptor) async throws {
        print("In renderStateTask for descriptor \(metalDescriptor)")
        guard let mtlDescriptor = await MTLRenderPipelineDescriptor(metalDescriptor, functionCache: self.functionCache) else {
            return
        }

        var reflection: MTLAutoreleasedRenderPipelineReflection? = nil
        let state = try self.device.makeRenderPipelineState(descriptor: mtlDescriptor, options: [.bufferTypeInfo], reflection: &reflection)

        // TODO: can we retrieve the thread execution width for render pipelines?
        let pipelineReflection = MetalPipelineReflection(threadExecutionWidth: 4, vertexFunction: mtlDescriptor.vertexFunction!, fragmentFunction: mtlDescriptor.fragmentFunction, renderState: state, renderReflection: reflection!)

        await self.setState(state, reflection: pipelineReflection, for: metalDescriptor)
    }

    /// Returns the pipeline state for the given descriptor and render target, creating it if needed.
    ///
    /// - Returns: The state, or `nil` if it could not be created (thrown errors are printed).
    public func state(descriptor: RenderPipelineDescriptor, renderTarget: RenderTargetDescriptor) async -> MTLRenderPipelineState? {
        let metalDescriptor = MetalRenderPipelineDescriptor(descriptor, renderTargetDescriptor: renderTarget)

        if let state = self.renderStates[metalDescriptor] {
            return state
        }

        // No suspension point occurs between the lookup and the insertion below,
        // so at most one task is ever registered per descriptor.
        let renderStateTask: Task<Void, Error>
        if let task = self.renderStateTasks[metalDescriptor] {
            renderStateTask = task
        } else {
            print("Creating renderStateTask for descriptor \(metalDescriptor)")
            renderStateTask = Task.detached {
                try await self.createRenderState(metalDescriptor: metalDescriptor)
            }
            self.renderStateTasks[metalDescriptor] = renderStateTask
        }

        do {
            try await renderStateTask.value
            return self.renderStates[metalDescriptor]
        } catch {
            print("MetalRenderGraph: Error creating render pipeline state for descriptor \(descriptor): \(error)")
            return nil
        }
    }

    /// Returns the reflection for `pipelineDescriptor`, creating the pipeline state first if
    /// no reflection has been recorded yet.
    ///
    /// - Returns: The reflection, or `nil` if the pipeline state could not be created.
    public func reflection(descriptor pipelineDescriptor: RenderPipelineDescriptor, renderTarget: RenderTargetDescriptor) async -> MetalPipelineReflection? {
        if let reflection = self.renderReflection[pipelineDescriptor] {
            return reflection
        }

        guard await self.state(descriptor: pipelineDescriptor, renderTarget: renderTarget) != nil else {
            return nil
        }

        // A stored state always has its reflection recorded alongside it, so a direct lookup
        // suffices; this also avoids unbounded self-recursion should the entry ever be missing.
        return self.renderReflection[pipelineDescriptor]
    }
}
/// Caches `MTLComputePipelineState` objects and their reflection data.
///
/// States are keyed by `CacheKey` (descriptor plus the threadgroup-size hint);
/// reflection is keyed by the `ComputePipelineDescriptor` alone.
final actor MetalComputePipelineCache {
    struct CacheKey: Hashable {
        var descriptor: ComputePipelineDescriptor
        var threadgroupSizeIsMultipleOfThreadExecutionWidth: Bool
    }

    let device: MTLDevice
    let functionCache: MetalFunctionCache

    private var computeStates = [CacheKey: MTLComputePipelineState]()
    private var computeReflection = [ComputePipelineDescriptor: MetalPipelineReflection]()
    /// Creation tasks by cache key; entries are never removed here.
    private var computeStateTasks = [CacheKey : Task<Void, Error>]()

    init(device: MTLDevice, functionCache: MetalFunctionCache) {
        self.device = device
        self.functionCache = functionCache
    }

    /// Stores a created state. Reflection is only recorded the first time a state is
    /// created for a given `ComputePipelineDescriptor`.
    private func setState(_ state: MTLComputePipelineState, reflection: MetalPipelineReflection, for cacheKey: CacheKey) {
        self.computeStates[cacheKey] = state
        if self.computeReflection[cacheKey.descriptor] == nil {
            self.computeReflection[cacheKey.descriptor] = reflection
        }
    }

    /// Builds the compute pipeline state for `cacheKey` and stores it (with reflection) in the cache.
    ///
    /// Returns without storing anything if the compute function cannot be obtained.
    /// Linked functions that cannot be obtained are skipped.
    /// - Throws: Any error thrown by `makeComputePipelineState(descriptor:options:reflection:)`.
    @_optimize(none) @inline(never)
    nonisolated func createComputeState(cacheKey: CacheKey) async throws {
        let descriptor = cacheKey.descriptor
        guard let function = await self.functionCache.function(for: descriptor.function) else {
            return
        }

        let mtlDescriptor = MTLComputePipelineDescriptor()
        mtlDescriptor.computeFunction = function
        mtlDescriptor.threadGroupSizeIsMultipleOfThreadExecutionWidth = cacheKey.threadgroupSizeIsMultipleOfThreadExecutionWidth

        if !descriptor.linkedFunctions.isEmpty, #available(iOS 14.0, macOS 11.0, *) {
            var functions = [MTLFunction]()
            for functionDescriptor in descriptor.linkedFunctions {
                if let linkedFunction = await self.functionCache.function(for: functionDescriptor) {
                    functions.append(linkedFunction)
                }
            }
            let linkedFunctions = MTLLinkedFunctions()
            linkedFunctions.functions = functions
            mtlDescriptor.linkedFunctions = linkedFunctions
        }

        var reflection: MTLAutoreleasedComputePipelineReflection? = nil
        let state = try self.device.makeComputePipelineState(descriptor: mtlDescriptor, options: [.bufferTypeInfo], reflection: &reflection)

        // `function` is the value assigned to `computeFunction` above, so no unwrap is needed.
        let pipelineReflection = MetalPipelineReflection(threadExecutionWidth: state.threadExecutionWidth, function: function, computeState: state, computeReflection: reflection!)

        await self.setState(state, reflection: pipelineReflection, for: cacheKey)
    }

    /// Returns the compute pipeline state for the given descriptor and hint, creating it if needed.
    ///
    /// - Parameter threadgroupSizeIsMultipleOfThreadExecutionWidth: Optimisation hint forwarded to
    ///   the Metal descriptor; it is part of the cache key.
    /// - Returns: The state, or `nil` if it could not be created (thrown errors are printed).
    public func state(descriptor: ComputePipelineDescriptor, threadgroupSizeIsMultipleOfThreadExecutionWidth: Bool = false) async -> MTLComputePipelineState? {
        // The caller decides whether the threadgroup size is always a multiple of the thread
        // execution width; the hint is passed through unchanged.
        let cacheKey = CacheKey(descriptor: descriptor, threadgroupSizeIsMultipleOfThreadExecutionWidth: threadgroupSizeIsMultipleOfThreadExecutionWidth)

        if let state = self.computeStates[cacheKey] {
            return state
        }

        // No suspension point occurs between the lookup and the insertion below,
        // so at most one task is ever registered per cache key.
        let computeStateTask: Task<Void, Error>
        if let task = self.computeStateTasks[cacheKey] {
            computeStateTask = task
        } else {
            computeStateTask = Task.detached {
                try await self.createComputeState(cacheKey: cacheKey)
            }
            self.computeStateTasks[cacheKey] = computeStateTask
        }

        do {
            try await computeStateTask.value
            return self.computeStates[cacheKey]
        } catch {
            print("MetalRenderGraph: Error creating compute pipeline state for descriptor \(descriptor): \(error)")
            return nil
        }
    }

    /// Returns the reflection for `pipelineDescriptor`, creating the pipeline state (with the
    /// default hint) first if no reflection has been recorded yet.
    ///
    /// - Returns: The reflection, or `nil` if the pipeline state could not be created.
    public func reflection(for pipelineDescriptor: ComputePipelineDescriptor) async -> MetalPipelineReflection? {
        if let reflection = self.computeReflection[pipelineDescriptor] {
            return reflection
        }

        guard await self.state(descriptor: pipelineDescriptor) != nil else {
            return nil
        }

        // A stored state always has its reflection recorded alongside it, so a direct lookup
        // suffices; this also avoids unbounded self-recursion should the entry ever be missing.
        return self.computeReflection[pipelineDescriptor]
    }
}
/// Caches `MTLDepthStencilState` objects, looked up by comparing `DepthStencilDescriptor` values.
final actor MetalDepthStencilStateCache {
    struct RenderPipelineFunctionNames : Hashable {
        var vertexFunction : String
        var fragmentFunction : String
    }

    let device: MTLDevice
    /// The state created from a default-initialised `DepthStencilDescriptor`.
    let defaultDepthState : MTLDepthStencilState

    /// Descriptor/state pairs, searched linearly; the default pair is always first.
    private var depthStates: [(DepthStencilDescriptor, MTLDepthStencilState)]

    init(device: MTLDevice) {
        self.device = device

        let defaultDescriptor = DepthStencilDescriptor()
        let defaultState = self.device.makeDepthStencilState(descriptor: MTLDepthStencilDescriptor(defaultDescriptor))!

        self.defaultDepthState = defaultState
        self.depthStates = [(defaultDescriptor, defaultState)]
    }

    /// Returns the state matching `descriptor`, creating and caching it on first use.
    /// Traps if the device fails to create the state.
    public subscript(descriptor: DepthStencilDescriptor) -> MTLDepthStencilState {
        if let cachedIndex = self.depthStates.firstIndex(where: { $0.0 == descriptor }) {
            return self.depthStates[cachedIndex].1
        }

        let newState = self.device.makeDepthStencilState(descriptor: MTLDepthStencilDescriptor(descriptor))!
        self.depthStates.append((descriptor, newState))
        return newState
    }
}
/// Caches `MTLArgumentEncoder` objects by `ArgumentBufferDescriptor`, guarded by a spin lock.
final class MetalArgumentEncoderCache {
    let device: MTLDevice
    private var cache: [ArgumentBufferDescriptor: MTLArgumentEncoder]
    let lock = SpinLock()

    init(device: MTLDevice) {
        self.cache = [:]
        self.device = device
    }

    deinit {
        // The lock is released explicitly when the cache is destroyed.
        self.lock.deinit()
    }

    /// Returns the encoder for `descriptor`, creating and caching it on first use.
    /// Traps if the device fails to create the encoder.
    public subscript(descriptor: ArgumentBufferDescriptor) -> MTLArgumentEncoder {
        return self.lock.withLock {
            guard let cachedEncoder = self.cache[descriptor] else {
                let newEncoder = self.device.makeArgumentEncoder(arguments: descriptor.argumentDescriptors)
                self.cache[descriptor] = newEncoder
                return newEncoder!
            }
            return cachedEncoder
        }
    }
}
/// Owns the function, pipeline, depth-stencil and argument-encoder caches for one device,
/// and rebuilds the library-dependent caches when the library file on disk changes.
final class MetalStateCaches {
    let device : MTLDevice

    /// Location of the loaded library file; `nil` when the default library is used.
    var libraryURL : URL?
    /// Modification date of the library file at the time it was last loaded.
    var loadedLibraryModificationDate : Date = .distantPast

    var functionCache: MetalFunctionCache
    var renderPipelineCache: MetalRenderPipelineCache
    var computePipelineCache: MetalComputePipelineCache
    let depthStencilCache: MetalDepthStencilStateCache
    let argumentEncoderCache: MetalArgumentEncoderCache

    /// Loads the library and creates all caches.
    ///
    /// - Parameter libraryPath: Path to a `.metallib` file, or to a directory containing the
    ///   platform-specific library; `nil` loads the device's default library.
    ///   Traps if the library cannot be loaded.
    public init(device: MTLDevice, libraryPath: String?) {
        self.device = device

        let library: MTLLibrary
        if let libraryPath = libraryPath {
            var resolvedURL = URL(fileURLWithPath: libraryPath)

            // A directory path is resolved to the platform-specific library inside it.
            var isDirectory: ObjCBool = false
            let pathExists = FileManager.default.fileExists(atPath: libraryPath, isDirectory: &isDirectory)
            if pathExists && isDirectory.boolValue {
                #if os(macOS) || targetEnvironment(macCatalyst)
                let libraryName = device.isAppleSiliconGPU ? "Library-macOSAppleSilicon.metallib" : "Library-macOS.metallib"
                #else
                let libraryName = "Library-iOS.metallib"
                #endif
                resolvedURL = resolvedURL.appendingPathComponent(libraryName)
            }

            library = try! device.makeLibrary(filepath: resolvedURL.path)
            self.libraryURL = resolvedURL
            self.loadedLibraryModificationDate = try! resolvedURL.resourceValues(forKeys: [.contentModificationDateKey]).contentModificationDate!
        } else {
            library = device.makeDefaultLibrary()!
        }

        let functionCache = MetalFunctionCache(library: library)
        self.functionCache = functionCache
        self.renderPipelineCache = MetalRenderPipelineCache(device: device, functionCache: functionCache)
        self.computePipelineCache = MetalComputePipelineCache(device: device, functionCache: functionCache)
        self.depthStencilCache = MetalDepthStencilStateCache(device: device)
        self.argumentEncoderCache = MetalArgumentEncoderCache(device: device)
    }

    /// Reloads the library and recreates the function and pipeline caches if the library file
    /// has been modified since it was last loaded. Does nothing when no library URL is set,
    /// or when the file cannot be read or loaded.
    func checkForLibraryReload() async {
        // Drop the cached value so the modification date is read afresh.
        self.libraryURL?.removeCachedResourceValue(forKey: .contentModificationDateKey)
        guard let libraryURL = self.libraryURL else { return }

        guard FileManager.default.fileExists(atPath: libraryURL.path),
              let currentModificationDate = try? libraryURL.resourceValues(forKeys: [.contentModificationDateKey]).contentModificationDate,
              currentModificationDate > self.loadedLibraryModificationDate else {
            return
        }

        // Wait for all commands to finish on all queues.
        for queue in QueueRegistry.allQueues {
            await queue.waitForCommandCompletion(queue.lastSubmittedCommand)
        }

        // Metal won't pick up the changes if we use the makeLibrary(filePath:) initialiser,
        // so we have to load into a DispatchData instead.
        guard let data = try? Data(contentsOf: libraryURL) else {
            return
        }
        let reloadedLibrary: MTLLibrary? = data.withUnsafeBytes { bytes in
            try? self.device.makeLibrary(data: DispatchData(bytes: bytes) as __DispatchData)
        }
        guard let library = reloadedLibrary else { return }

        let newFunctionCache = MetalFunctionCache(library: library)
        self.functionCache = newFunctionCache
        self.renderPipelineCache = MetalRenderPipelineCache(device: self.device, functionCache: newFunctionCache)
        self.computePipelineCache = MetalComputePipelineCache(device: self.device, functionCache: newFunctionCache)

        VisibleFunctionTableRegistry.instance.markAllAsUninitialised()
        IntersectionFunctionTableRegistry.instance.markAllAsUninitialised()

        self.loadedLibraryModificationDate = currentModificationDate
    }

    /// Forwards to the render pipeline cache's reflection lookup.
    public func renderPipelineReflection(descriptor pipelineDescriptor: RenderPipelineDescriptor, renderTarget: RenderTargetDescriptor) async -> MetalPipelineReflection? {
        let reflection = await self.renderPipelineCache.reflection(descriptor: pipelineDescriptor, renderTarget: renderTarget)
        return reflection
    }

    /// Forwards to the compute pipeline cache's reflection lookup.
    public func computePipelineReflection(descriptor: ComputePipelineDescriptor) async -> MetalPipelineReflection? {
        let reflection = await self.computePipelineCache.reflection(for: descriptor)
        return reflection
    }
}
#endif // canImport(Metal) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment