RQ-VAE pseudo code
from typing import List, Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
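
# NOTE: The original gist leaves Encoder and Decoder undefined. The two MLPs
# below are a minimal sketch (an assumption, not part of the original pseudo
# code) so the file runs end to end; any encoder/decoder with matching
# input/output dimensions can be substituted.
class Encoder(nn.Module):
    """Maps (batch_size, input_dim) -> (batch_size, latent_dim)."""

    def __init__(self, input_dim: int, latent_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 2 * latent_dim),
            nn.ReLU(),
            nn.Linear(2 * latent_dim, latent_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class Decoder(nn.Module):
    """Maps (batch_size, latent_dim) -> (batch_size, output_dim)."""

    def __init__(self, latent_dim: int, output_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 2 * latent_dim),
            nn.ReLU(),
            nn.Linear(2 * latent_dim, output_dim),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.net(z)
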
class RQVAE(nn.Module):
    def __init__(self,
                 input_dim: int,          # Shape: scalar (e.g., 768 for BERT embeddings)
                 hidden_dims: List[int],  # Shape: [dim1, dim2, ..., latent_dim]
                 num_codebooks: int,      # Shape: scalar (e.g., 3 for 3-level quantization)
                 codebook_size: int):     # Shape: scalar (e.g., 256 vectors per codebook)
        """
        Initialize RQ-VAE with the specified dimensions and parameters.

        Args:
            input_dim: Original input dimensionality
            hidden_dims: Dimensions of hidden layers; the last one is the latent dimension
            num_codebooks: Number of quantization levels
            codebook_size: Number of vectors in each codebook
        """
        super().__init__()
        self.encoder = Encoder(input_dim, hidden_dims[-1])
        self.decoder = Decoder(hidden_dims[-1], input_dim)
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        # Codebooks: one (codebook_size, latent_dim) parameter per quantization level.
        # nn.ParameterList registers them so they are trained and moved with the module.
        self.codebooks = nn.ParameterList([
            nn.Parameter(torch.randn(codebook_size, hidden_dims[-1]))
            for _ in range(num_codebooks)
        ])
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Convert input to latent representation.

        Args:
            x: Input tensor of shape (batch_size, input_dim)
        Returns:
            Latent representation of shape (batch_size, latent_dim)
        """
        return self.encoder(x)
    def quantize(self,
                 z: torch.Tensor  # Shape: (batch_size, latent_dim)
                 ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Perform residual quantization on latent vectors.

        Args:
            z: Latent vectors to quantize
        Returns:
            tuple:
                - Quantized vectors of shape (batch_size, latent_dim)
                - List of index tensors of shape [(batch_size,)] * num_codebooks
        """
        residual = z                      # Shape: (batch_size, latent_dim)
        quantized = torch.zeros_like(z)   # Shape: (batch_size, latent_dim)
        indices: List[torch.Tensor] = []  # List of (batch_size,) tensors

        for codebook in self.codebooks:
            # Distance from each residual to every codebook vector.
            distances = torch.cdist(residual, codebook)   # Shape: (batch_size, codebook_size)
            min_indices = torch.argmin(distances, dim=1)  # Shape: (batch_size,)
            indices.append(min_indices)
            # Nearest codebook vector for each item in the batch.
            selected = codebook[min_indices]              # Shape: (batch_size, latent_dim)
            quantized = quantized + selected
            residual = residual - selected
        return quantized, indices
    def decode(self, quantized: torch.Tensor) -> torch.Tensor:
        """
        Reconstruct input from the quantized representation.

        Args:
            quantized: Tensor of shape (batch_size, latent_dim)
        Returns:
            Reconstructed input of shape (batch_size, input_dim)
        """
        return self.decoder(quantized)
    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the RQ-VAE.

        Args:
            x: Input tensor of shape (batch_size, input_dim)
        Returns:
            Dictionary containing:
                - 'reconstructed': (batch_size, input_dim)
                - 'quantized': (batch_size, latent_dim)
                - 'indices': List of (batch_size,) tensors
                - 'z': (batch_size, latent_dim)
        """
        z = self.encode(x)
        quantized, indices = self.quantize(z)
        # Straight-through estimator: argmin quantization is not differentiable,
        # so copy gradients from the decoder input back to the encoder output.
        quantized_st = z + (quantized - z).detach()
        reconstructed = self.decode(quantized_st)
        return {
            'reconstructed': reconstructed,  # Shape: (batch_size, input_dim)
            'quantized': quantized,          # Shape: (batch_size, latent_dim)
            'indices': indices,              # List[(batch_size,)] * num_codebooks
            'z': z                           # Shape: (batch_size, latent_dim)
        }
    def loss_function(self,
                      x: torch.Tensor,    # Shape: (batch_size, input_dim)
                      forward_output: Dict[str, torch.Tensor],
                      beta: float = 0.25  # Commitment loss weight (scalar)
                      ) -> torch.Tensor:  # Shape: scalar
        """
        Compute the RQ-VAE loss: reconstruction + codebook + commitment terms.

        Args:
            x: Original input
            forward_output: Output dictionary from the forward pass
            beta: Commitment loss weight
        Returns:
            Total loss value (scalar)
        """
        recon = forward_output['reconstructed']
        quantized = forward_output['quantized']
        z = forward_output['z']

        # All three terms reduce to scalars.
        recon_loss = F.mse_loss(recon, x)
        # Codebook loss: pull codebook vectors toward the (frozen) encoder output.
        codebook_loss = F.mse_loss(quantized, z.detach())
        # Commitment loss: keep the encoder output close to its (frozen) quantization.
        commit_loss = F.mse_loss(quantized.detach(), z)
        return recon_loss + codebook_loss + beta * commit_loss
# Example usage shapes:
# batch_size, input_dim, latent_dim = 32, 768, 64
# x = torch.randn(batch_size, input_dim)
# rqvae = RQVAE(input_dim=input_dim,
#               hidden_dims=[512, 256, latent_dim],
#               num_codebooks=3,
#               codebook_size=256)
# output = rqvae(x)
# loss = rqvae.loss_function(x, output)
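
# Minimal smoke test (not part of the original gist): makes the commented
# example above executable, assuming the sketch Encoder/Decoder defined near
# the top of this file.
if __name__ == "__main__":
    batch_size, input_dim, latent_dim = 32, 768, 64
    x = torch.randn(batch_size, input_dim)
    rqvae = RQVAE(input_dim=input_dim,
                  hidden_dims=[512, 256, latent_dim],
                  num_codebooks=3,
                  codebook_size=256)

    output = rqvae(x)
    loss = rqvae.loss_function(x, output)
    loss.backward()  # gradients reach the encoder, decoder, and codebooks

    # Each item's discrete code is one index per codebook; summing the selected
    # codebook vectors reproduces the quantized latent exactly.
    rebuilt = sum(cb[idx] for cb, idx in zip(rqvae.codebooks, output['indices']))
    assert torch.allclose(rebuilt, output['quantized'])
    print(f"loss={loss.item():.4f}, "
          f"levels={len(output['indices'])}, "
          f"reconstruction shape={tuple(output['reconstructed'].shape)}")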