Skip to content

Instantly share code, notes, and snippets.

@dondragmer
Last active April 15, 2025 09:21
Show Gist options
  • Save dondragmer/e014834aeb55789c521d9a8fce01dea9 to your computer and use it in GitHub Desktop.
Save dondragmer/e014834aeb55789c521d9a8fce01dea9 to your computer and use it in GitHub Desktop.
An optimized bitonic sorting compute shader. Has no bank conflicts, uses shader model 6.6 wave intrinsics, and works with different sort sizes and per-wave lane counts.
// buffer holding the values to sort; the entry point reads one element per thread and writes the sorted value back to the same slot
RWBuffer<uint> SortBuffer : register(u0);
// number of values sorted by one thread group (one value per thread)
static const uint sSortSize = 1024; // can be any power of two up to 1024 (the max threads in a group)
// to avoid all bank conflicts there needs to be a space of padding inserted at every multiple of every power of the wave size
// (if an index is a multiple of several powers of the wave size a pad needs to be added for each)
// the smallest wave size possible is 4, so the most padding needed is (sort size) * (1/4 + 1/16 + 1/64 ...)
// which converges to (sort size) / 3
static const uint sGroupsharedSortValuesToPaddingRatio = 3;
// group shared scratch array used to exchange values between waves; sized for the sort values plus the worst case padding
groupshared uint sharedSortArray[sSortSize + (sSortSize / sGroupsharedSortValuesToPaddingRatio)];
// maps a logical element index to its slot in the padded groupshared array
// the result is index + index / W + index / W^2 + ... (W = lanes per wave, integer division),
// i.e. one extra slot is skipped for every multiple of every power of the wave size at or below the index
// index 0 maps to 0
uint RecursivelyPadGroupsharedIndexToAvoidBankConflicts(uint index)
{
    const uint lanesPerWave = WaveGetLaneCount();

    uint padded = 0;
    uint remaining = index;
    while (remaining != 0)
    {
        // first pass adds the index itself, each later pass adds the pads owed to the next power of the wave size
        padded += remaining;
        remaining /= lanesPerWave;
    }
    return padded;
}
// runs the compare-exchange steps of a bitonic merge across the lanes of the current wave
// value:         this lane's element
// sortLevel:     log2 of the number of lanes being merged; strides run from 2^(sortLevel - 1) down to 1
//                (callers in this file always pass sortLevel >= 1)
// sortAscending: merge direction for this lane
// returns the element this lane holds after the last compare-exchange
uint WaveBitonicMerge(uint value, uint sortLevel, bool sortAscending)
{
    const uint lane = WaveGetLaneIndex();

    uint stride = 1u << (sortLevel - 1);
    while (stride != 0)
    {
        // every lane exchanges with the lane whose index differs only in the stride bit
        const uint partnerValue = WaveReadLaneAt(value, lane ^ stride);

        // the lane with the stride bit clear is the lower half of the pair
        // the lower lane keeps the smaller value when ascending, the upper lane keeps it when descending
        const bool inLowerHalf = (lane & stride) == 0;
        const bool wantsSmaller = (inLowerHalf == sortAscending);
        const bool ownIsLarger = (value > partnerValue);

        // take the partner's value when the lane wants the smaller value and its own is larger,
        // or when it wants the larger value and its own is not larger
        if (ownIsLarger == wantsSmaller)
        {
            value = partnerValue;
        }

        stride >>= 1;
    }
    return value;
}
// bitonic sorts the values held by the lanes of the current wave by merging blocks of 2, 4, ... 2^maxSortLevel lanes
// value:         this lane's element
// maxSortLevel:  log2 of the number of lanes to sort together
// sortAscending: direction of the final, largest merge
// returns the element this lane holds once all merge levels have run
uint WaveBitonicSort(uint value, uint maxSortLevel, bool sortAscending)
{
    const uint lane = WaveGetLaneIndex();

    uint level = 1;
    while (level <= maxSortLevel)
    {
        // blocks whose lane index has bit 'level' set merge in the opposite direction to sortAscending,
        // so neighbouring blocks form a bitonic sequence for the next level
        const bool inFlippedBlock = ((lane >> level) & 1u) != 0;
        value = WaveBitonicMerge(value, level, inFlippedBlock != sortAscending);
        ++level;
    }
    return value;
}
// bitonic sorts sSortSize values across the whole thread group, one value per thread
// value:       this thread's element
// threadIndex: index of the thread within the group (the entry point passes SV_GroupThreadID.x)
// returns the element this thread holds once all merge levels have run
// values are first sorted inside each wave with wave intrinsics, then merged across waves through
// the padded groupshared array, re-reading them in an interleaved order so every merge step can
// still be done with wave intrinsics
uint GroupBitonicSort(uint value, uint threadIndex)
{
    uint sortLevelForFullSort = firstbithigh(sSortSize);
    uint waveMaxSortLevel = firstbithigh(WaveGetLaneCount());
    uint waveIndex = threadIndex >> waveMaxSortLevel;
    uint paddedThreadIndex = RecursivelyPadGroupsharedIndexToAvoidBankConflicts(threadIndex);

    // bitonic sort all the values within the wave
    // even waves sort ascending and odd waves descending so each pair of waves forms a bitonic sequence
    value = WaveBitonicSort(value, min(waveMaxSortLevel, sortLevelForFullSort), (waveIndex & 1u) == 0);

    // continue the bitonic sort across waves
    for (uint currentTopSortLevel = waveMaxSortLevel + 1; currentTopSortLevel <= sortLevelForFullSort; currentTopSortLevel++)
    {
        sharedSortArray[paddedThreadIndex] = value;

        // for larger sorts interleave the values across waves, so each lane of the first wave gets the first value from every other wave
        // the second wave gets the second, etc. then do the bitonic merge using wave intrinsics and write the values back
        // if the wave size is small, do this recursively by raising the interleaving stride to the power of the needed supersort level
        for (uint supersortLevel = (currentTopSortLevel - 1) / waveMaxSortLevel; supersortLevel > 0; supersortLevel--)
        {
            // wait for the writes of the previous pass (or the store above) before reading other threads' slots
            GroupMemoryBarrierWithGroupSync();

            uint logTwoOfCrossWaveStride = supersortLevel * waveMaxSortLevel;

            // each lane in the wave reads/writes an element spaced out by the cross wave stride
            // but it's possible for those reads to be out of bounds
            // when that happens there also won't be enough waves to fill out the stride in the interleaved network
            // so wrap those out of bounds lanes around with an offset equal to the wave count to fill in the gaps
            // after padding this won't cause any bank conflicts
            uint laneGroupMask = (1u << (sortLevelForFullSort - logTwoOfCrossWaveStride)) - 1;
            uint interleavedIndex = ((WaveGetLaneIndex() & laneGroupMask) << logTwoOfCrossWaveStride)
                                  + ((WaveGetLaneIndex() & ~laneGroupMask) << (logTwoOfCrossWaveStride - waveMaxSortLevel));

            // each wave gets the index its lanes are accessing offset by one so the indices are interleaved
            // once there are enough waves to fill the gap between the lane's stride, jump to the next section of the array
            uint waveGroupMask = (1u << logTwoOfCrossWaveStride) - 1;
            interleavedIndex += (waveIndex & waveGroupMask)
                              + ((waveIndex & ~waveGroupMask) * WaveGetLaneCount());

            uint paddedInterleavedIndex = RecursivelyPadGroupsharedIndexToAvoidBankConflicts(interleavedIndex);

            // a wave can only merge across as many levels as it has lane index bits
            // on the first (coarsest) pass the remaining level count is already <= waveMaxSortLevel,
            // but on every later pass it is larger and has to be clamped, otherwise the merge would start at a
            // stride >= the wave size and WaveReadLaneAt would be asked for a lane past the end of the wave
            uint subsortLevel = min(currentTopSortLevel - logTwoOfCrossWaveStride, waveMaxSortLevel);

            // the merge direction is fixed by the element's position in the full array, not by the lane
            bool sortAscending = ((interleavedIndex >> currentTopSortLevel) & 1u) == 0;
            sharedSortArray[paddedInterleavedIndex] = WaveBitonicMerge(sharedSortArray[paddedInterleavedIndex],
                                                                       subsortLevel,
                                                                       sortAscending);
        }

        // do the final levels of bitonic merging within the wave
        GroupMemoryBarrierWithGroupSync();
        bool sortAscending = ((threadIndex >> currentTopSortLevel) & 1u) == 0;
        value = WaveBitonicMerge(sharedSortArray[paddedThreadIndex], waveMaxSortLevel, sortAscending);
    }
    return value;
}
// compute entry point, one thread group of sSortSize threads per dispatched group
// each thread loads the element at its dispatch index, takes part in the group wide sort
// and stores the value it ends up with back to the same slot
[numthreads(sSortSize, 1, 1)]
void SortEntrypoint(uint3 ThreadID : SV_GroupThreadID, uint3 DispatchThreadID : SV_DispatchThreadID)
{
    const uint bufferSlot = DispatchThreadID.x;
    const uint unsortedValue = SortBuffer[bufferSlot];
    SortBuffer[bufferSlot] = GroupBitonicSort(unsortedValue, ThreadID.x);
}
Copyright (c) 2024 DJ Shea
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment