# PyTorch 2.x; Apple Silicon; runs on MPS if available, else CPU.
import math
import torch
if torch.backends.mps.is_available():  # official API to detect MPS
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# --- Quantization helpers (per-tensor, symmetric) ---
@torch.no_grad()
def quantize_4bit_per_tensor(x: torch.Tensor):
    """
    Symmetric per-tensor 4-bit: q in [-7, 7], zero-point = 0.
    Returns (q_int, scale, x_hat). q_int is stored in int8 for convenience.
    """
    assert x.is_floating_point()
    s = x.abs().max() / 7.0
    s = torch.clamp(s, min=torch.finfo(x.dtype).eps)  # avoid division by zero
    q = torch.round(x / s).clamp_(-7, 7)
    q_int8 = q.to(torch.int8)
    x_hat = q_int8.to(x.dtype) * s
    return q_int8, s, x_hat

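# Worked example (illustrative numbers, not taken from the tensors below): if max|W| = 3.5,
# then s = 3.5 / 7 = 0.5, and a weight w = 1.2 maps to q = round(1.2 / 0.5) = 2, which
# dequantizes to x_hat = 2 * 0.5 = 1.0 (absolute error 0.2, bounded by s / 2 = 0.25).
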
@torch.no_grad()
def quantize_1bit_sign_meanabs(x: torch.Tensor):
    """
    Binary baseline: q in {-1, +1} with shared scale = mean(|x|).
    Returns (q_sign, scale, x_hat). q_sign is int8 in {-1, +1}.
    """
    assert x.is_floating_point()
    # assign x == 0 to +1 so every weight gets a valid {-1, +1} code (torch.sign would return 0)
    q = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))
    s = x.abs().mean()
    q_int8 = q.to(torch.int8)
    x_hat = q_int8.to(x.dtype) * s
    return q_int8, s, x_hat

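# Why mean(|x|)? For fixed signs s_i = sign(x_i), the scalar a that minimizes
# sum_i (x_i - a * s_i)^2 is a = (1/n) * sum_i x_i * s_i = mean(|x|), the same
# shared scale used by XNOR-Net-style binarization.
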
@torch.no_grad()
def mse(a: torch.Tensor, b: torch.Tensor):
    return torch.mean((a - b) ** 2).item()

# --- Toy layer and data on MPS ---
torch.manual_seed(0)
in_features, out_features, batch = 256, 256, 8
W = torch.randn(out_features, in_features, device=device, dtype=torch.float32)
x = torch.randn(batch, in_features, device=device, dtype=torch.float32)
# Baseline output
y_ref = x @ W.t()
# 4-bit quantize/dequantize
q4, s4, W4_hat = quantize_4bit_per_tensor(W)
y4 = x @ W4_hat.t()
# 1-bit quantize/dequantize
q1, s1, W1_hat = quantize_1bit_sign_meanabs(W)
y1 = x @ W1_hat.t()
print(f"Device: {device}, W dtype: {W.dtype}")
print(f"MSE weight recon: 4-bit={mse(W, W4_hat):.6f} | 1-bit={mse(W, W1_hat):.6f}")
print(f"MSE output delta: 4-bit={mse(y_ref, y4):.6f} | 1-bit={mse(y_ref, y1):.6f}")
# Optional: host-side nibble packing to illustrate memory footprint (not used by MPS)
def pack_int4_signed(q_int8: torch.Tensor) -> torch.Tensor:
    """
    Packs two signed 4-bit values from q in [-7, 7] into one uint8 byte.
    Returns a flat CPU tensor of dtype uint8 for illustration.
    """
    q_cpu = q_int8.detach().to("cpu").flatten()  # flatten so pairing/padding works for any shape
    u = torch.clamp(q_cpu + 8, 0, 15).to(torch.uint8)  # map [-7, 7] -> [1, 15], 0 unused
    if u.numel() % 2 == 1:
        u = torch.cat([u, torch.zeros(1, dtype=torch.uint8)], dim=0)
    u0 = u[0::2]  # low nibbles
    u1 = u[1::2]  # high nibbles
    packed = u0 | (u1 << 4)
    return packed

packed_q4 = pack_int4_signed(q4)
bytes_fp32 = W.numel() * 4
bytes_int4 = packed_q4.numel() # one byte stores two weights
print(f"Storage illustration: fp32 ~{bytes_fp32/1e6:.2f} MB vs int4 packed ~{bytes_int4/1e6:.2f} MB (≈8x smaller than fp32)")