Skip to content

Instantly share code, notes, and snippets.

@TimDettmers
Last active September 30, 2021 15:45
Show Gist options
  • Save TimDettmers/f731c917117e00979852d7dbc1e76cb8 to your computer and use it in GitHub Desktop.
Save TimDettmers/f731c917117e00979852d7dbc1e76cb8 to your computer and use it in GitHub Desktop.
Calculate Huffman compression ratio with bitsandbytes
import torch
import bitsandbytes as bnb
from heapq import heappush, heappop, heapify
# Sample input: a 1024x1024 tensor drawn from a normal distribution
# (mean 0, std 0.5), allocated on the CUDA device.
a = torch.normal(0, 0.5, size=(1024, 1024),device='cuda')
def get_compression(x: torch.Tensor) -> float:
    """Return the fraction of space saved by Huffman-coding the quantized tensor.

    The input is quantized with ``bnb.functional.quantize_blockwise``, the
    frequency of every distinct quantized value is counted, and a Huffman
    code is built from those frequencies via ``encode``.

    Args:
        x: CUDA tensor of dtype ``torch.float32`` or ``torch.float16``.

    Returns:
        ``1.0 - encoded_bits / raw_bits``, where ``raw_bits`` charges 8 bits
        per quantized element. 0.0 means no saving; values closer to 1.0
        mean stronger compression.

    Raises:
        AssertionError: if ``x`` is not on a CUDA device or has another dtype.
    """
    assert x.device.type == 'cuda'
    assert x.dtype in [torch.float32, torch.float16]

    # Only the quantized codes are needed; the second returned value
    # (the quantization state) is ignored here.
    quantized, _ = bnb.functional.quantize_blockwise(x)

    symbols, occurrences = torch.unique(quantized.int(), return_counts=True)
    frequency = {
        symbol.item(): occurrence.item()
        for symbol, occurrence in zip(symbols, occurrences)
    }

    # Size of the Huffman-coded stream: code length times symbol frequency.
    encoded_bits = sum(
        len(code) * frequency[symbol] for symbol, code in encode(frequency)
    )
    raw_bits = quantized.numel() * 8
    return 1.0 - (encoded_bits / raw_bits)
# taken from: https://rosettacode.org/wiki/Huffman_coding#Python
def encode(symb2freq):
    """Huffman encode the given dict mapping symbols to weights.

    Args:
        symb2freq: Mapping from symbol to its weight (frequency). Symbols
            must be mutually comparable, because ties on weight are broken
            by comparing the ``[symbol, code]`` pairs.

    Returns:
        A list of ``[symbol, code]`` pairs, where ``code`` is a string of
        ``'0'``/``'1'`` characters. The list is sorted by code length, then
        by the pair itself. An empty mapping yields an empty list, and a
        mapping with a single symbol yields the one-bit code ``'0'`` for it.
    """
    # Nothing to encode: avoid popping from an empty heap below.
    if not symb2freq:
        return []

    heap = [[wt, [sym, ""]] for sym, wt in symb2freq.items()]

    # With a single symbol the merge loop never runs and the symbol would
    # keep the empty code, i.e. zero bits per occurrence. Give it one bit so
    # the encoded stream is decodable and bit counts are meaningful.
    if len(heap) == 1:
        heap[0][1][1] = "0"

    heapify(heap)
    while len(heap) > 1:
        # Merge the two lightest subtrees; every leaf in the lighter one gets
        # a leading '0', every leaf in the heavier one a leading '1'.
        lo = heappop(heap)
        hi = heappop(heap)
        for pair in lo[1:]:
            pair[1] = '0' + pair[1]
        for pair in hi[1:]:
            pair[1] = '1' + pair[1]
        heappush(heap, [lo[0] + hi[0]] + lo[1:] + hi[1:])
    return sorted(heappop(heap)[1:], key=lambda p: (len(p[-1]), p))
# Print the value returned by get_compression for the sample tensor `a`.
print(get_compression(a))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment