Created
March 23, 2024 23:30
-
-
Save catid/977de9dcd74e05d91e29662001439f0c to your computer and use it in GitHub Desktop.
GIVT GMM Decoder (Claude 3)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Collaboration between Claude-3 and GPT-4 to implement https://arxiv.org/pdf/2312.02116.pdf
# This is just the GMM decoder part of the model they propose (which is the novel part of their approach).
# This one was mainly generated by Claude-3.
# The AIs provided two implementations of the idea and revised each other's code.
# I tested that the unit tests pass but haven't tried it in a language model yet.
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class ImprovedGMMParametersPrediction(nn.Module):
    """Predict Gaussian-mixture parameters for every position of an input.

    Three independent linear heads map each hidden vector to:

    * ``mu``        -- ``num_components * output_dim`` component means;
    * ``log_sigma`` -- ``num_components * output_dim`` log standard deviations,
      exponentiated in ``forward`` so the returned ``sigma`` is positive;
    * ``logits_pi`` -- ``num_components`` unnormalised mixture logits,
      softmax-normalised in ``forward``.

    Args:
        hidden_dim: Size of the last dimension of the input tensor.
        output_dim: Dimensionality of each Gaussian component.
        num_components: Number of mixture components.
    """

    def __init__(self, hidden_dim, output_dim, num_components):
        super().__init__()
        self.num_components = num_components
        # Stored so forward() can reshape with explicit sizes: PyTorch cannot
        # infer a -1 dimension when the input has zero elements (empty batch).
        self.output_dim = output_dim
        # Parameter-prediction heads.
        self.mu = nn.Linear(hidden_dim, output_dim * num_components)
        self.log_sigma = nn.Linear(hidden_dim, output_dim * num_components)
        self.logits_pi = nn.Linear(hidden_dim, num_components)

    def forward(self, x):
        """Return ``(mu, sigma, pi)`` for input ``x`` of shape ``(*lead, hidden_dim)``.

        ``lead`` may be any number of leading dimensions, including
        ``(batch, seq)`` or an empty batch.

        Returns:
            mu: ``(*lead, num_components, output_dim)`` component means.
            sigma: ``(*lead, num_components, output_dim)``, ``exp(log_sigma)``.
            pi: ``(*lead, num_components)`` mixing weights that sum to one
                over the last dimension.
        """
        # Split the flat head output into (component, dim), keeping the same
        # memory layout as before: element k*output_dim + d -> [k, d].
        param_shape = (*x.shape[:-1], self.num_components, self.output_dim)

        mu = self.mu(x).view(param_shape)
        log_sigma = self.log_sigma(x).view(param_shape)
        logits_pi = self.logits_pi(x)

        # Softmax for the mixing coefficients and exp for the standard
        # deviations. NOTE(review): exp is unbounded; very large or very small
        # log_sigma values can overflow to inf or underflow to 0 in float32.
        pi = F.softmax(logits_pi, dim=-1)
        sigma = torch.exp(log_sigma)
        return mu, sigma, pi
class ImprovedGMMOutput(nn.Module):
    """Gaussian-mixture output head: hidden states -> (means, scales, weights).

    Three independent linear heads produce, for every input position,
    ``num_mixtures * output_dim`` means, ``num_mixtures * output_dim`` scales
    (passed through softplus) and ``num_mixtures`` mixture logits (passed
    through softmax).

    Args:
        hidden_dim: Size of the last dimension of the input tensor.
        output_dim: Dimensionality of each Gaussian component.
        num_mixtures: Number of mixture components.
    """

    def __init__(self, hidden_dim, output_dim, num_mixtures):
        super().__init__()
        self.num_mixtures = num_mixtures
        # Stored so forward() can reshape with explicit sizes: PyTorch cannot
        # infer a -1 dimension when the input has zero elements (empty batch).
        self.output_dim = output_dim
        # Parameter-prediction layers.
        self.fc_means = nn.Linear(hidden_dim, output_dim * num_mixtures)
        self.fc_scales = nn.Linear(hidden_dim, output_dim * num_mixtures)
        self.fc_weights = nn.Linear(hidden_dim, num_mixtures)

    def forward(self, x):
        """Return ``(means, scales, weights)`` for ``x`` of shape ``(*lead, hidden_dim)``.

        ``lead`` may be any number of leading dimensions, including
        ``(batch, seq)`` or an empty batch.

        Returns:
            means: ``(*lead, num_mixtures, output_dim)``.
            scales: ``(*lead, num_mixtures, output_dim)``, softplus of the raw
                head output. They are positive up to float underflow for very
                negative pre-activations.
            weights: ``(*lead, num_mixtures)``, softmax over the last dimension.
        """
        # Same layout as before: flat element k*output_dim + d -> [k, d].
        param_shape = (*x.shape[:-1], self.num_mixtures, self.output_dim)

        means = self.fc_means(x).view(param_shape)
        scales = F.softplus(self.fc_scales(x)).view(param_shape)
        weights = F.softmax(self.fc_weights(x), dim=-1)
        return means, scales, weights
# Unit tests | |
from torch.testing import assert_allclose | |
def test_improved_gmm_parameter_prediction():
    """Sanity-check shapes and GMM constraints of ImprovedGMMParametersPrediction.

    Checks that mu and sigma are (batch, seq, components, output_dim), that pi
    is (batch, seq, components) and sums to one over components, and that
    every sigma is strictly positive.
    """
    batch, seq, hidden, out_dim, n_comp = 2, 3, 4, 5, 3

    predictor = ImprovedGMMParametersPrediction(hidden, out_dim, n_comp)
    inputs = torch.randn(batch, seq, hidden)
    mu, sigma, pi = predictor(inputs)

    per_component_shape = (batch, seq, n_comp, out_dim)
    assert mu.shape == per_component_shape
    assert sigma.shape == per_component_shape
    assert pi.shape == (batch, seq, n_comp)

    mixture_totals = pi.sum(dim=-1)
    assert torch.allclose(mixture_totals, torch.ones(batch, seq), rtol=1e-5, atol=1e-8)
    assert (sigma > 0).all()
def test_improved_gmm_output():
    """Sanity-check shapes and GMM constraints of ImprovedGMMOutput.

    Checks that means and scales are (batch, seq, mixtures, output_dim), that
    weights are (batch, seq, mixtures) and sum to one over mixtures, and that
    every scale is strictly positive.
    """
    batch, seq, hidden, out_dim, n_mix = 2, 3, 4, 5, 3

    head = ImprovedGMMOutput(hidden, out_dim, n_mix)
    inputs = torch.randn(batch, seq, hidden)
    means, scales, weights = head(inputs)

    per_mixture_shape = (batch, seq, n_mix, out_dim)
    assert means.shape == per_mixture_shape
    assert scales.shape == per_mixture_shape
    assert weights.shape == (batch, seq, n_mix)

    weight_totals = weights.sum(dim=-1)
    assert torch.allclose(weight_totals, torch.ones(batch, seq), rtol=1e-5, atol=1e-8)
    assert (scales > 0).all()
if __name__ == "__main__":
    # Run each unit test in order; a failing assert stops the script before
    # the success message is printed.
    for run_test in (test_improved_gmm_parameter_prediction, test_improved_gmm_output):
        run_test()
    print("All tests passed!")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment