Skip to content

Instantly share code, notes, and snippets.

@vwxyzjn
Last active January 31, 2025 17:31
Show Gist options
  • Save vwxyzjn/ab8e6c010ea589a0b57cdfb70604c810 to your computer and use it in GitHub Desktop.
Save vwxyzjn/ab8e6c010ea589a0b57cdfb70604c810 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.optim as optim
# Demo: use the k1 term, log(q) - log(p), directly as a training loss.
#
# p is a fixed target distribution over 10 categories and q is a learnable
# distribution parameterized by unconstrained logits. Both are handled as
# log-probabilities obtained with log_softmax.

# Create target distribution (fixed)
target_logits = torch.randn(10)
target_log_probs = torch.log_softmax(target_logits, dim=0)

# Create learnable distribution
learnable_logits = nn.Parameter(torch.rand_like(target_logits))  # Initialize randomly

# Setup optimizer
optimizer = optim.Adam([learnable_logits], lr=0.1)

# Training loop
for step in range(40):
    # Get current log probabilities (log q)
    current_log_probs = torch.log_softmax(learnable_logits, dim=0)

    # k1 estimator: -log(r) = -log(p/q) = log(q) - log(p)
    # NOTE: nothing is sampled here. The term is averaged with equal weight
    # over all categories, not over draws from q. target_log_probs carries no
    # gradient, so it only shifts the loss by a constant: the gradient with
    # respect to learnable_logits does not depend on the target at all.
    kl_loss = torch.mean(current_log_probs - target_log_probs)

    # Backward pass and optimize
    optimizer.zero_grad()
    kl_loss.backward()
    optimizer.step()

    # Report progress every fifth step (the loss shown is the pre-update value).
    if step % 5 == 0:
        print(f"Step {step}, KL Loss: {kl_loss.item():.4f}")
        print(f"Current distribution: {torch.softmax(learnable_logits, dim=0).detach().numpy()}")

# Compare the learned distribution against the target after training.
print("\nFinal distributions w/ using kl1 estimator:")
print(f"Target: {torch.softmax(target_logits, dim=0).numpy()}")
print(f"Learned: {torch.softmax(learnable_logits, dim=0).detach().numpy()}")
import torch
import torch.nn as nn
import torch.optim as optim
# Demo: use the k2 term, 1/2 * (log r)^2 with r = p/q, directly as a
# training loss.
#
# p is a fixed target distribution over 10 categories and q is a learnable
# distribution parameterized by unconstrained logits. Both are handled as
# log-probabilities obtained with log_softmax.

# Create target distribution (fixed)
target_logits = torch.randn(10)
target_log_probs = torch.log_softmax(target_logits, dim=0)

# Create learnable distribution
learnable_logits = nn.Parameter(torch.rand_like(target_logits))  # Initialize randomly

# Setup optimizer
optimizer = optim.Adam([learnable_logits], lr=0.1)

# Training loop
for step in range(40):
    # Get current log probabilities (log q)
    current_log_probs = torch.log_softmax(learnable_logits, dim=0)

    # Calculate log ratio: log(p(x)/q(x)) = log p(x) - log q(x)
    log_ratio = target_log_probs - current_log_probs

    # Calculate k2 estimator: 1/2 * (log r)^2
    # The term is averaged with equal weight over all categories (nothing is
    # sampled). Every term is non-negative and is zero exactly where
    # log p == log q for that category.
    kl_loss = torch.mean(0.5 * log_ratio.pow(2))

    # Backward pass and optimize
    optimizer.zero_grad()
    kl_loss.backward()
    optimizer.step()

    # Report progress every fifth step (the loss shown is the pre-update value).
    if step % 5 == 0:
        print(f"Step {step}, KL Loss: {kl_loss.item():.4f}")
        print(f"Current distribution: {torch.softmax(learnable_logits, dim=0).detach().numpy()}")

# Compare the learned distribution against the target after training.
print("\nFinal distributions using kl2 estimator:")
print(f"Target: {torch.softmax(target_logits, dim=0).numpy()}")
print(f"Learned: {torch.softmax(learnable_logits, dim=0).detach().numpy()}")
import torch
import torch.nn as nn
import torch.optim as optim
# Demo: use the k3 term, (r - 1) - log(r) with r = p/q, directly as a
# training loss.
#
# p is a fixed target distribution over 10 categories and q is a learnable
# distribution parameterized by unconstrained logits. Both are handled as
# log-probabilities obtained with log_softmax.

# Create target distribution (fixed)
target_logits = torch.randn(10)
target_log_probs = torch.log_softmax(target_logits, dim=0)

# Create learnable distribution
learnable_logits = nn.Parameter(torch.rand_like(target_logits))  # Initialize randomly

# Setup optimizer
optimizer = optim.Adam([learnable_logits], lr=0.1)

# Training loop
for step in range(40):
    # Get current log probabilities (log q)
    current_log_probs = torch.log_softmax(learnable_logits, dim=0)

    # Calculate log ratio: log(p(x)/q(x)) = log p(x) - log q(x)
    log_ratio = target_log_probs - current_log_probs

    # Calculate k3 estimator: (r - 1) - log(r)
    # Where r = exp(log_ratio)
    # The term is averaged with equal weight over all categories (nothing is
    # sampled). (r - 1) - log(r) is non-negative for every r > 0 and is zero
    # only at r == 1, i.e. where p and q agree for that category.
    ratio = torch.exp(log_ratio)
    kl_loss = torch.mean((ratio - 1) - log_ratio)

    # Backward pass and optimize
    optimizer.zero_grad()
    kl_loss.backward()
    optimizer.step()

    # Report progress every fifth step (the loss shown is the pre-update value).
    if step % 5 == 0:
        print(f"Step {step}, KL Loss: {kl_loss.item():.4f}")
        print(f"Current distribution: {torch.softmax(learnable_logits, dim=0).detach().numpy()}")

# Compare the learned distribution against the target after training.
print("\nFinal distributions:")
print(f"Target: {torch.softmax(target_logits, dim=0).numpy()}")
print(f"Learned: {torch.softmax(learnable_logits, dim=0).detach().numpy()}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment