Skip to content

Instantly share code, notes, and snippets.

@Kyuuhachi
Created February 25, 2025 23:29
Show Gist options
  • Save Kyuuhachi/fb72bc68eb498b4a5801096b902f3fe0 to your computer and use it in GitHub Desktop.
Save Kyuuhachi/fb72bc68eb498b4a5801096b902f3fe0 to your computer and use it in GitHub Desktop.
from __future__ import annotations
import typing as T
# Static type checkers see ``float32`` as a plain alias of ``float``.
# At runtime numpy's ``float32`` is used when numpy can be imported;
# otherwise the builtin ``float`` stands in for it.
if T.TYPE_CHECKING:
    float32: T.TypeAlias = float
else:
    try:
        from numpy import float32
    except ImportError:
        float32 = float
import dataclasses as dc
import struct
import sys
__all__ = ["Reader", "Writer"]
A = T.TypeVar("A")
R = T.TypeVar("R", bound="Reader")
@dc.dataclass(repr=False)
class Reader:
    """Cursor over a byte buffer with struct-based decoding helpers.

    ``data`` is the buffer being read and ``pos`` is the current read
    offset.  Every read method consumes bytes starting at ``pos`` and
    advances ``pos`` past what it consumed.
    """

    data: bytes
    pos: int = 0

    def __repr__(self) -> str:
        return f"{type(self).__name__}({len(self)})"
    __str__ = __repr__

    def __len__(self) -> int:
        """Return the total size of the buffer (independent of ``pos``)."""
        return len(self.data)

    def __getitem__(self, n: int) -> bytes:
        """Read exactly ``n`` bytes (``reader[n]``) and advance past them.

        Raises:
            ValueError: if fewer than ``n`` bytes remain; ``pos`` is left
                unchanged in that case.
        """
        v = self.data[self.pos:self.pos+n]
        if len(v) != n:
            raise ValueError(f"At 0x{self.pos:04x}: tried to read {n} bytes, but only {len(v)} were available")
        self.pos += n
        return v

    def __iter__(self) -> T.Iterator[int]:
        """Yield the remaining bytes one at a time, advancing ``pos``."""
        end = len(self.data)
        while self.pos < end:
            yield self.byte()

    def byte(self) -> int:
        """Read a single byte as an int (``IndexError`` past the end)."""
        v = self.data[self.pos]
        self.pos += 1
        return v

    def zstr(self) -> bytes:
        """Read a NUL-terminated byte string and skip its terminator.

        Returns:
            The bytes before the terminator (terminator not included).

        Raises:
            ValueError: if no NUL byte exists at or after ``pos``; ``pos``
                is left unchanged in that case.
        """
        # Search in place instead of copying the whole tail of the buffer.
        end = self.data.find(0, self.pos)
        if end < 0:
            # Without this check the -1 from find() would silently return
            # a truncated slice and leave the cursor where it was.
            raise ValueError(f"At 0x{self.pos:04x}: no NUL terminator found")
        s = self.data[self.pos:end]
        self.pos = end + 1
        return s

    @property
    def remaining(self) -> int:
        """Number of bytes between ``pos`` and the end of the buffer."""
        return len(self.data) - self.pos

    def at(self: R, pos: int|None = None) -> R:
        """Return a copy over the same data positioned at ``pos``.

        With ``pos=None`` the copy starts at the current position.  The
        original reader's position is not modified.
        """
        return dc.replace(self, pos = pos if pos is not None else self.pos)

    def sub(self: R, n: int) -> R:
        """Consume ``n`` bytes and return them as a new reader at offset 0."""
        assert 0 <= n <= self.remaining, (n, self.remaining)
        data = self.data[self.pos:self.pos+n]
        self.pos += n
        return dc.replace(self, pos = 0, data = data)

    def unpack(self, spec: str) -> tuple[T.Any, ...]:
        """Read ``struct.calcsize(spec)`` bytes and unpack them with ``spec``.

        ``spec`` is handed to ``struct`` unchanged, so a spec without a
        byte-order prefix uses the struct module's native mode.
        """
        return struct.unpack(spec, self[struct.calcsize(spec)])

    # Fixed-width scalar readers; each consumes one value of the given format.
    def u8 (self) -> int: return self.unpack("B")[0]
    def u16(self) -> int: return self.unpack("H")[0]
    def u32(self) -> int: return self.unpack("I")[0]
    def u64(self) -> int: return self.unpack("Q")[0]

    def i8 (self) -> int: return self.unpack("b")[0]
    def i16(self) -> int: return self.unpack("h")[0]
    def i32(self) -> int: return self.unpack("i")[0]
    def i64(self) -> int: return self.unpack("q")[0]

    def f32(self) -> float: return float32(self.unpack("f")[0])
    def f64(self) -> float: return self.unpack("d")[0]

    # Readers that additionally verify the value read; on mismatch `_check`
    # rewinds ``pos`` and raises ValueError.
    def check(self, data: bytes): _check(self, lambda: self[len(data)], data)

    def check_u8 (self, v: int): _check(self, self.u8, v)
    def check_u16(self, v: int): _check(self, self.u16, v)
    def check_u32(self, v: int): _check(self, self.u32, v)
    def check_u64(self, v: int): _check(self, self.u64, v)

    def check_i8 (self, v: int): _check(self, self.i8, v)
    def check_i16(self, v: int): _check(self, self.i16, v)
    def check_i32(self, v: int): _check(self, self.i32, v)
    def check_i64(self, v: int): _check(self, self.i64, v)

    def dump(self, *,
            start=None, lines=None, length=None, end=None, width=None,
            num=6, encoding="cp932", mark=frozenset(), blank=None,
            skip=False, file=sys.stdout):
        """Print a hex dump of part of the buffer, coloured with ANSI escapes.

        Args:
            start, end: byte range to dump.  Defaults are ``self.pos`` and
                ``len(self)``; both are clamped to the buffer.
            length: size of the range in bytes; combined with ``start`` or
                ``end`` (at most two of the three may be given).
            lines: size of the range in rows (``lines * width`` bytes);
                mutually exclusive with ``length``.
            width: bytes per row.  Defaults to 48 when a text column is
                shown and 72 when ``encoding`` is None.
            num: minimum number of digits in the offset column; a falsy
                value disables the column.
            encoding: codec used for the decoded-text column, or None to
                omit that column.
            mark: offsets to highlight.  A marker is drawn in the gap just
                before each marked offset ("•"), before ``self.pos`` ("│"),
                or both ("┿").
            blank: whether to print a blank line after the dump.  By
                default one is printed unless ``lines == 1``.
            skip: if true, move ``self.pos`` to ``end`` after dumping a
                non-empty range.
            file: stream the dump is printed to.
        """
        CSI = "\x1B\x5B"
        import re
        escape = re.compile("[\x00-\x1F\x7F\x80-\x9F]")

        mark = frozenset(mark)
        if width is None:
            width = 72 if encoding is None else 48

        assert (lines is not None) + (length is not None) <= 1
        if lines is not None: length = lines * width
        oneline = lines == 1 if blank is None else not blank
        del lines

        assert (length is not None) + (start is not None) + (end is not None) <= 2
        if length is None:
            if start is None: start = self.pos
            if end is None: end = len(self)
        else:
            if start is None: start = end - length if end is not None else self.pos
            if end is None: end = start + length
        del length

        if start < 0: start = 0
        if end > len(self): end = len(self)

        # Currently active style.  None never equals a tuple, so the very
        # first call to style() always emits an escape sequence.
        current = None
        def style(*codes):
            """Switch the terminal style, emitting an escape only on change."""
            nonlocal current
            if codes != current:
                params = ";".join(map(str, [0, *codes]))
                hl.append(f"{CSI}{params}m")
                current = codes

        if start == end:
            hl = []
            style(2)
            hl.append("--empty--")
            style()
            print("".join(hl), file=file)
            if not oneline:
                print(file=file)
            return

        numwidth = len("%X" % max(0, len(self)-1))
        for i in range(start, end, width):
            hl = []
            chunk = bytes(self.data[i:min(i+width, end)])
            if not chunk: break
            # Pad a short final row with None so every row has `width` cells.
            chunkl = list(chunk)
            chunkl.extend([None] * (width - len(chunkl)))

            if num:
                if numwidth < num:
                    # Dimmed leading zeros pad the offset to `num` digits.
                    style(2,33)
                    hl.append("0" * (num - numwidth))
                style(33)
                hl.append("{:0{}X} ".format(i, numwidth))

            for j, b in enumerate(chunkl, i):
                if b is None:
                    style()
                    # A byte cell is two hex digits plus one separator, so
                    # padding must be three columns wide to keep the text
                    # column aligned with the rows above.
                    hl.append("   ")
                    continue

                if 0x00 == b: newfmt = [2]
                elif 0x20 <= b < 0x7F: newfmt = [38,5,10]
                elif 0xFF == b: newfmt = [38,5,9]
                else: newfmt = []
                style(*newfmt)
                hl.append(f"{b:02X}")

                # The separator after byte j marks the boundary before j+1.
                if j+1 == self.pos:
                    style(1,34,7)
                elif j+1 in mark:
                    style(1,34)
                hl.append(" •│┿"[(j+1 in mark) + (j+1 == self.pos)*2])

            if encoding is not None:
                style()
                hl.append(escape.sub(f"{CSI}2m·{CSI}m", chunk.decode(encoding, errors="replace")))
            elif hl[-1].isspace():
                # Drop trailing whitespace only; a cursor or mark glyph at
                # the end of the row must stay visible.
                hl.pop()
            style()
            print("".join(hl), file=file)

        if not oneline:
            print(file=file)
        if skip:
            self.pos = end
def _check(f: Reader, func: T.Callable[[], A], v: A):
    """Call ``func`` and require that its result equals ``v``.

    ``func`` is expected to read from ``f``.  If the value it returns
    differs from ``v``, ``f.pos`` is restored to what it was before the
    call and a ``ValueError`` naming that position is raised.
    """
    origin = f.pos
    actual = func()
    mismatch = actual != v
    if mismatch:
        # Rewind so the caller can inspect or retry from the same offset.
        f.pos = origin
        raise ValueError(f"at {origin:X}: got {actual}, expected {v}")
class Label:
    """Opaque marker object; ``Writer`` maps each label to a byte offset."""
@dc.dataclass(repr=False)
class Writer:
    """Growable byte buffer with labels and deferred ("thunk") fields.

    ``data`` is the output so far, ``labels`` maps labels to byte offsets,
    and ``thunks`` maps an offset to ``(size, callable)``: a placeholder of
    ``size`` bytes that is filled in when the writer is converted with
    ``bytes()``.
    """

    data: bytearray
    thunks: dict[int, tuple[int, T.Callable[[Writer], bytes]]]
    labels: dict[Label, int]

    def __init__(self, data: bytes = b""):
        """Start a writer whose buffer initially holds a copy of ``data``."""
        self.data = bytearray(data)
        self.thunks = {}
        self.labels = {}

    def __repr__(self) -> str:
        return f"{type(self).__name__}({len(self)})"
    __str__ = __repr__

    def __len__(self) -> int:
        """Return the number of bytes written so far."""
        return len(self.data)

    def __bytes__(self) -> bytes:
        """Resolve every pending thunk in place and return the final bytes.

        Each thunk is called with this writer and must return exactly the
        number of bytes that was reserved for it.  The thunk table is
        emptied once all of them have been applied.
        """
        for offset, (size, resolve) in self.thunks.items():
            patch = resolve(self)
            assert len(patch) == size, (patch, size)
            self.data[offset:offset+size] = patch
        self.thunks.clear()
        return bytes(self.data)

    def write(self, v: bytes) -> None:
        """Append raw bytes to the buffer."""
        self.data += v

    def __iadd__(self, v: Writer) -> Writer:
        """Append another writer, shifting its thunks and labels to fit."""
        for offset, pending in v.thunks.items():
            self.thunks[offset + len(self.data)] = pending
        for label, where in v.labels.items():
            self.labels[label] = where + len(self.data)
        # The data is appended last so the shifts above use the old length.
        self.data += v.data
        return self

    def __add__(self, v: Writer) -> Writer:
        """Return a new writer holding this writer followed by ``v``."""
        combined = Writer()
        combined += self
        combined += v
        return combined

    def __getitem__(self, label: Label) -> int:
        """Return the offset recorded for ``label`` (``KeyError`` if unset)."""
        return self.labels[label]

    def __setitem__(self, label: Label, pos: int) -> None:
        """Record ``pos`` as the offset of ``label``."""
        self.labels[label] = pos

    def place(self, label: Label) -> Label:
        """Bind ``label`` to the current end of the buffer and return it."""
        self[label] = len(self)
        return label

    def pack(self, spec: str, *args: T.Any) -> None:
        """Append ``struct.pack(spec, *args)`` to the buffer."""
        packed = struct.pack(spec, *args)
        self.write(packed)

    def delay(self, n: int, thunk: T.Callable[[Writer], bytes]) -> None:
        """Reserve ``n`` zero bytes here, to be filled by ``thunk`` later."""
        here = len(self.data)
        self.thunks[here] = (n, thunk)
        self.write(bytes(n))

    def diff(self, width: int, a: Label, b: Label, offset: int = 0) -> None:
        """Reserve a ``width``-byte signed little-endian ``b - a + offset``.

        The label positions are looked up when the thunk runs, so both
        labels may be placed after this call.
        """
        def resolve(w: Writer) -> bytes:
            return int.to_bytes(w[b] - w[a] + offset, width, "little", signed = True)
        self.delay(width, resolve)

    # Fixed-width scalar writers; each appends one value of the given format.
    def u8(self, v: int) -> None:
        self.pack("B", v)

    def u16(self, v: int) -> None:
        self.pack("H", v)

    def u32(self, v: int) -> None:
        self.pack("I", v)

    def u64(self, v: int) -> None:
        self.pack("Q", v)

    def i8(self, v: int) -> None:
        self.pack("b", v)

    def i16(self, v: int) -> None:
        self.pack("h", v)

    def i32(self, v: int) -> None:
        self.pack("i", v)

    def i64(self, v: int) -> None:
        self.pack("q", v)

    def f32(self, v: float) -> None:
        self.pack("f", v)

    def f64(self, v: float) -> None:
        self.pack("d", v)

    def zstr(self, v: bytes) -> None:
        """Append ``v`` followed by a single NUL terminator."""
        assert b"\0" not in v
        self.write(v)
        self.u8(0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment