Skip to content

Instantly share code, notes, and snippets.

@colltoaction
Last active October 6, 2015 00:53
Show Gist options
  • Save colltoaction/ba2a4aa5677967bed255 to your computer and use it in GitHub Desktop.
Save colltoaction/ba2a4aa5677967bed255 to your computer and use it in GitHub Desktop.
Adaptive Arithmetic Coder written in Python.
class AdaptiveArithmeticCoder:
    """Adaptive arithmetic coder over a fixed alphabet, using float intervals.

    Every symbol starts with frequency ``FREQ_INIT``.  Encoding a symbol
    narrows the working interval to that symbol's sub-interval and then
    increments the symbol's frequency, so later symbols are coded with the
    updated model.
    """

    # Initial frequency given to every symbol of the alphabet.
    FREQ_INIT = 1

    def __init__(self, character_space):
        """Create a coder for the symbols in *character_space*.

        character_space: non-empty sequence of symbols (e.g. ``"ABCD"``).  A
            symbol's position decides where its sub-interval is laid out.

        Raises ValueError if *character_space* is empty.
        """
        if not character_space:
            raise ValueError("character_space must contain at least one symbol")
        self.character_space = character_space
        # Read FREQ_INIT through self so a subclass override is honoured.
        self.freqs = [self.FREQ_INIT] * len(character_space)
        # count is the total of all frequencies (the probability denominator);
        # it only equals the alphabet size when FREQ_INIT is 1.
        self.count = sum(self.freqs)
        self.intervals = [None] * len(self.freqs)
        self.update_intervals((0, 1))

    def update_intervals(self, interval):
        """Split *interval* among the symbols in proportion to their frequencies.

        interval: ``(low, high)`` pair to subdivide.  ``self.intervals[i]``
            becomes the slice for ``character_space[i]``; slices are laid out
            consecutively starting at ``low``.
        """
        low, high = interval
        width = high - low
        start = low
        for i, freq in enumerate(self.freqs):
            share = freq / self.count * width
            self.intervals[i] = (start, start + share)
            start += share

    def encode(self, c):
        """Encode symbol *c* and update the adaptive model.

        The working interval shrinks to the sub-interval *c* held before this
        call; that interval is then re-partitioned using *c*'s incremented
        frequency.

        Raises whatever ``character_space.index(c)`` raises when *c* is not
        part of the alphabet (ValueError for str, list and tuple).
        """
        index = self.character_space.index(c)
        self.freqs[index] += 1
        self.count += 1
        self.update_intervals(self.intervals[index])

    def current_interval(self):
        """Return the ``(low, high)`` bounds of the current working interval."""
        low = self.intervals[0][0]
        high = self.intervals[-1][1]
        return low, high

    def current(self):
        """Return the midpoint of the current working interval."""
        low, high = self.current_interval()
        return (low + high) / 2

    def current_binary(self):
        """Return a binary fraction lying strictly inside the current interval.

        The result is a string of ``"0"``/``"1"`` characters; character ``k``
        (1-based) is the bit weighted ``2 ** -k``.  Bits are chosen greedily
        and the string ends as soon as its value is strictly between the
        interval bounds.

        Raises ArithmeticError if the interval has become too narrow for
        floating point to represent any value strictly inside it.
        """
        low, high = self.current_interval()
        bits = []
        value = 0
        exponent = -1
        while not low < value < high:
            step = 2.0 ** exponent
            if step == 0.0:
                # The bit weight underflowed: no further bit can move value,
                # so the loop could never terminate.
                raise ArithmeticError(
                    "interval is too narrow to be encoded with float precision"
                )
            # ">=" (not ">"): landing exactly on the upper bound would leave
            # value outside the open interval forever.
            if value + step >= high:
                bits.append("0")
            else:
                value += step
                bits.append("1")
            exponent -= 1
        return "".join(bits)
if __name__ == '__main__':
    # Command-line entry point: read text from standard input, encode every
    # character (newlines excluded) and print either the final interval or
    # its binary coding.
    import argparse
    import sys  # use sys.stdin directly rather than the incidental os.sys alias

    parser = argparse.ArgumentParser(description='Dynamically encodes a string received through standard input.')
    parser.add_argument('-interval', action="store_true", help='Print the floating interval rather than the binary coding.')
    parser.add_argument('-character_space', type=str, required=True, help='All possible characters in the input. e.g. "ABCD".')
    args = parser.parse_args()

    coder = AdaptiveArithmeticCoder(args.character_space)
    for line in sys.stdin:
        # The line terminator is not part of the message being encoded.
        for c in line.rstrip('\n'):
            coder.encode(c)

    if args.interval:
        print(coder.current_interval())
    else:
        print(coder.current_binary())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment