Simple SmoothQuant Implementation
""" | |
SmoothQuant implementation. See: https://arxiv.org/pdf/2211.10438.pdf | |
Some details are model-specific, so the code may need tweaking. | |
""" | |

import functools
from typing import Dict, Iterable, Tuple

import torch
from torch import nn, Tensor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


#####################################################################################################
# Compute max statistics for activations and weights

def get_tensor_max(weight: Tensor) -> Tensor:
    """Compute the column-wise (per-input-channel) max of a 2D weight tensor"""
    return torch.max(weight.abs(), dim=0).values


def combine_qkv_maxs(model: nn.Module, maxs: Dict[str, torch.Tensor]) -> None:
    """Take the elementwise max over the Q/K/V weight maxs for each layer"""
    # Q/K/V projections consume the same LayerNorm output, so they must share one smoothing factor.
    # Group Q/K/V weight maxs by layer (assumes the layer index sits at position 2 of the
    # dotted module name, which is model-specific).
    qkv_names = []
    for i in range(model.n_layers):
        names = []
        for name in maxs.keys():
            layer_num = name.split(".")[2]
            if int(layer_num) == i:
                if any(n in name for n in ["query", "key", "value"]):
                    names.append(name)
        qkv_names.append(names)
    # Combine Q/K/V weight maxs for each layer
    for qkv_name in qkv_names:
        q_name, k_name, v_name = qkv_name
        q, k, v = maxs[q_name], maxs[k_name], maxs[v_name]
        max_tensor = torch.max(torch.max(q, k), v)
        maxs[q_name] = maxs[k_name] = maxs[v_name] = max_tensor


def get_weight_maxs(model: nn.Module) -> Dict[str, torch.Tensor]:
    """Calculate per-input-channel max of weights for Linear layers that follow a LayerNorm"""
    weight_maxs = {}
    module_names = ["query", "key", "value", "ff.0"]  # model-specific names of the smoothed Linears
    for name, module in model.named_modules():
        if any(n in name for n in module_names) and isinstance(module, nn.Linear):
            weight_maxs[name] = get_tensor_max(module.weight)
    combine_qkv_maxs(model, weight_maxs)
    return weight_maxs


def get_actv_maxs(model: nn.Module, datastream: Iterable, num_batches: int = 10) -> Dict[str, torch.Tensor]:
    """Calculate channel-wise max of LayerNorm outputs (i.e. the inputs to the smoothed Linears)"""
    actv_maxs = {}

    def update_actv_stats(name: str, actv: Tensor) -> None:
        """Update the running max for a given module name"""
        # Flatten (batch, seq, hidden) -> (batch * seq, hidden) before reducing
        x = torch.flatten(actv, start_dim=0, end_dim=1)
        maxs = torch.max(x.abs(), dim=0).values
        if name in actv_maxs:
            actv_maxs[name] = torch.max(actv_maxs[name], maxs)
        else:
            actv_maxs[name] = maxs

    def actv_max_hook(module: nn.Module, input: Tuple[torch.Tensor, ...], output: Tensor, name: str) -> None:
        """Forward hook: record the LayerNorm output"""
        if isinstance(output, tuple):
            output = output[0]
        update_actv_stats(name, output)

    # Register hooks on every LayerNorm
    hooks = []
    for name, module in model.named_modules():
        if isinstance(module, nn.LayerNorm):
            hooks.append(module.register_forward_hook(functools.partial(actv_max_hook, name=name)))

    # Collect activation stats on a few batches
    with torch.no_grad():
        for i, batch in enumerate(datastream):
            if i >= num_batches:
                break
            data = batch[0]
            model(data.to(device), chunk_size=20000)  # chunk_size is model-specific

    # Remove hooks
    for h in hooks:
        h.remove()
    return actv_maxs


#####################################################################################################
# Apply smoothing

def smoothing_from_maxs(actv_maxs: Tensor, weight_maxs: Tensor, alpha: float = 0.75) -> Tensor:
    """Calculate per-channel smoothing factors s = actv_max^alpha / weight_max^(1 - alpha)"""
    # Clamp to avoid zero/inf factors for channels that never activate
    return (actv_maxs.clamp(min=1e-5) ** alpha) / (weight_maxs.clamp(min=1e-5) ** (1 - alpha))


def get_smoothing_factors(actv_maxs: Dict[str, Tensor], weight_maxs: Dict[str, Tensor], alpha: float = 0.75) -> Dict[str, Tensor]:
    """Calculate per-module scaling factors from the actv & weight maxs:
    1/s for each LayerNorm, s for the Linear layer(s) that follow it"""
    smoothing_factors = {}
    for actv_name, actv_max in actv_maxs.items():
        actv_layer_num = actv_name.split(".")[2]
        for weight_name, weight_max in weight_maxs.items():
            weight_layer_num = weight_name.split(".")[2]
            if actv_layer_num != weight_layer_num:
                continue
            # Pair a LayerNorm with a Linear only if both sit in the same block type
            if all("attention" in name for name in [actv_name, weight_name]) or \
               all("feed_forward" in name for name in [actv_name, weight_name]):
                smoothing_factors[actv_name] = smoothing_from_maxs(actv_max, weight_max, alpha) ** -1
                smoothing_factors[weight_name] = smoothing_from_maxs(actv_max, weight_max, alpha)
    return smoothing_factors


def smooth_parameters(model: nn.Module, smoothing_factors: Dict[str, Tensor]) -> None:
    """Apply the smoothing factors to model parameters in place"""
    for name, module in model.named_modules():
        if name in smoothing_factors:
            factor = smoothing_factors[name]
            module.weight = nn.Parameter(module.weight * factor)
            # Only scale LayerNorm biases. Linear biases are added after the matmul, which
            # smoothing leaves unchanged (and the per-input-channel factor would not match
            # the Linear bias shape anyway).
            if isinstance(module, nn.LayerNorm) and module.bias is not None:
                module.bias = nn.Parameter(module.bias * factor)


#####################################################################################################
# Wrapper function

def apply_smoothquant(model: nn.Module, datastream: Iterable, alpha: float = 0.75) -> None:
    """
    Applies SmoothQuant to the given model using the provided datastream.

    Parameters:
    - model: PyTorch model to which SmoothQuant will be applied.
    - datastream: An iterable (e.g. DataLoader) that yields batches of input data for the model.
    - alpha: The balance factor between activations and weights, as described in the paper.
    """
    actv_maxs = get_actv_maxs(model, datastream=datastream, num_batches=10)
    weight_maxs = get_weight_maxs(model)
    smoothing_factors = get_smoothing_factors(actv_maxs, weight_maxs, alpha)
    smooth_parameters(model, smoothing_factors)
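

#####################################################################################################
# Example usage -- a minimal sketch only. `MyTransformer`, `my_dataset`, and the `n_layers`
# attribute are hypothetical; the module-name patterns above ("query", "ff.0", "attention",
# "feed_forward", layer index at position 2 of the module path) are model-specific and will
# usually need adjusting.
#
#     from torch.utils.data import DataLoader
#
#     model = MyTransformer().to(device).eval()
#     calibration_loader = DataLoader(my_dataset, batch_size=8)
#
#     # Smooth in place; afterwards the model is mathematically equivalent but has damped
#     # activation outliers, ready for a standard INT8 quantization flow.
#     apply_smoothquant(model, datastream=calibration_loader, alpha=0.75)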