Lookup Free Quantization (LFQ) for PyTorch.
"""Lookup Free Quantization (LFQ) for PyTorch.""" | |
from dataclasses import dataclass | |
from itertools import product | |
import math | |
from typing import Optional | |
import torch | |
from torch import distributed as dist, nn | |
from torch.distributed import nn as dnn | |
from torch.nn import functional as F | |
@dataclass | |
class LFQReturn: | |
"""Return type for LFQ.""" | |
indices: torch.Tensor | |
quantized: torch.Tensor | |
loss: Optional[torch.Tensor] | |
class LFQ(nn.Module): | |
"""Lookup Free Quantization (LFQ) for PyTorch.""" | |
def __init__(self, embedding_dim: int): | |
super().__init__() | |
self.num_embeddings = 2**embedding_dim | |
self.embedding_dim = embedding_dim | |
weight = torch.tensor(list(product((-1.0, 1.0), repeat=embedding_dim))) | |
self.register_buffer("weight", weight) | |
self.register_buffer("place_values", 2 ** torch.arange(embedding_dim).flip(-1)) | |
def extra_repr(self) -> str: | |
return f"embedding_dim={self.embedding_dim}" | |
def quantize(self, x: torch.Tensor) -> torch.Tensor: | |
return (x > 0).to(x.dtype) * 2.0 - 1.0 | |
def indices_to_latents(self, indices: torch.Tensor) -> torch.Tensor: | |
return self.weight[indices] | |
def latents_to_indices(self, latents: torch.Tensor) -> torch.Tensor: | |
return torch.sum(self.place_values * (latents > 0), dim=-1) | |
def forward( | |
self, x: torch.Tensor, distributed: bool = False, group: Optional[dist.ProcessGroup] = None | |
) -> LFQReturn: | |
quantized = self.quantize(x) | |
indices = self.latents_to_indices(quantized) | |
loss = None | |
if self.training: | |
# this gradient estimator is different from the LFQ paper but seems to work better. | |
# it's the same as `torch.softmax(logits, dim=-1) @ self.weight` (the straight through | |
# softmax estimator) but more efficient to compute. | |
quantized_soft = torch.tanh(x) | |
quantized = quantized + (quantized_soft - quantized_soft.detach()) | |
# probabilities are proportional to exp(-||x - embedding||^2 / 2). | |
logits = x @ self.weight.T - torch.sum(x**2, dim=-1, keepdim=True) / 2 | |
logp = F.log_softmax(logits, dim=-1) | |
# loss is the negative Jensen-Shannon divergence between the distributions for each | |
# token, plus a constant to make the loss always positive. | |
logp_sum = torch.logsumexp(logp.flatten(0, -2), dim=0) | |
n = logp_sum.new_tensor(logp.flatten(0, -2).shape[0]) | |
if distributed: | |
logp_sum = torch.logsumexp(torch.stack(dnn.all_gather(logp_sum, group)), dim=0) | |
dist.all_reduce(n, dist.ReduceOp.AVG, group) | |
logp_mixture = F.log_softmax(logp_sum, dim=-1) | |
entr_components = -torch.sum(torch.exp(logp) * logp) / n | |
entr_mixture = -torch.sum(torch.exp(logp_mixture) * logp_mixture) | |
loss = math.log(self.num_embeddings) + entr_components - entr_mixture | |
return LFQReturn(indices, quantized, loss) |
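
A minimal usage sketch, not part of the gist: the encoder/decoder modules, tensor shapes, and the 0.1 entropy loss weight below are illustrative assumptions. The quantizer only requires that the last dimension of its input equal embedding_dim.

# Usage sketch (illustrative only): wiring LFQ into an autoencoder-style training step.
# The encoder/decoder, batch shape, and entropy loss weight are assumed, not from the gist.
import torch
from torch import nn
from torch.nn import functional as F

embedding_dim = 10                        # codebook size is 2**10 = 1024
lfq = LFQ(embedding_dim)                  # modules default to training mode, so loss is returned
encoder = nn.Linear(784, embedding_dim)   # hypothetical encoder
decoder = nn.Linear(embedding_dim, 784)   # hypothetical decoder

x = torch.randn(32, 784)                  # hypothetical input batch
latents = encoder(x)                      # (..., embedding_dim) continuous latents
out = lfq(latents)                        # LFQReturn(indices, quantized, loss)

recon = decoder(out.quantized)            # gradients flow through the tanh estimator
loss = F.mse_loss(recon, x)
if out.loss is not None:                  # entropy loss is only computed in training mode
    loss = loss + 0.1 * out.loss          # assumed weighting of the entropy term
loss.backward()

# Tokens can be stored as integer indices and decoded back to the +/-1 latents:
codes = out.indices                       # (32,) int64 codes in [0, 2**embedding_dim)
restored = lfq.indices_to_latents(codes)
assert torch.equal(restored, lfq.quantize(latents))

When training under DistributedDataParallel, forward accepts distributed=True (and optionally a process group) so that the mixture entropy term is computed over the tokens from all ranks rather than from a single rank.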