# Source: GitHub gist celsowm/6706529abd1d1134a0012d81bf0ed256 by @celsowm
# (last active June 19, 2025 03:17).
# File: python/sglang/srt/weight_loader/gguf_loader.py
from __future__ import annotations

import io
import math
import mmap
import pathlib
import struct
import warnings
from enum import IntEnum
from typing import Any, BinaryIO, Dict, List, Tuple

import numpy as np
import torch

# Optional tqdm for progress bar
try:
    import tqdm
except ImportError:
    tqdm = None
# ────────────────────────────
# basic binary helpers
# ────────────────────────────
# File signature: the first four bytes of every GGUF file.
_MAGIC = b"GGUF"

# Fixed header that follows: magic, uint32 version, uint64 tensor count,
# uint64 metadata key/value count (all little-endian).
_HDR = struct.Struct("<4sIQQ")

# Little-endian unsigned integer codecs: 8, 16, 32 and 64 bit.
_UINT8, _UINT16, _UINT32, _UINT64 = (struct.Struct(f"<{code}") for code in "BHIQ")

# Little-endian IEEE-754 binary16 (half precision).
_FP16 = struct.Struct("<e")
# ────────────────────────────
# GGML / GGUF enums
# ────────────────────────────
class GGMLType(IntEnum):
    """GGML tensor element types, numbered exactly as ggml's ``enum ggml_type``.

    The loader converts the raw uint32 type id from each tensor-table entry
    with ``GGMLType(id)``, so these values must match ggml; a wrong value makes
    a tensor decode with the wrong format. Ids 4 and 5 (the removed Q4_2 /
    Q4_3 formats) are intentionally absent.
    """

    F32 = 0
    F16 = 1
    Q4_0 = 2
    Q4_1 = 3
    Q5_0 = 6
    Q5_1 = 7
    Q8_0 = 8
    Q8_1 = 9
    Q2_K = 10
    Q3_K = 11
    Q4_K = 12
    Q5_K = 13
    Q6_K = 14
    Q8_K = 15
    IQ2_XXS = 16
    IQ2_XS = 17
    IQ3_XXS = 18
    IQ1_S = 19
    IQ4_NL = 20
    IQ3_S = 21
    IQ2_S = 22
    IQ4_XS = 23
    I8 = 24
    I16 = 25
    I32 = 26
    I64 = 27
    F64 = 28
    IQ1_M = 29
    BF16 = 30
    # Deprecated names kept only so existing references keep working. Each
    # shares its value with the member it is named after, which makes it an
    # IntEnum *alias* of that member rather than a separate type id.
    I8_ALT = 24
    I16_ALT = 25
    I32_ALT = 26
    I64_ALT = 27
class GGUFMetaValueType(IntEnum):
    """Type tag stored as a uint32 in front of every GGUF metadata value.

    UINT8..INT32, FLOAT32, BOOL and UINT64..FLOAT64 are fixed-width scalars.
    STRING is a length-prefixed UTF-8 string. ARRAY is an element-type tag
    plus a count followed by the elements.
    """

    # 8-, 16- and 32-bit integers, unsigned/signed pairs
    UINT8 = 0
    INT8 = 1
    UINT16 = 2
    INT16 = 3
    UINT32 = 4
    INT32 = 5
    # single-precision float and one-byte boolean
    FLOAT32 = 6
    BOOL = 7
    # variable-length kinds
    STRING = 8
    ARRAY = 9
    # 64-bit scalars
    UINT64 = 10
    INT64 = 11
    FLOAT64 = 12
# Codecs for the fixed-width scalar metadata types. BOOL, STRING and ARRAY
# are deliberately absent: _read_meta_value decodes those kinds explicitly.
_META_VALUE_STRUCTS: Dict[GGUFMetaValueType, struct.Struct] = dict(
    [
        (GGUFMetaValueType.UINT8, _UINT8),
        (GGUFMetaValueType.INT8, struct.Struct("<b")),
        (GGUFMetaValueType.UINT16, _UINT16),
        (GGUFMetaValueType.INT16, struct.Struct("<h")),
        (GGUFMetaValueType.UINT32, _UINT32),
        (GGUFMetaValueType.INT32, struct.Struct("<i")),
        (GGUFMetaValueType.FLOAT32, struct.Struct("<f")),
        (GGUFMetaValueType.UINT64, _UINT64),
        (GGUFMetaValueType.INT64, struct.Struct("<q")),
        (GGUFMetaValueType.FLOAT64, struct.Struct("<d")),
    ]
)
# ────────────────────────────
# De-quant helpers
# ────────────────────────────
def _dequant_q8_0(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    """De-quantise ``n_blocks`` GGML Q8_0 blocks from *buf* into *dst*.

    Block layout (34 bytes, 32 values): ``fp16 d`` followed by 32 ``int8``
    quants; each value is ``q * d``.

    Args:
        buf: raw tensor bytes; must hold at least ``n_blocks * 34`` bytes.
        dst: float32 output; elements ``[0, n_blocks * 32)`` are written.
        n_blocks: number of blocks to decode.
    """
    if n_blocks <= 0:
        return
    # A packed structured dtype mirrors the on-disk block, so the whole
    # tensor is decoded in a few vectorised operations, not a Python loop.
    layout = np.dtype([("d", "<f2"), ("qs", "i1", (32,))])
    blocks = np.frombuffer(buf, dtype=layout, count=n_blocks)
    d = blocks["d"].astype(np.float32)[:, None]
    dst[: n_blocks * 32] = (blocks["qs"].astype(np.float32) * d).reshape(-1)
def _dequant_q4_0(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    """De-quantise ``n_blocks`` GGML Q4_0 blocks from *buf* into *dst*.

    Block layout (18 bytes, 32 values): ``fp16 d`` followed by 16 bytes of
    packed 4-bit quants. As in ggml, the low nibble of byte ``j`` is element
    ``j`` and the high nibble is element ``j + 16``; each value is
    ``(q - 8) * d``.

    Args:
        buf: raw tensor bytes; must hold at least ``n_blocks * 18`` bytes.
        dst: float32 output; elements ``[0, n_blocks * 32)`` are written.
        n_blocks: number of blocks to decode.
    """
    if n_blocks <= 0:
        return
    layout = np.dtype([("d", "<f2"), ("qs", "u1", (16,))])
    blocks = np.frombuffer(buf, dtype=layout, count=n_blocks)
    qs = blocks["qs"]
    # Low nibbles fill the first half of the block and high nibbles the
    # second half; they are NOT interleaved.
    q = np.concatenate([qs & 0x0F, qs >> 4], axis=1).astype(np.float32) - 8.0
    d = blocks["d"].astype(np.float32)[:, None]
    dst[: n_blocks * 32] = (q * d).reshape(-1)
def _dequant_q4_1(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    """De-quantise ``n_blocks`` GGML Q4_1 blocks from *buf* into *dst*.

    Block layout (20 bytes, 32 values): ``fp16 d``, ``fp16 m``, then 16 bytes
    of packed 4-bit quants. The low nibble of byte ``j`` is element ``j`` and
    the high nibble is element ``j + 16``; each value is ``q * d + m``.

    Args:
        buf: raw tensor bytes; must hold at least ``n_blocks * 20`` bytes.
        dst: float32 output; elements ``[0, n_blocks * 32)`` are written.
        n_blocks: number of blocks to decode.
    """
    if n_blocks <= 0:
        return
    layout = np.dtype([("d", "<f2"), ("m", "<f2"), ("qs", "u1", (16,))])
    blocks = np.frombuffer(buf, dtype=layout, count=n_blocks)
    qs = blocks["qs"]
    # First half of the block = low nibbles, second half = high nibbles.
    q = np.concatenate([qs & 0x0F, qs >> 4], axis=1).astype(np.float32)
    d = blocks["d"].astype(np.float32)[:, None]
    m = blocks["m"].astype(np.float32)[:, None]
    dst[: n_blocks * 32] = (q * d + m).reshape(-1)
def _dequant_q8_1(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    """De-quantise ``n_blocks`` GGML Q8_1 blocks from *buf* into *dst*.

    Block layout (36 bytes, 32 values): ``fp16 d``, ``fp16 s``, then 32
    ``int8`` quants. Each value is ``q * d``. In ggml, ``s`` holds the
    precomputed ``d * sum(q)`` used by dot-product kernels; it is not a zero
    point, so it does not take part in de-quantisation.

    Args:
        buf: raw tensor bytes; must hold at least ``n_blocks * 36`` bytes.
        dst: float32 output; elements ``[0, n_blocks * 32)`` are written.
        n_blocks: number of blocks to decode.
    """
    if n_blocks <= 0:
        return
    layout = np.dtype([("d", "<f2"), ("s", "<f2"), ("qs", "i1", (32,))])
    blocks = np.frombuffer(buf, dtype=layout, count=n_blocks)
    d = blocks["d"].astype(np.float32)[:, None]
    dst[: n_blocks * 32] = (blocks["qs"].astype(np.float32) * d).reshape(-1)
def _dequant_q4_k(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    """De-quantise ``n_blocks`` GGML Q4_K super-blocks from *buf* into *dst*.

    Block layout (144 bytes, 256 values)::

        fp16 d           super-block scale applied to the sub-block scales
        fp16 dmin        super-block scale applied to the sub-block mins
        u8   scales[12]  eight 6-bit scales + eight 6-bit mins, bit-packed
        u8   qs[128]     4-bit quants

    The 256 values form 8 sub-blocks of 32. Sub-blocks ``2k`` and ``2k+1``
    take the low and high nibbles of ``qs[32k:32k+32]`` respectively, and
    each value is ``d*scale[b]*q - dmin*min[b]`` (the min is subtracted).

    Args:
        buf: raw tensor bytes; must hold at least ``n_blocks * 144`` bytes.
        dst: float32 output; elements ``[0, n_blocks * 256)`` are written.
        n_blocks: number of super-blocks to decode.
    """
    if n_blocks <= 0:
        return
    layout = np.dtype(
        [
            ("d", "<f2"),
            ("dmin", "<f2"),
            ("scales", "u1", (12,)),
            ("qs", "u1", (128,)),
        ]
    )
    blocks = np.frombuffer(buf, dtype=layout, count=n_blocks)

    # Unpack the 6-bit scales/mins (ggml's get_scale_min_k4):
    #   sub-blocks 0-3: low 6 bits of bytes 0-3 (scales) and 4-7 (mins);
    #   sub-blocks 4-7: low/high nibble of bytes 8-11 give the low 4 bits,
    #                   the top 2 bits of bytes 0-3 / 4-7 give the high 2 bits.
    packed = blocks["scales"]
    b_sc, b_mn, b_lo = packed[:, 0:4], packed[:, 4:8], packed[:, 8:12]
    sc = np.concatenate([b_sc & 0x3F, (b_lo & 0x0F) | ((b_sc >> 6) << 4)], axis=1)
    mn = np.concatenate([b_mn & 0x3F, (b_lo >> 4) | ((b_mn >> 6) << 4)], axis=1)

    d = blocks["d"].astype(np.float32)[:, None]
    dmin = blocks["dmin"].astype(np.float32)[:, None]
    scale = (d * sc.astype(np.float32))[:, :, None]  # (n_blocks, 8, 1)
    offset = (dmin * mn.astype(np.float32))[:, :, None]  # (n_blocks, 8, 1)

    # (n, 4 chunks, 32 bytes) -> (n, 4, [low, high], 32) -> (n, 8 sub-blocks, 32)
    qs = blocks["qs"].reshape(n_blocks, 4, 32)
    q = np.stack([qs & 0x0F, qs >> 4], axis=2).reshape(n_blocks, 8, 32).astype(np.float32)

    dst[: n_blocks * 256] = (q * scale - offset).reshape(-1)
def _dequant_q6_k(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    """De-quantise ``n_blocks`` GGML Q6_K super-blocks from *buf* into *dst*.

    Block layout (210 bytes, 256 values) -- note that ``d`` comes LAST::

        u8   ql[128]     low 4 bits of each quant
        u8   qh[64]      high 2 bits of each quant
        i8   scales[16]  one scale per 16 values
        fp16 d           super-block scale

    The block is processed in two halves of 128 values. In half ``h``,
    value ``p = 64*s + 32*c + l`` (s, c in {0, 1}, l < 32) takes nibble ``s``
    of ``ql[64h + 32c + l]`` as its low 4 bits. It takes bits
    ``2*(2s+c) .. 2*(2s+c)+1`` of ``qh[32h + l]`` as its high 2 bits. Each
    value is ``d * scales[i // 16] * (q - 32)`` for global index ``i``.

    Args:
        buf: raw tensor bytes; must hold at least ``n_blocks * 210`` bytes.
        dst: float32 output; elements ``[0, n_blocks * 256)`` are written.
        n_blocks: number of super-blocks to decode.
    """
    if n_blocks <= 0:
        return
    layout = np.dtype(
        [
            ("ql", "u1", (128,)),
            ("qh", "u1", (64,)),
            ("scales", "i1", (16,)),
            ("d", "<f2"),
        ]
    )
    blocks = np.frombuffer(buf, dtype=layout, count=n_blocks)

    # Low 4 bits: (n, half, 64) -> (n, half, [low, high nibble], 64) -> (n, half, 128)
    ql = blocks["ql"].reshape(n_blocks, 2, 64)
    low4 = np.stack([ql & 0x0F, ql >> 4], axis=2).reshape(n_blocks, 2, 128)

    # High 2 bits: each qh byte feeds 4 values, 32 positions apart.
    qh = blocks["qh"].reshape(n_blocks, 2, 1, 32)
    shifts = np.array([0, 2, 4, 6], dtype=np.uint8).reshape(1, 1, 4, 1)
    high2 = ((qh >> shifts) & 0x03).reshape(n_blocks, 2, 128)

    q = (low4 | (high2 << 4)).astype(np.float32) - 32.0  # signed range -32..31

    scale = blocks["d"].astype(np.float32)[:, None] * blocks["scales"].astype(np.float32)
    values = q.reshape(n_blocks, 16, 16) * scale[:, :, None]
    dst[: n_blocks * 256] = values.reshape(-1)
# Type-info map: how each GGML tensor type is stored on disk.
#   torch_dtype        - plain little-endian array, reinterpreted directly
#   block/size/dequant - quantised: `block` values per `size`-byte block,
#                        expanded to float32 by `dequant(buf, dst, n_blocks)`
#   supported=False    - recognised but not yet ported
#
# NOTE: the GGMLType.*_ALT names are intentionally NOT used as keys. They are
# IntEnum aliases sharing an integer value with another member, so an entry
# keyed by an alias silently overwrites that member's entry (e.g. F64 data
# would be reinterpreted as int64).
_TYPE_INFO: Dict[GGMLType, Dict[str, Any]] = {
    # primitive
    GGMLType.F32: dict(torch_dtype=torch.float32),
    GGMLType.F16: dict(torch_dtype=torch.float16),
    GGMLType.BF16: dict(torch_dtype=torch.bfloat16),
    GGMLType.F64: dict(torch_dtype=torch.float64),
    GGMLType.I8: dict(torch_dtype=torch.int8),
    GGMLType.I16: dict(torch_dtype=torch.int16),
    GGMLType.I32: dict(torch_dtype=torch.int32),
    GGMLType.I64: dict(torch_dtype=torch.int64),
    # quantised (supported); size = bytes per block
    GGMLType.Q8_0: dict(block=32, size=34, dequant=_dequant_q8_0),  # 2 + 32
    GGMLType.Q4_0: dict(block=32, size=18, dequant=_dequant_q4_0),  # 2 + 16
    GGMLType.Q4_1: dict(block=32, size=20, dequant=_dequant_q4_1),  # 2 + 2 + 16
    GGMLType.Q8_1: dict(block=32, size=36, dequant=_dequant_q8_1),  # 2 + 2 + 32
    GGMLType.Q4_K: dict(block=256, size=144, dequant=_dequant_q4_k),  # 2 + 2 + 12 + 128
    GGMLType.Q6_K: dict(block=256, size=210, dequant=_dequant_q6_k),  # 128 + 64 + 16 + 2
    # quantised (not yet ported)
    GGMLType.Q2_K: dict(supported=False),
    GGMLType.Q3_K: dict(supported=False),
    GGMLType.Q5_K: dict(supported=False),
    GGMLType.Q8_K: dict(supported=False),
}
# ────────────────────────────
# meta helpers
# ────────────────────────────
def _read_string(buf: BinaryIO, file_size: int) -> str:
    """Read one GGUF string: a little-endian uint64 byte length, then the bytes.

    Args:
        buf: readable stream positioned at the length prefix; must support
            ``read`` and ``tell``.
        file_size: total size of the underlying file, used to reject lengths
            that would run past the end.

    Returns:
        The decoded UTF-8 string.

    Raises:
        ValueError: the length prefix or the string payload is truncated.
        UnicodeDecodeError: the payload is not valid UTF-8.
    """
    prefix = buf.read(_UINT64.size)
    # A short read here would otherwise surface as a bare struct.error.
    if len(prefix) < _UINT64.size:
        raise ValueError("GGUF file truncated while reading string")
    (length,) = _UINT64.unpack(prefix)
    if buf.tell() + length > file_size:
        raise ValueError("GGUF file truncated while reading string")
    return buf.read(length).decode("utf-8")
def _read_meta_value(buf: BinaryIO, vt: GGUFMetaValueType, file_size: int) -> Any:
    """Decode one metadata value of type *vt* at the current position of *buf*.

    Fixed-width scalars use the codecs in ``_META_VALUE_STRUCTS``. BOOL is a
    single byte, and STRING is delegated to ``_read_string``. ARRAY is a uint32
    element-type tag and a uint64 count; that many elements follow, decoded
    recursively and returned as a list.

    Raises:
        ValueError: *vt* is not a handled type, an array's element-type tag is
            unknown, or a string runs past *file_size*.
    """
    codec = _META_VALUE_STRUCTS.get(vt)
    if codec is not None:
        (scalar,) = codec.unpack(buf.read(codec.size))
        return scalar

    if vt == GGUFMetaValueType.BOOL:
        (flag,) = struct.unpack("<?", buf.read(1))
        return flag

    if vt == GGUFMetaValueType.STRING:
        return _read_string(buf, file_size)

    if vt != GGUFMetaValueType.ARRAY:
        raise ValueError(f"unhandled meta type {vt}")

    (raw_item_type,) = _UINT32.unpack(buf.read(4))
    element_type = GGUFMetaValueType(raw_item_type)
    (n_items,) = _UINT64.unpack(buf.read(8))
    return [_read_meta_value(buf, element_type, file_size) for _ in range(n_items)]
# ────────────────────────────
# tensor-descriptor
# ────────────────────────────
class _TensorInfo:
    """One entry of the GGUF tensor table.

    ``offset`` is whatever the caller passes in; the loader later rebases it
    to an absolute file position. ``n_bytes`` starts at 0 and is assigned by
    the loader once the tensor's on-disk size is known.
    """

    __slots__ = ("name", "shape", "dtype", "offset", "n_bytes")

    def __init__(self, name: str, shape: List[int], dtype: GGMLType, offset: int):
        self.name, self.shape, self.dtype, self.offset = name, shape, dtype, offset
        self.n_bytes = 0
# ────────────────────────────
# public loader
# ────────────────────────────
class GGUFLoader:
    """Read a GGUF file into ``(metadata, {tensor_name: torch.Tensor})``.

    Plain tensors are reinterpreted from their raw bytes. Quantised tensors
    are expanded to float32 by the ``dequant`` helper registered in
    ``_TYPE_INFO``. Floating-point results are then cast to *target_dtype*;
    integer tensors keep their integer dtype.

    Args:
        path: location of the ``.gguf`` file.
        target_dtype: dtype for floating-point tensors (default float16).
    """

    def __init__(self, path: str | pathlib.Path, target_dtype: torch.dtype = torch.float16):
        self._path = pathlib.Path(path)
        self._target_dtype = target_dtype

    def load(self) -> Tuple[Dict[str, Any], Dict[str, torch.Tensor]]:
        """Parse the file and return ``(metadata, weights)``.

        Raises:
            FileNotFoundError: the path is not an existing file.
            ValueError: the file is malformed or truncated, or it uses an
                unknown tensor type.
            NotImplementedError: a tensor uses a recognised but unported
                quant type.
        """
        if not self._path.is_file():
            raise FileNotFoundError(self._path)

        meta, tensors, tensor_data_buffers = self._read_structure()

        pairs = zip(tensors, tensor_data_buffers)
        if tqdm:
            pairs = tqdm.tqdm(pairs, total=len(tensors), desc="Loading tensors", unit="tensors")

        weights: Dict[str, torch.Tensor] = {
            t.name: self._decode_tensor(t, raw_data) for t, raw_data in pairs
        }
        return meta, weights

    def _decode_tensor(self, t: _TensorInfo, raw_data: bytes) -> torch.Tensor:
        """Turn one tensor's raw bytes into a torch tensor shaped ``t.shape``."""
        info = _TYPE_INFO.get(t.dtype)
        if info is None:
            raise ValueError(f"{t.name}: unknown GGMLType {t.dtype}")
        if not info.get("supported", True):
            raise NotImplementedError(f"{t.dtype.name} not yet supported")

        shape = tuple(t.shape)
        numel = math.prod(shape)

        if "torch_dtype" in info:
            src_dtype = info["torch_dtype"]
            if numel == 0:
                # torch.frombuffer rejects empty buffers.
                tensor = torch.empty(shape, dtype=src_dtype)
            else:
                # A private, writable bytearray: torch.frombuffer shares its
                # memory (keeping it alive), so no extra clone is needed.
                tensor = torch.frombuffer(bytearray(raw_data), dtype=src_dtype).reshape(shape)
        else:
            out = np.empty(numel, dtype=np.float32)
            info["dequant"](memoryview(raw_data), out, numel // info["block"])
            tensor = torch.from_numpy(out).reshape(shape)

        # Only floating-point data is re-typed. Casting integer tensors to e.g.
        # float16 would silently lose values (and overflow int64 to inf).
        if tensor.is_floating_point():
            tensor = tensor.to(self._target_dtype)
        return tensor

    def _read_structure(self) -> Tuple[Dict[str, Any], List[_TensorInfo], List[bytes]]:
        """Parse the header, metadata and tensor table, and copy out tensor bytes.

        Returns:
            ``(meta, tensors, buffers)``, where ``buffers[i]`` is a ``bytes``
            copy of ``tensors[i]``'s data (empty for unported quant types).
            Both lists always have the same length and order.

        Raises:
            ValueError: bad magic, unsupported version or byte order, invalid
                alignment, unknown tensor type, or tensor data past end of file.
        """
        with self._path.open("rb") as fh, mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            file_size = len(mm)
            # The mmap is itself file-like (read/tell), so parse it directly;
            # wrapping it in io.BytesIO would copy the whole file into RAM.
            if file_size < _HDR.size:
                raise ValueError("not a GGUF file")
            magic, version, n_tensors, n_meta = _HDR.unpack(mm.read(_HDR.size))
            if magic != _MAGIC:
                raise ValueError("not a GGUF file")
            if version < 2:
                # GGUF v1 used 32-bit counts and string lengths, which this
                # parser (uint64 everywhere) cannot read.
                raise ValueError(f"GGUF v{version} is not supported (need v2 or newer)")
            if version & 0xFFFF == 0:
                # A byte-swapped (big-endian) file's version reads as e.g. 0x03000000.
                raise ValueError("big-endian GGUF files are not supported")
            if version > 3:
                warnings.warn(f"GGUF v{version} – parser tested up to v3")

            meta: Dict[str, Any] = {}
            for _ in range(n_meta):
                key = _read_string(mm, file_size)
                vt = GGUFMetaValueType(_UINT32.unpack(mm.read(4))[0])
                meta[key] = _read_meta_value(mm, vt, file_size)

            alignment = int(meta.get("general.alignment", 32))
            if alignment <= 0:
                raise ValueError(f"invalid general.alignment {alignment}")

            tensors = [self._read_tensor_info(mm, file_size) for _ in range(n_tensors)]

            # Tensor data starts at the next multiple of `alignment` after the
            # table; the offsets stored in the table are relative to that point.
            data_start = mm.tell()
            data_start += -data_start % alignment

            buffers: List[bytes] = []
            for t in tensors:
                t.offset += data_start
                t.n_bytes = self._tensor_nbytes(t)
                if t.offset % alignment:
                    raise ValueError(f"{t.name} offset {t.offset} not aligned ({alignment})")
                end = t.offset + t.n_bytes
                if end > file_size:
                    raise ValueError(
                        f"{t.name}: data [{t.offset}, {end}) exceeds file size {file_size}"
                    )
                # Append for EVERY tensor (even b"") so buffers stay
                # index-aligned with `tensors` for the zip in load().
                buffers.append(mm[t.offset : end])
        return meta, tensors, buffers

    @staticmethod
    def _read_tensor_info(stream: BinaryIO, file_size: int) -> _TensorInfo:
        """Read one tensor-table entry: name, dims, type id and relative offset."""
        name = _read_string(stream, file_size)
        (ndims,) = _UINT32.unpack(stream.read(4))
        dims = [_UINT64.unpack(stream.read(8))[0] for _ in range(ndims)]
        (raw_type,) = _UINT32.unpack(stream.read(4))
        (offset,) = _UINT64.unpack(stream.read(8))
        try:
            dtype = GGMLType(raw_type)
        except ValueError:
            raise ValueError(f"{name}: unknown GGMLType id {raw_type}") from None
        # GGUF lists dims fastest-varying first; torch shapes are the reverse.
        return _TensorInfo(name, dims[::-1], dtype, offset)

    @staticmethod
    def _tensor_nbytes(t: _TensorInfo) -> int:
        """Return the on-disk size of *t*'s data (0 for unported quant types)."""
        info = _TYPE_INFO.get(t.dtype)
        if info is None:
            raise ValueError(f"{t.name}: unknown GGMLType {t.dtype}")
        numel = math.prod(t.shape)  # exact Python int; 1 for an empty shape
        if "torch_dtype" in info:
            return numel * torch.tensor([], dtype=info["torch_dtype"]).element_size()
        if "block" in info:
            bs = info["block"]
            if numel % bs:
                raise ValueError(f"{t.name}: {numel} not mult of {bs}")
            return (numel // bs) * info["size"]
        if not info.get("supported", True):
            # Size unknown; load() raises NotImplementedError for this tensor.
            return 0
        raise ValueError(f"Tensor {t.name} has quant type {t.dtype.name} with no size info.")
# quick CLI sanity
if __name__ == "__main__":
    import sys
    import textwrap

    # To install tqdm for the progress bar: pip install tqdm
    if not tqdm:
        print("Note: 'tqdm' not found. To see a progress bar, run: pip install tqdm", file=sys.stderr)

    # Usage: python gguf_loader.py [path/to/model.gguf]
    # Without an argument this falls back to the original hard-coded path.
    default_path = r'D:\ia\levi_chat\llm\Llama-3.2-1B-Instruct-Q4_K_M.gguf'
    gguf_path = sys.argv[1] if len(sys.argv) > 1 else default_path
    try:
        print(f"Loading: {gguf_path}")
        meta, w = GGUFLoader(gguf_path).load()

        print("\n--- Metadata ---")
        for key, value in list(meta.items())[:15]:
            if isinstance(value, list):
                print(f"- {key}: (list of {len(value)} items)")
            else:
                print(f"- {key}: {textwrap.shorten(str(value), 100)}")

        print("\n--- Tensors ---")
        print(f"Total tensors loaded: {len(w)}")
        # A file with no tensors is valid; next(iter({})) would raise StopIteration.
        if w:
            first_key = next(iter(w))
            first = w[first_key]
            print(f"First tensor: '{first_key}' | Shape: {tuple(first.shape)} | DType: {first.dtype}")
        print("\nSuccessfully loaded GGUF file.")
    except (FileNotFoundError, NotImplementedError, ValueError) as e:
        print(f"\nError: {e}", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        import traceback
        print(f"\nAn unexpected error occurred: {e}", file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment