Created
February 25, 2025 23:29
-
-
Save Kyuuhachi/fb72bc68eb498b4a5801096b902f3fe0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import typing as T | |
# ``float32`` is the type Reader.f32 wraps its result in.
# Static type checkers see a plain ``float`` alias.
# At runtime numpy's float32 is used when numpy is importable; otherwise it
# falls back to the builtin ``float``.
if T.TYPE_CHECKING:
    float32: T.TypeAlias = float
else:
    try:
        from numpy import float32
    except ImportError:
        # numpy is optional: degrade to plain Python floats.
        float32 = float
import dataclasses as dc | |
import struct | |
import sys | |
# Public API exported by ``from <module> import *``.
__all__ = ["Reader", "Writer"]
# Generic value type: the value read by ``_check``'s callback and the
# expected value it is compared against.
A = T.TypeVar("A")
# Self-type for Reader methods (``at``/``sub``) that return a copy built
# with ``dataclasses.replace``, so subclasses get their own type back.
R = T.TypeVar("R", bound="Reader")
@dc.dataclass(repr=False) | |
class Reader: | |
data: bytes | |
pos: int = 0 | |
def __repr__(self) -> str: | |
return f"{type(self).__name__}({len(self)})" | |
__str__ = __repr__ | |
def __len__(self) -> int: | |
return len(self.data) | |
def __getitem__(self, n: int) -> bytes: | |
v = self.data[self.pos:self.pos+n] | |
if len(v) != n: | |
raise ValueError(f"At 0x{self.pos:04x}: tried to read {n} bytes, but only {len(v)} were available") | |
self.pos += n | |
return v | |
def __iter__(self) -> T.Iterator[int]: | |
end = len(self.data) | |
while self.pos < end: | |
yield self.byte() | |
def byte(self) -> int: | |
v = self.data[self.pos] | |
self.pos += 1 | |
return v | |
def zstr(self) -> bytes: | |
l = self.data[self.pos:].find(0) | |
s = self.data[self.pos:self.pos+l] | |
self.pos += l + 1 | |
return s | |
@property | |
def remaining(self) -> int: | |
return len(self.data) - self.pos | |
def at(self: R, pos: int|None = None) -> R: | |
return dc.replace(self, pos = pos if pos is not None else self.pos) | |
def sub(self: R, n: int) -> R: | |
assert 0 <= n <= self.remaining, (n, self.remaining) | |
data = self.data[self.pos:self.pos+n] | |
self.pos += n | |
return dc.replace(self, pos = 0, data = data) | |
def unpack(self, spec: str) -> tuple[T.Any, ...]: | |
return struct.unpack(spec, self[struct.calcsize(spec)]) | |
def u8 (self) -> int: return self.unpack("B")[0] | |
def u16(self) -> int: return self.unpack("H")[0] | |
def u32(self) -> int: return self.unpack("I")[0] | |
def u64(self) -> int: return self.unpack("Q")[0] | |
def i8 (self) -> int: return self.unpack("b")[0] | |
def i16(self) -> int: return self.unpack("h")[0] | |
def i32(self) -> int: return self.unpack("i")[0] | |
def i64(self) -> int: return self.unpack("q")[0] | |
def f32(self) -> float: return float32(self.unpack("f")[0]) | |
def f64(self) -> float: return self.unpack("d")[0] | |
def check(self, data: bytes): _check(self, lambda: self[len(data)], data) | |
def check_u8 (self, v: int): _check(self, self.u8, v) | |
def check_u16(self, v: int): _check(self, self.u16, v) | |
def check_u32(self, v: int): _check(self, self.u32, v) | |
def check_u64(self, v: int): _check(self, self.u64, v) | |
def check_i8 (self, v: int): _check(self, self.i8, v) | |
def check_i16(self, v: int): _check(self, self.i16, v) | |
def check_i32(self, v: int): _check(self, self.i32, v) | |
def check_i64(self, v: int): _check(self, self.i64, v) | |
def dump(self, *, | |
start=None, lines=None, length=None, end=None, width=None, | |
num=6, encoding="cp932", mark=frozenset(), blank=None, | |
skip=False, file=sys.stdout): | |
CSI = "\x1B\x5B" | |
import re | |
escape = re.compile("[\x00-\x1F\x7F\x80-\x9F]") | |
mark = frozenset(mark) | |
if width is None: | |
width = 72 if encoding is None else 48 | |
assert (lines is not None) + (length is not None) <= 1 | |
if lines is not None: length = lines * width | |
oneline = lines == 1 if blank is None else not blank | |
del lines | |
assert (length is not None) + (start is not None) + (end is not None) <= 2 | |
if length is None: | |
if start is None: start = self.pos | |
if end is None: end = len(self) | |
else: | |
if start is None: start = end - length if end is not None else self.pos | |
if end is None: end = start + length | |
del length | |
if start < 0: start = 0 | |
if end > len(self): end = len(self) | |
fmt = "" | |
def format(*f): | |
nonlocal fmt | |
if f != fmt: | |
hl.append(f"{CSI}%sm" % ";".join(map(str, [0, *f]))) | |
fmt = f | |
if start == end: | |
hl = [] | |
format(2) | |
hl.append("--empty--") | |
format() | |
print("".join(hl), file=file) | |
if not oneline: | |
print(file=file) | |
return | |
numwidth = len("%X" % max(0, len(self)-1)) | |
for i in range(start, end, width): | |
hl = [] | |
chunk = bytes(self.data[i:min(i+width, end)]) | |
chunkl = list(chunk) | |
if not chunk: break | |
while len(chunkl) < width: | |
chunkl.append(None) | |
if num: | |
if numwidth < num: | |
format(2,33) | |
hl.append("0" * (num - numwidth)) | |
format(33) | |
hl.append("{:0{}X} ".format(i, numwidth)) | |
for j, b in enumerate(chunkl, i): | |
if b is None: | |
format() | |
hl.append(" ") | |
continue | |
if 0x00 == b : newfmt = [2] | |
elif 0x20 <= b < 0x7F: newfmt = [38,5,10] | |
elif 0xFF == b : newfmt = [38,5,9] | |
else: newfmt = [] | |
format(*newfmt) | |
hl.append(f"{b:02X}") | |
if j+1 == self.pos: | |
format(1,34,7) | |
elif j+1 in mark: | |
format(1,34) | |
hl.append(" •│┿"[(j+1 in mark) + (j+1 == self.pos)*2]) | |
if encoding is not None: | |
format() | |
hl.append(escape.sub(f"{CSI}2m·{CSI}m", chunk.decode(encoding, errors="replace"))) | |
elif j+1 not in mark: | |
hl.pop() # Trailing space | |
format() | |
print("".join(hl), file=file) | |
if not oneline: | |
print(file=file) | |
if skip: | |
self.pos = end | |
def _check(f: Reader, func: T.Callable[[], A], v: A):
    """Read a value with *func* and verify that it equals *v*.

    If the value differs, *f*'s cursor is rewound to where the read started
    and ValueError is raised. Exceptions raised by *func* propagate as-is.
    """
    origin = f.pos
    actual = func()
    mismatch = actual != v
    if mismatch:
        f.pos = origin
        raise ValueError(f"at {origin:X}: got {actual}, expected {v}")
class Label:
    """Opaque position marker used as a key in a Writer's label table.

    Instances carry no data and define no comparison methods, so they are
    compared and hashed by identity.
    """
@dc.dataclass(repr=False)
class Writer:
    """Append-only binary writer with labels and deferred (patched) fields.

    ``data`` holds the bytes written so far. ``labels`` maps label objects to
    byte offsets in ``data``. ``thunks`` maps an offset to ``(n, thunk)``: an
    n-byte zero placeholder that ``bytes(self)`` overwrites with
    ``thunk(self)``, so thunks can refer to labels placed later.
    The fixed-width writers use :mod:`struct` formats without a byte-order
    prefix, i.e. native byte order and size.
    """

    data: bytearray
    thunks: dict[int, tuple[int, T.Callable[[Writer], bytes]]]
    labels: dict[Label, int]

    def __init__(self, data: bytes = b""):
        self.data = bytearray(data)
        self.thunks = {}
        self.labels = {}

    def __repr__(self) -> str:
        return f"{type(self).__name__}({len(self)})"

    __str__ = __repr__

    def __len__(self) -> int:
        """Return the number of bytes written so far (placeholders included)."""
        return len(self.data)

    def __bytes__(self) -> bytes:
        """Resolve every deferred field in place and return the final bytes.

        Each thunk is called with this writer and must return exactly the
        number of bytes it reserved. Resolved thunks are discarded, so later
        calls return the already-patched data.
        """
        for pos, (n, thunk) in self.thunks.items():
            b = thunk(self)
            assert len(b) == n, (b, n)
            self.data[pos:pos+n] = b
        self.thunks.clear()
        return bytes(self.data)

    def write(self, v: bytes) -> None:
        """Append raw bytes."""
        self.data += v

    def __iadd__(self, v: Writer) -> Writer:
        """Append another writer's data, shifting its thunks and labels by our length."""
        base = len(self.data)
        # Snapshot everything from ``v`` before mutating ``self``. ``v`` may be
        # ``self`` (``w += w``). Updating a dict while iterating it raises
        # RuntimeError, and a bytearray cannot be resized while it exports
        # its own buffer for ``+=``.
        moved_thunks = [(pos + base, entry) for pos, entry in v.thunks.items()]
        moved_labels = [(label, pos + base) for label, pos in v.labels.items()]
        extra = bytes(v.data) if v is self else v.data
        self.thunks.update(moved_thunks)
        self.labels.update(moved_labels)
        self.data += extra
        return self

    def __add__(self, v: Writer) -> Writer:
        """Return a new Writer holding ``self`` followed by *v*."""
        w = Writer()
        w += self
        w += v
        return w

    def __getitem__(self, label: Label) -> int:
        """Return the offset of *label* (KeyError if it was never placed)."""
        return self.labels[label]

    def __setitem__(self, label: Label, pos: int) -> None:
        """Set *label* to offset *pos*."""
        self.labels[label] = pos

    def place(self, label: Label) -> Label:
        """Point *label* at the current end of the data and return it."""
        self[label] = len(self)
        return label

    def pack(self, spec: str, *args: T.Any) -> None:
        """Append ``struct.pack(spec, *args)``."""
        self.write(struct.pack(spec, *args))

    def delay(self, n: int, thunk: T.Callable[[Writer], bytes]) -> None:
        """Reserve *n* zero bytes, to be replaced by ``thunk(self)`` in ``bytes(self)``."""
        self.thunks[len(self.data)] = (n, thunk)
        self.write(bytes(n))

    def diff(self, width: int, a: Label, b: Label, offset: int = 0) -> None:
        """Reserve *width* bytes for ``self[b] - self[a] + offset``.

        The value is encoded as a signed little-endian integer when
        ``bytes(self)`` is called.
        """
        self.delay(width, lambda r: int.to_bytes(r[b] - r[a] + offset, width, "little", signed = True))

    def u8(self, v: int) -> None: self.pack("B", v)
    def u16(self, v: int) -> None: self.pack("H", v)
    def u32(self, v: int) -> None: self.pack("I", v)
    def u64(self, v: int) -> None: self.pack("Q", v)
    def i8(self, v: int) -> None: self.pack("b", v)
    def i16(self, v: int) -> None: self.pack("h", v)
    def i32(self, v: int) -> None: self.pack("i", v)
    def i64(self, v: int) -> None: self.pack("q", v)
    def f32(self, v: float) -> None: self.pack("f", v)
    def f64(self, v: float) -> None: self.pack("d", v)

    def zstr(self, v: bytes) -> None:
        """Append *v* followed by a NUL terminator (*v* must not contain NUL)."""
        assert b"\0" not in v
        self.write(v)
        self.u8(0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment