Last active
March 30, 2026 10:14
-
-
Save aaaddress1/a226e5e401b02a935805fabc97552db1 to your computer and use it in GitHub Desktop.
Toy TurboQuant Note for DeepMind Research ICLR'26
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import math | |
| import torch | |
# ============================================================
# TurboQuant toy demo for a beginner who only knows MHA basics
# ------------------------------------------------------------
# This script shows:
#   1. Long prompt -> hidden states -> key vectors
#   2. Algorithm 1 style compression:
#      rotate with Pi -> scalar quantize with codebook
#   3. Residual computation
#   4. Algorithm 2 style residual sketch:
#      qjl = sign(S @ r), gamma = ||r||
#   5. Reconstruction
#
# This is a teaching demo, not the exact production implementation.
# ============================================================
torch.set_printoptions(precision=4, sci_mode=False)  # compact tensor printing
# ------------------------------------------------------------
# Step 0: pretend we already have a long prompt
# In real LLMs, tokens come from tokenizer + embedding + layers.
# Here we only simulate one head's key vectors.
# ------------------------------------------------------------
T = 12  # prompt length (toy)
d_model = 8  # hidden-state width
head_dim = 4  # per-head key dimension
torch.manual_seed(0)  # fixed seed so every run prints identical numbers
# Fake hidden states for a long prompt, shape [T, d_model]
H = torch.randn(T, d_model)
# Fake Wk projection for one attention head, shape [d_model, head_dim]
Wk = torch.randn(d_model, head_dim)
# Compute all key vectors for the prompt
K = H @ Wk  # shape: [T, head_dim]
print("All key vectors K shape:", K.shape)
# Pick one token to inspect, e.g. token index 5
x = K[5].clone()  # clone() so the demo vector does not alias into K
print("\nOriginal key vector x:")
print(x)
# ------------------------------------------------------------
# Step 1: define an orthogonal matrix Pi
# We use a fixed 4x4 Hadamard-like orthogonal matrix
# because it is easy to understand and verify by hand.
# In the paper, Pi is a random rotation matrix.
# ------------------------------------------------------------
# Rows are +/-1 Hadamard rows; the 0.5 factor normalizes each row to
# unit length, so Pi is orthogonal (Pi^T @ Pi = I, printed below).
Pi = 0.5 * torch.tensor([
    [ 1.0, 1.0, 1.0, 1.0],
    [ 1.0, -1.0, 1.0, -1.0],
    [ 1.0, 1.0, -1.0, -1.0],
    [ 1.0, -1.0, -1.0, 1.0],
])
# Sanity check: should print the 4x4 identity matrix.
print("\nPi^T @ Pi:")
print(Pi.T @ Pi)
# ------------------------------------------------------------
# Step 2: rotate x -> y = Pi @ x
# This changes the coordinate system but preserves geometry
# (orthogonal maps keep norms and inner products unchanged).
# ------------------------------------------------------------
y = Pi @ x
print("\nRotated vector y = Pi @ x:")
print(y)
# ------------------------------------------------------------
# Step 3: Algorithm 1 style scalar quantization
# We use a simple 2-bit codebook for teaching:
#   4 centroids = [-0.75, -0.25, 0.25, 0.75]
#
# In the paper, the codebook is optimized via Lloyd-Max.
# ------------------------------------------------------------
codebook = torch.tensor([-0.75, -0.25, 0.25, 0.75])
def encode_with_codebook(v, codebook):
    """Quantize each entry of *v* to the index of its nearest centroid.

    Returns a LongTensor of indices with the same shape as ``v``;
    ties resolve to the lower index (argmin behavior).
    """
    # Broadcast v against the 1-D codebook: [..., 1] - [C] -> [..., C]
    # matrix of absolute distances to every centroid.
    distances = (v.unsqueeze(-1) - codebook).abs()
    # Nearest centroid along the codebook axis.
    return torch.argmin(distances, dim=-1)
def decode_with_codebook(idx, codebook):
    """Map quantization indices back to their centroid values."""
    # Plain tensor indexing gathers codebook[i] for every entry of idx,
    # returning a float tensor with the same shape as idx.
    return codebook[idx]
# Quantize the rotated vector: the indices are what would actually be
# stored (2 bits per coordinate with this 4-entry codebook).
idx = encode_with_codebook(y, codebook)
# Dequantize back to centroid values in the rotated space.
y_hat_mse = decode_with_codebook(idx, codebook)
# Undo the rotation; Pi is orthogonal, so its inverse is Pi.T.
x_hat_mse = Pi.T @ y_hat_mse
print("\nCodebook:")
print(codebook)
print("\nEncoded indices idx:")
print(idx)
print("\nDecoded rotated vector y_hat_mse:")
print(y_hat_mse)
print("\nReconstructed vector x_hat_mse = Pi^T @ y_hat_mse:")
print(x_hat_mse)
# ------------------------------------------------------------
# Step 4: residual
# This is what Algorithm 2 tries to encode cheaply.
# ------------------------------------------------------------
r = x - x_hat_mse  # quantization error left over after Step 3
gamma = r.norm()  # its Euclidean norm (stored as a single scalar)
print("\nResidual r = x - x_hat_mse:")
print(r)
print("\nResidual norm gamma:")
print(gamma.item())
# ------------------------------------------------------------
# Step 5: QJL-like residual sketch
# In the paper, S is a random Gaussian matrix.
# Here we reuse Pi for a simple toy demo.
#   qjl = sign(S @ r)
# ------------------------------------------------------------
S = Pi.clone()
proj = S @ r
qjl = torch.sign(proj)
# torch.sign maps 0 -> 0; force those entries to +1 so the sketch is
# strictly +/-1, i.e. a true 1-bit code per coordinate.
qjl[qjl == 0] = 1.0
print("\nProjected residual S @ r:")
print(proj)
print("\n1-bit sign sketch qjl:")
print(qjl)
# ------------------------------------------------------------
# Step 6: reconstruct residual approximately
# Paper-style correction term:
#   r_hat = sqrt(pi/2)/d * gamma * S^T @ qjl
# NOTE(review): sqrt(pi/2) is the unbiasing constant for a *Gaussian*
# sign sketch; S here is the Hadamard-like Pi, so this is only an
# approximation — acceptable for a toy demo, confirm against the paper.
# ------------------------------------------------------------
r_hat = math.sqrt(math.pi / 2.0) / head_dim * gamma * (S.T @ qjl)
print("\nApprox residual reconstruction r_hat:")
print(r_hat)
# Final reconstruction: coarse MSE estimate plus residual correction.
x_hat_final = x_hat_mse + r_hat
print("\nFinal reconstruction x_hat_final:")
print(x_hat_final)
# ------------------------------------------------------------
# Step 7: compare errors
# Mean squared reconstruction error with and without the residual
# correction from Step 6.
# ------------------------------------------------------------
mse_only = ((x - x_hat_mse) ** 2).mean().item()
two_stage = ((x - x_hat_final) ** 2).mean().item()
print("\nMSE-only reconstruction error:", mse_only)
print("Two-stage reconstruction error:", two_stage)
# ------------------------------------------------------------
# Step 8: show effect on an attention-style inner product
# Pretend q is the new query vector for a future token.
# Attention scores are q·k inner products, so this is the quantity a
# KV-cache compression scheme ultimately needs to preserve.
# ------------------------------------------------------------
q = torch.randn(head_dim)
true_ip = torch.dot(q, x).item()  # exact score against the original key
mse_ip = torch.dot(q, x_hat_mse).item()  # score after Algorithm 1 only
final_ip = torch.dot(q, x_hat_final).item()  # score after both stages
print("\nQuery vector q:")
print(q)
print("\nTrue inner product q·x :", true_ip)
print("MSE-only inner product q·xhat :", mse_ip)
print("Final inner product q·xhat :", final_ip)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment