Last active
September 17, 2024 21:00
-
-
Save dsevero/7e02d96e079ce44b89ff33d7a1ce1738 to your computer and use it in GitHub Desktop.
Asymmetric Numeral Systems (ANS) codec in pure Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def push(state, symbol, cdf_func, prec):
    '''Encode `symbol` onto the integer ANS `state` and return the new state.

    cdf_func -- callable mapping symbol -> (cdf_low, cdf_high), the symbol's
                interval in the quantized CDF. The interval must be non-empty:
                a zero-width interval raises ZeroDivisionError below.
    prec     -- multiplier applied to the quotient; `pop` uses the same value
                as its modulus (presumably the total mass of the quantized
                CDF -- confirm against the model in use).
    '''
    low, high = cdf_func(symbol)
    width = high - low
    # Split the state by the symbol's frequency; the quotient is re-scaled by
    # `prec` and the remainder is shifted into the symbol's CDF interval.
    quotient, remainder = divmod(state, width)
    return prec * quotient + remainder + low
def pop(state, icdf_func, cdf_func, prec):
    '''Decode one symbol from the integer ANS `state`.

    icdf_func -- callable mapping a CDF value in [0, prec) to
                 (symbol, cdf_low, cdf_high).
    cdf_func  -- accepted for signature symmetry with `push`; this
                 implementation does not call it.
    prec      -- modulus used to extract the CDF value from the state.

    Returns (symbol, new_state).
    '''
    # The low "digit" of the state (base `prec`) identifies the symbol's slot.
    upper, slot = divmod(state, prec)
    symbol, low, high = icdf_func(slot)
    # Rebuild the earlier state: scale the quotient by the symbol's frequency
    # and add back the offset of the slot inside the symbol's interval.
    return symbol, (high - low) * upper + slot - low
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
''' Round-trip demo: encode and decode a short sequence with the rANS codec.
Heavily inspired by https://github.com/j-towns/ans-notes
'''
from math import log2 | |
from functools import reduce | |
from rans import push, pop | |
# Codec configuration: encoding starts from state 0 and probabilities are
# quantized to integer counts out of `precision`.
initial_state = 0
precision = 8

# Toy source: three symbols and their probabilities.
alphabet = [0, 1, 2]
pmf = [1/2, 1/4, 1/4]

# Shannon entropy of the source, in bits per symbol.
entropy = sum(p*log2(1/p) for p in pmf)

# Quantized cumulative distribution: cdf[s] and cdf[s+1] bound symbol s.
# For pmf=[1/2, 1/4, 1/4] at precision=8, the quantized cdf=[0, 4, 6, 8]
# NOTE: each probability is rounded on its own, so the last entry only equals
# `precision` when the rounded counts happen to sum to it.
cdf = [0]
for prob in pmf:
    cdf.append(cdf[-1] + round(prob*precision))
# ANS requires these 2 functions. | |
def cdf_func(symbol):
    ''' Function signature is symbol -> (cdf_low, cdf_high).
    Looks up the symbol's interval in the module-level `cdf` table, using
    the symbol itself as the index of the lower bound.
    This can be substituted for a more complex model like a neural network'''
    lower = cdf[symbol]
    upper = cdf[symbol + 1]
    return lower, upper
def icdf_func(cdf_value):
    ''' Function signature is cdf_value -> (symbol, cdf_low, cdf_high).
    Finds the symbol whose interval contains cdf_value, i.e.
    cdf_low <= cdf_value < cdf_high with (cdf_low, cdf_high) = cdf_func(symbol).
    Raises ValueError when no symbol in `alphabet` covers cdf_value, rather
    than falling through and returning None.
    This can be substituted for a more complex model like a neural network'''
    for symbol in alphabet:
        cdf_low, cdf_high = cdf_func(symbol)
        if cdf_low <= cdf_value < cdf_high:
            return symbol, cdf_low, cdf_high
    # Reachable when the symbol intervals leave a gap, e.g. a cdf_value at or
    # beyond the last cdf entry.
    raise ValueError(
        f'cdf_value {cdf_value!r} is not covered by any symbol interval')
# Some symbols to compress
sequence = 100*[2, 0, 0, 1]

# Encode
# Symbols are pushed in reverse so that popping returns them in the
# original order (the decode loop below reads them front to back).
state = initial_state
for symbol in reversed(sequence):
    state = push(state, symbol, cdf_func, precision)

# The whole message is held in one integer; its bit length is the code size.
rate = state.bit_length()/len(sequence)

# Decode
decoded_sequence = []
for _ in sequence:
    decoded_symbol, state = pop(state, icdf_func, cdf_func, precision)
    decoded_sequence.append(decoded_symbol)

# Sanity checks
assert decoded_sequence == sequence
assert (rate - entropy) < 0.01

print(f'''
- Encoded {len(sequence)} symbols
- Rate: {rate} bits/symbol
- Entropy: {entropy} bits
''')

# Expected output:
# - Encoded 400 symbols
# - Rate: 1.5025 bits/symbol
# - Entropy: 1.5 bits
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment