@rgov
Last active June 28, 2024 00:42
Example of how to use Metal SIMD functions to perform a reduction operation
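Both the per-SIMD step and the final step of the kernel rely on the same pattern: every lane repeatedly adds `simd_shuffle_down(val, offset)` while `offset` halves. The C++ sketch below is a CPU model of that pattern, not Metal code. The helpers `shuffle_down` and `simd_reduce_lane0` are hypothetical. The sketch assumes a power-of-two SIMD width, and it assumes that lanes whose source lane is out of range keep their own value, which is our reading of the MSL specification's description of `simd_shuffle_down`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU model of simd_shuffle_down for one SIMD-group: lane i gets
// the value of lane i + delta, and lanes whose source is out of range keep
// their own value.
std::vector<int> shuffle_down(const std::vector<int>& lanes, std::size_t delta) {
    std::vector<int> out(lanes.size());
    for (std::size_t i = 0; i < lanes.size(); ++i)
        out[i] = (i + delta < lanes.size()) ? lanes[i + delta] : lanes[i];
    return out;
}

// Hypothetical model of the kernel's loop
//     for (offset = simd_size / 2; offset > 0; offset /= 2)
//         val += simd_shuffle_down(val, offset);
// executed by every lane at once. Returns the value lane 0 ends up holding.
int simd_reduce_lane0(std::vector<int> lanes) {
    const std::size_t width = lanes.size();
    assert(width > 0 && (width & (width - 1)) == 0);  // power-of-two width
    for (std::size_t offset = width / 2; offset > 0; offset /= 2) {
        const std::vector<int> shifted = shuffle_down(lanes, offset);
        for (std::size_t i = 0; i < width; ++i)
            lanes[i] += shifted[i];
    }
    return lanes[0];
}
```

For the lanes `{1, 2, 3, 4}`, the values go from `{1, 2, 3, 4}` to `{4, 6, 6, 8}` to `{10, 12, 14, 16}`, so `simd_reduce_lane0({1, 2, 3, 4})` returns 10. Only lane 0 holds the true sum; the other lanes hold partial or double-counted totals, which is why the kernel reads the result only from `simd_lane_id == 0` and, at the end, `lid == 0`.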
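```metal
// Adapted from the Metal Shading Language Specification, Version 3.2, p. 186.
// The version in the PDF has two errors, which are corrected here.
//
// Assumptions (not checked by the kernel):
//  - The grid has exactly one thread per element of `input`; there is no
//    bounds check on input[gid].
//  - lsize is a multiple of simd_size, and lsize / simd_size <= simd_size, so
//    the loop below runs at most once (e.g. 32-wide SIMD-groups and at most
//    1024 threads per threadgroup).
//  - ldata has room for lsize / simd_size ints, allocated on the host
//    (e.g. with setThreadgroupMemoryLength:atIndex:).
//  - *output is zeroed before dispatch; each threadgroup adds its sum to it.
#include <metal_stdlib>
using namespace metal;

kernel void
reduce(const device int *input [[buffer(0)]],
       device atomic_int *output [[buffer(1)]],
       threadgroup int *ldata [[threadgroup(0)]],
       uint gid [[thread_position_in_grid]],
       uint lid [[thread_position_in_threadgroup]],
       uint lsize [[threads_per_threadgroup]],
       uint simd_size [[threads_per_simdgroup]],
       uint simd_lane_id [[thread_index_in_simdgroup]],
       uint simd_group_id [[simdgroup_index_in_threadgroup]])
{
    // Perform the first level of reduction.
    // Read from device memory, write to threadgroup memory.
    int val = input[gid];

    // Each pass divides the number of live partial sums by simd_size.
    for (uint s = lsize / simd_size; s > 1; s /= simd_size)
    {
        // Perform per-SIMD partial reduction; lane 0 ends up with the sum.
        for (uint offset = simd_size / 2; offset > 0; offset /= 2)
            val += simd_shuffle_down(val, offset);

        // Write per-SIMD partial reduction value to threadgroup memory.
        if (simd_lane_id == 0)
            ldata[simd_group_id] = val;

        // Wait for all partial reductions to complete.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // The first s threads load the s partial sums; the rest contribute 0.
        val = (lid < s) ? ldata[lid] : 0;
    }

    // Perform final per-SIMD partial reduction to calculate
    // the threadgroup partial reduction result (held by lid == 0).
    for (uint offset = simd_size / 2; offset > 0; offset /= 2)
        val += simd_shuffle_down(val, offset);

    // Atomically update the reduction result.
    if (lid == 0)
        atomic_fetch_add_explicit(output, val, memory_order_relaxed);
}
```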
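The outer `s` loop is easier to follow in a sequential model. This C++ sketch simulates one threadgroup running `reduce`; the name `threadgroup_reduce` is hypothetical, `vals[lid]` stands in for each thread's `val`, and a vector stands in for `ldata`. It assumes a power-of-two SIMD width and a threadgroup size that is a multiple of it, and it reproduces the kernel's control flow as written, limitations included. Because every step runs for all threads before the next starts, the barrier is implicit.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical sequential model of one threadgroup running the `reduce` kernel.
// vals[lid] is thread lid's `val` (initially input[gid]), vals.size() is lsize,
// and simd_size is the SIMD-group width.
int threadgroup_reduce(std::vector<int> vals, std::size_t simd_size) {
    const std::size_t lsize = vals.size();
    assert(simd_size > 0 && (simd_size & (simd_size - 1)) == 0);  // power of two
    assert(lsize > 0 && lsize % simd_size == 0);

    // The "val += simd_shuffle_down(val, offset)" loop, applied to every
    // SIMD-group at once. Lanes whose source lane is out of range keep
    // their own value, mirroring simd_shuffle_down.
    auto simd_reduce_all = [&](std::vector<int>& v) {
        for (std::size_t offset = simd_size / 2; offset > 0; offset /= 2) {
            std::vector<int> shifted(lsize);
            for (std::size_t lid = 0; lid < lsize; ++lid) {
                const std::size_t lane = lid % simd_size;
                shifted[lid] = (lane + offset < simd_size) ? v[lid + offset] : v[lid];
            }
            for (std::size_t lid = 0; lid < lsize; ++lid)
                v[lid] += shifted[lid];
        }
    };

    std::vector<int> ldata(lsize / simd_size);  // stands in for threadgroup memory
    for (std::size_t s = lsize / simd_size; s > 1; s /= simd_size) {
        simd_reduce_all(vals);
        // Lane 0 of each SIMD-group writes its partial sum to ldata.
        for (std::size_t g = 0; g < ldata.size(); ++g)
            ldata[g] = vals[g * simd_size];
        // The first s threads reload the partial sums; the rest contribute 0.
        for (std::size_t lid = 0; lid < lsize; ++lid)
            vals[lid] = (lid < s) ? ldata[lid] : 0;
    }
    simd_reduce_all(vals);  // final per-SIMD reduction
    return vals[0];         // the value thread lid == 0 would add to *output
}
```

With 1024 threads and a SIMD width of 32, `s` starts at 32, the loop body runs once, and SIMD-group 0 folds the 32 partial sums, so the result equals the serial sum. With 72 threads and a width of 8, however, the model returns 64 for an input of 72 ones: the nine partial sums do not fit in one SIMD-group, and the loop exits because 9 / 8 == 1. This suggests keeping lsize / simd_size <= simd_size so that the loop body runs at most once.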