Last active
September 30, 2021 15:45
-
-
Save TimDettmers/f731c917117e00979852d7dbc1e76cb8 to your computer and use it in GitHub Desktop.
Calculate Huffman compression ratio with bitsandbytes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import bitsandbytes as bnb | |
from heapq import heappush, heappop, heapify | |
# Sample input for the demo: a 1024x1024 tensor drawn from a normal
# distribution with mean 0 and standard deviation 0.5, allocated on the GPU.
a = torch.normal(
    mean=0.0,
    std=0.5,
    size=(1024, 1024),
    device='cuda',
)
def get_compression(x: torch.Tensor) -> float:
    """Return the fraction of space saved by Huffman-coding the quantized tensor.

    ``x`` is quantized with ``bnb.functional.quantize_blockwise``, the
    frequency of every distinct quantized value is counted, and a Huffman code
    is built from those frequencies with ``encode``. The total encoded size is
    then compared against a baseline of 8 bits per quantized element.

    Args:
        x: CUDA tensor of dtype ``torch.float32`` or ``torch.float16``.

    Returns:
        ``1 - huffman_bits / (num_elements * 8)``; a larger value means the
        Huffman code saves more space relative to the 8-bit baseline.

    Raises:
        AssertionError: if ``x`` is not on a CUDA device or has another dtype.
    """
    # Kept as asserts so callers see the same exception type as before.
    assert x.device.type == 'cuda'
    assert x.dtype in [torch.float32, torch.float16]

    # Only the first return value (the quantized tensor) is needed here.
    C, _ = bnb.functional.quantize_blockwise(x)
    val, counts = torch.unique(C.int(), return_counts=True)

    # Two bulk transfers to Python lists instead of one .item() call (and
    # therefore one device-to-host sync) per distinct symbol.
    symb2freq = dict(zip(val.tolist(), counts.tolist()))

    # Encoded size = sum over symbols of (code length * occurrences).
    total_bits = sum(
        len(code) * symb2freq[sym] for sym, code in encode(symb2freq)
    )
    return 1.0 - (total_bits / (C.numel() * 8))
# Huffman encoder taken from: https://rosettacode.org/wiki/Huffman_coding#Python
def encode(symb2freq):
    """Build a Huffman code for a mapping of symbols to weights.

    Args:
        symb2freq: dict mapping each symbol to its weight (frequency).
            Symbols must be mutually comparable, because they are used to
            break ties between equal weights and to order the result.

    Returns:
        A list of ``[symbol, code]`` pairs, where ``code`` is a string of
        ``'0'``/``'1'`` characters, sorted by code length and then by the
        pair itself. An empty mapping yields an empty list. A mapping with a
        single symbol yields the one-bit code ``'0'`` for it, so every symbol
        always costs at least one bit.
    """
    if not symb2freq:
        # Nothing to encode; popping from an empty heap would raise.
        return []

    # Each heap entry is [total_weight, [sym, code], [sym, code], ...].
    heap = [[wt, [sym, ""]] for sym, wt in symb2freq.items()]

    if len(heap) == 1:
        # The merge loop never runs for a lone symbol, which would leave it
        # with a zero-length code; give it an explicit one-bit code instead.
        heap[0][1][1] = "0"

    heapify(heap)
    while len(heap) > 1:
        # Merge the two lightest subtrees, prefixing their codes with the
        # branch bit taken to reach them from the new parent.
        lo = heappop(heap)
        hi = heappop(heap)
        for pair in lo[1:]:
            pair[1] = '0' + pair[1]
        for pair in hi[1:]:
            pair[1] = '1' + pair[1]
        heappush(heap, [lo[0] + hi[0]] + lo[1:] + hi[1:])

    # Drop the leading total weight and return the pairs, shortest code first.
    return sorted(heappop(heap)[1:], key=lambda p: (len(p[-1]), p))
# Run the demo: print the Huffman compression rate of the sample tensor.
print(get_compression(x=a))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment