Skip to content

Instantly share code, notes, and snippets.

@razhangwei
Last active January 30, 2025 19:47
Show Gist options
  • Save razhangwei/63058e4081816927891023ae20287289 to your computer and use it in GitHub Desktop.
Save razhangwei/63058e4081816927891023ae20287289 to your computer and use it in GitHub Desktop.
RQ-VAE pseudo code
from typing import List, Dict, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class RQVAE(nn.Module):
    """Residual-Quantized VAE: encoder -> residual vector quantizer -> decoder.

    The latent vector produced by the encoder is approximated by a sum of
    ``num_codebooks`` code vectors, one picked from each codebook. Every
    level quantizes the residual left over by the previous levels.

    Subclassing ``nn.Module`` (and keeping the codebooks in an
    ``nn.ParameterList``) registers every learnable tensor, so
    ``model.parameters()``, ``model.to(device)``, ``state_dict()`` and
    ``model(x)`` all work.
    """

    def __init__(self,
                 input_dim: int,          # scalar (e.g., 768 for BERT embeddings)
                 hidden_dims: List[int],  # [dim1, dim2, ..., latent_dim]
                 num_codebooks: int,      # scalar (e.g., 3 for 3-level quantization)
                 codebook_size: int):     # scalar (e.g., 256 vectors per codebook)
        """
        Initialize RQ-VAE with specified dimensions and parameters.

        Args:
            input_dim: Original input dimensionality.
            hidden_dims: Dimensions of hidden layers; the last one is the
                latent dimension.
            num_codebooks: Number of quantization levels.
            codebook_size: Number of vectors in each codebook.

        Raises:
            ValueError: If ``hidden_dims`` is empty (no latent dimension).
        """
        super().__init__()
        if not hidden_dims:
            raise ValueError(
                "hidden_dims must contain at least the latent dimension")
        latent_dim = hidden_dims[-1]

        # NOTE(review): Encoder / Decoder are not defined in this file; they
        # are assumed to be nn.Module classes taking (in_dim, out_dim).
        # Only hidden_dims[-1] is forwarded to them here -- confirm whether
        # the intermediate hidden sizes should be passed as well.
        self.encoder = Encoder(input_dim, latent_dim)
        self.decoder = Decoder(latent_dim, input_dim)
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size

        # One learnable (codebook_size, latent_dim) table per level.
        # ParameterList (rather than a plain list) registers them on the
        # module so optimizers and checkpoints can see them.
        self.codebooks = nn.ParameterList([
            nn.Parameter(torch.randn(codebook_size, latent_dim))
            for _ in range(num_codebooks)
        ])

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Convert input to latent representation.

        Args:
            x: Input tensor of shape (batch_size, input_dim).

        Returns:
            Latent representation of shape (batch_size, latent_dim).
        """
        return self.encoder(x)

    def quantize(self,
                 z: torch.Tensor  # Shape: (batch_size, latent_dim)
                 ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Perform residual quantization on latent vectors.

        Args:
            z: Latent vectors to quantize, shape (batch_size, latent_dim).

        Returns:
            tuple:
                - Quantized vectors of shape (batch_size, latent_dim): the
                  sum of the selected code vectors over all levels. It is
                  differentiable w.r.t. the codebooks but not w.r.t. ``z``.
                - List of ``num_codebooks`` index tensors, each of shape
                  (batch_size,).
        """
        residual = z                      # (batch_size, latent_dim)
        quantized = torch.zeros_like(z)   # (batch_size, latent_dim)
        indices: List[torch.Tensor] = []  # one (batch_size,) tensor per level

        for codebook in self.codebooks:
            # The distances only feed argmin, which is not differentiable,
            # so detach to avoid building a useless autograd graph.
            # Shape: (batch_size, codebook_size)
            distances = torch.cdist(residual.detach(), codebook.detach())
            min_indices = torch.argmin(distances, dim=1)  # (batch_size,)
            indices.append(min_indices)

            selected = codebook[min_indices]  # (batch_size, latent_dim)
            quantized = quantized + selected
            # The next level quantizes what this level failed to capture.
            residual = residual - selected

        return quantized, indices

    def decode(self, quantized: torch.Tensor) -> torch.Tensor:
        """
        Reconstruct input from quantized representation.

        Args:
            quantized: Tensor of shape (batch_size, latent_dim).

        Returns:
            Reconstructed input of shape (batch_size, input_dim).
        """
        return self.decoder(quantized)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass through RQ-VAE.

        Args:
            x: Input tensor of shape (batch_size, input_dim).

        Returns:
            Dictionary containing:
                - 'reconstructed': (batch_size, input_dim)
                - 'quantized': (batch_size, latent_dim)
                - 'indices': list of ``num_codebooks`` (batch_size,) tensors
                  (a list, despite the ``Dict[str, Tensor]`` annotation)
                - 'z': (batch_size, latent_dim)
        """
        z = self.encode(x)
        quantized, indices = self.quantize(z)

        # Straight-through estimator: the decoder sees the quantized values,
        # but the gradient is copied onto z, because the argmin lookup
        # itself is not differentiable. Without this the reconstruction
        # loss would never reach the encoder.
        decoder_input = z + (quantized - z).detach()
        reconstructed = self.decode(decoder_input)

        return {
            'reconstructed': reconstructed,  # (batch_size, input_dim)
            'quantized': quantized,          # (batch_size, latent_dim)
            'indices': indices,              # List[(batch_size,)] * num_codebooks
            'z': z,                          # (batch_size, latent_dim)
        }

    def loss_function(self,
                      x: torch.Tensor,  # Shape: (batch_size, input_dim)
                      forward_output: Dict[str, torch.Tensor],
                      beta: float = 0.25  # Shape: scalar
                      ) -> torch.Tensor:  # Shape: scalar
        """
        Compute RQ-VAE loss: reconstruction + codebook + beta * commitment.

        Args:
            x: Original input.
            forward_output: Output dictionary from the forward pass; the
                keys 'reconstructed', 'quantized' and 'z' are read.
            beta: Commitment loss weight.

        Returns:
            Total loss value (scalar).
        """
        recon = forward_output['reconstructed']
        quantized = forward_output['quantized']
        z = forward_output['z']

        # All three terms reduce to scalar values.
        recon_loss = F.mse_loss(recon, x)
        # Codebook term: pulls the code vectors towards the (frozen)
        # encoder output.
        codebook_loss = F.mse_loss(quantized, z.detach())
        # Commitment term: keeps the encoder output close to the (frozen)
        # codes. beta weights this term, not the codebook term.
        commit_loss = F.mse_loss(z, quantized.detach())
        return recon_loss + codebook_loss + beta * commit_loss
# Example usage shapes:
# batch_size, input_dim, latent_dim = 32, 768, 64
# x = torch.randn(batch_size, input_dim)
# rqvae = RQVAE(input_dim=input_dim,
# hidden_dims=[512, 256, latent_dim],
# num_codebooks=3,
# codebook_size=256)
# output = rqvae(x)
# loss = rqvae.loss_function(x, output)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment