# PyTorch 2.x; Apple Silicon; runs on MPS if available, else CPU.
import torch

if torch.backends.mps.is_available():  # official API to detect MPS
    device = torch.device("mps")
else:
    device = torch.device("cpu")
# --- Quantization helpers (per-tensor, symmetric) ---
@torch.no_grad()
def quantize_4bit_per_tensor(x: torch.Tensor):
    """
    Symmetric per-tensor 4-bit: q in [-7, 7], zero-point = 0.
    Returns (q_int, scale, x_hat). q_int is stored in int8 for convenience.
    """
    assert x.is_floating_point()
    s = x.abs().max() / 7.0
    s = torch.clamp(s, min=torch.finfo(x.dtype).eps)  # avoid division by zero
    q = torch.round(x / s).clamp_(-7, 7)
    q_int8 = q.to(torch.int8)
    x_hat = q_int8.to(x.dtype) * s
    return q_int8, s, x_hat
@torch.no_grad()
def quantize_1bit_sign_meanabs(x: torch.Tensor):
    """
    Binary baseline: q in {-1, +1} with shared scale = mean(|x|).
    Returns (q_sign, scale, x_hat). q_sign is int8 in {-1, +1}.
    """
    assert x.is_floating_point()
    # Map x == 0 to +1 so every entry has a well-defined sign
    # (exact zeros are not representable here; they dequantize to +scale).
    q = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))
    s = x.abs().mean()
    q_int8 = q.to(torch.int8)
    x_hat = q_int8.to(x.dtype) * s
    return q_int8, s, x_hat
@torch.no_grad()
def mse(a: torch.Tensor, b: torch.Tensor):
    return torch.mean((a - b) ** 2).item()
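
# Quick illustrative check (added sketch; the toy values below are assumptions, not part
# of the original gist): entries that are exact multiples of the 4-bit scale reconstruct
# almost exactly, while the 1-bit quantizer collapses every magnitude to mean(|x|).
_toy = torch.tensor([-2.8, -0.4, 0.4, 1.2, 2.8], device=device)
_, _s4, _toy4 = quantize_4bit_per_tensor(_toy)    # s = 2.8 / 7 = 0.4, so q = [-7, -1, 1, 3, 7]
_, _s1, _toy1 = quantize_1bit_sign_meanabs(_toy)  # every entry becomes +/-mean(|x|) = +/-1.52
print(f"Toy recon MSE: 4-bit={mse(_toy, _toy4):.6f} | 1-bit={mse(_toy, _toy1):.6f}")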
# --- Toy layer and data on MPS ---
torch.manual_seed(0)
in_features, out_features, batch = 256, 256, 8
W = torch.randn(out_features, in_features, device=device, dtype=torch.float32)
x = torch.randn(batch, in_features, device=device, dtype=torch.float32)

# Baseline output
y_ref = x @ W.t()

# 4-bit quantize/dequantize
q4, s4, W4_hat = quantize_4bit_per_tensor(W)
y4 = x @ W4_hat.t()

# 1-bit quantize/dequantize
q1, s1, W1_hat = quantize_1bit_sign_meanabs(W)
y1 = x @ W1_hat.t()

print(f"Device: {device}, W dtype: {W.dtype}")
print(f"MSE weight recon: 4-bit={mse(W, W4_hat):.6f} | 1-bit={mse(W, W1_hat):.6f}")
print(f"MSE output delta: 4-bit={mse(y_ref, y4):.6f} | 1-bit={mse(y_ref, y1):.6f}")
# Optional: host-side nibble packing to illustrate memory footprint (not used by MPS)
def pack_int4_signed(q_int8: torch.Tensor) -> torch.Tensor:
    """
    Packs two signed 4-bit values from q in [-7, 7] into one uint8 byte.
    Returns a flat CPU tensor of dtype uint8 for illustration.
    """
    q_cpu = q_int8.detach().to("cpu").flatten()  # flatten so pairing works for any shape
    u = torch.clamp(q_cpu + 8, 0, 15).to(torch.uint8)  # map [-7, 7] -> [1, 15]; code 0 unused
    if u.numel() % 2 == 1:
        u = torch.cat([u, torch.zeros(1, dtype=torch.uint8)], dim=0)  # pad to an even count
    u0 = u[0::2]  # even-index values -> low nibble
    u1 = u[1::2]  # odd-index values -> high nibble
    packed = u0 | (u1 << 4)
    return packed
packed_q4 = pack_int4_signed(q4)
bytes_fp32 = W.numel() * 4
bytes_int4 = packed_q4.numel()  # one byte stores two weights
print(f"Storage illustration: fp32 ~{bytes_fp32/1e6:.2f} MB vs int4 packed ~{bytes_int4/1e6:.2f} MB (≈8x smaller than fp32)")