Created January 20, 2021 23:32
-
-
Save dondragmer/c75a1a50f1cdd00c104d3483375bdb2f to your computer and use it in GitHub Desktop.
An optimized GPU counting sort
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//Compute shader for a GPU radix sort of 32-bit keys (with an optional 32-bit payload), one 8-bit digit per pass.
//The host runs the four kernels below in this order for every digit (see PrefixSorterSetup.DoSort):
//  CountTotalsInBlock         - count each digit value in every 1024-element block, and where it starts in the sorted block
//  BlockCountPostfixSum       - turn the per-block counts into running totals over the blocks
//  CalculateOffsetsForEachKey - combine those totals into, per block and digit value, the offset from local to global index
//  FinalSort                  - sort every block locally in groupshared memory and scatter it to its global position
#pragma use_dxc //enable SM 6.0 features, in Unity this is only supported on version 2020.2.0a8 or later with D3D12 enabled | |
#pragma kernel CountTotalsInBlock | |
#pragma kernel BlockCountPostfixSum | |
#pragma kernel CalculateOffsetsForEachKey | |
#pragma kernel FinalSort | |
//lowest bit of the 8-bit digit sorted in this pass (the host sets 0, 8, 16, 24)
uint _FirstBitToSort; | |
//number of valid keys; threads whose global index is at or past this value are padding
int _NumElements; | |
//number of 1024-element blocks; the host computes it as ceil(_NumElements / 1024)
int _NumBlocks; | |
//when false, FinalSort neither loads nor writes the payload
bool _ShouldSortPayload; | |
//input/output of one pass; the host swaps its A and B buffers between passes
Buffer<uint> KeyInputBuffer; | |
RWBuffer<uint> KeyOutputBuffer; | |
Buffer<uint> PayloadInputBuffer; | |
RWBuffer<uint> PayloadOutputBuffer; | |
//Both textures are created by the host _NumBlocks wide and 64 rows tall: the column is the block index and row r
//stores 4 consecutive digit values, so digit k lives in row k / 4, component k % 4.
//PerBlockKeyCountsTexture: occurrences of each digit per block; BlockCountPostfixSum rewrites it as a running
//total over blocks 0..b.
//BlockToGlobalKeyOffsetsTexture: start of each digit inside the sorted block; CalculateOffsetsForEachKey rewrites it
//as the value FinalSort adds to a thread's local sorted index to get its global output index.
RWTexture2D<int4> PerBlockKeyCountsTexture; | |
RWTexture2D<int4> BlockToGlobalKeyOffsetsTexture; | |
//this program assumes 32 lane wide waves (i.e. Nvidia cards), 64 lane waves would require more changes than just adjusting these values | |
static const uint WAVE_SIZE = 32; | |
static const uint HIGHEST_LANE = WAVE_SIZE - 1; | |
//row stride of the per-wave count table in FinalSort; the extra element is presumably there to avoid
//groupshared-memory bank conflicts on its transposed read (TODO confirm)
static const uint WAVE_SIZE_PLUS_PAD = WAVE_SIZE + 1; | |
static const uint HALF_WAVE_SIZE = WAVE_SIZE / 2; | |
//5 for 32-lane waves
static const uint LOG2_WAVE_SIZE = firstbitlow(WAVE_SIZE); | |
/* | |
* ----------------------------------------------------------------------------------------------------------- | |
*/ | |
//CountTotalsInBlock: one 1024-thread group per block of 1024 keys (the host dispatches _NumBlocks groups).
//Counts how often each of the 256 digit values occurs in the block and, for every digit row, writes
//  PerBlockKeyCountsTexture[block][row]       = the counts
//  BlockToGlobalKeyOffsetsTexture[block][row] = their exclusive prefix sum, i.e. where each digit starts in the sorted block
//Counts are packed two per uint: digits 0..127 in the low 16 bits and digits 128..255 in the high 16 bits. A block
//holds at most 1024 keys, so one half can never carry into the other.
//totalCountsInBlock[lane + 32 * s] holds the packed counts of digits (lane * 4 + s) and (128 + lane * 4 + s).
groupshared uint totalCountsInBlock[128]; | |
[numthreads(1024, 1, 1)] | |
void CountTotalsInBlock(uint3 threadID : SV_GroupThreadID, uint3 groupID : SV_GroupID) | |
{ | |
uint laneIndex = WaveGetLaneIndex(); //the index of this thread in its wavefront | |
uint waveIndex = threadID.x / WAVE_SIZE; //the index of this wavefront in the group | |
uint globalIndex = threadID.x + (groupID.x * 1024); //the index this thread is loading from the global array | |
//the load index is clamped so padding threads in the last block still read a valid element
uint rawSortKey = KeyInputBuffer[min(globalIndex, _NumElements - 1)]; | |
//padding threads count as digit 0xFF; it is the last digit, so they never move the start of any digit
uint sortKey = (globalIndex < _NumElements) ? ((rawSortKey >> _FirstBitToSort) & 0xFF) : 0xFF; | |
//initialize counts to 0 because we are going to atomically add to them | |
if (threadID.x < 128) | |
{ | |
totalCountsInBlock[threadID.x] = 0; | |
} | |
GroupMemoryBarrierWithGroupSync(); | |
//this will contain the total occurrences in this wave for 8 keys | |
//4 sequential keys from the first 128 keys in the lower 16 bits of each element | |
//4 sequential keys from the latter 128 keys in the upper 16 bits of each element | |
//(lane L covers digits L*4 .. L*4+3 and 128+L*4 .. 128+L*4+3)
uint4 countsForWave; | |
{ | |
uint equalsLaneMask = ~0; //bitmask of lanes in the wave whose key bits 2..6 equal this lane's index | |
for (int bit = 0; bit < LOG2_WAVE_SIZE; bit++) | |
{ | |
//start at the third bit so each lane can store 4 consecutive keys | |
bool isBitSet = sortKey & (1 << (bit + 2)); | |
uint bitSetMask = WaveActiveBallot(isBitSet).x; | |
equalsLaneMask &= (laneIndex.x & (1 << bit)) ? bitSetMask : ~bitSetMask; | |
} | |
//this wave will get the counts in this wave for all 8 permutations of the following bits: | |
//bit 0, bit 1 and bit 7 (bit 7 selects the lower or the upper 16-bit half)
bool isBitSet = sortKey & (1 << 0); | |
uint firstBitSetMask = WaveActiveBallot(isBitSet).x; | |
isBitSet = sortKey & (1 << 1); | |
uint secondBitSetMask = WaveActiveBallot(isBitSet).x; | |
isBitSet = sortKey & (1 << 7); | |
uint eighthBitSetMask = WaveActiveBallot(isBitSet).x; | |
//each iteration complements a mask, so iteration (firstBit, secondBit) keeps exactly the lanes
//whose key bit 0 == firstBit and bit 1 == secondBit
for (int secondBit = 0; secondBit < 2; secondBit++) | |
{ | |
secondBitSetMask = ~secondBitSetMask; | |
for (int firstBit = 0; firstBit < 2; firstBit++) | |
{ | |
firstBitSetMask = ~firstBitSetMask; | |
//pack two counts with different 8th bits into the same value | |
uint countA = countbits(equalsLaneMask & firstBitSetMask & secondBitSetMask & ~eighthBitSetMask); | |
countsForWave[firstBit + secondBit * 2] = (countA & 0xFFFF); | |
uint countB = countbits(equalsLaneMask & firstBitSetMask & secondBitSetMask & eighthBitSetMask); | |
countsForWave[firstBit + secondBit * 2] |= (countB << 16); | |
} | |
} | |
} | |
//atomically add the counts from every wave together | |
//lane L adds the packed counts of digits L*4+s (and 128+L*4+s) into totalCountsInBlock[L + 32*s]
for (uint subIndex = 0; subIndex < 4; subIndex++) | |
{ | |
uint writeIndex = laneIndex + (subIndex * WAVE_SIZE); | |
InterlockedAdd(totalCountsInBlock[writeIndex], countsForWave[subIndex]); | |
} | |
GroupMemoryBarrierWithGroupSync(); | |
//have the first two waves output the results | |
//wave 0 writes rows 0..31 (digits 0..127, low halves), wave 1 writes rows 32..63 (digits 128..255, high halves)
[branch] | |
if (waveIndex <= 1) | |
{ | |
uint4 countsForBlock; | |
for (int subIndex = 0; subIndex < 4; subIndex++) | |
{ | |
uint readIndex = laneIndex + (subIndex * WAVE_SIZE); | |
countsForBlock[subIndex] = totalCountsInBlock[readIndex]; | |
} | |
//output the total count of each 1-byte key in this group | |
uint4 unpackedCounts = (waveIndex == 0) ? (countsForBlock & 0xFFFF) : (countsForBlock >> 16); | |
PerBlockKeyCountsTexture[uint2(groupID.x, threadID.x)] = (int4) unpackedCounts; | |
//calculate a prefix sum of every key's count which would equal the index that key starts at in the sorted block | |
uint4 countPrefix; | |
countPrefix.x = 0; | |
countPrefix.y = countsForBlock.x; | |
countPrefix.z = countPrefix.y + countsForBlock.y; | |
countPrefix.w = countPrefix.z + countsForBlock.z; | |
//add in the prefix from the other lanes | |
//(WavePrefixSum is exclusive; both packed halves are summed at once)
countPrefix += WavePrefixSum(countPrefix.w + countsForBlock.w); | |
//add the total count of the first half of the keys to the second half | |
//the highest lane's inclusive total holds the count of digits 0..127 in its low half; shifting it left by 16
//drops the high half and adds that count to the start of every upper digit
uint firstHalfTotal = countPrefix.w + countsForBlock.w; | |
firstHalfTotal = WaveReadLaneAt(firstHalfTotal, HIGHEST_LANE); | |
countPrefix += firstHalfTotal << 16; | |
//output the final starting index | |
uint4 unpackedPrefix = (waveIndex == 0) ? (countPrefix & 0xFFFF) : (countPrefix >> 16); | |
BlockToGlobalKeyOffsetsTexture[uint2(groupID.x, threadID.x)] = (int4) unpackedPrefix; | |
} | |
} | |
/* | |
* ----------------------------------------------------------------------------------------------------------- | |
*/ | |
//BlockCountPostfixSum: the host dispatches (1, 64, 1) groups, so group y handles texture row y (4 digits) for all blocks.
//It rewrites PerBlockKeyCountsTexture[b][y] as the inclusive running total of the counts over blocks 0..b (despite
//the "Postfix" name this is an inclusive prefix sum). Each thread owns one block; blocks are processed 1024 at a
//time and runningTotals carries the sum over all earlier chunks.
//eachWaveTotals[w] holds the inclusive total of wave w (1024 threads / 32 lanes = 32 waves per group).
groupshared int4 eachWaveTotals[32]; | |
[numthreads(1024, 1, 1)] | |
void BlockCountPostfixSum(uint3 threadID : SV_GroupThreadID, uint3 groupID : SV_GroupID) | |
{ | |
uint laneIndex = WaveGetLaneIndex(); //the index of this thread in its wavefront | |
uint waveIndex = threadID.x / WAVE_SIZE; //the index of this wavefront in the group | |
int4 blockCounts = (threadID.x < _NumBlocks) ? PerBlockKeyCountsTexture[uint2(threadID.x, groupID.y)] : 0; | |
//calculate the postfix 1024 blocks at a time | |
//the loop handles every 1024-block chunk except the last (full or partial) one, which is finished after it
int4 runningTotals = 0; | |
int startingBlock; | |
for (startingBlock = 0; startingBlock < (_NumBlocks - 1024); startingBlock += 1024) | |
{ | |
//exclusive wave prefix plus this thread's own counts = inclusive prefix within the wave
int4 blockCountPostfix = WavePrefixSum(blockCounts) + blockCounts; | |
//load the next set of key counts now to maximize the time until we need to use it | |
int blockLoadIndex = threadID.x + startingBlock; | |
blockCounts = (blockLoadIndex + 1024 < _NumBlocks) ? PerBlockKeyCountsTexture[uint2(blockLoadIndex + 1024, groupID.y)] : 0; | |
//have last lane in each wave output the total counts for this wave | |
if (laneIndex == HIGHEST_LANE) | |
{ | |
eachWaveTotals[waveIndex] = blockCountPostfix; | |
} | |
GroupMemoryBarrierWithGroupSync(); | |
//get the totals of all waves before this one | |
int4 allWaveTotals = eachWaveTotals[laneIndex]; | |
int4 previousWaveTotal = (laneIndex < waveIndex) ? allWaveTotals : 0; //only keep totals for waves less than this one | |
previousWaveTotal = WaveActiveSum(previousWaveTotal); | |
blockCountPostfix += previousWaveTotal; | |
if (blockLoadIndex < _NumBlocks) | |
{ | |
PerBlockKeyCountsTexture[uint2(blockLoadIndex, groupID.y)] = blockCountPostfix + runningTotals; | |
} | |
//get totals from all 1024 blocks to add to the next set | |
runningTotals += WaveActiveSum(allWaveTotals); | |
//keep eachWaveTotals intact until every wave has read it
GroupMemoryBarrierWithGroupSync(); | |
} | |
//calculate postfix for final set of blocks | |
int4 blockCountPostfix = WavePrefixSum(blockCounts) + blockCounts; | |
//have last lane in each wave output the total count for this wave | |
if (laneIndex == HIGHEST_LANE) | |
{ | |
eachWaveTotals[waveIndex] = blockCountPostfix; | |
} | |
GroupMemoryBarrierWithGroupSync(); | |
//get the totals of all waves before this one | |
int4 previousWaveTotal = eachWaveTotals[laneIndex]; | |
previousWaveTotal = (laneIndex < waveIndex) ? previousWaveTotal : 0; //only keep totals for waves less than this one | |
previousWaveTotal = WaveActiveSum(previousWaveTotal); | |
blockCountPostfix += previousWaveTotal; | |
int blockLoadIndex = threadID.x + startingBlock; | |
if (blockLoadIndex < _NumBlocks) | |
{ | |
PerBlockKeyCountsTexture[uint2(blockLoadIndex, groupID.y)] = blockCountPostfix + runningTotals; | |
} | |
} | |
/* | |
* ----------------------------------------------------------------------------------------------------------- | |
*/ | |
//CalculateOffsetsForEachKey: the host dispatches _NumBlocks groups of 32 threads (one wave per block).
//Thread y handles digits 4y..4y+3 (row y) and 128+4y..128+4y+3 (row y + WAVE_SIZE). For each digit it computes
//  globalStart = (keys with a smaller digit in the whole array) + (keys with this digit in earlier blocks)
//and stores globalStart - (start of the digit inside this sorted block) into BlockToGlobalKeyOffsetsTexture,
//so FinalSort only has to add its local sorted index. Requires the running totals from BlockCountPostfixSum.
[numthreads(1, 32, 1)] | |
void CalculateOffsetsForEachKey(uint3 threadID : SV_GroupThreadID, uint3 groupID : SV_GroupID) | |
{ | |
//get the total counts of each key in the entire global array | |
//(the last block's running total covers every block)
int4 globalKeyCountsA = PerBlockKeyCountsTexture[uint2(_NumBlocks - 1, threadID.y)]; | |
int4 globalKeyCountsB = PerBlockKeyCountsTexture[uint2(_NumBlocks - 1, threadID.y + WAVE_SIZE)]; | |
//get the totals counts of each key in all previous blocks | |
//(the previous block's running total covers blocks 0..groupID.x-1; block 0 has none)
int4 previousBlockKeyCountsA = 0; | |
int4 previousBlockKeyCountsB = 0; | |
if (groupID.x > 0) | |
{ | |
previousBlockKeyCountsA = PerBlockKeyCountsTexture[uint2(groupID.x - 1, threadID.y)]; | |
previousBlockKeyCountsB = PerBlockKeyCountsTexture[uint2(groupID.x - 1, threadID.y + WAVE_SIZE)]; | |
} | |
//get the start index of each key inside the sorted block | |
int4 blockKeyStartIndexA = BlockToGlobalKeyOffsetsTexture[uint2(groupID.x, threadID.y)]; | |
int4 blockKeyStartIndexB = BlockToGlobalKeyOffsetsTexture[uint2(groupID.x, threadID.y + WAVE_SIZE)]; | |
//generate prefix sum of total counts of each key to get each key's global start index | |
//(inclusive prefix over the 4 digits of the row; .w becomes the row total)
globalKeyCountsA.y += globalKeyCountsA.x; | |
globalKeyCountsA.z += globalKeyCountsA.y; | |
globalKeyCountsA.w += globalKeyCountsA.z; | |
globalKeyCountsB.y += globalKeyCountsB.x; | |
globalKeyCountsB.z += globalKeyCountsB.y; | |
globalKeyCountsB.w += globalKeyCountsB.z; | |
//prefix sum for the first half of the keys | |
//int4(0, xyz) turns the inclusive row prefix into an exclusive one
int crossLanePrefixSumLower = WavePrefixSum(globalKeyCountsA.w); | |
int4 globalKeyStartIndexA = int4(0, globalKeyCountsA.xyz) + crossLanePrefixSumLower + previousBlockKeyCountsA; | |
BlockToGlobalKeyOffsetsTexture[uint2(groupID.x, threadID.y)] = globalKeyStartIndexA - blockKeyStartIndexA; | |
//prefix sum for the second half of the keys | |
int crossLanePrefixSumUpper = WavePrefixSum(globalKeyCountsB.w); | |
crossLanePrefixSumUpper += WaveReadLaneAt(crossLanePrefixSumLower + globalKeyCountsA.w, HIGHEST_LANE); //add the first half total | |
int4 globalKeyStartIndexB = int4(0, globalKeyCountsB.xyz) + crossLanePrefixSumUpper + previousBlockKeyCountsB; | |
BlockToGlobalKeyOffsetsTexture[uint2(groupID.x, threadID.y + WAVE_SIZE)] = globalKeyStartIndexB - blockKeyStartIndexB; | |
} | |
/* | |
* ----------------------------------------------------------------------------------------------------------- | |
*/ | |
//FinalSort: one 1024-thread group per block (the host dispatches _NumBlocks groups). It stably sorts the block by the
//current 8-bit digit in groupshared memory, using two 4-bit passes (low nibble, then high nibble). It then writes the
//element at sorted position threadID.x to  threadID.x + BlockToGlobalKeyOffsetsTexture[block][digit].
//countsOrSortedData is reused across phases:
//  counting: [wave * WAVE_SIZE_PLUS_PAD + i] holds per-wave counts (later per-wave prefix sums), and
//            [WAVE_SIZE * WAVE_SIZE_PLUS_PAD + key] holds the number of keys in the block smaller than key;
//  scatter:  [0..1023] holds the sorted keys and [1024..2047] the sorted payloads.
groupshared uint countsOrSortedData[1024 * 2]; | |
[numthreads(1024, 1, 1)] | |
void FinalSort(uint3 threadID : SV_GroupThreadID, uint3 groupID : SV_GroupID) | |
{ | |
uint laneIndex = WaveGetLaneIndex(); //the index of this thread in its wavefront | |
uint waveIndex = threadID.x / WAVE_SIZE; //the index of this wavefront in the group | |
uint globalIndex = threadID.x + (groupID.x * 1024); //the index this thread is loading from the global array | |
uint rawSortKey = KeyInputBuffer[min(globalIndex, _NumElements - 1)]; | |
uint sortPayload = 0; | |
if (_ShouldSortPayload) //only load a payload if we are actually sorting it | |
{ | |
sortPayload = PayloadInputBuffer[min(globalIndex, _NumElements - 1)]; | |
} | |
//do a local sort for this block in two stages, the lower 4 bits of the key and then the upper 4 bits | |
[unroll] | |
for (int subsortFirstBit = 0; subsortFirstBit < 8; subsortFirstBit += 4) | |
{ | |
//separates the previous stage's groupshared readback from this stage's count writes
if (subsortFirstBit != 0) | |
{ | |
GroupMemoryBarrierWithGroupSync(); | |
} | |
//padding positions get the largest nibble 0x0F in both stages; because the sort is stable they stay at the end
//of the block, which is what lets the globalIndex check at the end skip them
uint sortKey = (globalIndex < _NumElements) ? ((rawSortKey >> (_FirstBitToSort + subsortFirstBit)) & 0x0F) : 0x0F; | |
uint internalSortIndex = 0; | |
//count keys in the wave | |
//equalsKeyMask: lanes with the same nibble as this lane
//equalsLaneMask: lanes whose nibble == (laneIndex & 0x0F); lessThanLaneMask: lanes whose nibble < (laneIndex & 0x0F)
{ | |
uint equalsKeyMask = ~0; | |
uint equalsLaneMask = ~0; | |
uint lessThanLaneMask = 0; | |
uint laneSortValue = laneIndex & 0x0F; | |
//bits are visited from low to high, so a higher differing bit overrides the less-than result of lower bits
for (int bit = 0; bit < 4; bit++) | |
{ | |
bool isBitSet = sortKey & (1 << bit); | |
uint bitSetMask = WaveActiveBallot(isBitSet).x; | |
equalsKeyMask &= isBitSet ? bitSetMask : ~bitSetMask; | |
if (laneSortValue & (1 << bit)) | |
{ | |
lessThanLaneMask |= ~bitSetMask; | |
equalsLaneMask &= bitSetMask; | |
} | |
else | |
{ | |
lessThanLaneMask &= ~bitSetMask; | |
equalsLaneMask &= ~bitSetMask; | |
} | |
} | |
//count the number of lanes before this one with the same key | |
//(shifting by HIGHEST_LANE - laneIndex and then by 1 avoids an out-of-range shift by 32 for lane 0)
internalSortIndex = countbits((equalsKeyMask << (HIGHEST_LANE - laneIndex)) << 1); | |
//first half of the wave outputs count of keys equal to its value, second half outputs count of keys less than its value | |
uint countToOutput = (laneIndex < HALF_WAVE_SIZE) ? countbits(equalsLaneMask) : countbits(lessThanLaneMask); | |
uint writeIndex = laneIndex + (waveIndex * WAVE_SIZE_PLUS_PAD); | |
countsOrSortedData[writeIndex] = countToOutput; | |
} | |
GroupMemoryBarrierWithGroupSync(); | |
//calculate the prefix sums and count of smaller keys for each of the 16 keys | |
//transposed read: lane l of wave w reads the value that lane w of wave l wrote
{ | |
uint readIndex = (laneIndex * WAVE_SIZE_PLUS_PAD) + waveIndex; | |
uint waveCounts = countsOrSortedData[readIndex]; | |
[branch] | |
if (waveIndex < 16) //first 16 waves are calculating prefix sums for each key | |
{ | |
//wave w: exclusive prefix over waves of the count of key w = same-key elements in earlier waves
uint prefixSumForKey = WavePrefixSum(waveCounts); | |
countsOrSortedData[readIndex] = prefixSumForKey; | |
} | |
else //last 16 waves are calculating starting index for each key | |
{ | |
//wave w: total over all waves of keys smaller than (w - 16)
uint offsetForKey = WaveActiveSum(waveCounts); | |
if (WaveIsFirstLane()) | |
{ | |
countsOrSortedData[(waveIndex - 16) + (WAVE_SIZE * WAVE_SIZE_PLUS_PAD)] = offsetForKey; | |
} | |
} | |
} | |
GroupMemoryBarrierWithGroupSync(); | |
//get the sorted index and then scatter into LDS | |
//sorted index = same-key lanes below + same-key elements in earlier waves + smaller keys in the block
{ | |
uint readIndex = sortKey + (waveIndex * WAVE_SIZE_PLUS_PAD); | |
internalSortIndex += countsOrSortedData[readIndex]; | |
internalSortIndex += countsOrSortedData[sortKey + (WAVE_SIZE * WAVE_SIZE_PLUS_PAD)]; | |
//every thread must finish reading the count table before any thread overwrites it with sorted data
GroupMemoryBarrierWithGroupSync(); | |
countsOrSortedData[internalSortIndex] = rawSortKey; | |
if (_ShouldSortPayload) | |
{ | |
countsOrSortedData[internalSortIndex + 1024] = sortPayload; | |
} | |
} | |
GroupMemoryBarrierWithGroupSync(); | |
//read the sorted data out of LDS | |
{ | |
rawSortKey = countsOrSortedData[threadID.x]; | |
if (_ShouldSortPayload) | |
{ | |
sortPayload = countsOrSortedData[threadID.x + 1024]; | |
} | |
} | |
} | |
//only sorted positions holding real elements are written out; the padding sits at the end of the last block
if (globalIndex < _NumElements) | |
{ | |
//load sorted data | |
uint sortKey = (rawSortKey >> _FirstBitToSort) & 0xFF; | |
int finalSortIndex = (int) threadID.x + BlockToGlobalKeyOffsetsTexture[uint2(groupID.x, sortKey / 4)][sortKey % 4]; | |
KeyOutputBuffer[finalSortIndex] = rawSortKey; | |
if (_ShouldSortPayload) | |
{ | |
PayloadOutputBuffer[finalSortIndex] = sortPayload; | |
} | |
} | |
} | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using UnityEngine; | |
using System; | |
//this setup class is made for Unity but the shader will work in any engine that supports D3D12 and HLSL SM 6.0 | |
/// <summary>
/// Unity driver for the GPU counting sort compute shader (<see cref="m_sortShader"/>).
/// It generates random 32-bit keys plus a payload holding each key's original index. It sorts them on the GPU every
/// frame with four 8-bit passes and validates the sorted output on demand.
/// Controls: Q / E step through <see cref="m_tests"/>, R toggles payload sorting, Space validates the latest sort.
/// </summary>
public class PrefixSorterSetup : MonoBehaviour
{
    //every requested test size is clamped to [1, maxElements]
    static readonly int maxElements = 1024 * 1024 * 8;

    //keys per block / thread group; must match the [numthreads(1024, 1, 1)] of the block kernels
    const int elementsPerBlock = 1024;

    //rows of the per-block textures: 256 byte values with 4 values packed per int4 texel
    const int keyRows = 64;

    //one pass per byte of a 32-bit key; the count is even, so the result lands back in the A buffers
    const int numPasses = 4;

    public ComputeShader m_sortShader;
    public int[] m_tests = { 1048576 };
    public bool m_shouldSortPayload = true;
    public string m_debugOutputElements = "";

    int m_testSizeIndex = 0;
    int m_numElements = -1;

    //kernel indices; -1 means "not resolved", which makes DoSort fail and Update rebuild everything
    int m_countTotalsKernel = -1;
    int m_blockPostfixKernel = -1;
    int m_calculateOffsetsKernel = -1;
    int m_finalSortKernel = -1;

    //CPU copies of the unsorted input, uploaded again before every sort
    uint[] m_sortingKeys;
    uint[] m_sortingPayload;

    //ping-pong buffers: each pass reads one of A/B and writes the other
    ComputeBuffer m_keysBufferA;
    ComputeBuffer m_keysBufferB;
    ComputeBuffer m_payloadBufferA;
    ComputeBuffer m_payloadBufferB;

    RenderTexture m_perBlockKeyCountsTexture;
    RenderTexture m_blockToGlobalKeyOffsetsTexture;

    private TextMesh m_debugDisplayText;

    // Start is called before the first frame update
    void Start()
    {
        m_debugDisplayText = GetComponent<TextMesh>();
        SetupComputeShader();
        ProcessControlsAndEditorSettings(); //sets up resources
    }

    // Update is called once per frame: sort, optionally validate, and refresh the on-screen status
    void Update()
    {
        ProcessControlsAndEditorSettings();
        if (DoSort())
        {
            if (Input.GetKeyDown(KeyCode.Space))
            {
                ValidateKeys();
                if (m_shouldSortPayload)
                {
                    ValidatePayload();
                }
            }
        }
        else
        {
            //the shader, a kernel or the GPU resources are missing: resolve everything again for the next frame
            m_countTotalsKernel = -1;
            m_blockPostfixKernel = -1;
            m_calculateOffsetsKernel = -1;
            m_finalSortKernel = -1;
            SetupComputeShader();
            BuildResources();
        }
        UpdateDebugText();
    }

    //shows the current test configuration on the attached TextMesh, if there is one
    void UpdateDebugText()
    {
        if (m_debugDisplayText == null)
        {
            return;
        }
        if (m_sortShader == null)
        {
            m_debugDisplayText.text = "SHADER NOT SET!";
            return;
        }
        string mode = m_shouldSortPayload ? "Keys and Payload" : "Keys Only";
        m_debugDisplayText.text = "Elements: " + m_numElements.ToString() + "\nSorting " + mode;
    }

    //applies keyboard controls and inspector edits; rebuilds the GPU resources when the element count changes
    void ProcessControlsAndEditorSettings()
    {
        //keyboard controls
        if (Input.GetKeyDown(KeyCode.Q))
        {
            m_testSizeIndex--;
        }
        if (Input.GetKeyDown(KeyCode.E))
        {
            m_testSizeIndex++;
        }
        if (Input.GetKeyDown(KeyCode.R))
        {
            m_shouldSortPayload = !m_shouldSortPayload;
        }
        //make sure there is at least 1 valid test size
        if (m_tests == null || m_tests.Length == 0)
        {
            m_tests = new int[] { 1048576 };
        }
        //wrap the test index to the valid range
        if (m_testSizeIndex < 0)
        {
            m_testSizeIndex = m_tests.Length - 1;
        }
        if (m_testSizeIndex >= m_tests.Length)
        {
            m_testSizeIndex = 0;
        }
        //pick the number of elements and clamp it to [1, maxElements]
        int newNumElements = m_tests[m_testSizeIndex];
        if (newNumElements > maxElements)
        {
            newNumElements = maxElements;
        }
        else if (newNumElements < 1)
        {
            newNumElements = 1;
        }
        //rebuild resources if number of elements changed
        if (m_numElements != newNumElements)
        {
            m_numElements = newNumElements;
            BuildResources();
        }
    }

    //resolves the four kernel indices; leaves them untouched when no shader is assigned
    void SetupComputeShader()
    {
        if (m_sortShader == null)
        {
            return;
        }
        m_countTotalsKernel = m_sortShader.FindKernel("CountTotalsInBlock");
        m_blockPostfixKernel = m_sortShader.FindKernel("BlockCountPostfixSum");
        m_calculateOffsetsKernel = m_sortShader.FindKernel("CalculateOffsetsForEachKey");
        m_finalSortKernel = m_sortShader.FindKernel("FinalSort");
    }

    //number of 1024-element blocks (thread groups) needed to cover m_numElements
    int NumBlocks()
    {
        return (m_numElements + elementsPerBlock - 1) / elementsPerBlock;
    }

    //generates new random input for m_numElements keys and (re)creates every buffer and texture the shader uses
    void BuildResources()
    {
        if (m_sortShader == null)
        {
            return;
        }
        //create an unsorted array of keys; the payload records each key's original index
        m_sortingKeys = new uint[m_numElements];
        m_sortingPayload = new uint[m_numElements];
        for (int i = 0; i < m_numElements; i++)
        {
            m_sortingKeys[i] = (uint)UnityEngine.Random.Range(int.MinValue, int.MaxValue);
            m_sortingPayload[i] = (uint)i;
        }
        //create the buffers
        m_keysBufferA = RecreateBuffer(m_keysBufferA, m_numElements);
        m_keysBufferB = RecreateBuffer(m_keysBufferB, m_numElements);
        m_payloadBufferA = RecreateBuffer(m_payloadBufferA, m_numElements);
        m_payloadBufferB = RecreateBuffer(m_payloadBufferB, m_numElements);
        //create the textures: one column per block, one row per 4 byte values
        RenderTextureDescriptor groupTotalsTexDesc = new RenderTextureDescriptor(NumBlocks(), keyRows, RenderTextureFormat.ARGBInt, 0);
        groupTotalsTexDesc.enableRandomWrite = true;

        ReleaseTexture(m_perBlockKeyCountsTexture);
        m_perBlockKeyCountsTexture = new RenderTexture(groupTotalsTexDesc);
        m_perBlockKeyCountsTexture.Create();
        m_sortShader.SetTexture(m_countTotalsKernel, "PerBlockKeyCountsTexture", m_perBlockKeyCountsTexture, 0);
        m_sortShader.SetTexture(m_blockPostfixKernel, "PerBlockKeyCountsTexture", m_perBlockKeyCountsTexture, 0);
        m_sortShader.SetTexture(m_calculateOffsetsKernel, "PerBlockKeyCountsTexture", m_perBlockKeyCountsTexture, 0);

        ReleaseTexture(m_blockToGlobalKeyOffsetsTexture);
        m_blockToGlobalKeyOffsetsTexture = new RenderTexture(groupTotalsTexDesc);
        m_blockToGlobalKeyOffsetsTexture.Create();
        m_sortShader.SetTexture(m_countTotalsKernel, "BlockToGlobalKeyOffsetsTexture", m_blockToGlobalKeyOffsetsTexture, 0);
        m_sortShader.SetTexture(m_calculateOffsetsKernel, "BlockToGlobalKeyOffsetsTexture", m_blockToGlobalKeyOffsetsTexture, 0);
        m_sortShader.SetTexture(m_finalSortKernel, "BlockToGlobalKeyOffsetsTexture", m_blockToGlobalKeyOffsetsTexture, 0);
    }

    //Uploads the unsorted input and runs every sort pass.
    //Returns false, doing nothing, when the shader, a kernel or the GPU resources are not ready.
    bool DoSort()
    {
        if (m_sortShader == null || m_numElements < 1 || m_numElements > maxElements
            || m_countTotalsKernel < 0 || m_blockPostfixKernel < 0
            || m_calculateOffsetsKernel < 0 || m_finalSortKernel < 0
            || m_sortingKeys == null || m_keysBufferA == null)
        {
            return false;
        }
        m_keysBufferA.SetData(m_sortingKeys, 0, 0, m_numElements);
        m_payloadBufferA.SetData(m_sortingPayload, 0, 0, m_numElements);
        int numBlocks = NumBlocks();
        m_sortShader.SetInt("_NumElements", m_numElements);
        m_sortShader.SetInt("_NumBlocks", numBlocks);
        m_sortShader.SetBool("_ShouldSortPayload", m_shouldSortPayload);
        //each pass sorts by one byte, from the lowest byte to the highest
        for (int pass = 0; pass < numPasses; pass++)
        {
            m_sortShader.SetInt("_FirstBitToSort", pass * 8);
            //flip the buffers every other pass
            bool readFromA = (pass % 2) == 0;
            ComputeBuffer keysIn = readFromA ? m_keysBufferA : m_keysBufferB;
            ComputeBuffer keysOut = readFromA ? m_keysBufferB : m_keysBufferA;
            ComputeBuffer payloadIn = readFromA ? m_payloadBufferA : m_payloadBufferB;
            ComputeBuffer payloadOut = readFromA ? m_payloadBufferB : m_payloadBufferA;
            m_sortShader.SetBuffer(m_countTotalsKernel, "KeyInputBuffer", keysIn);
            m_sortShader.SetBuffer(m_finalSortKernel, "KeyInputBuffer", keysIn);
            m_sortShader.SetBuffer(m_finalSortKernel, "KeyOutputBuffer", keysOut);
            m_sortShader.SetBuffer(m_finalSortKernel, "PayloadInputBuffer", payloadIn);
            m_sortShader.SetBuffer(m_finalSortKernel, "PayloadOutputBuffer", payloadOut);

            m_sortShader.Dispatch(m_countTotalsKernel, numBlocks, 1, 1);
            m_sortShader.Dispatch(m_blockPostfixKernel, 1, keyRows, 1);
            m_sortShader.Dispatch(m_calculateOffsetsKernel, numBlocks, 1, 1);
            m_sortShader.Dispatch(m_finalSortKernel, numBlocks, 1, 1);
        }
        return true;
    }

    //reads the sorted keys back from the GPU and reports whether they are in non-decreasing order
    void ValidateKeys()
    {
        uint[] values = new uint[m_numElements];
        m_keysBufferA.GetData(values, 0, 0, m_numElements);
        ReportSortedness("Keys", values);
    }

    //Reads the sorted payload back from the GPU. The payload holds each element's original index, so it is mapped
    //back to the original keys and those are checked for non-decreasing order.
    void ValidatePayload()
    {
        uint[] originalIndices = new uint[m_numElements];
        m_payloadBufferA.GetData(originalIndices, 0, 0, m_numElements);
        uint[] keysInPayloadOrder = new uint[m_numElements];
        for (int i = 0; i < m_numElements; i++)
        {
            uint originalIndex = originalIndices[i];
            //a broken sort can leave garbage here; report it instead of throwing IndexOutOfRangeException
            if (originalIndex >= m_sortingKeys.Length)
            {
                m_debugOutputElements = "";
                Debug.LogError("Payload: Size = " + m_numElements + " | invalid original index " + originalIndex + " at " + i);
                return;
            }
            keysInPayloadOrder[i] = m_sortingKeys[originalIndex];
        }
        ReportSortedness("Payload", keysInPayloadOrder);
    }

    //Checks that values is in non-decreasing order and stores the first 1024 values in m_debugOutputElements.
    //It then logs a one-line summary, prefixed with label, that includes the values around the first
    //out-of-order element (or the first values when everything is sorted).
    void ReportSortedness(string label, uint[] values)
    {
        const int numToPrint = 10;
        const int numDebugElements = 1024;

        var debugText = new System.Text.StringBuilder();
        int failedOn = -1;
        for (int i = 0; i < values.Length; i++)
        {
            if (i < numDebugElements)
            {
                //start a new line every 32 values, plus an empty line before every group of 256
                if (i % 32 == 0)
                {
                    debugText.Append('\n');
                }
                if (i % 256 == 0)
                {
                    debugText.Append('\n');
                }
                debugText.Append(values[i]).Append(", ");
            }
            //compare as integers: converting to float rounds keys above 2^24 and would hide ordering errors
            if (failedOn < 0 && i > 0 && values[i] < values[i - 1])
            {
                failedOn = i;
            }
        }
        m_debugOutputElements = debugText.ToString();

        bool isSorted = failedOn < 0;
        string output = label + ": Size = " + values.Length + " | Sorted = " + isSorted + " | Failed On = " + failedOn + " | Values: ";
        int startIndex = Math.Max(0, failedOn - 5);
        for (int i = startIndex; i < startIndex + numToPrint && i < values.Length; i++)
        {
            output += values[i] + ", ";
        }
        Debug.Log(output);
    }

    //releases the old buffer (if any) and returns a new one holding count 32-bit elements
    static ComputeBuffer RecreateBuffer(ComputeBuffer oldBuffer, int count)
    {
        ReleaseBuffer(oldBuffer);
        return new ComputeBuffer(count, sizeof(uint));
    }

    static void ReleaseBuffer(ComputeBuffer buffer)
    {
        if (buffer != null)
        {
            buffer.Release();
        }
    }

    //frees the texture's GPU memory (Release) and the Unity object itself (Destroy); Release alone keeps the object alive
    static void ReleaseTexture(RenderTexture texture)
    {
        if (texture != null)
        {
            texture.Release();
            Destroy(texture);
        }
    }

    private void OnDestroy()
    {
        ReleaseBuffer(m_keysBufferA);
        ReleaseBuffer(m_keysBufferB);
        ReleaseBuffer(m_payloadBufferA);
        ReleaseBuffer(m_payloadBufferB);
        ReleaseTexture(m_perBlockKeyCountsTexture);
        ReleaseTexture(m_blockToGlobalKeyOffsetsTexture);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.