Last active
August 26, 2024 11:28
-
-
Save DavidBuchanan314/fe7d87548332a34991f7b258962a845d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import io
import sys
import zlib
# Fixed 8-byte signature expected at the very start of a PNG stream;
# parse_png() compares the first bytes of its input against this value.
PNG_MAGIC = (
    b"\x89"  # non-ASCII lead byte
    b"PNG"   # format name in ASCII
    b"\r\n"  # CR LF
    b"\x1a"  # Ctrl-Z
    b"\n"    # LF
)
def parse_png_chunk(stream):
    """Read a single PNG chunk from *stream* and return ``(ctype, body)``.

    A chunk is laid out as a 4-byte big-endian body length, a 4-byte chunk
    type, the body itself, and a 4-byte big-endian CRC-32 computed over
    ``type + body``.

    Args:
        stream: binary file-like object positioned at the start of a chunk.

    Returns:
        A ``(ctype, body)`` tuple of ``bytes``.

    Raises:
        ValueError: if the stream ends before the chunk is complete, or if
            the stored CRC-32 does not match the chunk contents.
    """
    header = stream.read(8)
    # A short read must be rejected explicitly: an empty read would decode
    # as length 0 / type b"" and its CRC (0) would even appear valid.
    if len(header) != 8:
        raise ValueError("truncated PNG chunk header")
    size = int.from_bytes(header[:4], "big")
    ctype = header[4:]

    body = stream.read(size)
    trailer = stream.read(4)
    if len(body) != size or len(trailer) != 4:
        raise ValueError("truncated PNG chunk")

    # Explicit check instead of `assert`, which disappears under `python -O`.
    if zlib.crc32(ctype + body) != int.from_bytes(trailer, "big"):
        raise ValueError("PNG chunk CRC mismatch")
    return ctype, body
def parse_png(stream):
    """Extract the IHDR body and the raw DEFLATE payload from a PNG stream.

    Chunks are read with ``parse_png_chunk`` until ``IEND``; the bodies of
    all ``IDAT`` chunks are concatenated in order.

    Args:
        stream: binary file-like object positioned at the start of the file.

    Returns:
        ``(ihdr, deflate)`` where ``ihdr`` is the body of the IHDR chunk and
        ``deflate`` is the concatenated IDAT data with its first 2 bytes
        (zlib header) and last 4 bytes (zlib checksum) removed.

    Raises:
        ValueError: if the signature is wrong, the stream ends before IEND,
            or no IHDR chunk was seen.
    """
    signature = stream.read(len(PNG_MAGIC))
    # Explicit check instead of `assert`, which disappears under `python -O`.
    if signature != PNG_MAGIC:
        raise ValueError("not a PNG file (bad signature)")

    ihdr = None
    idat_parts = []  # joined once at the end instead of repeated `+=`
    while True:
        ctype, body = parse_png_chunk(stream)
        # An empty type can only come from reading past the end of the
        # data; without this guard a file lacking IEND would loop forever.
        if not ctype:
            raise ValueError("PNG stream ended before IEND chunk")
        if ctype == b"IEND":
            break
        if ctype == b"IDAT":
            idat_parts.append(body)
        elif ctype == b"IHDR":
            ihdr = body

    if ihdr is None:
        raise ValueError("PNG has no IHDR chunk")
    idat = b"".join(idat_parts)
    return ihdr, idat[2:-4]  # strip zlib
def decompress(raw):
    """Inflate a headerless (raw) DEFLATE stream and return the decoded bytes.

    Args:
        raw: the compressed data, without any zlib/gzip wrapper
            (``wbits=-15`` selects raw DEFLATE with a 32 KiB window).

    Returns:
        The decompressed data as ``bytes``.
    """
    inflater = zlib.decompressobj(wbits=-15)
    # Decompress.flush() takes an optional output-buffer *length*, not a
    # flush mode, so it is called without arguments here.
    return inflater.decompress(raw) + inflater.flush()
# TODO: implement rabin-karp algorithm | |
# current implementation is very slow!!! (and bad!!!) | |
class BackrefFinder(): | |
def __init__(self, window_size=2**15): | |
self.window_size = window_size | |
self.buf = b"" | |
def feed(self, data): | |
self.buf += data | |
def find(self, lookahead): | |
window = self.buf[-self.window_size:] | |
if not window: | |
return 0, None | |
if len(lookahead) < 3: | |
return 0, None | |
x = -1 | |
longest = 0 | |
longest_dist = None | |
try: | |
while True: | |
x = (window+lookahead[:2]).index(lookahead[:3], x+1) | |
i = None | |
for i in range(x+3, x+258+1): | |
if (i - x) >= len(lookahead): # can't look ahead any further | |
break | |
if i < len(window): # look back within window | |
if window[i] != lookahead[i - x]: | |
break | |
else: # look back wihthin lookahead (e.g. for RLE) | |
if lookahead[i-len(window)] != lookahead[i - x]: | |
break | |
if i - x >= longest: | |
longest = i - x | |
longest_dist = len(window) - x | |
except ValueError: # I wish there was a better way to handle .index() failing... | |
return longest, longest_dist | |
# Match-length table, indexed by (length symbol - 257).
# Each entry is (number of extra bits read after the symbol, base length);
# the value of the extra bits is added to the base length.
LENGTHS = [
    # symbols 257-264: no extra bits
    (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), (0, 8), (0, 9), (0, 10),
    # symbols 265-268: 1 extra bit
    (1, 11), (1, 13), (1, 15), (1, 17),
    # symbols 269-272: 2 extra bits
    (2, 19), (2, 23), (2, 27), (2, 31),
    # symbols 273-276: 3 extra bits
    (3, 35), (3, 43), (3, 51), (3, 59),
    # symbols 277-280: 4 extra bits
    (4, 67), (4, 83), (4, 99), (4, 115),
    # symbols 281-284: 5 extra bits
    (5, 131), (5, 163), (5, 195), (5, 227),
    # symbol 285
    (0, 258),
    # does not exist, here to make some code neater...
    (None, 259),
]
# Match-distance table, indexed by the 5-bit distance code.
# Each entry is (number of extra bits read after the code, base distance);
# the value of the extra bits is added to the base distance.
DISTANCES = [
    # codes 0-3: no extra bits
    (0, 1), (0, 2), (0, 3), (0, 4),
    # codes 4-5: 1 extra bit
    (1, 5), (1, 7),
    # codes 6-7: 2 extra bits
    (2, 9), (2, 13),
    # codes 8-9: 3 extra bits
    (3, 17), (3, 25),
    # codes 10-11: 4 extra bits
    (4, 33), (4, 49),
    # codes 12-13: 5 extra bits
    (5, 65), (5, 97),
    # codes 14-15: 6 extra bits
    (6, 129), (6, 193),
    # codes 16-17: 7 extra bits
    (7, 257), (7, 385),
    # codes 18-19: 8 extra bits
    (8, 513), (8, 769),
    # codes 20-21: 9 extra bits
    (9, 1025), (9, 1537),
    # codes 22-23: 10 extra bits
    (10, 2049), (10, 3073),
    # codes 24-25: 11 extra bits
    (11, 4097), (11, 6145),
    # codes 26-27: 12 extra bits
    (12, 8193), (12, 12289),
    # codes 28-29: 13 extra bits
    (13, 16385), (13, 24577),
    # does not exist, here to make some code neater...
    (None, 32769),
]
class Decompressor():
    """DEFLATE decoder that recovers bits hidden in the encoder's choices.

    The raw DEFLATE stream is decoded symbol by symbol. Before each symbol
    of a fixed-Huffman block, ``BackrefFinder`` is asked for the longest
    match available at the current output position (using the already known
    decompressed data); the difference between that match and what the
    stream actually encodes is turned into payload bits, collected
    LSB-first in ``steg_bytes``.

    Only stored and fixed-Huffman blocks are supported.
    """

    # BackrefFinder.find() caps matches at 258 bytes and so never examines
    # more than the first 259 bytes of its argument; handing it only that
    # much avoids copying the whole remainder of `orig` for every symbol.
    _LOOKAHEAD = 259

    def __init__(self, stream, original):
        """Prepare to decode *stream*.

        Args:
            stream: binary file-like object holding the raw DEFLATE data.
            original: the fully decompressed data (``bytes``), used as the
                lookahead when searching for the best available match.
        """
        self.stream = stream
        self.orig = original  # original decompressed data
        self.byte = None      # byte currently being consumed bit by bit
        self.prevbit = 7      # index of the last bit taken; 7 = byte used up
        self.buf = bytearray()  # output so far (bytearray: O(1) appends)
        self.bf = BackrefFinder()
        self.steg_bytes = b""  # completed payload bytes
        self.steg_bit = 0      # position of the next payload bit
        self.steg_byte = 0     # payload byte under construction

    def next_bit(self):
        """Return the next bit of the stream, least significant bit first.

        Raises:
            EOFError: if the stream has no more data.
        """
        if self.prevbit == 7:
            chunk = self.stream.read(1)
            if not chunk:
                raise EOFError("unexpected end of DEFLATE stream")
            self.byte = chunk[0]
            self.prevbit = 0
        else:
            self.prevbit += 1
        return (self.byte >> self.prevbit) & 1

    def write_steg_bit(self, bit):
        """Append one payload bit; a byte is emitted after every 8 bits."""
        # Only the lowest bit is kept so an out-of-range value can neither
        # spill into neighbouring bits nor make the byte unencodable.
        self.steg_byte |= (bit & 1) << self.steg_bit
        self.steg_bit += 1
        if self.steg_bit == 8:
            self.steg_bytes += bytes([self.steg_byte])
            self.steg_bit = 0
            self.steg_byte = 0

    def read_data_element(self, nbits):
        """Read an *nbits*-wide integer stored least significant bit first."""
        value = 0
        for shift in range(nbits):
            value |= self.next_bit() << shift
        return value

    def read_huffman_bits(self, nbits, prefix=0):
        """Read *nbits* bits most significant bit first, appended to *prefix*."""
        value = prefix
        for _ in range(nbits):
            value = (value << 1) | self.next_bit()
        return value

    def read_huffman_symbol(self):
        """Decode one literal/length symbol (0-287) of the fixed Huffman code."""
        # The first five bits are enough to tell which code length applies.
        preview = self.read_huffman_bits(5)
        if 0b00110 <= preview <= 0b10111:    # 8-bit codes: symbols 0-143
            return self.read_huffman_bits(3, preview) - 0b0011_0000 + 0
        if 0b11001 <= preview <= 0b11111:    # 9-bit codes: symbols 144-255
            return self.read_huffman_bits(4, preview) - 0b1_1001_0000 + 144
        if 0b00000 <= preview <= 0b00101:    # 7-bit codes: symbols 256-279
            return self.read_huffman_bits(2, preview) - 0b000_0000 + 256
        # remaining prefix 0b11000: 8-bit codes, symbols 280-287
        return self.read_huffman_bits(3, preview) - 0b1100_0000 + 280

    def read_fixed_huffman_block(self):
        """Decode the symbols of one fixed-Huffman block.

        Returns:
            A list of ``("lit", byte)`` and ``("ref", length, distance)``
            tuples, in stream order, up to (not including) end-of-block.

        Raises:
            ValueError: if the block contains a length symbol above 285 or
                a distance code above 29.
        """
        symbols = []
        while True:
            symbol = self.read_huffman_symbol()
            if symbol < 0x100:
                symbols.append(("lit", symbol))
                continue
            if symbol == 0x100:  # end of block
                break
            # 286 and 287 can be decoded but have no table entry.
            if symbol > 285:
                raise ValueError(f"invalid length symbol {symbol}")
            ebits, length = LENGTHS[symbol - 257]
            length += self.read_data_element(ebits)
            dcode = self.read_huffman_bits(5)
            # 30 and 31 fit in five bits but have no table entry.
            if dcode > 29:
                raise ValueError(f"invalid distance code {dcode}")
            ebits, distance = DISTANCES[dcode]
            distance += self.read_data_element(ebits)
            symbols.append(("ref", length, distance))
        return symbols

    def try_recover_bits(self, actual_len, actual_dist, optimal_len, optimal_dist):
        """Derive payload bits from one encoder decision.

        Args:
            actual_len: match length found in the stream (0 for a literal).
            actual_dist: match distance found in the stream (unused).
            optimal_len: longest match length that was available (0 = none).
            optimal_dist: distance of that match (unused).

        Nothing is written when no match was available. When the best match
        was shorter than 6, one bit records whether the stream deviates from
        it. Otherwise two bits carry ``optimal_len - actual_len``, high bit
        first.
        """
        if optimal_len == 0:
            return  # nothing
        if optimal_len < 6:
            self.write_steg_bit((optimal_len != actual_len) & 1)
            return
        delta = optimal_len - actual_len
        self.write_steg_bit(delta >> 1)
        self.write_steg_bit(delta & 1)

    def process_symbols(self, symbols):
        """Apply decoded *symbols* to the output and harvest payload bits.

        Raises:
            ValueError: if a back-reference points before the start of the
                output.
            Exception: if a symbol is neither ``"lit"`` nor ``"ref"``.
        """
        for symbol in symbols:
            position = len(self.buf)
            lookahead = self.orig[position:position + self._LOOKAHEAD]
            longest, longest_dist = self.bf.find(lookahead)
            kind = symbol[0]
            if kind == "lit":
                new = bytes([symbol[1]])
                self.bf.feed(new)
                self.buf += new
                self.try_recover_bits(0, None, longest, longest_dist)
            elif kind == "ref":
                length, distance = symbol[1:3]
                self.try_recover_bits(length, distance, longest, longest_dist)
                start = len(self.buf) - distance
                if start < 0:
                    raise ValueError("back-reference before start of output")
                # Byte by byte on purpose: when length > distance the copy
                # reads bytes that this same copy has just produced.
                for offset in range(length):
                    self.buf.append(self.buf[start + offset])
                self.bf.feed(bytes(self.buf[-length:]))
            else:
                raise Exception("unexpected symbol")

    def read_block(self):
        """Decode one DEFLATE block and return its BFINAL bit.

        Raises:
            ValueError: on a corrupt stored-block header or reserved block
                type.
            NotImplementedError: for dynamic-Huffman blocks.
            EOFError: if the stream ends inside the block.
        """
        bfinal = self.next_bit()
        btype = self.read_data_element(2)
        if btype == 0b00:
            # Stored block: drop the rest of the current byte. This is done
            # by resetting the bit position rather than by reading bits
            # inside an `assert`, which would vanish under `python -O` and
            # leave a loop that never ends.
            self.prevbit = 7
            size = self.read_data_element(16)
            notsize = self.read_data_element(16)
            if size != notsize ^ 0xffff:
                raise ValueError("stored block length check failed")
            data = self.stream.read(size)
            if len(data) != size:
                raise EOFError("unexpected end of DEFLATE stream")
            self.buf += data
            # Keep the match finder's history in step with the output so
            # later back-references into this block can be found.
            self.bf.feed(data)
        elif btype == 0b01:
            self.process_symbols(self.read_fixed_huffman_block())
        elif btype == 0b10:
            raise NotImplementedError("not implemented")
        else:
            raise ValueError("reserved DEFLATE block type")
        return bfinal

    def decompress(self):
        """Decode blocks until the final one and return the output as bytes."""
        while self.read_block() != 1:
            pass
        return bytes(self.buf)
def steg_unpack(image):
    """Return the payload bytes recovered from the PNG read from *image*.

    Args:
        image: binary file-like object positioned at the start of a PNG.

    Returns:
        The ``steg_bytes`` collected by ``Decompressor`` while re-decoding
        the image's DEFLATE stream.
    """
    _header, compressed = parse_png(image)
    # The decoder needs the plain decompressed data as its reference.
    reference = decompress(compressed)
    extractor = Decompressor(io.BytesIO(compressed), reference)
    extractor.decompress()
    return extractor.steg_bytes
def main():
    """Command-line entry point: extract the payload of a PNG into a file.

    Expects exactly two arguments, the input PNG path and the output path.
    Prints a usage line and exits with status 1 otherwise.
    """
    if len(sys.argv) != 3:
        print(f"USAGE: {sys.argv[0]} input.png output.whatever", file=sys.stderr)
        sys.exit(1)

    # Extract first, so a failure does not leave a truncated output file.
    with open(sys.argv[1], "rb") as image:
        payload = steg_unpack(image)
    with open(sys.argv[2], "wb") as output:
        output.write(payload)


# Guarded so that importing this module does not run (and exit) the script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment