@razhangwei
Created September 2, 2024 17:41
NormalFloat 4 Quantization #pytorch #quantization
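
A minimal PyTorch sketch of NormalFloat-4 (NF4) quantization, the 4-bit data type introduced in the QLoRA paper (Dettmers et al., 2023). The 16 levels are quantiles of a standard normal distribution normalized to [-1, 1], which matches the roughly Gaussian distribution of trained network weights. For simplicity this sketch uses a single per-tensor absmax scale and keeps the dequantized result in full precision; production implementations (e.g. bitsandbytes) scale per block and pack two 4-bit codes per byte.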
import torch
import torch.nn as nn
import torch.nn.functional as F


class NF4Quantizer(nn.Module):
    def __init__(self):
        super().__init__()
        # The 16 NF4 levels; registered as a buffer so they follow the module
        # across devices when .to()/.cuda() is called.
        self.register_buffer("nf4_values", torch.tensor([
            -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
            -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
            0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
            0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
            0.7229568362236023, 1.0,
        ]))

    def forward(self, x):
        # Per-tensor absmax scaling into [-1, 1].
        scale = x.abs().max()
        x_scaled = x / scale
        # Index of the nearest NF4 level for every element.
        x_quantized = torch.argmin(torch.abs(x_scaled.unsqueeze(-1) - self.nf4_values), dim=-1)
        # Codebook lookup via F.embedding (a nice trick): embedding weights must be
        # 2-D, so view the 16-entry codebook as (16, 1) and squeeze the result.
        x_dequantized = F.embedding(x_quantized, self.nf4_values.unsqueeze(1)).squeeze(-1)
        return x_dequantized * scale
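
# Where do the 16 constants above come from? A hedged sketch, assuming scipy is
# available: per the QLoRA construction, the levels are evenly spaced quantiles
# of N(0, 1) -- 8 on the positive side, 7 on the negative side, plus an exact
# zero -- normalized by the largest magnitude so both endpoints land on +/-1.
# The `offset` constant is the one I recall from the reference implementation;
# treat this function as illustrative rather than authoritative.
from scipy.stats import norm  # assumed extra dependency, used only in this sketch

def nf4_levels(offset=0.9677083):
    pos = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1].numpy())   # 8 positive quantiles
    neg = -norm.ppf(torch.linspace(offset, 0.5, 8)[:-1].numpy())  # 7 negative quantiles
    levels = torch.tensor(sorted(neg.tolist() + [0.0] + pos.tolist()))
    return levels / levels.max()  # endpoints normalize to exactly -1 and +1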
class NF4QuantizedModule(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.quantizer = NF4Quantizer()

    def forward(self, *args, **kwargs):
        # Quantize weights before the forward pass, remembering the originals.
        original_weights = {}
        for name, param in self.module.named_parameters():
            if "weight" in name:
                original_weights[name] = param.data
                param.data = self.quantizer(param.data)
        # Perform the forward pass with quantized weights.
        output = self.module(*args, **kwargs)
        # Restore the original full-precision weights.
        for name, param in self.module.named_parameters():
            if "weight" in name:
                param.data = original_weights[name]
        return output
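
# Quick illustrative sanity check: round-trip error through NF4 should be small
# for roughly Gaussian data, which is what the codebook is tuned for.
_q = NF4Quantizer()
_w = torch.randn(4096)
print("NF4 round-trip mean abs error:", (_w - _q(_w)).abs().mean().item())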
# Example usage
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 5),
)

# Wrap the entire model with NF4 quantization
quantized_model = NF4QuantizedModule(model)

# Create a random input
x = torch.randn(3, 10)

# Forward pass through the quantized model
output = quantized_model(x)
print("Input shape:", x.shape)
print("Output shape:", output.shape)
# Compare original and quantized weight ranges
with torch.no_grad():
    for name, param in model.named_parameters():
        if "weight" in name:
            q_param = quantized_model.quantizer(param)
            print(f"\nLayer: {name}")
            print("Original weight range:", param.min().item(), "to", param.max().item())
            print("Quantized weight range:", q_param.min().item(), "to", q_param.max().item())
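
# The quantizer above uses one absmax scale for the whole tensor, so a single
# outlier stretches the codebook for every element. NF4 as described in the
# QLoRA paper instead scales small blocks independently (block size 64 is
# typical). A hedged sketch of that idea reusing the codebook; the padding and
# clamping details here are my own illustrative choices:
def quantize_blockwise(x, levels, block_size=64):
    flat = x.flatten()
    pad = (-flat.numel()) % block_size                # pad so blocks divide evenly
    blocks = F.pad(flat, (0, pad)).view(-1, block_size)
    scales = blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)  # per-block absmax
    idx = (blocks / scales).unsqueeze(-1).sub(levels).abs().argmin(dim=-1)
    deq = levels[idx] * scales                        # codebook lookup times block scale
    return deq.flatten()[: x.numel()].view_as(x)

w = model[0].weight.detach()
print("\nPer-tensor vs block-wise NF4 error on the first Linear weight:")
print("  per-tensor:", (w - quantized_model.quantizer(w)).abs().mean().item())
print("  block-wise:", (w - quantize_blockwise(w, quantized_model.quantizer.nf4_values)).abs().mean().item())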