Last active
April 16, 2024 22:51
-
-
Save nanokatze/fea4f881c7c5b7b1b6931a92e7771a21 to your computer and use it in GitHub Desktop.
LZ4 decompressor for Vulkan written in Slang
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Unsigned 64-bit aliases used for byte counts and addresses below.
typedef uint64_t size_t;
typedef uint64_t uintptr_t;

// Job queue shared by one producer invocation and the consumer waves.
//   head: jobs published so far; the producer advances it atomically.
//   tail: jobs claimed so far; consumers advance it with compare-exchange.
//   buf:  job slots. Declared with one element but indexed past it, so the
//         host presumably allocates the real capacity — TODO confirm.
// NOTE: the pop side would probably benefit from a lock; consumers currently
// race on `tail` with a CAS loop.
struct Queue<T>
{
    uint32_t head;
    uint32_t tail;
    T        buf[1];
};
// Atomically loads *p with acquire semantics at Device scope.
//
// Memory semantics = Acquire (0x2) | UniformMemory (0x40) | MakeVisible (0x4000).
// The previous value 0x2002 (Acquire | MakeAvailable) is invalid SPIR-V:
// MakeAvailable may only accompany Release/AcquireRelease. A non-relaxed
// ordering also needs a storage-class bit to order anything. UniformMemory is
// the one that covers StorageBuffer/PhysicalStorageBuffer.
// NOTE(review): MakeVisible assumes the module uses the Vulkan memory model,
// and Device scope there needs vulkanMemoryModelDeviceScope — TODO confirm
// (QueueFamily = 5 is the alternative).
uint myAtomicLoad(uint *p) {
    uint scope = 1;          // Device
    uint semantics = 0x4042; // Acquire | UniformMemory | MakeVisible
    return spirv_asm { OpAtomicLoad $$uint result $p $scope $semantics };
}
// Atomically adds `value` to *p and returns the previous value.
// The operation is acquire+release at Device scope.
//
// Memory semantics = AcquireRelease (0x8) | UniformMemory (0x40)
//                  | MakeAvailable (0x2000) | MakeVisible (0x4000).
// The previous value 0x6006 set both Acquire and Release. SPIR-V allows at
// most one ordering bit, so AcquireRelease is the correct spelling. It also
// lacked a storage-class bit.
// NOTE(review): assumes the Vulkan memory model with device-scope support —
// TODO confirm.
uint myAtomicAdd(uint *p, uint value) {
    uint scope = 1;          // Device
    uint semantics = 0x6048; // AcquireRelease | UniformMemory | MakeAvailable | MakeVisible
    return spirv_asm { OpAtomicIAdd $$uint result $p $scope $semantics $value };
}
// Atomically stores `value` into *p if *p == comparator. Returns the value
// *p held before the operation; the caller compares it with `comparator` to
// learn whether the swap happened.
//
// Equal semantics   = AcquireRelease | UniformMemory | MakeAvailable | MakeVisible (0x6048)
// Unequal semantics = Acquire | UniformMemory | MakeVisible (0x4042)
// The old values were invalid:
//  - 0x6006 combined Acquire and Release.
//  - 0x2002 paired MakeAvailable with a non-release order. The Unequal
//    operand may not carry release semantics at all.
// NOTE(review): assumes the Vulkan memory model with device-scope support —
// TODO confirm.
uint myAtomicCompareExchange(uint *p, uint comparator, uint value) {
    uint scope = 1;                 // Device
    uint semantics = 0x6048;        // used when the exchange happens
    uint semanticsUnequal = 0x4042; // used when *p != comparator (load only)
    // Operand order per SPIR-V: Pointer, Scope, Equal, Unequal, Value, Comparator.
    return spirv_asm { OpAtomicCompareExchange $$uint result $p $scope $semantics $semanticsUnequal $value $comparator };
}
// Upper bound on the bytes moved by one copy job.
// TODO: this should be specified by the host in launch params.
#define MAX_COPY_JOB_SIZE 8192

// Job::type tag: copy `size` bytes from `src` to `dst`.
#define JOB_TYPE_COPY_MEMORY 1

// One unit of work placed in the job queue.
struct Job
{
    int8_t   type; // JOB_TYPE_* tag
    uint8_t* dst;  // destination address
    uint8_t* src;  // source address
    size_t   size; // number of bytes
};
/* | |
struct Comp { | |
Queue<Job> jobQueue; | |
}; | |
*/ | |
// Progress record filled in by the producer invocation of `decompress`.
struct Results
{
    uint64_t consumed; // bytes of src parsed
    uint64_t produced; // bytes of dst accounted for (only uncompressed blocks advance it today)
    uint64_t t0;       // not written by this shader; presumably host timestamps — TODO confirm
    uint64_t t1;       // not written by this shader; presumably host timestamps — TODO confirm
    uint32_t spinDown; // non-zero once no more jobs will be enqueued
                       // TODO: move into a Comp struct
};
// Push-constant block: device pointers to the job queue, the input/output
// buffers and the result record, plus buffer lengths in bytes.
struct Push
{
    Queue<Job>* jobQueue; // work queue shared by all waves
    uint8_t*    src;      // LZ4 frame data
    size_t      srcLen;   // bytes readable at src
    uint8_t*    dst;      // decompressed output
    size_t      dstLen;   // presumably the capacity of dst — unused here, TODO confirm
    Results*    results;  // progress / completion record
};

[[vk::push_constant]] Push push;
// Reads a little-endian 32-bit word starting at p, with no alignment
// requirement: the word is assembled byte by byte, most significant first.
uint32_t load32_unaligned(uint8_t *p)
{
    uint32_t word = 0;
    for (int b = 3; b >= 0; b--)
        word = (word << 8) | (uint32_t)p[b];
    return word;
}
// Parses one LZ4 frame from push.src on a single invocation. For every
// uncompressed block it enqueues COPY_MEMORY jobs of at most
// MAX_COPY_JOB_SIZE bytes. It then records bytes consumed/produced in
// push.results and finally sets push.results.spinDown to announce that no
// more jobs will be enqueued.
// Compressed blocks are skipped for now: they are neither decoded nor
// counted in `produced`.
// Malformed or truncated input stops parsing early instead of reading past
// push.srcLen.
// NOTE(review): push.jobQueue capacity is not checked; the host must size it
// for the worst case — TODO confirm.
void produceJobs() {
    size_t i = 0;
    size_t j = 0;
    // Single producer: remember our write slot locally and publish it via
    // `head` only after the job has been stored. Consumers that observe the
    // new head then never read an unwritten slot.
    uint32_t slot = myAtomicLoad(&push.jobQueue.head);
    // 7 = smallest frame header: magic(4) + FLG + BD + header checksum.
    while (i + 7 <= push.srcLen) {
        uint32_t magic = load32_unaligned(push.src + i);
        if (magic != 0x184d2204) {
            // Not an LZ4 frame: stop instead of parsing garbage.
            // TODO: report the error (Results has no field for it yet).
            break;
        }
        i += 4;
        uint8_t flg = push.src[i];
        i++;
        // BD byte (block maximum size): not used yet.
        i++;
        // TODO: consult the block-independence flag (FLG bit 5).
        if (flg & (1 << 3)) {
            // Content size (8 bytes).
            //
            // TODO: require this. We don't necessarily *need* this, but if
            // content size is not present, we'll need to decompress this frame
            // before we can proceed.
            i += 8;
        }
        if (flg & (1 << 0)) {
            // Dictionary ID (4 bytes).
            i += 4;
        }
        // Header checksum.
        i++;
        bool sawEndMark = false;
        while (i + 4 <= push.srcLen) {
            uint32_t word = load32_unaligned(push.src + i);
            if (word == 0x00000000) {
                sawEndMark = true;
                break;
            }
            i += 4;
            bool uncompressed = (word & 0x80000000) != 0;
            uint32_t blockSize = word & 0x7fffffff;
            if (blockSize > push.srcLen - i) {
                // Block claims more bytes than remain: truncated frame.
                break;
            }
            // TODO: we need to do more work here to figure out block
            // uncompressed size and such.
            if (uncompressed) {
                // TODO: we might be able to use the entire subgroup to speed
                // this part up
                for (size_t k = 0; k < blockSize; k += MAX_COPY_JOB_SIZE) {
                    Job job = {};
                    job.type = JOB_TYPE_COPY_MEMORY;
                    job.dst = push.dst + j + k;
                    job.src = push.src + i + k;
                    // The final chunk gets only what is left of the block.
                    job.size = min((size_t)blockSize - k, (size_t)MAX_COPY_JOB_SIZE);
                    // NOTE(review): this plain store is ordered before the
                    // release on `head` only if it is emitted as a non-private
                    // access under the Vulkan memory model — TODO confirm.
                    push.jobQueue.buf[slot] = job;
                    slot++;
                    myAtomicAdd(&push.jobQueue.head, 1);
                }
                i += blockSize;
                j += blockSize;
            } else {
                // We need to scan the block ourselves and fan out jobs.
                i += blockSize;
            }
            if (flg & (1 << 4)) {
                // Block checksum.
                i += 4;
            }
        }
        if (!sawEndMark) {
            // Ran out of input before the end mark.
            break;
        }
        // End mark.
        i += 4;
        if (flg & (1 << 2)) {
            // Content checksum.
            i += 4;
        }
        // Handle only one frame for now.
        break;
    }
    push.results.consumed = i;
    push.results.produced = j;
    // TODO: should be myAtomicStore
    myAtomicAdd(&push.results.spinDown, 1);
}

// Each wave repeatedly claims one job from push.jobQueue, using its first
// active lane. It then copies the job's bytes with all of its active lanes.
// The loop ends once the queue is empty and the producer has set spinDown.
// NOTE(review): waves spin-wait on the producer invocation. That relies on
// the producer making forward progress concurrently, which GPUs do not
// guarantee across workgroups — TODO confirm dispatch size.
void consumeJobs() {
    while (true) {
        bool quit = false;
        uint32_t slot = 0;
        if (WaveIsFirstLane()) {
            uint32_t tail = myAtomicLoad(&push.jobQueue.tail);
            while (true) {
                uint32_t head = myAtomicLoad(&push.jobQueue.head);
                if (tail >= head) {
                    if (myAtomicLoad(&push.results.spinDown) == 0)
                        continue; // producer still running: poll again
                    // spinDown is set only after the last job was published,
                    // so a head loaded *after* observing it is final.
                    head = myAtomicLoad(&push.jobQueue.head);
                    if (tail >= head) {
                        quit = true;
                        break;
                    }
                }
                uint32_t was = myAtomicCompareExchange(&push.jobQueue.tail, tail, tail + 1);
                if (was == tail) {
                    slot = tail;
                    break;
                }
                tail = was;
            }
        }
        slot = WaveReadLaneFirst(slot);
        if (WaveReadLaneFirst(quit))
            return;
        Job job = push.jobQueue.buf[slot];
        // Lane rank among active lanes; stride by the active-lane count.
        size_t stride = WaveActiveSum(1);
        for (size_t k = WavePrefixSum(1); k < job.size; k += stride) {
            job.dst[k] = job.src[k];
        }
    }
}

[shader("compute")]
[numthreads(16, 1, 1)]
void decompress(uint3 groupID : SV_GroupID, uint32_t index : SV_DispatchThreadID) {
    // TODO: get a ticket with compare exchange instead of hard-wiring thread 0.
    if (index == 0) {
        produceJobs();
    }
    consumeJobs();
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Unsigned 64-bit type for byte counts and offsets.
typedef uint64_t size_t;

// Push-constant block describing one raw LZ4 block to decode.
// TODO: pass array of blocks to decompress, instead of params for a single block
struct Push
{
    uint8_t* src;    // compressed block
    size_t   srcLen; // bytes readable at src
    uint8_t* dst;    // decoded output
    size_t   dstLen; // presumably the capacity of dst — TODO confirm
};

[[vk::push_constant]] Push push;
// Returns the index of the lowest set bit of a subgroup ballot.
// This wraps OpGroupNonUniformBallotFindLSB at Subgroup scope.
// SPIR-V leaves the result undefined when no bit is set, so callers must
// handle an all-zero ballot themselves.
uint subgroupBallotFindLSB(uint4 value)
{
    uint4 ballot = value;
    uint lowest = spirv_asm {
        OpCapability GroupNonUniformBallot;
        OpGroupNonUniformBallotFindLSB $$uint result Subgroup $ballot
    };
    return lowest;
}
// Copies n bytes from src to dst, cooperatively across the active lanes.
// Each lane starts at its rank among the active lanes and advances by the
// number of lanes still copying, so adjacent lanes touch adjacent bytes.
// TODO: use a bigger unit when possible
void subgroupMemcpy(uint8_t *dst, uint8_t *src, size_t n)
{
    size_t cursor = WavePrefixSum(1);
    while (cursor < n) {
        dst[cursor] = src[cursor];
        cursor += WaveActiveSum(1);
    }
}
// Result of consumeVarint: the decoded value plus the input bytes it used.
struct ConsumedVarint
{
    uint64_t value; // sum of the bytes in the length extension
    size_t   len;   // bytes consumed from the input
};
// Decodes an LZ4 length extension. The extension is a run of 0xFF bytes
// ended by the first byte != 0xFF, and its value is the sum of every byte in
// the run, terminator included.
//
// Must be called uniformly by all active lanes of the subgroup. Each pass
// speculatively loads one byte per active lane and reduces across the
// subgroup.
//
// p: first extension byte. n: bytes readable at p.
// Returns {value, len}, where len is the number of bytes consumed and never
// exceeds n. If the input ends before a terminator, value is the sum of all
// n bytes and len == n (truncated input).
//
// Fixes vs. the ballot version:
//  - An all-0xFF window gave an empty ballot. FindLSB of that is undefined,
//    which led to run = 0 and an early exit with a wrong value.
//  - A lane past the end counted as a terminator and was consumed, so len
//    could be n + 1.
//  - Inactive lanes in the subgroup left holes in the window.
// TODO: we should have a separate tunable knob for how far we should speculate.
ConsumedVarint consumeVarint(uint8_t *p, size_t n) {
    // Dense rank of this lane among active lanes, and the number of active
    // lanes. Using the rank (not WaveGetLaneIndex) keeps the speculative
    // window contiguous even when some subgroup lanes are inactive.
    uint lane = WavePrefixCountBits(true);
    uint width = WaveActiveCountBits(true);
    uint64_t value = 0;
    size_t i = 0;
    while (true) {
        bool inRange = i + lane < n;
        uint8_t byte = 0;
        if (inRange) {
            byte = p[i + lane];
        }
        // The run stops at the first lane holding a non-0xFF byte or lying
        // past the end of the input.
        bool stops = !inRange || byte != 255;
        uint first = WaveActiveMin(stops ? lane : width);
        if (first == width) {
            // Every lane read 0xFF: take the whole window and keep going.
            value += WaveActiveSum((uint64_t)byte);
            i += width;
            continue;
        }
        // Consume the terminator byte, but not a lane that was past the end.
        uint run = (i + first < n) ? first + 1 : first;
        value += WaveActiveSum(lane < run ? (uint64_t)byte : (uint64_t)0);
        i += run;
        break;
    }
    ConsumedVarint result = {value, i};
    return result;
}
// Decodes one raw LZ4 block (push.src, push.srcLen) into push.dst.
// Each sequence is: token, optional literal-length extension, literals,
// 2-byte little-endian offset, optional match-length extension. The last
// sequence carries literals only.
// Only workgroup (0,0,0) decodes; the subgroup shares the literal copies.
// Malformed input stops decoding rather than reading or writing out of
// bounds:
//  - literals past either buffer,
//  - a truncated offset,
//  - an offset of 0 or reaching before dst,
//  - a match past dstLen.
[shader("compute")]
[numthreads(16, 1, 1)] // BUG: we just want single-subgroup workgroups
void decompressBlock( uint3 what : SV_GroupID /* uint32_t index : SV_DispatchThreadID */) {
    // Currently we use a subgroup per block, we might want to scale up to
    // workgroup per block if blocks are big enough in practice.
    // `any`, not `all`: with `all`, groups such as (1,0,0) slipped through
    // and decoded the same block concurrently into the same output.
    if (any(what != uint3(0)))
        return;
    size_t i = 0;
    size_t j = 0;
    while (i < push.srcLen) {
        uint8_t token = push.src[i++];
        size_t literallength = (size_t)(token >> 4);
        if ((token >> 4) == 15) {
            // TODO: we can load this varint speculatively
            ConsumedVarint wat = consumeVarint(push.src + i, push.srcLen - i);
            literallength += wat.value;
            i += wat.len;
        }
        if (i > push.srcLen || literallength > push.srcLen - i || literallength > push.dstLen - j)
            break; // literals run past the input or the output
        subgroupMemcpy(push.dst + j, push.src + i, literallength);
        i += literallength;
        j += literallength;
        if (i == push.srcLen)
            break; // last sequence: literals only
        if (push.srcLen - i < 2)
            break; // truncated offset
        size_t offset = (size_t)((uint16_t)push.src[i+0] | ((uint16_t)push.src[i+1] << 8));
        i += 2;
        if (offset == 0 || offset > j)
            break; // invalid offset: would read at or before the start of dst
        size_t matchlength = 4 + (size_t)(token & 0xf);
        if ((token & 0xf) == 15) {
            // TODO: we can load this varint speculatively
            ConsumedVarint wat = consumeVarint(push.src + i, push.srcLen - i);
            matchlength += wat.value;
            i += wat.len;
        }
        if (matchlength > push.dstLen - j)
            break; // match runs past the output
        // TODO: do this with the entire subgroup. The byte-at-a-time order
        // matters when offset < matchlength (overlapping copy).
        // NOTE(review): this lane reads dst bytes that other lanes stored in
        // subgroupMemcpy. Nothing here makes those stores visible to it
        // (no barrier) — TODO confirm / add synchronization.
        if (WaveIsFirstLane()) {
            for (size_t k = 0; k < matchlength; k++) {
                push.dst[j+k] = push.dst[j-offset+k];
            }
        }
        j += matchlength;
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment