'''
This file is heavily inspired by https://github.com/j-towns/ans-notes/blob/master/rans.py

We describe a variant of bits-back coding called BB-Huffman. This file is meant
purely as an educational tool and is in no way optimized. The goals are to
1. illustrate how bits-back coding can be used with Huffman-like codecs;
2. show how the cost of using a selector to toggle between codebooks can be greatly reduced.

- Symbols are integers between 0 and 2;
- Selectors are integers between 0 and 2.

Q, P and p are codebooks. All codewords are prefix-free.
- Q encodes the selectors with 2 sets of codewords based on the parity of the symbol;
- P encodes the symbols with 3 sets of codewords based on the value of the selector;
- p is an alternative (i.e. prior) codebook for the selectors, which uses a fixed set of codewords.

Note that Q, P and p can be substituted by any entropy coder that outputs explicit codewords.
Symbol parity is used to simulate the symbol-dependency of Q.
'''
from random import choices, seed
from math import log2


def encode(bitstring: str, codeword: str):
    '''Encodes codeword onto bitstring.'''
    return codeword + bitstring


def decode(bitstring: str, codebook: 'List[str]'):
    '''Reads off bits from bitstring until a codeword is matched
    in the codebook. Only works if the codebook is prefix-free.'''
    for i in range(1, len(bitstring) + 1):
        for j, codeword in enumerate(codebook):
            if bitstring[:i] == codeword:
                return j, bitstring[i:]
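
# Quick round-trip check (illustrative addition, not in the original gist):
# encode prepends a codeword to the front of the bitstring, and decode strips a
# matching prefix back off, so the two behave like a stack (last-in, first-out).
assert encode('1', '10') == '101'
assert decode('101', ['0', '10', '11']) == (1, '1')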
# For reproducibility.
seed(1337)

# Symbols are drawn i.i.d. between 0 and 2 with probability
# proportional to the source.
source = [1/2, 1/4, 1/4]
N = len(source)
symbols: 'List[int]' = choices(range(N), source, k=1_000)

# Codebooks.
# P encodes symbols given a selector (i.e. conditional likelihood).
# Note that the codeword lengths perfectly match the source distribution,
# and hence the entropy is achievable.
P = [['0', '10', '11'],
     ['0', '10', '11'],
     ['1', '01', '00']]
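
# Sanity check (illustrative addition, not in the original gist): every row of P
# assigns codeword lengths equal to -log2 of the source probabilities, so the
# symbol itself always costs exactly its information content, whatever the selector.
assert all(len(P[z][s]) == -log2(source[s]) for z in range(len(P)) for s in range(N))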
# Encodes selectors given a symbol (i.e. approximate posterior).
Q = [['0', '10', '11'],
     ['10', '11', '0']]
M = len(Q)

# Encodes selectors, without conditioning on a symbol (i.e. prior).
p = ['0', '10', '11']
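
# Prefix-freeness check (illustrative addition, not in the original gist): decode()
# above only works if no codeword in a codebook is a prefix of another codeword
# in the same codebook.
def is_prefix_free(codebook: 'List[str]') -> bool:
    return not any(a != b and b.startswith(a) for a in codebook for b in codebook)

assert all(is_prefix_free(codebook) for codebook in P + Q + [p])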
# Compression.
# Initial bits required to start the bits-back chaining.
initial_comp = '10'
comp = initial_comp
for s in symbols:
    z, comp = decode(comp, Q[s % 2])  # decode a selector, based on the symbol parity
    comp = encode(comp, P[z][s])      # encode the symbol, based on the selector
    comp = encode(comp, p[z])         # encode the selector

# During compression, note that if the codewords p[z] and Q[s % 2][z] have the same
# length, then there is no increase in message length due to the use of a selector.
# This doesn't hold for every s % 2 == 1, hence we pay a small penalty. The net
# increase in code-length is len(P[z][s]) + len(p[z]) - len(Q[s % 2][z]), which
# becomes the negative evidence lower bound (NELBO) when code-lengths L are
# converted into their implied probabilities 2**(-L).
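# Concrete accounting (illustrative addition, not in the original gist): for even
# symbols, Q[0] and p are identical, so the selector adds no bits at all. For odd
# symbols, Q[1] = ['10', '11', '0'] while p = ['0', '10', '11'], so the selector's
# net cost len(p[z]) - len(Q[1][z]) is -1, 0 or +1 bits for z = 0, 1, 2 respectively.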
# Calculate metrics.
message_length = len(comp)/len(symbols)
message_entropy = -sum(log2(source[s]) for s in symbols)/len(symbols)
entropy = -sum(source[s]*log2(source[s]) for s in range(N))

# Decompression.
# Proceeds in the exact opposite order of compression.
decomp = list()
while comp != initial_comp:
    z, comp = decode(comp, p)         # decode the selector
    s, comp = decode(comp, P[z])      # decode the symbol, based on the selector
    comp = encode(comp, Q[s % 2][z])  # bits-back step: knowing the selector, get back the bits of Q
    decomp.insert(0, s)

# Check if decoding was done properly.
assert decomp == symbols

# The difference between the message length and the message entropy is the initial
# bits plus the divergence (i.e. wrong-code penalty) between Q and p.
print(f'''
message length: {message_length}
message entropy: {message_entropy}
source entropy: {entropy}
''')

# The output should be:
#
# message length: 1.524
# message entropy: 1.522
# source entropy: 1.5
#
# Main take-away: the penalty for using a selector to toggle between codecs can be
# reduced significantly using bits-back coding.
In the first `encode` function, should `comp` be `bitstring`?

Thanks for the comment. Yes, `comp` and `bitstring` are the same. I used `comp` as a shorthand for "compressed".
Feel free to leave a comment with some feedback! This is a rough first draft.