Skip to content

Instantly share code, notes, and snippets.

@guoxx
Forked from natevm/hploc.slang
Created August 27, 2024 15:51
Show Gist options
  • Save guoxx/e37cedad73271c0348edaa730eeb3ad9 to your computer and use it in GitHub Desktop.
Save guoxx/e37cedad73271c0348edaa730eeb3ad9 to your computer and use it in GitHub Desktop.
An implementation of "H-PLOC Hierarchical Parallel Locally-Ordered Clustering for Bounding Volume Hierarchy Construction".
// MIT License
// Copyright (c) 2024 Nathan V. Morrical
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
// The below is an implementation of the H-PLOC algorithm by Benthin et al.,
// for the paper, see the following link: https://gpuopen.com/download/publications/HPLOC.pdf
//
// This code requires a supporting parallel radix sort implementation.
// I recommend using the "One Sweep" algorithm, an HLSL implementation of which
// can be found here: https://github.com/b0nes164/GPUSorting
//
// Alternatively, you can sort the key-value pairs produced in the Morton kernel
// on the CPU, then upload the results back to the device. This will be slower, but
// is a good way to get started with the algorithm.
using namespace gprt;
#include "hploc.slangh"
import sfc;
#define WAVE_SIZE 32
// Results in a search radius of 8, which is what's recommended in the paper.
// Larger radii than this aren't helpful.
#define SEARCH_RADIUS_SHIFT 3
// Assuming at most 128 geometries for now, but this can be increased if needed.
// (Might be able to use the full 255... just need to test it out...)
#define GEOM_ID_BVH2 255
#define INVALID_ID UINT32_MAX
// Shouldn't happen, but useful for debugging purposes.
#define ERROR_OUT_OF_BOUNDS -1
#define ERROR_TIMEOUT -2
// Number of trailing zero bits in a 32-bit value.
// Returns 32 when no bit is set at all (x == 0).
uint countTrailingZero(uint x) {
    // Zero is answered explicitly rather than through firstbitlow.
    return (x != 0) ? firstbitlow(x) : 32;
}
// Finds the position of the n-th set bit of a 32-bit mask.
//   mask - bit field to search
//   n    - 1-based rank of the set bit to locate (1 = lowest set bit)
// Returns the bit position (0..31), or uint(-1) when the mask holds fewer
// than n set bits (which includes n == 0).
uint __fns(uint mask, uint n) {
    uint remaining = mask;
    uint seen = 0;
    while (remaining != 0) {
        // Visit set bits from least to most significant.
        const uint position = firstbitlow(remaining);
        seen++;
        if (seen == n) {
            return position;
        }
        remaining &= (remaining - 1); // drop the bit we just visited
    }
    return uint(-1);
}
// Delta function from section 3 of
// "Fast and Simple Agglomerative LBVH Construction".
// XORs the codes stored at indices a and b of mc; a larger result means the
// two codes differ in a more significant bit.
//   a, b - indices into mc (callers in this file pass b == a + 1)
//   n    - number of codes in mc
// Returns the maximum 64-bit value when a < 0 or b >= n, so that a pair
// reaching outside the array never compares as the smaller delta.
// When both codes are identical, falls back to a ^ (a + 1) so the result is
// still non-zero.
uint64_t delta(const int a, const int b, const uint n, Buffer mc)
{
    const bool outsideRange = (a < 0) || (b >= n);
    if (outsideRange) return (uint64_t)(-1);

    const uint64_t codeA = load<uint64_t>(mc, a);
    const uint64_t codeB = load<uint64_t>(mc, b);
    const uint64_t difference = codeA ^ codeB;

    if (difference != 0) return difference;
    return a ^ (a + 1ull);
}
// Chooses which boundary of the inclusive range [a, b] links to the parent.
// Returns b when the delta between b and b+1 is smaller than the delta
// between a-1 and a, and a-1 otherwise. A range starting at 0 has no left
// boundary and therefore always returns b.
// In effect this picks the side that grows the current Morton range the least.
uint findParentID(const int a, const int b, const uint n, Buffer mc)
{
    // No element to the left of the range: the right boundary is the only option.
    if (a == 0) return b;

    const bool rightIsCloser = (b != n) && (delta(b, b + 1, n, mc) < delta(a - 1, a, n, mc));
    if (rightIsCloser) return b;

    return a - 1;
}
// Surface area of an axis-aligned box given as (min corner, max corner).
// The result is clamped so it is never negative.
float getSurfaceArea(float2x3 bounds) {
    const float3 extent = bounds[1] - bounds[0];
    return max(2.0f * (extent.x * extent.y + extent.x * extent.z + extent.y * extent.z), 0.0f);
}
// "Distance" between two boxes for the PLOC nearest-neighbor search: the
// surface area of the box enclosing both A and B.
// Each box is (min corner, max corner). If either box has max < min on any
// axis it is treated as invalid and 1e38f is returned.
// The result is never negative.
float distanceFct(float2x3 A, float2x3 B) {
    const bool invalidA = any(A[1] < A[0]);
    const bool invalidB = any(B[1] < B[0]);
    if (invalidA || invalidB) return 1e38f; // one box is invalid

    const float3 unionMin = min(A[0], B[0]);
    const float3 unionMax = max(A[1], B[1]);
    const float3 d = unionMax - unionMin;
    return max(2.0f * (d.x * d.y + d.x * d.z + d.y * d.z), 0.0);
}
// Packs the distance from lane ID to a neighbor lane into an integer whose
// lowest bit is left clear (callers OR a parity bit into it), so candidates
// can be compared with a single atomic min.
// Computes (neighbor - ID - 1) << 1, i.e. it is written for neighbor > ID.
// Taken from here: https://github.com/RenderKit/embree/blob/8ced524ed3066adfe8175ac15a187f3d6bd26f14/kernels/rthwif/rtbuild/rthwif_embree_builder_ploc.h#L1658
uint encodeRelativeOffset(const uint ID, const uint neighbor)
{
    return (neighbor - ID - 1) << 1;
}
// Recovers the neighbor lane from a value produced by encodeRelativeOffset
// with a parity bit in bit 0.
//   localID - lane the offset is relative to
//   offset  - encoded value: distance-1 in the upper bits, parity in bit 0
//   ID      - lane index whose parity is compared with bit 0 of offset
// The neighbor lies at localID + distance when the low bits of offset and ID
// agree, and at localID - distance otherwise.
// Taken from here: https://github.com/RenderKit/embree/blob/8ced524ed3066adfe8175ac15a187f3d6bd26f14/kernels/rthwif/rtbuild/rthwif_embree_builder_ploc.h#L1664
int decodeRelativeOffset(const int localID, const uint offset, const uint ID)
{
    const int distance = (int)((offset >> 1) + 1);
    const bool forward = ((offset ^ ID) & 1) == 0;
    return forward ? (localID + distance) : (localID - distance);
}
// Returns an 8-bit biased exponent one step above the exponent of num, so
// the matching power of two is a conservative upper bound for it.
// The result is kept within [2, 254]; 255 is avoided because it encodes infinity.
uint floatToExponent(float num) {
    const uint biasedExponent = (asuint(num) >> 23) & 0xFF; // exponent field only
    return uint(clamp(biasedExponent + 1, 2u, 254u));
}
// Builds the float whose exponent field is `exponent` and whose sign and
// mantissa bits are zero, i.e. a power-of-two scale factor.
float exponentToFloat(uint exponent) {
    return asfloat(exponent << 23);
}
// Builds an 8-bit mask with bit s set when child slot s holds an internal node.
//   assignedChildren - per-slot child index, or INVALID_ID for an empty slot
//   numLeaves        - child indices >= numLeaves are treated as internal nodes
uint getIMask(uint assignedChildren[8], uint numLeaves) {
    uint innerMask = 0;
    for (uint slot = 0; slot < 8; ++slot) {
        const uint child = assignedChildren[slot];
        if (child == INVALID_ID) continue; // empty slot contributes nothing
        if (child >= numLeaves) innerMask |= (1u << slot);
    }
    return innerMask;
}
// Maps a dispatch thread to the geometry that owns it.
// Walks params.primPrefix from the front and counts the leading entries that
// are <= the thread index; the geometry is the last of those entries.
// Returns INVALID_ID when even the first entry is greater than the thread index.
// Only DispatchThreadID.x is considered (multi-dimensional dispatches are not
// handled). This is a plain linear search.
uint GetGeometryID(in HPLOCParams params, uint3 DispatchThreadID) {
    const uint threadID = DispatchThreadID.x;
    const uint numGeoms = params.primPrefix.size;

    uint numStartsAtOrBelow = 0;
    for (uint g = 0; g < numGeoms; ++g) {
        if (load<uint>(params.primPrefix, g) > threadID) break;
        numStartsAtOrBelow = g + 1;
    }

    if (numStartsAtOrBelow == 0) return INVALID_ID;
    return numStartsAtOrBelow - 1;
}
// Maps a dispatch thread to a primitive index local to geometry geomID.
//   - geomID == GEOM_ID_BVH2: the thread index itself is returned (thread and
//     BVH2 node index are assumed to correspond directly).
//   - geomID outside params.primPrefix: INVALID_ID.
//   - otherwise: thread index minus params.primPrefix[geomID].
// Only DispatchThreadID.x is considered.
uint GetPrimitiveID(in HPLOCParams params, uint3 DispatchThreadID, uint geomID) {
    const uint threadID = DispatchThreadID.x;

    if (geomID == GEOM_ID_BVH2) return threadID;

    // geomID is unsigned, so only the upper bound needs checking.
    if (geomID >= params.primPrefix.size) return INVALID_ID;

    const uint firstThreadOfGeometry = load<uint>(params.primPrefix, geomID);
    return threadID - firstThreadOfGeometry;
}
// Fetches the three vertex positions of triangle primID in geometry geomID.
//   v - receives the vertices; all zero when the geometry or primitive index
//       is rejected.
// Returns false when geomID is outside params.primPrefix, when primID is not
// below the size of the geometry's index buffer, or when any loaded vertex
// component is NaN (in that case v still holds the loaded values).
bool LoadTriangle(in HPLOCParams params, uint primID, uint geomID, out float3 v[3]) {
    v = float3[3](float3(0.0), float3(0.0), float3(0.0));

    // geomID is unsigned, so only the upper bound needs checking.
    if (geomID >= params.primPrefix.size) return false;

    gprt::Buffer triangles = load<gprt::Buffer>(params.triangles, geomID);
    gprt::Buffer vertices = load<gprt::Buffer>(params.vertices, geomID);
    if (primID >= triangles.size) return false;

    const uint3 indices = load<uint3>(triangles, primID);
    v[0] = load<float3>(vertices, indices.x);
    v[1] = load<float3>(vertices, indices.y);
    v[2] = load<float3>(vertices, indices.z);

    // Reject triangles carrying NaN positions.
    const bool hasNaN = any(isnan(v[0])) || any(isnan(v[1])) || any(isnan(v[2]));
    return !hasNaN;
}
// Convenience overload taking a packed cluster ID: the top 8 bits hold the
// geometry ID and the low 24 bits the primitive ID.
bool LoadTriangle(in HPLOCParams params, uint clusterID, out float3 v[3]) {
    const uint primID = clusterID & 0x00FFFFFF;
    const uint geomID = clusterID >> 24;
    return LoadTriangle(params, primID, geomID, v);
}
// Computes the AABB (min corner, max corner) of triangle primID in geometry geomID.
// When LoadTriangle rejects the triangle (invalid IDs or NaN vertices), bounds
// is set to the empty box (+FLT_MAX, -FLT_MAX) and false is returned.
bool LoadTriangleBounds(in HPLOCParams params, uint primID, uint geomID, out float2x3 bounds) {
    float3 corners[3];
    if (!LoadTriangle(params, primID, geomID, corners)) {
        bounds = float2x3(float3(FLT_MAX), -float3(FLT_MAX));
        return false;
    }

    const float3 lo = min(min(corners[0], corners[1]), corners[2]);
    const float3 hi = max(max(corners[0], corners[1]), corners[2]);
    bounds = float2x3(lo, hi);
    return true;
}
// Reads the AABB of BVH2 node nodeID from params.BVH2.
// The min corner is the xyz of the float4 at index nodeID*2 and the max corner
// the xyz of the float4 at nodeID*2 + 1.
// NOTE(review): presumably each BVH2Node occupies two float4 slots — confirm
// against the BVH2Node declaration.
// Returns false, with bounds set to the empty box, when nodeID >= params.BVH2.size.
bool LoadBvh2Bounds(in HPLOCParams params, uint nodeID, out float2x3 bounds) {
    if (nodeID >= params.BVH2.size) {
        bounds = float2x3(float3(FLT_MAX), -float3(FLT_MAX));
        return false;
    }

    const float4 lo = load<float4>(params.BVH2, nodeID * 2 + 0);
    const float4 hi = load<float4>(params.BVH2, nodeID * 2 + 1);
    bounds = float2x3(lo.xyz, hi.xyz);
    return true;
}
// Fetches the AABB of a cluster that is either a triangle or a BVH2 inner node.
//   clusterIndex.x - primitive index, or BVH2 node index
//   clusterIndex.y - geometry ID, or GEOM_ID_BVH2 for an inner node
// Returns false and the empty box (+FLT_MAX, -FLT_MAX) when the cluster
// cannot be resolved.
bool LoadAABB(in HPLOCParams params, uint2 clusterIndex, out float2x3 bounds) {
    const uint index = clusterIndex.x;
    const uint kind = clusterIndex.y;

    // Internal BVH2 node
    if (kind == GEOM_ID_BVH2) return LoadBvh2Bounds(params, index, bounds);

    // Leaf belonging to one of the registered geometries
    if (kind < params.primPrefix.size) return LoadTriangleBounds(params, index, kind, bounds);

    // Other geometry types could be dispatched here; anything else is invalid.
    bounds = float2x3(float3(FLT_MAX), -float3(FLT_MAX));
    return false;
}
// Overload for a packed cluster ID (geometry ID in the top 8 bits, primitive
// or node index in the low 24 bits); unpacks it and forwards to the uint2 overload.
bool LoadAABB(in HPLOCParams params, uint32_t compressedClusterID, out float2x3 bounds) {
    const uint2 unpacked = uint2(compressedClusterID & 0x00FFFFFF, compressedClusterID >> 24);
    return LoadAABB(params, unpacked, bounds);
}
// Computes the centroid (average of the three vertices) of triangle primID in
// geometry geomID. Returns false and a zero center when the triangle cannot
// be loaded.
bool LoadTriangleCenter(in HPLOCParams params, uint primID, uint geomID, out float3 center) {
    float3 corners[3];
    if (!LoadTriangle(params, primID, geomID, corners)) {
        center = float3(0.0);
        return false;
    }
    center = (corners[0] + corners[1] + corners[2]) / 3.f;
    return true;
}
// Computes the midpoint of the AABB of BVH2 node nodeID.
// Returns false and a zero center when the node's bounds cannot be loaded.
bool LoadBVH2Center(in HPLOCParams params, uint nodeID, out float3 center) {
    float2x3 nodeBounds;
    if (!LoadBvh2Bounds(params, nodeID, nodeBounds)) {
        center = float3(0.0);
        return false;
    }
    center = (nodeBounds[0] + nodeBounds[1]) * 0.5f;
    return true;
}
// Fetches the center of a cluster that is either a triangle or a BVH2 inner node.
//   clusterIndex.x - primitive index, or BVH2 node index
//   clusterIndex.y - geometry ID, or GEOM_ID_BVH2 for an inner node
// Returns false and a zero center when the cluster cannot be resolved.
bool LoadCenter(in HPLOCParams params, uint2 clusterIndex, out float3 center) {
    const uint index = clusterIndex.x;
    const uint kind = clusterIndex.y;

    // Internal BVH2 node
    if (kind == GEOM_ID_BVH2) return LoadBVH2Center(params, index, center);

    // Leaf belonging to one of the registered geometries
    if (kind < params.primPrefix.size) return LoadTriangleCenter(params, index, kind, center);

    // Other geometry types could be dispatched here; anything else is invalid.
    center = float3(0.0);
    return false;
}
// Accumulates the bounding box of all geometry and counts the valid primitives.
// Run N threads, where N is the sum of all prims of all geometries.
// Writes: params.rootBounds[0] (atomic min), params.rootBounds[1] (atomic max)
// and params.AC[0] (number of valid primitives).
[shader("compute")]
[numthreads(128, 1, 1)]
void HPLOC_Bounds(uint3 DispatchThreadID: SV_DispatchThreadID, uniform HPLOCParams params)
{
    if (DispatchThreadID.x >= params.N) return;

    // Resolve which primitive of which geometry this thread handles.
    const uint geomID = GetGeometryID(params, DispatchThreadID);
    const uint2 cluster = uint2(GetPrimitiveID(params, DispatchThreadID, geomID), geomID);

    // An invalid primitive yields the empty box, which does not affect min/max.
    float2x3 primBounds;
    const bool isValid = LoadAABB(params, cluster, primBounds);

    // Reduce inside the wave first so only one lane per wave touches the
    // global atomics.
    const uint waveValidCount = WaveActiveCountBits(isValid);
    const float3 waveMin = WaveActiveMin(primBounds[0]);
    const float3 waveMax = WaveActiveMax(primBounds[1]);

    if (!WaveIsFirstLane()) return;

    atomicMin32f(params.rootBounds, 0, waveMin);
    atomicMax32f(params.rootBounds, 1, waveMax);
    atomicAdd(params.AC, 0, waveValidCount);
}
// Spreads the low 22 bits of n so that input bit i lands on output bit 3*i,
// leaving two zero bits between consecutive input bits (the building block
// of a 3D Morton code). Bits above bit 21 of the input are discarded.
uint64_t separate_bits_64(uint64_t n)
{
    uint64_t spread = n & 0x00000000003FFFFFull;                   // keep 22 input bits
    spread = (spread ^ (spread << 32)) & 0xFFFF00000000FFFFull;    // split into 16-bit groups
    spread = (spread ^ (spread << 16)) & 0x00FF0000FF0000FFull;    // 8-bit groups
    spread = (spread ^ (spread << 8))  & 0xF00F00F00F00F00Full;    // 4-bit groups
    spread = (spread ^ (spread << 4))  & 0x30C30C30C30C30C3ull;    // 2-bit groups
    spread = (spread ^ (spread << 2))  & 0x9249249249249249ull;    // single bits, stride 3
    return spread;
}
// Builds a 64-bit Morton code from a 3D position.
// Each component is scaled by 2^20 and truncated to an integer, then the
// three integers are bit-interleaved (x in bit 0, y in bit 1, z in bit 2 of
// every triple).
// NOTE(review): presumably xyz is expected in [0, 1] — confirm against callers.
inline uint64_t morton64_encode(float3 xyz)
{
    const float scale = (float)(1ull << 20);
    const uint64_t3 quantized = uint64_t3(xyz * scale);

    const uint64_t xBits = separate_bits_64(quantized.x);
    const uint64_t yBits = separate_bits_64(quantized.y) << 1;
    const uint64_t zBits = separate_bits_64(quantized.z) << 2;
    return xBits | yBits | zBits;
}
// Computes the 64-bit Morton code for each primitive of each geometry.
// If a primitive is invalid, it receives a code of UINT64_MAX, causing the
// code to be reordered to the end of the list during the radix sort.
// Run N threads, where N is the sum of all prims of all geometries.
// Reads:  params.rootBounds (scene AABB: min at index 0, max at index 1).
// Writes: params.C (codes), params.I (cluster indices), params.pID (reset to
//         INVALID_ID), params.indexPairs (root pair at 0, invalid elsewhere),
//         params.AC[3] (BVH8 node counter, set to 1 by thread 0).
[shader("compute")]
[numthreads(128, 1, 1)]
void HPLOC_SFC(uint3 DispatchThreadID: SV_DispatchThreadID, uint3 GroupID: SV_GroupID, uniform HPLOCParams params)
{
    uint threadID = DispatchThreadID.x;
    if (threadID >= params.N) return;

    // Determine the geometry and primitive ID for the current thread
    uint geomID = GetGeometryID(params, DispatchThreadID);
    uint primID = GetPrimitiveID(params, DispatchThreadID, geomID);
    uint2 clusterIndex = uint2(primID, geomID);

    // NOTE(review): this early-out skips every store below, so a thread whose
    // primitive ID is INVALID_ID leaves its code/index/pID slots unwritten
    // instead of receiving the UINT64_MAX code described above — confirm intent.
    if (primID >= params.N) return;

    // Transform center into 0-1 range relative to the scene bounds.
    // morton64_encode quantizes its input as if it were in [0, 1], so the
    // world-space center must be normalized before encoding.
    float3 aabbMin = load<float3>(params.rootBounds, 0);
    float3 aabbMax = load<float3>(params.rootBounds, 1);
    float3 center;
    bool valid = LoadCenter(params, clusterIndex, center);
    // A flat scene axis would divide by zero; a tiny floor keeps that axis at 0.
    float3 extent = max(aabbMax - aabbMin, float3(1e-30f));
    float3 unitCenter = saturate((center - aabbMin) / extent);

    // Default to a code that throws invalid primitives to the end
    uint64_t code = (valid) ? morton64_encode(unitCenter) : UINT64_MAX;

    // Store the sfc code as the key, and the cluster index as the value
    store<uint64_t>(params.C, threadID, code);
    store<int2>(params.I, threadID, clusterIndex);

    // Also, initialize the scratch "parent IDs" buffer to an invalid state.
    // As we walk up the tree from the leaves, this will be used to deactivate lanes.
    store<uint>(params.pID, threadID, INVALID_ID);

    // Array of index pairs for BVH2 to BVH8 conversion
    if (threadID == 0) {
        // Allocate the wide root node.
        // High 32 bits: cluster ID of the BVH2 root (type GEOM_ID_BVH2, node 0).
        // Low 32 bits: where the BVH8 root will go (slot 0).
        // The shift is done on an unsigned value so bit 31 is not treated as a sign bit.
        uint64_t rootBvh2Cluster = uint64_t(uint(GEOM_ID_BVH2) << 24);
        uint64_t pair = (rootBvh2Cluster << 32ull) | (0ull);
        store<uint64_t>(params.indexPairs, threadID, pair);
        store<int>(params.AC, 3, 1); // global counter of allocated BVH8 nodes
    }
    else {
        uint64_t pair = UINT64_MAX; // invalid pair
        store<uint64_t>(params.indexPairs, threadID, pair);
    }
}
// Todo... try to remove this shared memory dependency
// One slot per lane holding the best (smallest) encoded neighbor candidate.
groupshared uint cached_neighbor[WAVE_SIZE];
// NN <- findNearestNeighbor(numPrims, CI, B)
// For every lane below numPrims, searches the lanes within SEARCH_RADIUS on
// either side for the one whose merged AABB with this lane has the smallest
// surface area.
//   numPrims - number of lanes that currently hold a cluster (lanes 0..numPrims-1)
//   bounds   - this lane's AABB (min corner, max corner)
// Returns an encoded key: the upper bits hold the (truncated) area bits, the
// low SEARCH_RADIUS_SHIFT+1 bits hold the relative offset plus a parity bit;
// decode the low bits with decodeRelativeOffset. A lane with no candidate
// returns INVALID_ID.
// Must be called by all lanes of the wave together (it contains barriers and
// wave intrinsics).
uint findNearestNeighbor(uint numPrims, float2x3 bounds) {
    uint localWaveID = WaveGetLaneIndex();
    cached_neighbor[localWaveID] = INVALID_ID;
    GroupMemoryBarrierWithWaveSync();

    const uint SEARCH_RADIUS = (uint)1 << SEARCH_RADIUS_SHIFT;
    const uint encode_mask = ~(((uint)1 << (SEARCH_RADIUS_SHIFT + 1)) - 1);
    const uint lastLane = uint(WAVE_SIZE - 1);
    uint min_area_index = -1;

    // Each lane only looks to the right; the result is also pushed to the
    // right-hand lane's slot, so every pair is evaluated exactly once.
    for (uint r = 1; r <= SEARCH_RADIUS; r++)
    {
        const uint neighborLane = localWaveID + r;
        // The shuffle runs on every lane, including those near the end of the
        // wave. Reading from a lane index outside the wave is undefined, so the
        // index is clamped; the clamped value is discarded by the check below.
        float2x3 neighborBounds = WaveReadLaneAt(bounds, min(neighborLane, lastLane));
        if (neighborLane < numPrims)
        {
            float new_area = distanceFct(bounds, neighborBounds); // note, currently returns infinite "cost" for invalid bounds
            uint new_area_i = ((asuint(new_area) << 1) & encode_mask);
            const uint encode0 = encodeRelativeOffset(localWaveID, neighborLane);
            // Bit 0 carries a parity so the decoder can tell left from right.
            const uint new_area_index0 = new_area_i | encode0 | (localWaveID & 1);
            const uint new_area_index1 = new_area_i | encode0 | ((neighborLane & 1) ^ 1);
            min_area_index = min(min_area_index, new_area_index0);
            InterlockedMin(cached_neighbor[neighborLane], new_area_index1);
        }
    }
    InterlockedMin(cached_neighbor[localWaveID], min_area_index);

    // Make sure every lane's atomic updates have landed before a lane reads
    // back its own slot, which other lanes also write to.
    GroupMemoryBarrierWithWaveSync();
    uint neighbor = cached_neighbor[localWaveID];
    return neighbor;
}
// numPrims <- mergeClustersCreateBVH2Node(numPrims, NN, CI, B)
// Merges every pair of lanes that chose each other as nearest neighbor into a
// new BVH2 node, then compacts the surviving clusters to the low lanes.
//   numPrims - number of lanes currently holding a cluster (lanes 0..numPrims-1)
//   NN       - encoded nearest neighbor of this lane (see findNearestNeighbor)
//   CI       - in: this lane's cluster index; out: the compacted cluster
//              index, or (INVALID_ID, INVALID_ID) for lanes past the new count
//   bounds   - in: this lane's AABB; out: AABB of the compacted cluster
//   params   - uses params.AC[1] as the BVH2 node allocation counter, writes
//              new nodes to params.BVH2, and sets params.AC[5] on error
// Returns the number of clusters left after merging, or 0 if the node
// allocation ran out of bounds.
// Must be called by all lanes of the wave together.
uint mergeClustersCreateBVH2Node(uint numPrims, uint NN, inout uint2 CI, inout float2x3 bounds, HPLOCParams params) {
    uint localWaveID = WaveGetLaneIndex();
    const uint decode_mask = (((uint)1 << (SEARCH_RADIUS_SHIFT + 1)) - 1);
    const uint n_i = decodeRelativeOffset(localWaveID, WaveReadLaneAt(NN, localWaveID) & decode_mask, localWaveID);
    const uint n_i_n_i = decodeRelativeOffset(n_i, WaveReadLaneAt(NN, n_i) & decode_mask, n_i);
    bool symmetricMatchFound = (localWaveID == n_i_n_i); // true if our neighbor thinks we're their nearest neighbor too
    bool laneIsLeftNeighbor = localWaveID < n_i; // true if we're the smaller of the pair
    const uint2 leftCluster = CI;
    const uint2 rightCluster = WaveReadLaneAt(CI, n_i);
    const float2x3 leftBounds = bounds;
    const float2x3 rightBounds = WaveReadLaneAt(bounds, n_i);

    // First, determine how many nodes we'll be generating in this wave.
    // Only the left of neighbor pairs will generate nodes. The right of the pair will clear.
    bool laneHasCluster = (localWaveID < numPrims);
    bool laneIsCreatingNode = ((laneHasCluster) && (symmetricMatchFound) && (laneIsLeftNeighbor));
    uint numToAppend = WaveActiveCountBits(laneIsCreatingNode);
    uint numNodesAllocated = INVALID_ID;

    // Then, use an atomic counter to allocate the new number of BVH2 nodes.
    if (localWaveID == 0) numNodesAllocated = atomicAdd(params.AC, 1, numToAppend);
    numNodesAllocated = WaveReadLaneFirst(numNodesAllocated);

    // Because BVH2 nodes are generated from the bottom up, we reverse the address here
    // to ensure the root node appears first.
    int baseNodeOffset = (((params.N - 1) - numToAppend) - numNodesAllocated);

    // This should never happen, but if it does, we need to set the error bit and stop.
    // (A tree over params.N leaves has at most params.N - 1 inner nodes.)
    if (numNodesAllocated >= params.N - 1) {
        store<int>(params.AC, 5, ERROR_OUT_OF_BOUNDS);
        return 0;
    }

    // Each creating lane takes the next slot after the wave's base offset.
    uint bvh2IndexPrefix = WavePrefixCountBits(laneIsCreatingNode);
    uint bvh2Index = baseNodeOffset + bvh2IndexPrefix;
    uint2 newCI = uint2(INVALID_ID, INVALID_ID);
    if (laneHasCluster) {
        newCI = CI;
        if (symmetricMatchFound) {
            if (laneIsLeftNeighbor) {
                // Merge the two clusters into a new BVH2 node
                float2x3 new_bounds;
                new_bounds[0] = min(leftBounds[0], rightBounds[0]);
                new_bounds[1] = max(leftBounds[1], rightBounds[1]);
                const BVH2Node newNode = BVH2Node(leftCluster, rightCluster, new_bounds);
                store<BVH2Node>(params.BVH2, bvh2Index, newNode);
                bounds = new_bounds;
                newCI = uint2(bvh2Index, GEOM_ID_BVH2); // Note, marking the merged cluster type here as a BVH2 node
            }
            else
                newCI = uint2(INVALID_ID, INVALID_ID); // Second item of pair with the larger index disables the slot
        }
    }

    // Compact away cleared nodes: lane k fetches the k-th surviving cluster.
    const bool laneKeepsCluster = laneHasCluster && (newCI.x != INVALID_ID);
    const uint total_reduction = WaveActiveCountBits(laneKeepsCluster);
    const uint compactionShuffle = __fns(WaveActiveBallot(laneKeepsCluster).x, localWaveID + 1);

    // Lanes past the surviving count have no source lane (__fns returns
    // uint(-1)). Shuffling from that index is undefined, so those lanes read
    // from themselves and are then explicitly reset to an invalid cluster.
    const bool hasSource = (compactionShuffle != uint(-1));
    const uint sourceLane = hasSource ? compactionShuffle : localWaveID;
    const float2x3 shuffledBounds = WaveReadLaneAt(bounds, sourceLane);
    const uint2 shuffledCI = WaveReadLaneAt(newCI, sourceLane);
    if (hasSource) {
        bounds = shuffledBounds;
        CI = shuffledCI;
    }
    else {
        bounds = float2x3(float3(FLT_MAX), -float3(FLT_MAX));
        CI = uint2(INVALID_ID, INVALID_ID);
    }
    return total_reduction;
}
// Loads cluster indices I[start, end) into consecutive lanes of the wave,
// beginning at lane `offset`. At most half a wave's worth of entries is loaded.
// Lanes outside [offset, offset + count) keep their CI unchanged.
// Returns the number of entries loaded, reduced to the number of loaded
// entries whose x component is not INVALID_ID.
uint loadIndices(uint start, uint end, Buffer I, inout uint2 CI, uint offset)
{
    const uint lane = WaveGetLaneIndex();
    const uint capacity = uint(WaveGetLaneCount() / 2);
    const uint requested = min(end - start, capacity);

    // Unsigned subtraction: lanes below `offset` wrap to a huge value and
    // fail the range test.
    const uint slot = lane - offset;
    const bool participates = (slot < requested);
    if (participates)
        CI = load<uint2>(I, start + slot);

    const uint numUsable = WaveActiveCountBits(participates && CI.x != INVALID_ID);
    return min(requested, numUsable);
}
// Writes each lane's cluster index back to I, starting at LStart.
// Call this with the number of primitives *before* PLOC merged them, so the
// cleared (invalid) entries are written back as well.
// CI registers are expected to be compact across the wave:
// size([CI0, CI1, ... CI_(numPrims-1), -1, -1, ... -1]) = origNumPrims
void storeIndices(int origNumPrims, uint2 CI, Buffer I, uint LStart) {
    const uint lane = WaveGetLaneIndex();
    if (lane >= origNumPrims) return; // lane holds nothing from this range
    store<uint2>(I, LStart + lane, CI);
}
// Wave-wide PLOC merge of the range owned by one lane.
// The whole wave cooperates on the range [L, R] (inclusive) and split point S
// held by lane selectedLaneID: clusters from both halves are loaded, merged
// repeatedly with their nearest neighbors, and the remaining cluster indices
// are written back to params.I starting at L.
//   final - when true on the selected lane, merging continues down to a
//           single cluster; otherwise it stops at half the wave's lane count.
// Must be called by all lanes of the wave together.
void plocMerge(uint selectedLaneID, uint L, uint R, uint S, bool final, uniform HPLOCParams params) {
    const uint lanesInWave = WaveGetLaneCount();

    // Broadcast the selected lane's range and split to the entire wave.
    const uint leftBegin = WaveReadLaneAt(L, selectedLaneID);
    const uint splitPoint = WaveReadLaneAt(S, selectedLaneID); // end of left half == start of right half
    const uint rightEnd = WaveReadLaneAt(R, selectedLaneID) + 1; // +1 makes the interval half-open (see page 7)
    const bool isFinal = WaveReadLaneAt(final, selectedLaneID);

    // Per-lane cluster index, initially invalid. The "y" component stores the
    // type (leaf, inner node) as well as the geometry ID.
    uint2 CI = uint2(INVALID_ID, INVALID_ID);
    const uint numLeft = loadIndices(leftBegin, splitPoint, params.I, CI, 0);
    const uint numRight = loadIndices(splitPoint, rightEnd, params.I, CI, numLeft);
    const uint numLoaded = numLeft + numRight;

    // Cache the AABB per lane to avoid redundant loads while merging.
    float2x3 bounds;
    LoadAABB(params, CI, bounds);

    const uint threshold = isFinal ? 1 : lanesInWave / 2;
    uint numPrims = numLoaded;
    while (numPrims > threshold) {
        const uint NN = findNearestNeighbor(numPrims, bounds);
        numPrims = mergeClustersCreateBVH2Node(numPrims, NN, CI, bounds, params);
    }

    /* (small typo in paper, this should be numLeft + numRight, not numPrims) */
    storeIndices(numLoaded, CI, params.I, leftBegin);
}
// Builds a binary BVH from the sorted (mortoncode / clusterIndex) pairs, following algorithm 1 in the H-PLOC paper.
// Run for N threads, where N is the number of primitives.
// Each lane starts with the single-element range [i, i] and walks upwards. At
// every step it exchanges its range boundary into params.pID; the first lane
// to arrive at a parent stops, the second one continues with the combined
// range. Whenever a lane's range is larger than half a wave (or covers all N
// codes), the whole wave runs plocMerge on that range.
// params.pID is expected to hold INVALID_ID in every slot on entry
// (HPLOC_SFC writes that value).
[shader("compute")]
[numthreads(WAVE_SIZE, 1, 1)]
void HPLOC_Build(uint3 DispatchThreadID: SV_DispatchThreadID, uint3 GroupThreadID: SV_GroupThreadID, uint3 GroupID: SV_GroupID, uniform HPLOCParams params)
{
using namespace gprt;
uint i = DispatchThreadID.x;
// The total number of leaves in the tree
uint N = params.N;
// The current node's (inclusive) child node range [L;R].
uint L = i, R = i;
// Initially, all lanes that correspond to codes are active.
// (Note, we will have some helper lanes beyond this range, rounded up to the wave size)
bool laneActive = i < N;
// Do bottom-up traversal as long as there are active lanes in wave
while (WaveActiveAnyTrue(laneActive)) {
// Boundary between the two child ranges; only set when this lane is the
// second to arrive at its parent.
uint split = INVALID_ID;
if (laneActive) {
int previousID = INVALID_ID;
// findParentID returns R when the parent is reached through the right
// boundary, and L - 1 when it is reached through the left boundary.
if (findParentID(L, R, N, params.C) == R) {
// Publish our left end at slot R and fetch whatever the sibling left there.
previousID = atomicExchange(params.pID, R, L);
if (previousID != INVALID_ID) {
// Sibling already arrived: it covers [R + 1, previousID].
split = R + 1;
R = previousID;
}
}
else {
// Publish our right end at slot L - 1 and fetch whatever the sibling left there.
previousID = atomicExchange(params.pID, L - 1, R);
if (previousID != INVALID_ID) {
// Sibling already arrived: it covers [previousID, L - 1].
split = L;
L = previousID;
}
}
// First to arrive: the sibling will carry on from here, so this lane stops.
if (previousID == INVALID_ID) laneActive = false;
}
uint size = R - L + 1;
bool final = laneActive && size == N; // Reached top of Morton tree, need to finish BVH2
// One bit per lane whose range needs a wave-wide merge: active with more
// than half a wave of entries, or holding the final range.
uint waveMask = WaveActiveBallot(laneActive && (size > WAVE_SIZE / 2) || final).x;
// Process the flagged lanes one at a time; every lane of the wave takes part
// in each merge.
while (waveMask != 0) {
uint laneID = countTrailingZero(waveMask);
plocMerge(laneID, L, R, split, final, params); // Wave-based PLOC++ (Algorithm 2)
waveMask = waveMask & (waveMask - 1); // Done with current lane
}
}
}
// A special stack for the BVH2->BVH8 conversion algorithm.
// Holds up to WIDTH child cluster IDs in total, split into leaves and inner
// (BVH2) nodes. Only inner nodes can be popped again.
struct ChildStack <let WIDTH : uint> {
    uint numLeaves;
    uint numNodes;
    // Cluster IDs representing inner nodes and leaves
    uint32_t leaves[WIDTH];
    uint32_t nodes[WIDTH];

    __init() {
        numLeaves = 0;
        numNodes = 0;
    }

    // Total number of children currently held (leaves + inner nodes).
    uint size() {
        return numNodes + numLeaves;
    }

    // Appends a leaf cluster. Returns false (and stores nothing) when full.
    [mutating]
    bool pushLeaf(uint32_t leaf) {
        if (numNodes + numLeaves >= WIDTH) return false;
        leaves[numLeaves++] = leaf;
        return true;
    }

    // Appends an inner-node cluster. Returns false (and stores nothing) when full.
    [mutating]
    bool pushNode(uint32_t node) {
        if (numNodes + numLeaves >= WIDTH) return false;
        nodes[numNodes++] = node;
        return true;
    }

    // Appends a cluster, classifying it by the type stored in its top 8 bits.
    [mutating]
    bool pushCluster(uint32_t clusterID) {
        if (clusterID >> 24 == GEOM_ID_BVH2) return pushNode(clusterID);
        else return pushLeaf(clusterID);
    }

    // Removes and returns the most recently pushed inner node,
    // or INVALID_ID when no inner node is held.
    [mutating]
    uint32_t popNode() {
        if (numNodes == 0) return INVALID_ID;
        return nodes[--numNodes];
    }

    // We can only open one of our inner nodes if:
    // 1 - we have an inner node to open
    // 2 - after opening that inner node, we have two open child slots
    bool canOpenInnerNode() {
        // note, popping will subtract 1 from numNodes.
        // therefore, we just need to guarantee one more open slot.
        // The capacity is WIDTH, matching the limit enforced by the push functions.
        bool hasRoom = numNodes + numLeaves < WIDTH;
        bool hasInnerNode = numNodes > 0;
        return hasRoom && hasInnerNode;
    }
}
// Minimal fixed-capacity LIFO stack of uints with room for N entries.
struct Stack<let N : uint> {
    uint top;      // number of entries currently stored
    uint data[N];

    __init() {
        top = 0;
    }

    // Pushes val. Returns false (and stores nothing) when the stack is full.
    [mutating]
    bool push(uint val) {
        const bool hasRoom = (top < N);
        if (hasRoom) {
            data[top] = val;
            top++;
        }
        return hasRoom;
    }

    // Removes and returns the newest entry, or INVALID_ID when empty.
    [mutating]
    uint pop() {
        if (isEmpty()) return INVALID_ID;
        top--;
        return data[top];
    }

    // True when no entries are stored.
    bool isEmpty() {
        return top == 0;
    }
};
// Implements the HPLOC BVH2->BVH8 conversion algorithm.
[shader("compute")]
[numthreads(128, 1, 1)]
void HPLOC_ToBVH8(uint3 DispatchThreadID: SV_DispatchThreadID, uniform HPLOCParams params)
{
int failsafe = 10000;
uint32_t bvh2ClusterID = INVALID_ID;
uint32_t bvh8ClusterID = INVALID_ID;
int threadID = DispatchThreadID.x;
if (threadID >= params.N) return;
while (true) {
int signal = load<int>(params.AC, 5);
if (signal != 0) return; // break out early if we're signaled to stop
uint64_t val = load<uint64_t>(params.indexPairs, threadID);
bvh2ClusterID = uint32_t(val >> 32ull);
bvh8ClusterID = uint32_t(val);
// temp failsafe to prevent infinite loops
if (failsafe-- <= 0) { store<int>(params.AC, 5, ERROR_TIMEOUT); break;}
// Only continue if we were assigned an inner BVH2 node
// (assuming architecture guarantees forward progress)
if (bvh2ClusterID == INVALID_ID) continue;
// Keep worker thread going until assigned a *leaf* node.
if ((bvh2ClusterID >> 24) != GEOM_ID_BVH2) {
// Once we have a leaf, store it and return.
BVH8Leaf leaf;
leaf.clusterID = bvh2ClusterID;
LoadTriangle(params, bvh2ClusterID, leaf.v);
store<BVH8Leaf>(params.BVH8L, bvh8ClusterID, leaf);
break;
}
// At this point, the "X" of the pair represents an address to our BVH8 node,
// and the "Y" represents an inner BVH2 node which needs collapsing.
// From here, our goal is to collect up to 8 children
var children = ChildStack<8>();
// Push the children of the assigned BVH2 node onto a fresh stack, ordered by surface area.
BVH2Node bvh2Node = load<BVH2Node>(params.BVH2, bvh2ClusterID & 0x00FFFFFF);
float2x3 bounds[2];
uint2 childClusters = uint2(bvh2Node.getChildCluster(0), bvh2Node.getChildCluster(1));
LoadAABB(params, childClusters[0], bounds[0]);
LoadAABB(params, childClusters[1], bounds[1]);
float2 SA = float2(getSurfaceArea(bounds[0]), getSurfaceArea(bounds[1]));
uint2 order = (SA[0] > SA[1]) ? uint2(0, 1) : uint2(1, 0);
children.pushCluster(childClusters[order[0]]);
children.pushCluster(childClusters[order[1]]);
// Traverse the BVH2 tree to collect up to 8 children
while (children.canOpenInnerNode()) {
// By popping a node here, we are effectively dropping it from the tree.
uint nodeID = children.popNode();
// Now we need to add back the two children of that node.
BVH2Node innerNode = load<BVH2Node>(params.BVH2, nodeID & 0x00FFFFFF);
// Determine which of this nodes children has the largest area...
float2x3 bounds[2];
uint2 childClusters = uint2(innerNode.getChildCluster(0), innerNode.getChildCluster(1));
LoadAABB(params, childClusters[0], bounds[0]);
LoadAABB(params, childClusters[1], bounds[1]);
float2 SA = float2(getSurfaceArea(bounds[0]), getSurfaceArea(bounds[1]));
uint2 order = (SA[0] > SA[1]) ? uint2(0, 1) : uint2(1, 0);
children.pushCluster(childClusters[order[0]]);
children.pushCluster(childClusters[order[1]]);
}
// Allocate a worker thread for every child we collected
// (excluding the current thread which we will reuse)
uint baseWorkerOffset = atomicAdd(params.AC, 2, (children.numNodes + children.numLeaves) - 1);
// Also allocate slots for our newly created inner BVH8 nodes and leaf clusters
uint childNodeBaseIndex = atomicAdd(params.AC, 3, children.numNodes);
uint primitiveBaseIndex = atomicAdd(params.AC, 4, children.numLeaves);
// Shouldn't happen, but useful for debugging purposes.
if (childNodeBaseIndex + children.numNodes > params.BVH8.size) {
store<int>(params.AC, 5, ERROR_OUT_OF_BOUNDS);
return;
}
// First, assign the leaves to contiguous worker threads. These threads will generate the leaves, then terminate.
// Then, contiguously assign the inner nodes to subsequent worker threads.
for (int i = 0; i < (children.numLeaves + children.numNodes); ++i) {
int j = i - children.numLeaves;
uint workerAddr = (i == 0) ? threadID : baseWorkerOffset + i;
uint64_t newPair;
if (i < children.numLeaves)
newPair = (uint64_t(children.leaves[i]) << 32ull) | (uint64_t(primitiveBaseIndex + i));
else
newPair = (uint64_t(children.nodes[j] ) << 32ull) | (uint64_t(childNodeBaseIndex + j));
store<uint64_t>(params.indexPairs, workerAddr, newPair);
}
// At this point, we're ready to generate the actual BVH8 node.
// First, determine the bounds of the children, and a common bound containing them all.
float2x3 parentBounds = float2x3(float3(FLT_MAX), -float3(FLT_MAX));
float2x3 childBounds[8];
for (int i = 0; i < children.size(); ++i) {
uint32_t clusterID = (i < children.numLeaves) ? children.leaves[i] : children.nodes[i - children.numLeaves];
LoadAABB(params, clusterID, childBounds[i]);
parentBounds[0] = min(parentBounds[0], childBounds[i][0]);
parentBounds[1] = max(parentBounds[1], childBounds[i][1]);
}
// Now, we "auction off" the children to the octants of the parent node.
float payoffs[8][8];
uint32_t assignments[8];
float prices[8];
for (int i = 0; i < 8; ++i) {
assignments[i] = INVALID_ID;
prices[i] = 0.0;
}
// Table costs(c,s) where each cell indicates the cost of placing a
// particular child node "c" in a particular child slot "s" (0-7).
float3 parentCentroid = (parentBounds[0] + parentBounds[1]) * 0.5f;
for (int c = 0; c < children.size(); ++c) {
float3 childCentroid = (childBounds[c][0] + childBounds[c][1]) * 0.5f;
for (int s = 0; s < 8; ++s) {
float3 DS;
DS.x = (s & 1) == 0 ? 1.0 : -1.0;
DS.y = (s & 2) == 0 ? 1.0 : -1.0;
DS.z = (s & 4) == 0 ? 1.0 : -1.0;
payoffs[c][s] = dot(childCentroid - parentCentroid, DS);
}
}
float epsilon = 1.f / children.size();
Stack<8> bidders;
for (int i = 0; i < children.size(); ++i) {
bidders.push(i);
}
while (!bidders.isEmpty()) {
if (failsafe-- <= 0) { store<int>(params.AC, 5, ERROR_TIMEOUT); return; }
int c = bidders.pop();
float winningReward = -FLT_MAX;
float secondWinningReward = -FLT_MAX;
int winningSlot = -1; // j
int secondWinningSlot = -1; // k
for (int s = 0; s < 8; ++s) {
float reward = payoffs[c][s] - prices[s];
if (reward > winningReward) {
winningReward = reward;
secondWinningReward = winningReward;
winningSlot = s;
secondWinningSlot = winningSlot;
}
else if (reward > secondWinningReward) {
secondWinningReward = reward;
secondWinningSlot = s;
}
}
// Raise the bid by the level at which bidder c is different from first and second winning items.
prices[winningSlot] += (winningReward - secondWinningReward) + epsilon;
int previousAssignment = assignments[winningSlot];
assignments[winningSlot] = c;
if (previousAssignment != INVALID_ID) {
bidders.push(previousAssignment);
}
}
// Now that the auction is complete, we can assign the children to the BVH8 node.
BVH8Node node;
// Quantize to save memory and reduce bandwidth requirements
float3 origin = parentBounds[0];
float3 diag = parentBounds[1] - parentBounds[0];
uint sxExp = floatToExponent(diag.x);
uint syExp = floatToExponent(diag.y);
uint szExp = floatToExponent(diag.z);
uint iMask = getIMask(assignments, children.numLeaves);
node.orxy = uint64_t(asuint(origin.x)) << 32ull | asuint(origin.y);
node.oeim = uint64_t(asuint(origin.z)) << 32ull |
uint64_t(sxExp) << 24ull |
uint64_t(syExp) << 16ull |
uint64_t(szExp) << 8ull |
uint64_t(iMask) << 0ull;
// common base addresses
node.ntbi = uint64_t(childNodeBaseIndex) << 32ull | uint64_t(primitiveBaseIndex);
// initialize child fields
node.meta = node.qlox = node.qloy = node.qloz = node.qhix = node.qhiy = node.qhiz = 0;
// Assign child ordering to the meta field
for (int i = 0; i < 8; ++i) {
uint child = assignments[i];
if (child == INVALID_ID) continue; // empty slot
else if (child < children.numLeaves) {
// leaf node
// high 3 bits unary encode num prims
// low 5 bits store child slot index (ranging in 0...23)
// (todo... store up to 24 triangles, 3 per child slot...)
uint meta = (1 << 5) | (child);
node.meta |= (uint64_t(meta) << (i * 8));
}
else {
// inner node
// high 3 bits of meta are 001.
// low 5 bits store child slot index + 24.
uint meta = (1 << 5) | ((child - children.numLeaves) + 24);
node.meta |= (uint64_t(meta) << (i * 8));
}
// Quantize child AABBs and fill out the meta fields
float3 scale = float3(exponentToFloat(sxExp), exponentToFloat(syExp), exponentToFloat(szExp));
// The child node's box relative to the parent
float3 scaledMin = (max(childBounds[child][0] - origin, float3(0.0f)) / scale) * 256.0f;
float3 scaledMax = (max(childBounds[child][1] - origin, float3(0.0f)) / scale) * 256.0f;
// Apply floor and ceil to ensure conservative quantization
uint3 childQLo = uint3(
uint(min(floor(scaledMin.x), 255.f)),
uint(min(floor(scaledMin.y), 255.f)),
uint(min(floor(scaledMin.z), 255.f))
);
uint3 childQHi = uint3(
uint(min(ceil(scaledMax.x), 255.f)),
uint(min(ceil(scaledMax.y), 255.f)),
uint(min(ceil(scaledMax.z), 255.f))
);
// Store the quantized child fields
node.qlox |= uint64_t(childQLo.x) << (i * 8);
node.qloy |= uint64_t(childQLo.y) << (i * 8);
node.qloz |= uint64_t(childQLo.z) << (i * 8);
node.qhix |= uint64_t(childQHi.x) << (i * 8);
node.qhiy |= uint64_t(childQHi.y) << (i * 8);
node.qhiz |= uint64_t(childQHi.z) << (i * 8);
}
// Finally, store the BVH8 node
store<BVH8Node>(params.BVH8, bvh8ClusterID, node);
}
}
// MIT License
// Copyright (c) 2023 Nathan V. Morrical
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
#pragma once
#include "gprt.h"
#if !defined(__SLANG_COMPILER__)
#include <cstring> // std::memcpy, used for host-side float <-> uint bit reinterpretation
#endif
using namespace gprt;
// Parameter record for the H-PLOC BVH build. The sizes quoted per field are the
// capacities the original author annotated; N is the primitive count below.
struct HPLOCParams {
int N; // Number of primitives across all referenced geometries
Buffer BVH2; // BVH2 nodes, excluding leaves ((N - 1) x 32 bytes)
Buffer BVH8; // BVH8 nodes (ceil((N x 2 - 1) / 8) x 80 bytes)
Buffer BVH8L; // BVH8 leaves (N x 64 bytes)
// Some atomic counters for allocation and scheduling, all initialized to 0 (6 x 4 bytes).
// The BVH8 conversion kernel writes an error code (ERROR_TIMEOUT) into slot 5 when its
// failsafe counter runs out.
Buffer AC;
Buffer I; // Cluster indices reordered by space filling curve codes (N x 4 bytes)
Buffer C; // Space filling curve codes, sorted in ascending order (N x 8 bytes)
Buffer pID; // BVH2 parent IDs. initialized to -1. (N x 4 bytes)
// BVH2 -> BVH8 pairs (N x 8 bytes). Each 64-bit entry is written by the conversion kernel as
// [32b BVH2 cluster ID] [32b destination index], where the destination index is an offset from
// either the primitive base index or the child node base index of the BVH8 node being built.
Buffer indexPairs;
// For geometry
Buffer rootBounds; // Two float3's storing aabb of trimesh
Buffer primPrefix; // An exclusive prefix sum of the prim counts in each geometry
Buffer triangles; // Triangle index buffers, one handle per geometry
Buffer vertices; // Triangle vertex buffers, one handle per geometry
};
// A 32 byte BVH2 node structure: one AABB plus two packed 32-bit child
// references stored in the .w lanes of the min and max corners.
//
// Child reference layout, as written by __init and decoded by the getters:
//   bits 31..24  geometry ID (8 bits). isChildLeaf() reports "not a leaf" only
//                when this whole byte equals 255.
//   bits 23..0   primitive ID
// NOTE(review): earlier comments described the top byte as [1b isInner][7b geom ID];
// the code packs 8 bits and tests all 8 against 255, so that is what is documented here.
// Presumably inner-node references are written with cluster.y == -1 — confirm in the builder.
struct BVH2Node {
  // [32b xlo] [32b ylo] [32b zlo] [8b - L Geom ID, 255 = not a leaf] [24b - L Prim ID]
  float4 aabbMinAndL;
  // [32b xhi] [32b yhi] [32b zhi] [8b - R Geom ID, 255 = not a leaf] [24b - R Prim ID]
  float4 aabbMaxAndR;

#if defined(__SLANG_COMPILER__)
  // Builds a node from its two child clusters and their merged bounds.
  //   leftCluster / rightCluster: .x supplies the low 24 bits (prim ID) and .y the
  //                               high 8 bits (geom ID) of the packed child reference.
  //   new_bounds:                 row 0 is the AABB minimum, row 1 the AABB maximum.
  __init(int2 leftCluster, int2 rightCluster, float2x3 new_bounds) {
    aabbMinAndL.xyz = new_bounds[0];
    aabbMaxAndR.xyz = new_bounds[1];
    uint32_t leftIndex = (leftCluster.x & 0x00FFFFFF) | ((leftCluster.y & 0x000000FF) << 24);
    uint32_t rightIndex = (rightCluster.x & 0x00FFFFFF) | ((rightCluster.y & 0x000000FF) << 24);
    // The packed references are stored bit-for-bit in the float .w lanes.
    aabbMinAndL.w = asfloat(leftIndex);
    aabbMaxAndR.w = asfloat(rightIndex);
  }
#endif

  // Returns the node's AABB as (row 0 = min corner, row 1 = max corner).
  float2x3 getBounds() {
#if defined(__SLANG_COMPILER__)
    return float2x3(aabbMinAndL.xyz, aabbMaxAndR.xyz);
#else
    float2x3 bounds;
    bounds[0] = aabbMinAndL.xyz();
    bounds[1] = aabbMaxAndR.xyz();
    return bounds;
#endif
  }

  // Returns the raw packed 32-bit child reference: the left child's for
  // childIndex == 0, the right child's for any other value. All other getters
  // decode the value returned here, so the bit reinterpretation lives in one place.
  uint32_t getChildCluster(int childIndex) {
#if defined(__SLANG_COMPILER__)
    return (childIndex == 0) ? asuint(aabbMinAndL.w) : asuint(aabbMaxAndR.w);
#else
    // Reading a float through a uint32_t* violates C++ strict aliasing; memcpy is
    // the well-defined bit cast and compiles down to a plain load.
    static_assert(sizeof(float) == sizeof(uint32_t), "child reference must fit the float lane exactly");
    uint32_t bits;
    std::memcpy(&bits, (childIndex == 0) ? &aabbMinAndL.w : &aabbMaxAndR.w, sizeof(bits));
    return bits;
#endif
  }

  // Returns false when the top 8 bits of the selected child reference are all
  // set (== 255), true otherwise.
  bool isChildLeaf(int childIndex) {
    uint32_t top8Bits = getChildCluster(childIndex) >> 24;
    return top8Bits != 255;
  }

  // Returns the low 24 bits (primitive ID) of the selected child reference.
  uint32_t getChildPrimID(int childIndex) {
    return getChildCluster(childIndex) & 0x00FFFFFF;
  }

  // Returns the high 8 bits (geometry ID) of the selected child reference.
  uint32_t getChildGeomID(int childIndex) {
    return (getChildCluster(childIndex) >> 24) & 0x000000FF;
  }
};
// Compressed 8-wide BVH node.
// 80 bytes (ten 64-bit words)
// Just as an example, following from here: https://users.aalto.fi/~laines9/publications/ylitie2017hpg_paper.pdf
//
// In the bracketed layouts below the leftmost field occupies the most significant bits.
// The per-child words (meta, qlo*, qhi*) are the exception: the conversion kernel writes
// child slot i with a shift of (i * 8), so slot 0 sits in the least significant byte.
struct BVH8Node {
// [32b origin x] [32b origin y]
uint64_t orxy;
// [32b origin z] [8b extent x] [8b extent y] [8b extent z] [8b inner node mask]
// The extent bytes hold the per-axis values the kernel obtains from floatToExponent()
// applied to the node's bounding-box diagonal.
uint64_t oeim;
// [32b child node base index] [32b triangle base index]
uint64_t ntbi;
// One byte per child slot; slot i occupies bits [8*i, 8*i + 7].
// To interpret each "meta" byte:
// Empty child slot: field set to 00000000
// Internal node: high 3 bits 001 while low 5 bits store child slot index + 24. (values range in 24-31)
// Leaf node: high 3 bits store number of triangles using unary encoding, and low 5 bits store the
// index of first triangle relative to the triangle base index (ranging in 0...23)
uint64_t meta;
// Quantized child bounds, one byte per child slot (slot i at bits [8*i, 8*i + 7]).
// qlo* hold the floor-rounded lower corner and qhi* the ceil-rounded upper corner of each
// child box, expressed relative to the node origin and clamped to 0..255 by the kernel.
uint64_t qlox;
uint64_t qloy;
uint64_t qloz;
uint64_t qhix;
uint64_t qhiy;
uint64_t qhiz;
};
// Leaf record referenced by BVH8 nodes.
// NOTE(review): HPLOCParams budgets 64 bytes per leaf; the fields below are three float3
// values plus one 32-bit word, so that figure presumably includes alignment padding — TODO confirm.
struct BVH8Leaf {
// A bit naive at the moment...
// Three float3 values — presumably the triangle's vertex positions; confirm against the writer.
float3 v[3];
// [8b geom ID] [24b prim ID]
uint32_t clusterID;
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment