Skip to content

Instantly share code, notes, and snippets.

@philipturner
Last active October 16, 2025 12:48
Show Gist options
  • Select an option

  • Save philipturner/af23dacd700601f645590bf7d83136ad to your computer and use it in GitHub Desktop.

Select an option

Save philipturner/af23dacd700601f645590bf7d83136ad to your computer and use it in GitHub Desktop.
Inspect rebuild
extension Application {
  // TODO: Before finishing the acceleration structure PR, remove the public
  // modifier for this.
  /// Registers the pending atom changes, then records the upload, remove,
  /// add, and rebuild passes of the BVH builder into a single command list.
  /// The crash buffer download is recorded last.
  ///
  /// - Parameter inFlightFrameID: Forwarded to every pass that takes an
  ///   in-flight frame index.
  public func updateBVH(inFlightFrameID: Int) {
    let pendingChanges = atoms.registerChanges()

    device.commandQueue.withCommandList { list in
      // Bind the descriptor heap.
      #if os(Windows)
      list.setDescriptorHeap(descriptorHeap)
      #endif

      // Housekeeping and upload of the registered changes.
      bvhBuilder.purgeResources(commandList: list)
      bvhBuilder.setupGeneralCounters(commandList: list)
      bvhBuilder.upload(
        transaction: pendingChanges,
        commandList: list,
        inFlightFrameID: inFlightFrameID)

      // Encode the remove process.
      bvhBuilder.removeProcess1(
        commandList: list, inFlightFrameID: inFlightFrameID)
      bvhBuilder.removeProcess2(commandList: list)
      bvhBuilder.removeProcess3(commandList: list)
      bvhBuilder.removeProcess4(commandList: list)

      // Encode the add process.
      bvhBuilder.addProcess1(
        commandList: list, inFlightFrameID: inFlightFrameID)
      bvhBuilder.addProcess2(commandList: list)
      bvhBuilder.addProcess3(
        commandList: list, inFlightFrameID: inFlightFrameID)

      // Encode the rebuild process.
      bvhBuilder.rebuildProcess1(commandList: list)
      bvhBuilder.rebuildProcess2(commandList: list)

      // Record the crash buffer download after all passes.
      bvhBuilder.counters.crashBuffer.download(
        commandList: list, inFlightFrameID: inFlightFrameID)
    }
  }

  // TODO: Before finishing the acceleration structure PR, remove the public
  // modifier for this.
  /// Records the motion-vector and voxel-mark resets into a command list,
  /// then drops the builder's stored `transactionArgs`.
  ///
  /// - Parameter inFlightFrameID: Forwarded to the motion-vector reset.
  public func forgetIdleState(inFlightFrameID: Int) {
    device.commandQueue.withCommandList { list in
      // Bind the descriptor heap.
      #if os(Windows)
      list.setDescriptorHeap(descriptorHeap)
      #endif

      bvhBuilder.resetMotionVectors(
        commandList: list, inFlightFrameID: inFlightFrameID)
      bvhBuilder.resetVoxelMarks(commandList: list)

      #if os(Windows)
      bvhBuilder.computeUAVBarrier(commandList: list)
      #endif
    }

    // Delete the transactionArgs state variable.
    bvhBuilder.transactionArgs = nil
  }
}
// TODO: Before finishing the acceleration structure PR, remove these debugging
// utilities from the code base.
extension Application {
  // Circumvent a flaky crash by holding a reference to the buffer while the
  // command list executes. Do not abuse this by calling any of the 'Debug'
  // functions more than once in a single program execution.
  nonisolated(unsafe)
  private static var downloadBuffers: [Buffer] = []

  /// Reads back the first 10 words of the general counters buffer.
  public func downloadGeneralCounters() -> [UInt32] {
    var words = [UInt32](repeating: .zero, count: 10)
    downloadDebugOutput(
      &words, copySourceBuffer: bvhBuilder.counters.general)
    return words
  }

  /// Reads back 4096 words from the dense `assignedSlotIDs` buffer.
  public func downloadAssignedSlotIDs() -> [UInt32] {
    var words = [UInt32](repeating: .zero, count: 4096)
    downloadDebugOutput(
      &words, copySourceBuffer: bvhBuilder.voxels.dense.assignedSlotIDs)
    return words
  }

  /// Reads back the sparse `memorySlots` buffer as 32-bit words. The word
  /// count is `memorySlotCount * MemorySlot.totalSize / 4`.
  public func downloadMemorySlots() -> [UInt32] {
    let wordCount =
      bvhBuilder.voxels.memorySlotCount * MemorySlot.totalSize / 4
    var words = [UInt32](repeating: .zero, count: wordCount)
    downloadDebugOutput(
      &words, copySourceBuffer: bvhBuilder.voxels.sparse.memorySlots)
    return words
  }

  /// Reads back one word per memory slot from the sparse
  /// `rebuiltVoxelCoords` buffer.
  public func downloadRebuiltVoxelCoords() -> [UInt32] {
    let wordCount = bvhBuilder.voxels.memorySlotCount
    var words = [UInt32](repeating: .zero, count: wordCount)
    downloadDebugOutput(
      &words, copySourceBuffer: bvhBuilder.voxels.sparse.rebuiltVoxelCoords)
    return words
  }

  /// Flushes the command queue and copies buffer contents into `outputData`.
  ///
  /// On macOS the source buffer is read directly. On other platforms an
  /// `.output` buffer of the same size is created and read instead; the copy
  /// into it is only recorded under `os(Windows)`.
  private func downloadDebugOutput<T>(
    _ outputData: inout [T],
    copySourceBuffer: Buffer
  ) {
    #if os(macOS)
    let readbackBuffer = copySourceBuffer
    #else
    let nativeBuffer = copySourceBuffer
    var descriptor = BufferDescriptor()
    descriptor.device = device
    descriptor.size = nativeBuffer.size
    descriptor.type = .output
    let readbackBuffer = Buffer(descriptor: descriptor)
    #endif

    // Retain the buffer for the rest of the program (see the note on
    // 'downloadBuffers').
    Self.downloadBuffers.append(readbackBuffer)

    #if os(Windows)
    device.commandQueue.withCommandList { list in
      list.download(
        nativeBuffer: nativeBuffer,
        outputBuffer: readbackBuffer)
    }
    #endif

    device.commandQueue.flush()
    outputData.withUnsafeMutableBytes { rawBytes in
      readbackBuffer.read(output: rawBytes)
    }
  }
}
import HDL
import MolecularRenderer
// Current task:
// - Test for correct functionality during rebuild.
// - Less complex than the previous test; quite easy and quick.
// - Out of scope for the previous test, does not need to be cross-coupled
// with the various possibilities for behavior during add/remove.
// - Will still rely on the same silicon carbide lattice as the previous test.
// - Archive the above testing code to a GitHub gist, along with its utilities
// in "Application+UpdateBVH.swift".
// Helpful facts about the test setup:
// atom count: 8631
// memory slot count: 3616
// memory slot size: 55304 B
// .headerLarge = 0 B
// .headerSmall = 8 B
// .referenceLarge = 2056 B
// .referenceSmall = 14344 B
// voxel group count: 64
// voxel count: 4096
/// Builds the `Application` used by this test: a device, a 1080x1080
/// display on that device, and an application descriptor with no upscaling
/// and a world dimension of 32.
@MainActor
func createApplication() -> Application {
  // Device, selected through `Device.fastestDeviceID`.
  var deviceDescriptor = DeviceDescriptor()
  deviceDescriptor.deviceID = Device.fastestDeviceID
  let device = Device(descriptor: deviceDescriptor)

  // Display, attached to the monitor the device reports as fastest.
  var displayDescriptor = DisplayDescriptor()
  displayDescriptor.device = device
  displayDescriptor.frameBufferSize = SIMD2<Int>(1080, 1080)
  displayDescriptor.monitorID = device.fastestMonitorID
  let display = Display(descriptor: displayDescriptor)

  // Application, tying the device and display together.
  var applicationDescriptor = ApplicationDescriptor()
  applicationDescriptor.device = device
  applicationDescriptor.display = display
  applicationDescriptor.upscaleFactor = 1
  applicationDescriptor.addressSpaceSize = 2_000_000
  applicationDescriptor.voxelAllocationSize = 200_000_000
  applicationDescriptor.worldDimension = 32
  return Application(descriptor: applicationDescriptor)
}
// Global application instance; the inspection helpers and the frame loop
// below all read from it.
let application = createApplication()
// Test geometry: a cubic lattice whose sites alternate between silicon and
// carbon (the silicon carbide lattice mentioned in the task notes above).
// NOTE(review): the bounds presumably span 10 unit cells along each of
// h, k, l — confirm against the HDL documentation.
let lattice = Lattice<Cubic> { h, k, l in
Bounds { 10 * (h + k + l) }
Material { .checkerboard(.silicon, .carbon) }
}
#if false
// Disabled interactive path: copy every lattice atom into the application,
// then render and present one image per invocation of the closure.
application.run {
  for atomID in lattice.atoms.indices {
    application.atoms[atomID] = lattice.atoms[atomID]
  }
  application.present(image: application.render())
}
#else
/// Formats `integer` in decimal, right-aligned in a field of at least four
/// characters. Values that already need four or more characters are
/// returned unchanged.
func pad<T: BinaryInteger>(_ integer: T) -> String {
  let digits = integer.description
  let missing = 4 - digits.count
  guard missing > 0 else {
    return digits
  }
  return String(repeating: " ", count: missing) + digits
}
/// Downloads the general counters, prints the ones of interest, and aborts
/// if the trailing words of either indirect dispatch argument triple are
/// not 1.
@MainActor
func analyzeGeneralCounters() {
  let counters = application.downloadGeneralCounters()

  print("atoms removed voxel count:", counters[0])
  // Words 1 and 2 complete the dispatch triple that starts at word 0.
  if counters[1] != 1 || counters[2] != 1 {
    fatalError("Indirect dispatch arguments were malformatted.")
  }

  print("vacant slot count:", counters[4])
  print("allocated slot count:", counters[5])

  print("rebuilt voxel count:", counters[6])
  // Words 7 and 8 complete the dispatch triple that starts at word 6.
  if counters[7] != 1 || counters[8] != 1 {
    fatalError("Indirect dispatch arguments were malformatted.")
  }
}
/// Downloads the rebuilt voxel coordinates and prints, for every entry that
/// is not `UInt32.max`, its index and the decoded lower corner.
@MainActor
func inspectRebuiltVoxels() {
  let voxelCoords = application.downloadRebuiltVoxelCoords()

  for (entryID, packed) in voxelCoords.enumerated()
  where packed != UInt32.max {
    // Unpack three coordinates: bits 0-9, bits 10-19, and bits 20 and up.
    let x = packed & 1023
    let y = (packed >> 10) & 1023
    let z = packed >> 20
    let decoded = SIMD3<UInt32>(x, y, z)

    // Scale by 2 and subtract half of the world dimension (32).
    let halfWorld = Float(32) / 2
    let lowerCorner = SIMD3<Float>(decoded) * 2 - halfWorld
    print(pad(entryID), lowerCorner)
  }
}
/// Prints, for every voxel with an assigned memory slot, the slot ID, the
/// atom count, and the first 12 atom IDs of its large reference list. Then
/// prints a histogram of how many voxels reference each atom.
///
/// Aborts if an atom ID is out of range or an atom is referenced by more
/// than 8 voxels.
@MainActor
func inspectMemorySlots() {
  let assignedSlotIDs = application.downloadAssignedSlotIDs()
  let memorySlots = application.downloadMemorySlots()

  // Layout constants from the test setup notes, in 32-bit words.
  let slotStride = 55304 / 4
  let largeListOffset = 2056 / 4

  var referencesPerAtom = [Int](repeating: .zero, count: 8631)
  for (voxelID, slotID) in assignedSlotIDs.enumerated()
  where slotID != UInt32.max {
    let headerAddress = Int(slotID) * slotStride
    let atomCount = memorySlots[headerAddress]
    print(pad(voxelID), pad(slotID), pad(atomCount), terminator: " ")

    let listAddress = headerAddress + largeListOffset
    for listIndex in 0..<Int(atomCount) {
      let atomID = memorySlots[listAddress + listIndex]
      if listIndex < 12 {
        print(pad(atomID), terminator: " ")
      }
      guard atomID < referencesPerAtom.count else {
        fatalError("Invalid atom ID: \(atomID)")
      }
      referencesPerAtom[Int(atomID)] += 1
    }
    print()
  }

  // Histogram: summary[k] is the number of atoms referenced exactly k times.
  var summary = [Int](repeating: .zero, count: 9)
  for referenceCount in referencesPerAtom {
    guard referenceCount <= 8 else {
      fatalError("Invalid reference count: \(referenceCount)")
    }
    summary[referenceCount] += 1
  }

  print()
  for (referenceCount, atomCount) in summary.enumerated() {
    print("\(pad(referenceCount)): \(pad(atomCount))")
  }
  print("total atom count: \(summary.dropFirst().reduce(0, +))")
  print("total reference count: \(referencesPerAtom.reduce(0, +))")
}
/// Inspects the 16-bit ("small") reference lists of the voxel at dense
/// index 2457 and cross-checks them against its large reference list.
///
/// For each of the 512 small voxels with a non-zero header, prints the
/// voxel ID, the start offset, the reference count, and the first 12
/// referenced atom IDs. Finishes with a histogram of how many small voxels
/// reference each atom of the slot.
///
/// Aborts with a diagnostic if the slot's atom count is not 1165, a small
/// header encodes an invalid range, a small atom ID is out of range, or an
/// atom is referenced by more than 27 small voxels.
@MainActor
func inspectSmallReferences() {
  let assignedSlotIDs = application.downloadAssignedSlotIDs()
  let memorySlots = application.downloadMemorySlots()

  for i in assignedSlotIDs.indices {
    let assignedSlotID = assignedSlotIDs[i]
    guard assignedSlotID != UInt32.max else {
      continue
    }
    // Only one voxel is inspected in detail.
    guard i == 2457 else {
      continue
    }

    // Slot layout, in 32-bit words: [atom count, small reference count],
    // then the small headers, the large references, the small references.
    let headerAddress = Int(assignedSlotID) * 55304 / 4
    let atomCount = memorySlots[headerAddress]
    guard atomCount == 1165 else {
      fatalError("Got unexpected atom count: \(atomCount)")
    }
    let smallRefCount = memorySlots[headerAddress + 1]
    print("large references allocated:", atomCount)
    print("small references allocated:", smallRefCount)

    let largeRefAddress = headerAddress + 2056 / 4
    let smallRefAddress = headerAddress + 14344 / 4

    // Split every 32-bit word into the two 16-bit references it contains.
    // The count is rounded up to a whole number of words.
    let smallRefWordCount = (Int(smallRefCount) + 1) / 2
    var smallReferences: [UInt16] = []
    smallReferences.reserveCapacity(2 * smallRefWordCount)
    for wordID in 0..<smallRefWordCount {
      let value32 = memorySlots[smallRefAddress + wordID]
      let casted = unsafeBitCast(value32, to: SIMD2<UInt16>.self)
      smallReferences.append(casted[0])
      smallReferences.append(casted[1])
    }

    var atomDuplicatedReferences = [Int](
      repeating: .zero, count: Int(atomCount))
    let smallHeaderBase = headerAddress + 8 / 4
    for voxelID in 0..<512 {
      let header = memorySlots[smallHeaderBase + voxelID]
      guard header != UInt32.zero else {
        continue
      }

      // A header holds two 16-bit offsets: start in the first lane, end in
      // the second.
      let headerCasted = unsafeBitCast(header, to: SIMD2<UInt16>.self)
      let offsetStart = headerCasted[0]
      let offsetEnd = headerCasted[1]

      // Reject malformed ranges with a diagnostic instead of trapping on
      // UInt16 underflow or an out-of-range array index below.
      guard offsetStart <= offsetEnd,
            Int(offsetEnd) <= smallReferences.count else {
        fatalError("Invalid reference range: \(offsetStart)..<\(offsetEnd)")
      }
      let refCount = offsetEnd - offsetStart
      print(pad(voxelID), pad(offsetStart), pad(refCount), terminator: " ")

      for referenceID in offsetStart..<offsetEnd {
        let smallAtomID = smallReferences[Int(referenceID)]

        // Validate the small ID before it is used to index the large
        // reference list.
        if smallAtomID >= atomCount {
          fatalError("Invalid small atom ID: \(smallAtomID)")
        }
        let largeAtomID = memorySlots[largeRefAddress + Int(smallAtomID)]

        // Only the first 12 references of each small voxel are printed.
        // The subtraction cannot overflow, unlike `offsetStart + 12`.
        if referenceID - offsetStart < 12 {
          print(pad(largeAtomID), terminator: " ")
        }
        atomDuplicatedReferences[Int(smallAtomID)] += 1
      }
      print()
    }

    // Histogram: summary[k] is the number of atoms referenced exactly k
    // times.
    var summary = [Int](repeating: .zero, count: 28)
    for referenceCount in atomDuplicatedReferences {
      if referenceCount > 27 {
        fatalError("Invalid reference count: \(referenceCount)")
      }
      summary[referenceCount] += 1
    }

    print()
    for referenceCount in summary.indices {
      let atomCount = summary[referenceCount]
      print("\(pad(referenceCount)): \(pad(atomCount))")
    }
    print("total atom count: \(summary[1...].reduce(0, +))")
    print("total reference count: \(atomDuplicatedReferences.reduce(0, +))")
  }
}
// Headless test driver: runs a single frame, then dumps diagnostics.
for frameID in 0...0 {
  let inFlightSlot = frameID % 3

  // Copy every lattice atom into the application.
  for atomID in lattice.atoms.indices {
    application.atoms[atomID] = lattice.atoms[atomID]
  }

  application.updateBVH(inFlightFrameID: inFlightSlot)
  application.forgetIdleState(inFlightFrameID: inFlightSlot)

  // Frame banner.
  print()
  print("===============")
  print("=== frame \(frameID) ===")
  print("===============")
  print()

  analyzeGeneralCounters()
  print()
  inspectSmallReferences()
}
#endif
// Shader source generator for the second pass of the rebuild process.
extension RebuildProcess {
// [numthreads(128, 1, 1)]
// dispatch indirect groups SIMD3(atomic counter, 1, 1)
// threadgroup memory 2068 B
//
// # Phase I
//
// loop over the cuboid bounding box of each atom
// atomically accumulate into threadgroupCounters
//
// # Phase II
//
// read 4 voxels per thread, on 128 threads in parallel
// prefix sum over 512 small voxels (SIMD + group reduction)
// save the prefix sum result for Phase IV
// if reference count is too large, crash w/ diagnostic info
// write reference count into memory slot header
//
// # Phase III
//
// loop over a 3x3x3 grid of small voxels for each atom
// run the cube-sphere test and mask out voxels outside the 2 nm bound
// atomically accumulate into threadgroupCounters
// write a 16-bit reference to sparse.memorySlots
//
// # Phase IV
//
// restore the prefix sum result
// read end of reference list from threadgroupCounters
// if atom count is zero, output UInt32(0)
// otherwise
// store two offsets relative to the slot's region for 16-bit references
// compress these two 16-bit offsets into a 32-bit word
//
// Returns the complete source text of the 'rebuildProcess2' compute kernel.
// The text is assembled from the platform-specific fragments produced by
// the nested helpers below: one variant under os(macOS), another otherwise.
//
// - Parameter worldDimension: Interpolated into the shader as
//   'worldDimension / 2', both when generating the voxel ID and when
//   offsetting the voxel's lower corner.
// - Returns: The shader source string.
static func createSource2(worldDimension: Float) -> String {
// atoms.atoms
// voxels.dense.assignedSlotIDs
// voxels.sparse.rebuiltVoxelCoords
// voxels.sparse.memorySlots [32, 16]
//
// Emits the entry point declaration together with its resource bindings.
// The memory slots are bound twice, once viewed as 32-bit words and once
// as 16-bit elements. The non-macOS variant also declares the groupshared
// array, the thread count, and the root signature here.
func functionSignature() -> String {
#if os(macOS)
"""
kernel void rebuildProcess2(
\(CrashBuffer.functionArguments),
device float4 *atoms [[buffer(1)]],
device uint *assignedSlotIDs [[buffer(2)]],
device uint *rebuiltVoxelCoords [[buffer(3)]],
device uint *memorySlots32 [[buffer(4)]],
device ushort *memorySlots16 [[buffer(5)]],
uint groupID [[threadgroup_position_in_grid]],
uint localID [[thread_position_in_threadgroup]])
"""
#else
"""
\(CrashBuffer.functionArguments)
RWStructuredBuffer<float4> atoms : register(u1);
RWStructuredBuffer<uint> assignedSlotIDs : register(u2);
RWStructuredBuffer<uint> rebuiltVoxelCoords : register(u3);
RWStructuredBuffer<uint> memorySlots32 : register(u4);
RWBuffer<uint> memorySlots16 : register(u5);
groupshared uint threadgroupMemory[517];
[numthreads(128, 1, 1)]
[RootSignature(
\(CrashBuffer.rootSignatureArguments)
"UAV(u1),"
"UAV(u2),"
"UAV(u3),"
"UAV(u4),"
"DescriptorTable(UAV(u5, numDescriptors = 1)),"
)]
void rebuildProcess2(
uint groupID : SV_GroupID,
uint localID : SV_GroupThreadID)
"""
#endif
}
// Declares the 517-word threadgroup array inside the kernel body on macOS.
// Returns an empty string otherwise, because the non-macOS signature above
// already declares the array at file scope.
func allocateThreadgroupMemory() -> String {
#if os(macOS)
"threadgroup uint threadgroupMemory[517];"
#else
""
#endif
}
// Better memory locality in the Z axis for ray tracing.
//
// Builds the expression that maps (localID, i) to an index in 0..<512 of
// the threadgroup array, for i in 0..<4 and 128 threads.
func threadgroupAddress(_ i: String) -> String {
"256 * (localID / 64) + (\(i) * 64) + (localID % 64)"
}
// Builds the statement produced by Reduction.atomicFetchAdd for the
// threadgroup array, with 'uint(address)' as the address, an operand of 1,
// and 'offset' as the output. The surrounding shader code must declare
// 'address' and 'offset' before this fragment.
// NOTE(review): presumably 'offset' receives the counter value from before
// the increment; confirm against Reduction.atomicFetchAdd.
func atomicFetchAdd() -> String {
#if os(macOS)
let buffer = "(threadgroup atomic_uint*)threadgroupMemory"
#else
let buffer = "threadgroupMemory"
#endif
return Reduction.atomicFetchAdd(
buffer: buffer,
address: "uint(address)",
operand: "1",
output: "offset")
}
// Wraps the expression in an explicit 'ushort(...)' conversion on macOS;
// passes it through unchanged otherwise.
func castUShort(_ input: String) -> String {
#if os(macOS)
"ushort(\(input))"
#else
input
#endif
}
// Everything inside the literal below is shader text emitted at runtime.
return """
\(Shader.importStandardLibrary)
\(cubeSphereTest())
\(functionSignature())
{
\(allocateThreadgroupMemory())
if (crashBuffer[0] != 1) {
return;
}
uint encodedVoxelCoords = rebuiltVoxelCoords[groupID];
uint3 voxelCoords = \(VoxelResources.decode("encodedVoxelCoords"));
uint voxelID =
\(VoxelResources.generate("voxelCoords", worldDimension / 2));
float3 lowerCorner = float3(voxelCoords) * 2;
lowerCorner -= float(\(worldDimension / 2));
uint assignedSlotID = assignedSlotIDs[voxelID];
uint headerAddress = assignedSlotID * \(MemorySlot.totalSize / 4);
uint listAddress = headerAddress;
listAddress += \(MemorySlot.offset(.referenceLarge) / 4);
uint atomCount = memorySlots32[headerAddress];
\(Shader.unroll)
for (uint i = 0; i < 4; ++i) {
uint address = \(threadgroupAddress("i"));
threadgroupMemory[address] = 0;
}
\(Reduction.groupLocalBarrier())
// =======================================================================
// === Phase I ===
// =======================================================================
for (uint i = localID; i < atomCount; i += 128) {
uint atomID = memorySlots32[listAddress + i];
float4 atom = atoms[atomID];
\(computeLoopBounds())
// Iterate over the footprint on the 3D grid.
for (float z = boxMin[2]; z < boxMax[2]; ++z) {
for (float y = boxMin[1]; y < boxMax[1]; ++y) {
for (float x = boxMin[0]; x < boxMax[0]; ++x) {
float3 xyz = float3(x, y, z);
float address = \(VoxelResources.generate("xyz", 8));
uint offset;
\(atomicFetchAdd())
}
}
}
}
\(Reduction.groupLocalBarrier())
// =======================================================================
// === Phase II ===
// =======================================================================
uint countersSum = 0;
uint4 counters = 0;
\(Shader.unroll)
for (uint i = 0; i < 4; ++i) {
uint address = \(threadgroupAddress("i"));
uint temp = threadgroupMemory[address];
counters[i] = countersSum;
countersSum += temp;
}
\(Reduction.groupLocalBarrier())
uint wavePrefixSum = \(Reduction.wavePrefixSum("countersSum"));
uint waveInclusiveSum = wavePrefixSum + countersSum;
uint waveTotalSum =
\(Reduction.waveReadLaneAt("waveInclusiveSum", laneID: 31));
threadgroupMemory[512 + (localID / 32)] = waveTotalSum;
\(Reduction.groupLocalBarrier())
\(Reduction.threadgroupSumPrimitive(offset: 512))
// Incorporate all contributions to the prefix sum.
counters += wavePrefixSum;
counters += threadgroupMemory[512 + (localID / 32)];
\(Shader.unroll)
for (uint i = 0; i < 4; ++i) {
uint address = \(threadgroupAddress("i"));
threadgroupMemory[address] = counters[i];
}
\(Reduction.groupLocalBarrier())
uint referenceCount = threadgroupMemory[516];
if (referenceCount > 20480) {
if (localID == 0) {
bool acquiredLock = false;
\(CrashBuffer.acquireLock(errorCode: 4))
if (acquiredLock) {
crashBuffer[1] = voxelCoords.x;
crashBuffer[2] = voxelCoords.y;
crashBuffer[3] = voxelCoords.z;
crashBuffer[4] = referenceCount;
}
}
return;
}
if (localID == 0) {
memorySlots32[headerAddress + 1] = referenceCount;
}
// =======================================================================
// === Phase III ===
// =======================================================================
uint listAddress16 = headerAddress * 2;
listAddress16 += \(MemorySlot.offset(.referenceSmall) / 2);
for (uint i = localID; i < atomCount; i += 128) {
uint atomID = memorySlots32[listAddress + i];
float4 atom = atoms[atomID];
\(computeLoopBounds())
#if 1
// Iterate over the footprint on the 3D grid.
\(Shader.loop)
for (float z = 0; z < 3; ++z) {
\(Shader.unroll)
for (float y = 0; y < 3; ++y) {
\(Shader.unroll)
for (float x = 0; x < 3; ++x) {
float3 xyz = boxMin + float3(x, y, z);
// Narrow down the cells with a cube-sphere intersection test.
bool intersected = cubeSphereTest(xyz, atom);
if (intersected && all(xyz < boxMax)) {
float address = \(VoxelResources.generate("xyz", 8));
uint offset;
\(atomicFetchAdd())
memorySlots16[listAddress16 + offset] = \(castUShort("i"));
}
}
}
}
#else
// Iterate over the footprint on the 3D grid.
for (float z = boxMin[2]; z < boxMax[2]; ++z) {
for (float y = boxMin[1]; y < boxMax[1]; ++y) {
for (float x = boxMin[0]; x < boxMax[0]; ++x) {
float3 xyz = float3(x, y, z);
float address = \(VoxelResources.generate("xyz", 8));
uint offset;
\(atomicFetchAdd())
memorySlots16[listAddress16 + offset] = \(castUShort("i"));
}
}
}
#endif
}
\(Reduction.groupLocalBarrier())
// =======================================================================
// === Phase IV ===
// =======================================================================
uint smallHeaderBase = headerAddress;
smallHeaderBase += \(MemorySlot.offset(.headerSmall) / 4);
\(Shader.unroll)
for (uint i = 0; i < 4; ++i) {
uint address = \(threadgroupAddress("i"));
uint counterAfter = threadgroupMemory[address];
uint counterBefore = counters[i];
uint headerValue = 0;
if (counterAfter > counterBefore) {
headerValue = counterBefore | (counterAfter << 16);
}
memorySlots32[smallHeaderBase + address] = headerValue;
}
}
"""
}
}
extension BVHBuilder {
  /// Records the second rebuild pass into `commandList`.
  ///
  /// Binds the crash buffer, the atoms, the assigned slot IDs, the rebuilt
  /// voxel coordinates, and the memory slots (at index 4, and again at
  /// index 5 as a buffer on macOS or a descriptor elsewhere). The dispatch
  /// is indirect, with arguments read from the general counters at the
  /// `.rebuiltVoxelCount` offset.
  func rebuildProcess2(commandList: CommandList) {
    commandList.withPipelineState(shaders.rebuild.process2) {
      counters.crashBuffer.setBufferBindings(commandList: commandList)

      commandList.setBuffer(atoms.atoms, index: 1)
      commandList.setBuffer(voxels.dense.assignedSlotIDs, index: 2)
      commandList.setBuffer(voxels.sparse.rebuiltVoxelCoords, index: 3)
      commandList.setBuffer(voxels.sparse.memorySlots, index: 4)

      // Second view of the memory slots, used for the 16-bit references.
      #if os(macOS)
      commandList.setBuffer(voxels.sparse.memorySlots, index: 5)
      #else
      commandList.setDescriptor(
        handleID: voxels.sparse.memorySlotsHandleID, index: 5)
      #endif

      let argumentsOffset = GeneralCounters.offset(.rebuiltVoxelCount)
      commandList.dispatchIndirect(
        buffer: counters.general, offset: argumentsOffset)
    }

    #if os(Windows)
    computeUAVBarrier(commandList: commandList)
    #endif
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment