Created
September 2, 2024 17:41
-
-
Save razhangwei/d20f4082533b213b509d70d90be16a8a to your computer and use it in GitHub Desktop.
NormalFloat 4 Quantization #pytorch #quantization
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class NF4Quantizer(nn.Module):
    """Quantize-dequantize tensors with the 16-level NormalFloat-4 (NF4) code book.

    The 16 levels are the NF4 constants from the QLoRA paper: quantiles of a
    standard normal distribution rescaled to [-1, 1].  ``forward`` performs a
    round trip — absmax-scale the input, snap each element to the nearest
    code-book level, and rescale back — returning the dequantized tensor.
    """

    def __init__(self):
        super().__init__()
        # register_buffer (instead of a plain attribute) so the code book
        # moves with .to(device)/.cuda() and is tracked in state_dict.
        self.register_buffer("nf4_values", torch.tensor([
            -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
            -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
            0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
            0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
            0.7229568362236023, 1.0,
        ]))

    def forward(self, x):
        """Return ``x`` with every element snapped to the nearest NF4 level.

        Args:
            x: tensor of any shape.

        Returns:
            Tensor of the same shape as ``x``, quantized to NF4 after absmax
            scaling and dequantized back to the original range.
        """
        # Hoist the absmax so it is computed once (the original computed it
        # twice); guard the all-zero case, which would otherwise produce
        # 0/0 -> NaN.
        scale = x.abs().max()
        if scale == 0:
            return torch.zeros_like(x)
        x_scaled = x / scale
        # Index of the nearest code-book entry for each element.
        idx = torch.argmin(torch.abs(x_scaled.unsqueeze(-1) - self.nf4_values), dim=-1)
        # Plain advanced indexing is the clearer equivalent of the original
        # F.embedding lookup on a 1-D weight.
        return self.nf4_values[idx] * scale
class NF4QuantizedModule(nn.Module):
    """Wrapper that runs a module's forward pass with NF4-quantized weights.

    On each call, every parameter whose name contains ``'weight'`` is
    temporarily replaced by its NF4 quantize-dequantize image, the wrapped
    module's forward is executed, and the original weights are restored.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module          # wrapped model; weights quantized per call
        self.quantizer = NF4Quantizer()

    def forward(self, *args, **kwargs):
        """Forward through ``self.module`` with temporarily quantized weights.

        Args/kwargs are passed straight through to the wrapped module; the
        wrapped module's output is returned unchanged.
        """
        # Snapshot the weight parameters once, then swap in quantized copies.
        swapped = []
        for name, param in self.module.named_parameters():
            if 'weight' in name:
                swapped.append((param, param.data))
                param.data = self.quantizer(param.data)
        try:
            return self.module(*args, **kwargs)
        finally:
            # Restore even if the forward raises — the original version left
            # the module with quantized weights on any exception.
            for param, original in swapped:
                param.data = original
# Example usage: wrap a small MLP, run one forward pass, and report how the
# NF4 round trip changes each weight tensor's value range.
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 5),
)

# Wrap the entire model with NF4 quantization
quantized_model = NF4QuantizedModule(model)

# Create a random input
x = torch.randn(3, 10)

# Forward pass through the quantized model
output = quantized_model(x)
print("Input shape:", x.shape)
print("Output shape:", output.shape)

# Compare original and quantized weight ranges.
for name, param in model.named_parameters():
    if 'weight' in name:
        # Quantize once per parameter — the original called the quantizer
        # twice (separately for min and max), doubling the work.
        quantized = quantized_model.quantizer(param)
        print(f"\nLayer: {name}")
        print("Original weight range:", param.min().item(), "to", param.max().item())
        print("Quantized weight range:", quantized.min().item(), "to",
              quantized.max().item())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment