Skip to content

Instantly share code, notes, and snippets.

@aaaddress1
Last active March 30, 2026 10:14
Show Gist options
  • Select an option

  • Save aaaddress1/a226e5e401b02a935805fabc97552db1 to your computer and use it in GitHub Desktop.

Select an option

Save aaaddress1/a226e5e401b02a935805fabc97552db1 to your computer and use it in GitHub Desktop.
Toy TurboQuant Note for DeepMind Research ICLR'26
import math
import torch
# ============================================================
# TurboQuant toy demo for a beginner who only knows MHA basics
# ------------------------------------------------------------
# This script shows:
# 1. Long prompt -> hidden states -> key vectors
# 2. Algorithm 1 style compression:
#    rotate with Pi -> scalar quantize with codebook
# 3. Residual computation
# 4. Algorithm 2 style residual sketch:
#    qjl = sign(S @ r), gamma = ||r||
# 5. Reconstruction
#
# This is a teaching demo, not the exact production implementation.
# ============================================================
torch.set_printoptions(precision=4, sci_mode=False)  # compact float printing
# ------------------------------------------------------------
# Step 0: pretend we already have a long prompt
# In real LLMs, tokens come from tokenizer + embedding + layers.
# Here we only simulate one head's key vectors.
# ------------------------------------------------------------
T = 12 # prompt length (toy)
d_model = 8 # hidden-state width of the fake model
head_dim = 4 # per-head key dimension
torch.manual_seed(0)  # fixed seed; the two randn calls below must stay in this order for reproducible output
# Fake hidden states for a long prompt
H = torch.randn(T, d_model)  # shape: [T, d_model]
# Fake Wk projection for one attention head
Wk = torch.randn(d_model, head_dim)  # shape: [d_model, head_dim]
# Compute all key vectors for the prompt
K = H @ Wk # shape: [T, head_dim]
print("All key vectors K shape:", K.shape)
# Pick one token to inspect, e.g. token index 5
x = K[5].clone() # clone() so the demo vector does not alias storage in K
print("\nOriginal key vector x:")
print(x)
# ------------------------------------------------------------
# Step 1: define an orthogonal matrix Pi
# We use a fixed 4x4 Hadamard-like orthogonal matrix
# because it is easy to understand and verify by hand.
# In the paper, Pi is a random rotation matrix.
# ------------------------------------------------------------
# 0.5 * H4 (4x4 Hadamard): each row has unit norm and the rows are
# mutually orthogonal, so Pi^T @ Pi is the identity.
Pi = 0.5 * torch.tensor([
[ 1.0, 1.0, 1.0, 1.0],
[ 1.0, -1.0, 1.0, -1.0],
[ 1.0, 1.0, -1.0, -1.0],
[ 1.0, -1.0, -1.0, 1.0],
])
# Sanity check: this should print the 4x4 identity matrix.
print("\nPi^T @ Pi:")
print(Pi.T @ Pi)
# ------------------------------------------------------------
# Step 2: rotate x -> y = Pi @ x
# This changes the coordinate system but preserves geometry
# (an orthogonal map keeps norms and inner products unchanged).
# ------------------------------------------------------------
y = Pi @ x
print("\nRotated vector y = Pi @ x:")
print(y)
# ------------------------------------------------------------
# Step 3: Algorithm 1 style scalar quantization
# We use a simple 2-bit codebook for teaching:
# 4 centroids = [-0.75, -0.25, 0.25, 0.75]
#
# In the paper, the codebook is optimized via Lloyd-Max.
# ------------------------------------------------------------
# 2 bits -> 4 centroids, evenly spaced; real codebooks are fit to the data.
codebook = torch.tensor([-0.75, -0.25, 0.25, 0.75])
def encode_with_codebook(v, codebook):
    """Quantize every entry of `v` to the index of its nearest centroid.

    Returns a LongTensor of codebook indices shaped like `v`; ties in
    absolute distance resolve to the lower index (argmin semantics).
    """
    # Broadcast v ([..., 1]) against codebook ([1, C]) -> [..., C] distances.
    distances = torch.abs(v.unsqueeze(-1) - codebook.unsqueeze(0))
    # Nearest centroid along the trailing (codebook) axis.
    nearest = torch.argmin(distances, dim=-1)
    return nearest
