Skip to content

Instantly share code, notes, and snippets.

@madmann91
Last active June 22, 2022 22:03
Show Gist options
  • Save madmann91/911068852892d76db59d72b288aec2dc to your computer and use it in GitHub Desktop.
Basic BVH traversal with early exit & traversal order optimizations
#version 440
// Capacity of the traversal stack used by intersect_ray_bvh. Pushes onto that
// stack are not bounds-checked, so the BVH must be shallow enough to fit.
#define STACK_SIZE 32
// Sentinel primitive ID meaning "no hit" (uint(-1), i.e. all bits set).
#define INVALID_HIT_ID uint(-1)
// The primitive data can be changed to the desired primitive type
// (including using BVHs as primitives for top-level BVH traversal).
#define PrimitiveData Triangle
#define HitData vec2 // Barycentric coordinates of a hit on a triangle (can be changed to anything)
// Primitive intersection routine used for leaves by intersect_ray_bvh.
// It is expected to return true only for an accepted hit in [tnear, tfar),
// in which case it writes the hit information and lowers tfar to the hit distance.
#define intersect_ray_primitive(ray, prim, hit, tnear, tfar) \
intersect_ray_triangle(ray, prim, hit, tnear, tfar)
// Triangle given by its three vertices (also the element type of the
// primitive buffer when PrimitiveData is Triangle, so keep the field order).
struct Triangle {
vec3 v0, v1, v2;
};
// Ray with origin `org` and direction `dir`. The valid distance interval is
// passed separately as (tnear, tfar) to the intersection routines.
struct Ray {
vec3 org;
vec3 dir;
};
// Result of a BVH query.
struct Hit {
uint id; // Primitive ID (index into primitives.data), or INVALID_HIT_ID if nothing was hit
HitData data; // Per-primitive hit information (only meaningful when id is valid)
};
// BVH node as stored in `bvh.nodes`.
// Note: The layout used for the BVH nodes is such that `first` points to
// the first child of the node for an inner node. The second child is
// assumed to be located at `first + 1`.
// If the children of a node are both leaves, their primitives should form
// a contiguous range (the traversal intersects them in a single loop).
struct Node {
vec3 min, max; // Corners of the axis-aligned bounding box of the node
uint first; // Index of first child if inner node, otherwise index of first primitive
uint prim_count; // Number of primitives (> 0 means leaf, 0 means inner node)
};
// Node array of the BVH. intersect_ray_bvh starts at index 1, i.e. it treats
// node 0 as the root whose children are nodes 1 and 2.
// NOTE(review): no memory layout qualifier is given, so these blocks use the
// default `shared` layout, whose member offsets (notably for the vec3 fields)
// are implementation-dependent. Confirm the host-side packing matches, or
// consider an explicit `std430` qualifier.
layout (binding = 0) buffer Bvh {
Node nodes[];
} bvh;
// Primitives referenced by leaves through the range [first, first + prim_count).
layout (binding = 1) buffer Primitives {
PrimitiveData data[];
} primitives;
// A hit is valid when its ID is anything other than the "no hit" sentinel.
bool is_valid(Hit hit) {
    return !(hit.id == INVALID_HIT_ID);
}
// Leaves carry at least one primitive; inner nodes have a zero primitive count.
bool is_leaf(Node node) {
    return node.prim_count != 0u;
}
// Slab test of a ray against an axis-aligned box.
//   org, inv_dir     : ray origin and component-wise reciprocal of its direction
//   box_min, box_max : box corners
//   tnear, tfar      : distance interval the result is clipped to
// Returns (entry, exit) distances clipped to [tnear, tfar]; the box is
// considered hit when entry <= exit.
vec2 intersect_ray_box(vec3 org, vec3 inv_dir, vec3 box_min, vec3 box_max, float tnear, float tfar) {
    // Distances at which the ray crosses each pair of slab planes.
    vec3 slab_a = (box_min - org) * inv_dir;
    vec3 slab_b = (box_max - org) * inv_dir;
    // Sort each pair so negative direction components are handled.
    vec3 enter = min(slab_a, slab_b);
    vec3 leave = max(slab_a, slab_b);
    // Latest entry and earliest exit across all axes and the input interval.
    float t_enter = max(enter.x, max(enter.y, max(enter.z, tnear)));
    float t_leave = min(leave.x, min(leave.y, min(leave.z, tfar)));
    return vec2(t_enter, t_leave);
}
// Ray/triangle intersection.
// On a hit at distance t in [tnear, tfar), stores the barycentric coordinates
// (u, v) in `data` (the hit point is (1 - u - v) * v0 + u * v1 + v * v2),
// sets tfar = t and returns true. Otherwise returns false and leaves both
// `data` and `tfar` unchanged.
// `data` is `inout` rather than `out` on purpose: GLSL copies `out` arguments
// back on return even when the callee never wrote them. With `out`, a miss
// would leave the caller's value (e.g. the data of an earlier hit) undefined.
bool intersect_ray_triangle(Ray ray, Triangle triangle, inout vec2 data, float tnear, inout float tfar) {
    // Both edges share v0; n is an unnormalized triangle normal.
    vec3 e1 = triangle.v0 - triangle.v1;
    vec3 e2 = triangle.v2 - triangle.v0;
    vec3 n = cross(e1, e2);
    vec3 c = triangle.v0 - ray.org;
    vec3 r = cross(ray.dir, c);
    // Cramer's rule; dot(n, ray.dir) is zero when the ray is parallel to the plane.
    float inv_det = 1.0f / dot(n, ray.dir);
    float u = dot(r, e2) * inv_det;
    float v = dot(r, e1) * inv_det;
    float w = 1.0f - u - v;
    // Inside test, edges inclusive.
    if (u >= 0 && v >= 0 && w >= 0) {
        float t = dot(n, c) * inv_det;
        if (t >= tnear && t < tfar) {
            data = vec2(u, v);
            tfar = t;
            return true;
        }
    }
    return false;
}
// Traverses the BVH in `bvh.nodes` and intersects `ray` with the primitives in
// `primitives.data` whose hit distance lies in [tnear, tfar).
//   is_any : if true, returns as soon as any hit is found and skips the
//            front-to-back ordering of children; otherwise keeps searching for
//            the closest hit.
//   tfar   : lowered by intersect_ray_primitive on every accepted hit, so on
//            return it holds the distance of the reported hit (if any).
// Returns a Hit whose id is INVALID_HIT_ID when nothing was hit; hit.data is
// only meaningful when is_valid(hit).
//
// Before using this function, make sure your BVH data layout matches those requirements:
// 1. The BVH is not just a single leaf, and its root is node 0,
// 2. The two children of one inner node are placed contiguously,
// 3. If one inner node has two leaves as children, their primitives form a contiguous range,
// 4. The traversal stack (STACK_SIZE entries) is not bounds-checked. At most one
//    entry is pushed per tree level, so the BVH depth must not exceed STACK_SIZE.
Hit intersect_ray_bvh(bool is_any, Ray ray, float tnear, inout float tfar) {
    Hit hit;
    hit.id = INVALID_HIT_ID;
    // NOTE(review): a zero direction component yields an infinite inverse
    // component, and (box - org) * inv_dir can then be 0 * inf in the slab test.
    vec3 inv_dir = 1.0 / ray.dir;
    // Each stack entry is the index of the first node of a sibling pair still to visit.
    uint stack[STACK_SIZE];
    uint stack_size = 0;
    // Index of the first node of the sibling pair being visited; 1 selects the
    // two children of the root node 0.
    uint top = 1;
    while (true) {
        // Note: This does not look up the parent of those two nodes,
        // it directly loads both its children. This is why it won't work with a
        // root node that is a leaf.
        // If the root node is a leaf, you can either create a dummy inner node to replace it,
        // or alternatively you can treat the case where the root node is a leaf separately.
        Node left = bvh.nodes[top];
        Node right = bvh.nodes[top + 1];
        vec2 intr_left = intersect_ray_box(ray.org, inv_dir, left.min, left.max, tnear, tfar);
        vec2 intr_right = intersect_ray_box(ray.org, inv_dir, right.min, right.max, tnear, tfar);
        bool hit_left = intr_left.y >= intr_left.x;
        bool hit_right = intr_right.y >= intr_right.x;

        // Directly intersect primitives contained in leaves.
        // Note: This assumes that, if both children are leaves,
        // their primitives form a contiguous range.
        uint prim_count =
            (hit_left ? left.prim_count : 0) +
            (hit_right ? right.prim_count : 0);
        if (prim_count > 0) {
            uint first_prim = hit_left && is_leaf(left) ? left.first : right.first;
            for (uint i = 0; i < prim_count; ++i) {
                PrimitiveData prim = primitives.data[first_prim + i];
                // Intersect into a temporary and copy it only on success: GLSL
                // copies `out` arguments back even when the callee does not write
                // them, so passing hit.data directly could clobber the data of an
                // earlier, valid hit whenever a later primitive is missed.
                HitData prim_hit_data;
                if (intersect_ray_primitive(ray, prim, prim_hit_data, tnear, tfar)) {
                    hit.id = first_prim + i;
                    hit.data = prim_hit_data;
                    if (is_any)
                        return hit;
                }
            }
            // Leaves are fully processed; only inner children remain to descend into.
            hit_left = hit_left && !is_leaf(left);
            hit_right = hit_right && !is_leaf(right);
        }

        // Push children on the stack, if any
        if (hit_left) {
            if (hit_right) {
                // Both children are hit: visit the nearer one first (closest-hit
                // mode only) and defer the other.
                uint first_child = left.first;
                uint second_child = right.first;
                if (!is_any && intr_left.x > intr_right.x) {
                    uint tmp = first_child;
                    first_child = second_child;
                    second_child = tmp;
                }
                stack[stack_size++] = second_child;
                top = first_child;
            } else {
                top = left.first;
            }
            continue;
        } else if (hit_right) {
            top = right.first;
            continue;
        }
        if (stack_size == 0)
            break;
        top = stack[--stack_size];
    }
    return hit;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment