Created
July 8, 2022 21:59
-
-
Save mmalex/fb28b439e12841d9b1563a15adfb8242 to your computer and use it in GitHub Desktop.
attempt to understand range coder by implementing the algorithm on the wikipedia page, of all places...
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#include <vector>
// arithmetic coder/decoder by @mmalex based on the wikipedia page https://en.wikipedia.org/wiki/Range_coding as of July 2022
// State shared by the range encoder and decoder. The current coding interval
// is [low, low + range), held in unsigned integers of type `artype`.
struct encdec_state {
    // Integer type holding the interval. Derived types and normalise() use it
    // through the names `artype` and `topshift`, so it can be widened here.
    using artype = uint16_t;
    // Right-shift that isolates the most significant byte of an artype value.
    static constexpr int topshift = sizeof(artype) * 8 - 8;

    artype low = 0;
    artype range = ~artype(0);  // all bits set: the widest possible interval

    // Discards the top byte of both fields, making room for 8 new low bits.
    void shift() {
        low <<= 8;
        range <<= 8;
    }

    // The heart of the algorithm: narrows the interval to the sub-range
    // [start, end) of a cumulative distribution with `total` slots, then
    // renormalises.
    //
    // Precondition: the caller has already divided self.range by `total`
    // (encoder::encode and decoder::decode both do this before normalising).
    //
    // `self` must provide byte(), which is invoked once for every byte that
    // leaves the top of the interval. Static + templated so that the concrete
    // type's byte() is called (and can be inlined) without a virtual call.
    template <typename T>
    static void normalise(T &self, uint32_t start, uint32_t end, uint32_t total) {
        self.low += start * self.range;
        self.range *= end - start;

        // While both ends of the interval agree on the top byte, that byte is
        // settled: hand it to byte() and shift it out.
        while ((self.low >> topshift) == ((self.low + self.range) >> topshift)) {
            self.byte();
            self.shift();
        }

        // Underflow: the interval straddles a top-byte boundary but is too
        // narrow to be split into `total` parts for the next symbol.
        if (self.range < total) {
            //printf("underflow range %llu\n", (uint64_t)self.range);
            while (self.range < total) {
                self.byte();
                self.shift();
            }
            // Extend the interval up to the maximum representable value.
            self.range = ~self.low;
        }
    }
};
struct encoder : public encdec_state { | |
std::vector<uint8_t> out; | |
inline void byte() { out.push_back(low>>topshift); } // push a byte to the output | |
inline void encode(uint32_t cdf_start, uint32_t cdf_end, uint32_t cdf_total) { // encode a symbol with the given range of cdf | |
range/=cdf_total; | |
normalise(*this,cdf_start,cdf_end,cdf_total); | |
} | |
inline void flush_encode() { // at the end of the stream, call this | |
artype ofs = artype(1)<<topshift; | |
while (range<ofs) { byte(); shift(); } | |
low+=ofs; | |
byte(); shift(); | |
} | |
}; | |
struct decoder : public encdec_state { | |
artype state=0; | |
const uint8_t *in=nullptr, *end=nullptr; | |
void byte() { state=(state<<8)+ ((in<end)?*in++ : 0); } // pull a byte from the input | |
decoder(const uint8_t *din, const uint8_t *dend) : encdec_state(), in(din), end(dend) { | |
for (int i=0;i<sizeof(artype);++i) byte(); | |
range=~artype(0); | |
} | |
inline uint32_t decode(uint32_t cdf_total) { // call normalise after. returns a value from 0 to total that needs looking up in an inverse cdf | |
range /= cdf_total; | |
return (state - low) / range; | |
} | |
}; | |
int main(int argc, char **argv) { | |
uint32_t cdf[257]={}; | |
std::vector<uint8_t> plaintext; | |
for (int i=0;i<100;++i) { | |
plaintext.push_back('A'+(rand()%54)); | |
cdf[1+plaintext.back()]++; | |
} | |
for (int i=0;i<256;++i) cdf[i+1]+=cdf[i]; | |
uint32_t total=cdf[256]; | |
std::vector<uint8_t> icdf(total); | |
for (int i=0;i<256;++i) | |
for (int j=cdf[i];j<cdf[i+1];++j) icdf[j]=i; | |
encoder enc; | |
for (auto sym : plaintext) { | |
printf("%c", sym); | |
enc.encode(cdf[sym], cdf[sym+1], total); | |
} | |
enc.flush_encode(); | |
printf("\n%d bytes\n",(int)enc.out.size()); | |
// decode it and make sure it round trips | |
decoder dec(enc.out.data(), enc.out.data()+enc.out.size()); | |
for (auto orig_sym : plaintext) { | |
uint32_t k = dec.decode(total); | |
uint32_t sym=icdf[k]; | |
encdec_state::normalise(dec,cdf[sym], cdf[sym+1], total); | |
printf("%c", sym); | |
if (sym!=orig_sym) { | |
printf("decode error, %d vs %d, range %d\n", sym, orig_sym, (int)dec.range); | |
assert(false); | |
} | |
} | |
printf("\n"); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Change artype to uint64_t in practice; it's uint16_t here because that stimulates the underflow case more thoroughly :) but it won't work well if the cdf total is too high.