Skip to content

Instantly share code, notes, and snippets.

@ncthbrt
Last active September 30, 2024 20:05
Show Gist options
  • Save ncthbrt/3b67f2a36d0d268a6d8cba8732aa8bbb to your computer and use it in GitHub Desktop.
Save ncthbrt/3b67f2a36d0d268a6d8cba8732aa8bbb to your computer and use it in GitHub Desktop.
An example of a generic GPU reduction in WESL
// Module signature that exposes only the single type `T`. In follow-up work this could be
// sugared so that the wrapping module can be elided.
mod sig Type {
// The single type carried by any module that conforms to this signature
// (e.g. U32 binds it to u32, Sum<N> binds it to its accumulator struct).
type T;
}
// Module signature that exposes a single constant `value`. In follow-up work this could be sugared so that the wrapping module can be elided.
mod sig Const<Type: Type> {
// The constant itself; its type is the `T` exposed by the `Type` argument
// (so Const<U32> modules provide a u32 `value`).
// NOTE(review): the generic parameter name `Type` shadows the signature name `Type`.
const value: Type::T;
}
// Abstract representation of a binary operation.
mod sig BinaryOp<OpElem: Type, LoadElem: Type> {
    // Re-export the element types carried by the generic arguments as members, so
    // implementing and consuming modules can name them through the operation module
    // itself. This is a common pattern to allow transfer of type information from a
    // generic input to an output module.
    // NOTE(review): consumers (ReduceWorkgroup, ReduceBuffer) spell these as
    // `Op::OpElem::T` / `Op::LoadElem::T`, i.e. as modules, whereas SumBinaryOp binds
    // them with `alias OpElem = Sum<N>::T` (a plain type). Confirm which form is intended.
    type LoadElem : LoadElem::T;
    type OpElem : OpElem::T;

    // Neutral element of binaryOp. Reductions use it to pad elements past the end of
    // the input, so binaryOp(identityOp(), x) must equal x.
    fn identityOp() -> OpElem;

    // Converts one element read from the source buffer into the operation's element
    // type. The parameter uses the re-exported member type, like the other members do
    // and like implementations (e.g. SumBinaryOp) declare it.
    fn loadOp(a: LoadElem) -> OpElem;

    // Combines two partial results. Tree reductions regroup calls to this function,
    // so it is expected to be associative.
    fn binaryOp(a: OpElem, b: OpElem) -> OpElem;
}
// In the future, modules representing numbers, vectors, matrices and other built-in types
// would be part of the standard library. But let's define some common operations for now.
mod sig Number {
    // The underlying numeric type (F32 binds it to f32).
    type T;

    // Neutral element for `add`; F32, for instance, returns 0.0.
    fn identity() -> T;

    // Sum of two values of the numeric type.
    fn add(a: T, b: T) -> T;
}
mod Sum<N: Number> {
    // Accumulator for a running sum over values of the number type `N::T`.
    // WGSL struct members are separated by commas (as in ReduceBuffer's Uniforms),
    // not terminated by semicolons.
    struct T {
        sum: N::T,
    }
}
mod SumBinaryOp<N: Number> -> BinaryOp<Sum<N>, Sum<N>> {
    // Both the loaded elements and the partial results are Sum<N> accumulators.
    alias OpElem = Sum<N>::T;
    alias LoadElem = Sum<N>::T;

    // The identity of a sum wraps N's own identity. A zero-value constructor
    // (`OpElem()`) would only coincide with it for numbers whose identity is the
    // zero value, and would leave Number::identity unused.
    fn identityOp() -> OpElem {
        return OpElem(N::identity());
    }

    // Loaded elements already have the accumulator type, so they are rewrapped as-is.
    fn loadOp(a: LoadElem) -> OpElem {
        return OpElem(a.sum);
    }

    // Adds two partial sums using N's addition.
    fn binaryOp(a: OpElem, b: OpElem) -> OpElem {
        return OpElem(N::add(a.sum, b.sum));
    }
}
mod F32 {
    // Number implementation backed by the built-in f32 type.
    alias T = f32;

    // Additive identity, written as an explicit f32 literal.
    fn identity() -> T {
        return 0.0f;
    }

    // Plain floating-point addition.
    fn add(a: T, b: T) -> T {
        let total = a + b;
        return total;
    }
}
mod U32 {
// Type module binding `T` to the built-in u32; used as the type argument of Const<U32>.
alias T = u32;
}
// Here we don't care about the exact generic module arguments
// passed to BinaryOp, as we can extract the underlying types from the
// operation module's members.
mod ReduceWorkgroup<Op: BinaryOp<_, _>, WorkSize: Const<U32>, Threads: Const<U32>> {
    // Per-workgroup scratch space shared by all invocations. reduceWorkgroup folds
    // work[0 .. Threads::value), so WorkSize::value is expected to be at least
    // Threads::value.
    var <workgroup> work: array<Op::OpElem::T, WorkSize::value>;

    // Tree-reduces work[0 .. Threads::value) with Op::binaryOp, leaving the result in
    // work[0]. It calls workgroupBarrier(), so every invocation of the workgroup must
    // call it (in uniform control flow), passing its local invocation index.
    //
    // At the step with stride `step`, invocation i (where i % step == 0) folds
    // work[2i + step] into work[2i]. After that step, work[0] holds the combination of
    // the first min(2 * step, Threads::value) elements. Only invocation 0 ever writes
    // work[0], so it can read the result afterwards without another barrier.
    fn reduceWorkgroup(localId: u32) {
        let workDex = localId << 1u;
        for (var step = 1u; step < Threads::value; step <<= 1u) {
            // Make the previous step's (or the caller's) writes visible.
            workgroupBarrier();
            // Skip partners beyond the populated range. Without this bound, the upper
            // invocations read and write work[Threads::value .. 2 * Threads::value).
            // That range can lie past the end of `work` (e.g. WorkSize 18, Threads 10),
            // and when Threads::value is not a power of two it gets folded into work[0].
            if localId % step == 0u && workDex + step < Threads::value {
                work[workDex] = Op::binaryOp(work[workDex], work[workDex + step]);
            }
        }
    }
}
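// Same as above: BinaryOp's generic arguments are left as `_`, and the element types are read from Op's members.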
mod ReduceBuffer<Op: BinaryOp<_, _>, BlockArea: Const<U32>, WorkSize: Const<U32>, Threads: Const<U32>> {
    // extend brings the module members into the namespace. In particular it supplies
    // the workgroup scratch array `work` and `reduceWorkgroup`, which folds
    // work[0 .. Threads::value) into work[0].
    extend ReduceWorkgroup<Op, WorkSize, Threads>;

    alias Input = Op::LoadElem::T;
    alias Output = Op::OpElem::T;

    struct Uniforms {
        sourceOffset: u32, // offset in Input elements to start reading in the source
        resultOffset: u32, // offset in Output elements to start writing in the results
    }

    @group(0) @binding(0) var<uniform> u: Uniforms;
    @group(0) @binding(1) var<storage, read> src: array<Input>;
    @group(0) @binding(2) var<storage, read_write> out: array<Output>;
    // buffer to hold debug values; nothing in this module writes to it
    @group(0) @binding(11) var<storage, read_write> debug: array<f32>;

    // Invocations per workgroup. Each invocation writes one element of `work`, and
    // reduceWorkgroup folds Threads::value elements, so the default ties the two
    // together. Overriding it to a different value makes them disagree.
    //
    // No `work` array is declared here. The one provided by the extended
    // ReduceWorkgroup module is the array reduceWorkgroup reduces, so
    // reduceBufferToWork must fill that same array. A second declaration would clash
    // with (or shadow) it.
    override workgroupThreads = Threads::value;

    // Reduce a buffer of values to a single value, written by the final pass at
    // out[u.resultOffset].
    //
    // Each dispatch does two reductions:
    //  . each invocation reduces a block of the src buffer into the workgroup buffer
    //  . one invocation per workgroup reduces the workgroup buffer into the out buffer
    // The driver issues multiple dispatches until the output is 1 element long
    // (subsequent passes use the output of the previous pass as the src).
    // The same output buffer can be used as input and output in subsequent passes;
    // the offsets in the uniforms select the input and output positions in the buffer.
    @compute
    @workgroup_size(workgroupThreads, 1, 1)
    fn main(
        @builtin(global_invocation_id) grid: vec3<u32>, // coords in the global compute grid
        @builtin(local_invocation_index) localIndex: u32, // index inside this workgroup
        @builtin(num_workgroups) numWorkgroups: vec3<u32>, // number of workgroups in this dispatch
        @builtin(workgroup_id) workgroupId: vec3<u32> // workgroup id in the dispatch
    ) {
        reduceBufferToWork(grid.xy, localIndex);
        let outDex = workgroupId.x + u.resultOffset;
        // Every invocation must take part: reduceWorkgroup contains barriers.
        reduceWorkgroup(localIndex);
        // Invocation 0 wrote work[0] last, so it can publish the workgroup's result.
        if localIndex == 0u {
            out[outDex] = work[0];
        }
    }

    // Reduces this invocation's block of the source buffer and stores the partial
    // result in work[localId].
    fn reduceBufferToWork(grid: vec2<u32>, localId: u32) {
        let values = fetchSrcBuffer(grid.x);
        let v = reduceSrcBlock(values);
        work[localId] = v;
    }

    // Loads BlockArea::value consecutive source elements, starting at
    // u.sourceOffset + gridX * BlockArea::value, converting each with Op::loadOp.
    // Positions past the end of src are filled with Op::identityOp() so they leave the
    // reduction unchanged.
    fn fetchSrcBuffer(gridX: u32) -> array<Output, BlockArea::value> {
        let start = u.sourceOffset + (gridX * BlockArea::value);
        let end = arrayLength(&src);
        var a = array<Output, BlockArea::value>();
        for (var i = 0u; i < BlockArea::value; i++) {
            let idx = i + start;
            if idx < end {
                a[i] = Op::loadOp(src[idx]);
            } else {
                a[i] = Op::identityOp();
            }
        }
        return a;
    }

    // Folds a block left to right with Op::binaryOp.
    // Assumes BlockArea::value >= 1, since a[0] is read unconditionally.
    fn reduceSrcBlock(a: array<Output, BlockArea::value>) -> Output {
        var v = a[0];
        for (var i = 1u; i < BlockArea::value; i++) {
            v = Op::binaryOp(v, a[i]);
        }
        return v;
    }
}
// To actually realize a concrete ReduceBuffer module, we need concrete const values
mod BlockArea -> Const<U32> {
    // Number of source elements each invocation loads and folds before the
    // workgroup-level reduction.
    const value: U32::T = 4u;
}
mod WorkSize -> Const<U32> {
    // Length of the workgroup scratch array `work` declared by ReduceWorkgroup.
    const value: U32::T = 18u;
}
mod Threads -> Const<U32> {
    // Bound of ReduceWorkgroup's reduction loop: the number of `work` elements
    // folded per workgroup.
    const value: U32::T = 10u;
}
// Putting everything together and into the global namespace
// Instantiate ReduceBuffer as an f32 sum reduction using the constant modules above.
// extend brings its members, including the `main` compute entry point, into scope.
extend ReduceBuffer<SumBinaryOp<F32>, BlockArea, WorkSize, Threads>;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment