
@crowsonkb
Last active February 20, 2025 16:37
Lookup Free Quantization (LFQ) for PyTorch.
"""Lookup Free Quantization (LFQ) for PyTorch."""
from dataclasses import dataclass
from itertools import product
import math
from typing import Optional
import torch
from torch import distributed as dist, nn
from torch.distributed import nn as dnn
from torch.nn import functional as F
@dataclass
class LFQReturn:
    """Return type for LFQ."""

    indices: torch.Tensor
    quantized: torch.Tensor
    loss: Optional[torch.Tensor]

class LFQ(nn.Module):
    """Lookup Free Quantization (LFQ) for PyTorch."""

    def __init__(self, embedding_dim: int):
        super().__init__()
        self.num_embeddings = 2**embedding_dim
        self.embedding_dim = embedding_dim
        weight = torch.tensor(list(product((-1.0, 1.0), repeat=embedding_dim)))
        self.register_buffer("weight", weight)
        self.register_buffer("place_values", 2 ** torch.arange(embedding_dim).flip(-1))

    def extra_repr(self) -> str:
        return f"embedding_dim={self.embedding_dim}"

    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        # Binarize each latent dimension to -1.0 or 1.0 according to its sign.
        return (x > 0).to(x.dtype) * 2.0 - 1.0

    def indices_to_latents(self, indices: torch.Tensor) -> torch.Tensor:
        return self.weight[indices]

    def latents_to_indices(self, latents: torch.Tensor) -> torch.Tensor:
        # Read the signs of the latent dimensions as the bits of an integer index.
        return torch.sum(self.place_values * (latents > 0), dim=-1)

    def forward(
        self, x: torch.Tensor, distributed: bool = False, group: Optional[dist.ProcessGroup] = None
    ) -> LFQReturn:
        quantized = self.quantize(x)
        indices = self.latents_to_indices(quantized)
        loss = None
        if self.training:
            # This gradient estimator is different from the LFQ paper but seems to work better.
            # It is the same as `torch.softmax(logits, dim=-1) @ self.weight` (the straight-through
            # softmax estimator) but more efficient to compute.
            quantized_soft = torch.tanh(x)
            quantized = quantized + (quantized_soft - quantized_soft.detach())
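            # In the forward pass the two tanh terms cancel, so downstream layers see the hard
            # +/-1 codes; in the backward pass the encoder receives the gradient of tanh(x)
            # instead of the zero gradient of the hard binarization.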
            # Probabilities are proportional to exp(-||x - embedding||^2 / 2).
            logits = x @ self.weight.T - torch.sum(x**2, dim=-1, keepdim=True) / 2
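            # Since every code w has ||w||^2 = embedding_dim, we have
            # -||x - w||^2 / 2 = x . w - ||x||^2 / 2 - embedding_dim / 2, and the constant
            # embedding_dim / 2 cancels in the softmax, so these logits give the same
            # distribution over codes as the negative squared distances would.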
            logp = F.log_softmax(logits, dim=-1)
            # The loss is the negative Jensen-Shannon divergence between the distributions for
            # each token, plus a constant to make the loss always positive.
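            # With uniform weights, JSD = H(mixture) - mean per-token H(p_i), so the loss below
            # equals log(num_embeddings) - JSD, which is nonnegative because
            # JSD <= H(mixture) <= log(num_embeddings). Minimizing it pushes each token's
            # distribution toward a confident choice while keeping the batch-average
            # distribution close to uniform, encouraging use of the whole codebook.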
            logp_sum = torch.logsumexp(logp.flatten(0, -2), dim=0)
            n = logp_sum.new_tensor(logp.flatten(0, -2).shape[0])
            if distributed:
                logp_sum = torch.logsumexp(torch.stack(dnn.all_gather(logp_sum, group)), dim=0)
                dist.all_reduce(n, dist.ReduceOp.AVG, group)
            logp_mixture = F.log_softmax(logp_sum, dim=-1)
            entr_components = -torch.sum(torch.exp(logp) * logp) / n
            entr_mixture = -torch.sum(torch.exp(logp_mixture) * logp_mixture)
            loss = math.log(self.num_embeddings) + entr_components - entr_mixture
        return LFQReturn(indices, quantized, loss)
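
A minimal usage sketch (the embedding_dim and tensor shapes below are arbitrary choices for illustration, not part of the gist): quantize a batch of latents in training mode, read out the entropy loss, and round-trip the indices back to codes. With distributed=True and a process group, the mixture entropy is computed over tokens gathered from all ranks.

lfq = LFQ(embedding_dim=8)  # codebook of 2**8 = 256 sign-vector codes
lfq.train()
x = torch.randn(4, 16, 8)  # (batch, tokens, embedding_dim)
out = lfq(x)
print(out.indices.shape)    # torch.Size([4, 16]), integer codes in [0, 255]
print(out.quantized.shape)  # torch.Size([4, 16, 8]), values in {-1.0, +1.0}
print(out.loss)             # scalar entropy loss to add to the training objective
# Decoding the indices recovers exactly the quantized codes.
assert torch.equal(lfq.indices_to_latents(out.indices), lfq.quantize(x))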