Skip to content

Instantly share code, notes, and snippets.

@grafuls
Created July 12, 2022 12:05
Show Gist options
  • Save grafuls/cac9f1f2a77e6790e3ed397091af732e to your computer and use it in GitHub Desktop.
Save grafuls/cac9f1f2a77e6790e3ed397091af732e to your computer and use it in GitHub Desktop.
import zlib
import struct
# Fixed 8-byte signature that must begin every PNG file.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


class Png:
    """Minimal decoder for 8-bit, non-interlaced, truecolor-with-alpha PNGs.

    All work happens eagerly in ``__init__``:

    1. The chunk stream is read, and each chunk's CRC is verified.
    2. The IHDR header is unpacked into attributes and validated.
    3. The concatenated IDAT payload is inflated.
    4. Scanline filters are undone into ``self.recon``, a flat row-major
       list of byte values (``bytes_per_pixel`` bytes per pixel).

    Every error is raised as a plain ``Exception`` with a descriptive message.
    """

    def __init__(self, file):
        """Read and decode the PNG file at path *file*.

        Raises:
            Exception: if the file is not a PNG, a chunk is truncated or fails
                its CRC, or the image uses an unsupported format.
        """
        self.file = file
        with open(file, "rb") as buffer:
            # Kept for backward compatibility; the handle is closed once the
            # chunk stream has been read.
            self.buffer = buffer
            self.chunks = self.read_chunks(buffer)
        self.ihdr = self.read_ihdr()
        (
            self.width,
            self.height,
            self.bitd,
            self.colort,
            self.compm,
            self.filterm,
            self.interlacem,
        ) = struct.unpack(">IIBBBBB", self.ihdr)
        # Reject unsupported formats *before* decoding: the reconstruction
        # below hard-codes 4 bytes per pixel (8-bit RGBA).
        self.validate()
        self.idat = self.read_idat()
        self.bytes_per_pixel = 4
        self.stride = self.width * self.bytes_per_pixel
        self.recon = self.read_recon()

    @staticmethod
    def _check_signature(_f):
        """Consume the signature bytes from *_f* and raise unless they match."""
        if _f.read(len(PNG_SIGNATURE)) != PNG_SIGNATURE:
            raise Exception("Invalid PNG Signature")

    def validate(self):
        """Check the signature and the IHDR fields this decoder supports.

        Raises:
            Exception: on a bad signature, a non-zero compression/filter/
                interlace method, or anything other than 8-bit color type 6.
        """
        with open(self.file, "rb") as _f:
            self._check_signature(_f)
        if self.compm != 0:
            raise Exception("invalid compression method")
        if self.filterm != 0:
            raise Exception("invalid filter method")
        if self.colort != 6:
            raise Exception("we only support truecolor with alpha")
        if self.bitd != 8:
            raise Exception("we only support a bit depth of 8")
        if self.interlacem != 0:
            raise Exception("we only support no interlacing")

    @staticmethod
    def read_chunk(_f):
        """Read one chunk from binary file object *_f* and verify its CRC.

        Returns:
            ``(chunk_type, chunk_data)`` as two ``bytes`` objects.

        Raises:
            Exception: if the stream ends mid-chunk or the CRC does not match.
        """
        header_buffer = _f.read(8)
        if len(header_buffer) != 8:
            raise Exception("unexpected end of file while reading chunk header")
        chunk_length, chunk_type = struct.unpack(">I4s", header_buffer)
        chunk_data = _f.read(chunk_length)
        if len(chunk_data) != chunk_length:
            raise Exception("unexpected end of file while reading chunk data")
        crc_buffer = _f.read(4)
        if len(crc_buffer) != 4:
            raise Exception("unexpected end of file while reading chunk CRC")
        (chunk_expected_crc,) = struct.unpack(">I", crc_buffer)
        # The CRC covers the type and data fields, but not the length field.
        chunk_actual_crc = zlib.crc32(chunk_data, zlib.crc32(chunk_type))
        if chunk_expected_crc != chunk_actual_crc:
            raise Exception("chunk checksum failed")
        return chunk_type, chunk_data

    def read_chunks(self, _f=None):
        """Return every ``(type, data)`` chunk up to and including IEND.

        Args:
            _f: optional binary file object positioned at the start of the
                file. When omitted, ``self.file`` is opened and read.

        Raises:
            Exception: on a bad signature or a truncated/corrupt chunk.
        """
        if _f is None:
            with open(self.file, "rb") as fresh:
                return self.read_chunks(fresh)
        self._check_signature(_f)
        chunks = []
        while True:
            chunk_type, chunk_data = self.read_chunk(_f)
            chunks.append((chunk_type, chunk_data))
            if chunk_type == b"IEND":
                return chunks

    def read_ihdr(self):
        """Return the raw 13-byte IHDR payload, which must be the first chunk.

        Raises:
            Exception: if the first chunk is not IHDR or has the wrong size.
        """
        if not self.chunks or self.chunks[0][0] != b"IHDR":
            raise Exception("first chunk must be IHDR")
        _, ihdr = self.chunks[0]
        if len(ihdr) != 13:
            raise Exception("invalid IHDR chunk length")
        return ihdr

    def read_idat(self):
        """Concatenate all IDAT chunk payloads and inflate them with zlib.

        Raises:
            Exception: if the file contains no IDAT chunk.
        """
        idat = b"".join(
            chunk_data
            for chunk_type, chunk_data in self.chunks
            if chunk_type == b"IDAT"
        )
        if not idat:
            raise Exception("no IDAT chunk found")
        return zlib.decompress(idat)

    @staticmethod
    def paeth_predictor(a, b, c):
        """Return whichever of a (left), b (up), c (up-left) is closest to a+b-c.

        Ties are broken in the order a, b, c.
        """
        p = a + b - c
        pa = abs(p - a)
        pb = abs(p - b)
        pc = abs(p - c)
        if pa <= pb and pa <= pc:
            return a
        if pb <= pc:
            return b
        return c

    def recon_a(self, r, c):
        """Reconstructed byte one pixel to the left of (r, c), or 0 at the edge."""
        return self.recon[r * self.stride + c - self.bytes_per_pixel] if c >= self.bytes_per_pixel else 0

    def recon_b(self, r, c):
        """Reconstructed byte directly above (r, c), or 0 on the first row."""
        return self.recon[(r - 1) * self.stride + c] if r > 0 else 0

    def recon_c(self, r, c):
        """Reconstructed byte above-left of (r, c), or 0 on the top/left edge."""
        return (
            self.recon[(r - 1) * self.stride + c - self.bytes_per_pixel]
            if r > 0 and c >= self.bytes_per_pixel
            else 0
        )

    def read_recon(self):
        """Undo the per-scanline filters of ``self.idat``.

        Each scanline is one filter-type byte followed by ``self.stride``
        filtered bytes. If the data ends early, decoding stops and the bytes
        reconstructed so far are returned.

        Returns:
            A flat row-major list of reconstructed byte values.

        Raises:
            Exception: on a filter type outside 0-4.
        """
        recon = []
        # recon_a/recon_b/recon_c read from self.recon, so expose the list
        # while it is being built (it is only appended to below).
        self.recon = recon
        data = self.idat
        size = len(data)
        i = 0
        for r in range(self.height):  # for each scanline
            if i >= size:
                break
            filter_type = data[i]  # first byte of scanline is filter type
            i += 1
            if filter_type > 4:
                raise Exception("unknown filter type: " + str(filter_type))
            for c in range(self.stride):  # for each byte in scanline
                if i >= size:
                    break
                filt_x = data[i]
                i += 1
                if filter_type == 0:  # None
                    recon_x = filt_x
                elif filter_type == 1:  # Sub
                    recon_x = filt_x + self.recon_a(r, c)
                elif filter_type == 2:  # Up
                    recon_x = filt_x + self.recon_b(r, c)
                elif filter_type == 3:  # Average (floor of the mean)
                    recon_x = filt_x + (self.recon_a(r, c) + self.recon_b(r, c)) // 2
                else:  # 4: Paeth
                    recon_x = filt_x + self.paeth_predictor(
                        self.recon_a(r, c), self.recon_b(r, c), self.recon_c(r, c)
                    )
                recon.append(recon_x & 0xFF)  # arithmetic is modulo 256
        return recon
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment