# python/sglang/srt/weight_loader/gguf_loader.py
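"""Stand-alone GGUF reader: parses a .gguf file's header, metadata and tensor
table, then materialises each tensor as a torch.Tensor, dequantising the
supported quantised formats on the fly."""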
from __future__ import annotations

import io
import mmap
import pathlib
import struct
from enum import IntEnum
from typing import Any, BinaryIO, Dict, List, Tuple

import numpy as np
import torch

# Optional tqdm for progress bar
try:
    import tqdm
except ImportError:
    tqdm = None

# ────────────────────────────
# basic binary helpers
# ────────────────────────────
_MAGIC = b"GGUF"
_HDR = struct.Struct("<4sIQQ")  # magic, version, n_tensors, n_meta
_UINT8 = struct.Struct("<B")
_UINT16 = struct.Struct("<H")
_UINT32 = struct.Struct("<I")
_UINT64 = struct.Struct("<Q")
_FP16 = struct.Struct("<e")  # little-endian IEEE-754 half
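# For reference, the fixed 24-byte header decodes like this (a minimal sketch;
# the packed bytes below are a made-up v3 header, not from a real file):
#
#   >>> _HDR.unpack(b"GGUF" + struct.pack("<IQQ", 3, 2, 5))
#   (b'GGUF', 3, 2, 5)   # magic, version, n_tensors, n_meta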
# ────────────────────────────
# GGML / GGUF enums
# ────────────────────────────
class GGMLType(IntEnum):
    # Integer ids follow ggml's public GGML_TYPE enum.
    F32 = 0
    F16 = 1
    Q4_0 = 2
    Q4_1 = 3
    # ids 4 and 5 (the old Q4_2 / Q4_3) were removed from ggml
    Q5_0 = 6
    Q5_1 = 7
    Q8_0 = 8
    Q8_1 = 9
    Q2_K = 10
    Q3_K = 11
    Q4_K = 12
    Q5_K = 13
    Q6_K = 14
    Q8_K = 15
    IQ2_XXS = 16
    IQ2_XS = 17
    IQ3_XXS = 18
    IQ1_S = 19
    IQ4_NL = 20
    IQ3_S = 21
    IQ2_S = 22
    IQ4_XS = 23
    I8 = 24
    I16 = 25
    I32 = 26
    I64 = 27
    F64 = 28
    IQ1_M = 29
    BF16 = 30
class GGUFMetaValueType(IntEnum):
    UINT8 = 0
    INT8 = 1
    UINT16 = 2
    INT16 = 3
    UINT32 = 4
    INT32 = 5
    FLOAT32 = 6
    BOOL = 7
    STRING = 8
    ARRAY = 9
    UINT64 = 10
    INT64 = 11
    FLOAT64 = 12

_META_VALUE_STRUCTS = {
    GGUFMetaValueType.UINT8: _UINT8,
    GGUFMetaValueType.INT8: struct.Struct("<b"),
    GGUFMetaValueType.UINT16: _UINT16,
    GGUFMetaValueType.INT16: struct.Struct("<h"),
    GGUFMetaValueType.UINT32: _UINT32,
    GGUFMetaValueType.INT32: struct.Struct("<i"),
    GGUFMetaValueType.FLOAT32: struct.Struct("<f"),
    GGUFMetaValueType.UINT64: _UINT64,
    GGUFMetaValueType.INT64: struct.Struct("<q"),
    GGUFMetaValueType.FLOAT64: struct.Struct("<d"),
}
# ────────────────────────────
# De-quant helpers
# ────────────────────────────
def _dequant_q8_0(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    # Q8_0 blocks of 32 values: fp16 scale, then 32 int8 quants (34 bytes).
    pos, BS = 0, 32
    for i in range(n_blocks):
        scale = np.float32(_FP16.unpack_from(buf, pos)[0]); pos += 2
        q = np.frombuffer(buf[pos:pos + BS], dtype=np.int8).astype(np.float32); pos += BS
        dst[i * BS:(i + 1) * BS] = q * scale
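# Quick self-check for the Q8_0 path (a sketch, not used by the loader):
#
#   >>> raw = _FP16.pack(0.5) + np.arange(32, dtype=np.int8).tobytes()
#   >>> out = np.empty(32, dtype=np.float32)
#   >>> _dequant_q8_0(memoryview(raw), out, 1)
#   >>> out[:4]
#   array([0. , 0.5, 1. , 1.5], dtype=float32)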
def _dequant_q4_0(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    # Q4_0 blocks of 32 values: fp16 scale, then 16 bytes of 4-bit quants
    # (18 bytes). Per ggml, low nibbles fill the first half of the block and
    # high nibbles the second half; values are centred by subtracting 8.
    pos, BS = 0, 32
    for i in range(n_blocks):
        scale = np.float32(_FP16.unpack_from(buf, pos)[0]); pos += 2
        qbytes = np.frombuffer(buf[pos:pos + 16], dtype=np.uint8); pos += 16
        blk = np.empty(BS, dtype=np.float32)
        blk[:16] = ((qbytes & 0x0F).astype(np.float32) - 8.0) * scale
        blk[16:] = ((qbytes >> 4).astype(np.float32) - 8.0) * scale
        dst[i * BS:(i + 1) * BS] = blk
def _dequant_q4_1(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    # Q4_1 blocks of 32 values: fp16 scale, fp16 min, then 16 bytes of 4-bit
    # quants (20 bytes); same half-split nibble layout as Q4_0, y = q*d + m.
    pos, BS = 0, 32
    for i in range(n_blocks):
        scale = np.float32(_FP16.unpack_from(buf, pos)[0]); pos += 2
        min_ = np.float32(_FP16.unpack_from(buf, pos)[0]); pos += 2
        qbytes = np.frombuffer(buf[pos:pos + 16], dtype=np.uint8); pos += 16
        blk = np.empty(BS, dtype=np.float32)
        blk[:16] = (qbytes & 0x0F).astype(np.float32) * scale + min_
        blk[16:] = (qbytes >> 4).astype(np.float32) * scale + min_
        dst[i * BS:(i + 1) * BS] = blk
def _dequant_q8_1(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    # Q8_1 blocks of 32 values: fp16 scale, fp16 s (d * sum of quants, kept
    # only for ggml's dot-product kernels), then 32 int8 quants (36 bytes).
    # Dequantisation itself is just q * scale.
    pos, BS = 0, 32
    for i in range(n_blocks):
        scale = np.float32(_FP16.unpack_from(buf, pos)[0]); pos += 2
        pos += 2  # skip s; it is redundant for dequantisation
        q = np.frombuffer(buf[pos:pos + BS], dtype=np.int8).astype(np.float32); pos += BS
        dst[i * BS:(i + 1) * BS] = q * scale
def _dequant_q4_k(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    # Q4_K super-blocks of 256 values (144 bytes each, per ggml's block_q4_K):
    # fp16 d, fp16 dmin, 12 bytes packing 6-bit scales/mins for eight
    # sub-blocks of 32, then 128 bytes of 4-bit quants. y = d*sc*q - dmin*m.
    pos, BS = 0, 256
    for i in range(n_blocks):
        d = np.float32(_FP16.unpack_from(buf, pos)[0]); pos += 2
        dmin = np.float32(_FP16.unpack_from(buf, pos)[0]); pos += 2
        sc_raw = np.frombuffer(buf[pos:pos + 12], dtype=np.uint8); pos += 12
        q = np.frombuffer(buf[pos:pos + 128], dtype=np.uint8); pos += 128
        # Unpack the 6-bit scales and mins (ggml's get_scale_min_k4).
        scales = np.empty(8, dtype=np.float32)
        mins = np.empty(8, dtype=np.float32)
        for j in range(4):
            scales[j] = sc_raw[j] & 63
            mins[j] = sc_raw[j + 4] & 63
            scales[j + 4] = (sc_raw[j + 8] & 0x0F) | ((sc_raw[j] >> 6) << 4)
            mins[j + 4] = (sc_raw[j + 8] >> 4) | ((sc_raw[j + 4] >> 6) << 4)
        blk = dst[i * BS:(i + 1) * BS]
        # Each 32-byte slice of q yields two 32-value sub-blocks (low nibbles,
        # then high nibbles), consuming two scale/min pairs.
        for j in range(4):
            qs = q[j * 32:(j + 1) * 32]
            lo = (qs & 0x0F).astype(np.float32)
            hi = (qs >> 4).astype(np.float32)
            blk[j * 64:j * 64 + 32] = d * scales[2 * j] * lo - dmin * mins[2 * j]
            blk[j * 64 + 32:j * 64 + 64] = d * scales[2 * j + 1] * hi - dmin * mins[2 * j + 1]
def _dequant_q6_k(buf: memoryview, dst: np.ndarray, n_blocks: int) -> None:
    # Q6_K super-blocks of 256 values (210 bytes each, per ggml's block_q6_K):
    # 128 bytes of low 4 bits, 64 bytes of high 2 bits, 16 int8 sub-block
    # scales, then the fp16 super-scale d. y = d * scales[sub] * (q - 32).
    pos, BS = 0, 256
    for i in range(n_blocks):
        ql = np.frombuffer(buf[pos:pos + 128], dtype=np.uint8); pos += 128
        qh = np.frombuffer(buf[pos:pos + 64], dtype=np.uint8); pos += 64
        scales = np.frombuffer(buf[pos:pos + 16], dtype=np.int8).astype(np.float32); pos += 16
        d = np.float32(_FP16.unpack_from(buf, pos)[0]); pos += 2
        blk = dst[i * BS:(i + 1) * BS]
        for half in range(2):  # two independent 128-value halves
            lo, hi = ql[half * 64:(half + 1) * 64], qh[half * 32:(half + 1) * 32]
            sc = scales[half * 8:(half + 1) * 8]
            # Rebuild the 6-bit values with ggml's interleaving: each of the
            # four 32-value segments takes its low 4 bits from a nibble of ql
            # and its high 2 bits from a field of qh.
            q1 = ((lo[:32] & 0x0F) | (((hi >> 0) & 3) << 4)).astype(np.float32) - 32.0
            q2 = ((lo[32:] & 0x0F) | (((hi >> 2) & 3) << 4)).astype(np.float32) - 32.0
            q3 = ((lo[:32] >> 4) | (((hi >> 4) & 3) << 4)).astype(np.float32) - 32.0
            q4 = ((lo[32:] >> 4) | (((hi >> 6) & 3) << 4)).astype(np.float32) - 32.0
            for k, qv in enumerate((q1, q2, q3, q4)):
                seg = half * 128 + k * 32
                # Each 32-value segment spans two 16-value sub-block scales.
                blk[seg:seg + 16] = d * sc[2 * k] * qv[:16]
                blk[seg + 16:seg + 32] = d * sc[2 * k + 1] * qv[16:]
# ────────────────────────────
# Type-info map
# ────────────────────────────
_TYPE_INFO: Dict[GGMLType, Dict[str, Any]] = {
    # primitive
    GGMLType.F32: dict(torch_dtype=torch.float32),
    GGMLType.F16: dict(torch_dtype=torch.float16),
    GGMLType.BF16: dict(torch_dtype=torch.bfloat16),
    GGMLType.F64: dict(torch_dtype=torch.float64),
    GGMLType.I8: dict(torch_dtype=torch.int8),
    GGMLType.I16: dict(torch_dtype=torch.int16),
    GGMLType.I32: dict(torch_dtype=torch.int32),
    GGMLType.I64: dict(torch_dtype=torch.int64),
    # quantised (supported); `size` is the byte size of one packed block
    GGMLType.Q8_0: dict(block=32, size=34, dequant=_dequant_q8_0),
    GGMLType.Q4_0: dict(block=32, size=18, dequant=_dequant_q4_0),
    GGMLType.Q4_1: dict(block=32, size=20, dequant=_dequant_q4_1),
    GGMLType.Q8_1: dict(block=32, size=36, dequant=_dequant_q8_1),
    GGMLType.Q4_K: dict(block=256, size=144, dequant=_dequant_q4_k),
    GGMLType.Q6_K: dict(block=256, size=210, dequant=_dequant_q6_k),
    # quantised (not yet ported)
    GGMLType.Q5_0: dict(supported=False),
    GGMLType.Q5_1: dict(supported=False),
    GGMLType.Q2_K: dict(supported=False),
    GGMLType.Q3_K: dict(supported=False),
    GGMLType.Q5_K: dict(supported=False),
    GGMLType.Q8_K: dict(supported=False),
}
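# Given the table above, the packed byte size of a tensor follows directly
# from its element count; `_packed_nbytes` below is a small illustration
# added here (the loader computes the same thing inline in _read_structure):
def _packed_nbytes(dtype: GGMLType, numel: int) -> int:
    info = _TYPE_INFO[dtype]
    if "torch_dtype" in info:
        return numel * torch.tensor([], dtype=info["torch_dtype"]).element_size()
    return (numel // info["block"]) * info["size"]
# e.g. a 4096 x 4096 Q4_K tensor packs to (4096*4096 // 256) * 144 = 9_437_184 bytes.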
# ────────────────────────────
# meta helpers
# ────────────────────────────
def _read_string(buf: BinaryIO, file_size: int) -> str:
    length = _UINT64.unpack(buf.read(8))[0]
    if buf.tell() + length > file_size:
        raise ValueError("GGUF file truncated while reading string")
    return buf.read(length).decode("utf-8")

def _read_meta_value(buf: BinaryIO, vt: GGUFMetaValueType, file_size: int) -> Any:
    s = _META_VALUE_STRUCTS.get(vt)
    if s:
        return s.unpack(buf.read(s.size))[0]
    if vt == GGUFMetaValueType.BOOL:
        return struct.unpack("<?", buf.read(1))[0]
    if vt == GGUFMetaValueType.STRING:
        return _read_string(buf, file_size)
    if vt == GGUFMetaValueType.ARRAY:
        item_type = GGUFMetaValueType(_UINT32.unpack(buf.read(4))[0])
        count = _UINT64.unpack(buf.read(8))[0]
        return [_read_meta_value(buf, item_type, file_size) for _ in range(count)]
    raise ValueError(f"unhandled meta type {vt}")
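# For example, a STRING value is a uint64 length followed by UTF-8 bytes
# (a sketch with made-up bytes, not from a real file):
#
#   >>> bio = io.BytesIO(struct.pack("<Q", 5) + b"llama")
#   >>> _read_meta_value(bio, GGUFMetaValueType.STRING, file_size=13)
#   'llama'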
# ────────────────────────────
# tensor descriptor
# ────────────────────────────
class _TensorInfo:
    __slots__ = ("name", "shape", "dtype", "offset", "n_bytes")

    def __init__(self, name: str, shape: List[int], dtype: GGMLType, offset: int):
        self.name = name
        self.shape = shape
        self.dtype = dtype
        self.offset = offset
        self.n_bytes = 0
# ────────────────────────────
# public loader
# ────────────────────────────
class GGUFLoader:
    def __init__(self, path: str | pathlib.Path, target_dtype: torch.dtype = torch.float16):
        self._path = pathlib.Path(path)
        self._target_dtype = target_dtype

    def load(self) -> Tuple[Dict[str, Any], Dict[str, torch.Tensor]]:
        if not self._path.is_file():
            raise FileNotFoundError(self._path)
        meta, tensors, tensor_data_buffers = self._read_structure()
        weights: Dict[str, torch.Tensor] = {}
        tensor_iterator = zip(tensors, tensor_data_buffers)
        if tqdm:  # optional progress bar
            tensor_iterator = tqdm.tqdm(
                tensor_iterator,
                total=len(tensors),
                desc="Loading tensors",
                unit="tensors",
            )
        for t, raw_data in tensor_iterator:
            info = _TYPE_INFO.get(t.dtype)
            if info is None:
                raise ValueError(f"{t.name}: unknown GGMLType {t.dtype}")
            if not info.get("supported", True):
                raise NotImplementedError(f"{t.dtype.name} not yet supported")
            if "torch_dtype" in info:
                # Copy into a writable bytearray so torch.frombuffer is safe.
                buf = bytearray(raw_data)
                tensor = torch.frombuffer(buf, dtype=info["torch_dtype"]).reshape(*t.shape).clone()
            else:
                numel = int(np.prod(t.shape)) if t.shape else 1
                bs, fn = info["block"], info["dequant"]
                out = np.empty(numel, dtype=np.float32)
                fn(memoryview(raw_data), out, numel // bs)
                tensor = torch.from_numpy(out).reshape(*t.shape)
            weights[t.name] = tensor.to(self._target_dtype)
        return meta, weights
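    # Typical use (a sketch; the path and tensor name are hypothetical and
    # depend on the model):
    #   meta, weights = GGUFLoader("model.gguf", target_dtype=torch.float32).load()
    #   emb = weights["token_embd.weight"]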
    def _read_structure(self) -> Tuple[Dict[str, Any], List[_TensorInfo], List[bytes]]:
        tensor_data_buffers: List[bytes] = []
        with self._path.open("rb") as fh, mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            file_size = len(mm)
            bf = io.BytesIO(mm)
            magic, version, n_tensors, n_meta = _HDR.unpack(bf.read(_HDR.size))
            if magic != _MAGIC:
                raise ValueError("not a GGUF file")
            if version < 2:
                raise ValueError(f"GGUF v{version} uses 32-bit counts and is not supported")
            if version > 3:
                print(f"warning: GGUF v{version} - parser tested up to v3")
            meta: Dict[str, Any] = {}
            for _ in range(n_meta):
                key = _read_string(bf, file_size)
                vt = GGUFMetaValueType(_UINT32.unpack(bf.read(4))[0])
                meta[key] = _read_meta_value(bf, vt, file_size)
            alignment = int(meta.get("general.alignment", 32))
            tensors: List[_TensorInfo] = []
            for _ in range(n_tensors):
                name = _read_string(bf, file_size)
                ndims = _UINT32.unpack(bf.read(4))[0]
                dims = [_UINT64.unpack(bf.read(8))[0] for _ in range(ndims)]
                # GGUF stores dims innermost-first; reverse for torch order.
                shape = list(reversed(dims))
                dtype = GGMLType(_UINT32.unpack(bf.read(4))[0])
                offset = _UINT64.unpack(bf.read(8))[0]
                tensors.append(_TensorInfo(name, shape, dtype, offset))
            # Tensor data starts at the next aligned offset after the
            # descriptors; per-tensor offsets are relative to that point.
            data_start_offset = bf.tell()
            data_start_offset += -data_start_offset % alignment
            for t in tensors:
                t.offset += data_start_offset
                numel = int(np.prod(t.shape)) if t.shape else 1
                info = _TYPE_INFO.get(t.dtype)
                if info is None:
                    raise ValueError(f"{t.name}: unknown GGMLType {t.dtype}")
                if "torch_dtype" in info:
                    elt = torch.tensor([], dtype=info["torch_dtype"])
                    t.n_bytes = numel * elt.element_size()
                elif "block" in info:
                    bs, blk_bytes = info["block"], info["size"]
                    if numel % bs:
                        raise ValueError(f"{t.name}: {numel} not a multiple of {bs}")
                    t.n_bytes = (numel // bs) * blk_bytes
                elif not info.get("supported", True):
                    t.n_bytes = 0  # size unknown; load() raises before use
                else:
                    raise ValueError(f"tensor {t.name} has quant type {t.dtype.name} with no size info")
                if t.offset % alignment:
                    raise ValueError(f"{t.name} offset {t.offset} not aligned ({alignment})")
                # Always append one buffer per tensor to keep load()'s zip aligned.
                tensor_data_buffers.append(mm[t.offset:t.offset + t.n_bytes])
        return meta, tensors, tensor_data_buffers
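# Alignment example for the padding step above: if the descriptors end at byte
# 115 with 32-byte alignment, -115 % 32 = 13, so tensor data begins at offset
# 128 and every descriptor offset is taken relative to that point.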
# ────────────────────────────
# quick CLI sanity check
# ────────────────────────────
if __name__ == "__main__":
    import sys, textwrap

    if not tqdm:
        print("Note: 'tqdm' not found. To see a progress bar, run: pip install tqdm", file=sys.stderr)

    # Pass a path on the command line, or update the default below.
    gguf_path = sys.argv[1] if len(sys.argv) > 1 else r'D:\ia\levi_chat\llm\Llama-3.2-1B-Instruct-Q4_K_M.gguf'
    try:
        print(f"Loading: {gguf_path}")
        meta, w = GGUFLoader(gguf_path).load()
        print("\n--- Metadata ---")
        for key, value in list(meta.items())[:15]:
            if isinstance(value, list):
                print(f"- {key}: (list of {len(value)} items)")
            else:
                print(f"- {key}: {textwrap.shorten(str(value), 100)}")
        print("\n--- Tensors ---")
        first_key = next(iter(w))
        print(f"Total tensors loaded: {len(w)}")
        print(f"First tensor: '{first_key}' | Shape: {tuple(w[first_key].shape)} | DType: {w[first_key].dtype}")
        print("\nSuccessfully loaded GGUF file.")
    except (FileNotFoundError, NotImplementedError, ValueError) as e:
        print(f"\nError: {e}", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        import traceback
        print(f"\nAn unexpected error occurred: {e}", file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)