def decode_with_codebook(idx, codebook):
    """Inverse of encode_with_codebook: map each index to its centroid value."""
    decoded = codebook[idx]
    return decoded
# Encode the rotated vector, then decode to get its quantized version.
idx = encode_with_codebook(y, codebook)
y_hat_mse = decode_with_codebook(idx, codebook)
# Undo the rotation: Pi is orthogonal, so its inverse is its transpose.
x_hat_mse = Pi.T @ y_hat_mse
print("\nCodebook:")
print(codebook)
print("\nEncoded indices idx:")
print(idx)
print("\nDecoded rotated vector y_hat_mse:")
print(y_hat_mse)
print("\nReconstructed vector x_hat_mse = Pi^T @ y_hat_mse:")
print(x_hat_mse)
# ------------------------------------------------------------
# Step 4: residual
# This is what Algorithm 2 tries to encode cheaply.
# ------------------------------------------------------------
r = x - x_hat_mse # quantization error left over after Step 3
gamma = r.norm() # scalar Euclidean norm of the residual, kept alongside the sketch
print("\nResidual r = x - x_hat_mse:")
print(r)
print("\nResidual norm gamma:")
print(gamma.item())
# ------------------------------------------------------------
# Step 5: QJL-like residual sketch
# In the paper, S is a random Gaussian matrix.
# Here we reuse Pi for a simple toy demo.
# qjl = sign(S @ r)
# ------------------------------------------------------------
S = Pi.clone()
proj = S @ r
qjl = torch.sign(proj)
# torch.sign(0) is 0; force it to +1 so every coordinate is strictly one bit
qjl[qjl == 0] = 1.0
print("\nProjected residual S @ r:")
print(proj)
print("\n1-bit sign sketch qjl:")
print(qjl)
# ------------------------------------------------------------
# Step 6: reconstruct residual approximately
# Paper-style correction term:
# r_hat = sqrt(pi/2)/d * gamma * S^T @ qjl
# NOTE(review): sqrt(pi/2)/d and gamma rescale the +/-1 sketch back to
# the residual's magnitude — formula taken as given from the paper's
# Algorithm 2; not derived here.
# ------------------------------------------------------------
r_hat = math.sqrt(math.pi / 2.0) / head_dim * gamma * (S.T @ qjl)
print("\nApprox residual reconstruction r_hat:")
print(r_hat)
# Final reconstruction: quantized estimate plus the sketched residual.
x_hat_final = x_hat_mse + r_hat
print("\nFinal reconstruction x_hat_final:")
print(x_hat_final)
# ------------------------------------------------------------
# Step 7: compare errors (the two-stage error should typically be lower)
# ------------------------------------------------------------
mse_only = ((x - x_hat_mse) ** 2).mean().item()
two_stage = ((x - x_hat_final) ** 2).mean().item()
print("\nMSE-only reconstruction error:", mse_only)
print("Two-stage reconstruction error:", two_stage)
# ------------------------------------------------------------
# Step 8: show effect on an attention-style inner product
# Pretend q is the new query vector for a future token.
# ------------------------------------------------------------
q = torch.randn(head_dim)
true_ip = torch.dot(q, x).item()
mse_ip = torch.dot(q, x_hat_mse).item()
final_ip = torch.dot(q, x_hat_final).item()
print("\nQuery vector q:")
print(q)
print("\nTrue inner product q·x :", true_ip)
print("MSE-only inner product q·xhat :", mse_ip)
print("Final inner product q·xhat :", final_ip)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment