"""
Toy Free‑Form Flow (FFF) training demo on a 2D mixture of Gaussians.
Requires:
pip install matplotlib
pip install -e third_party/FFF
"""

import argparse

import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader, Dataset

try:
    from fff.loss import fff_loss
except ImportError as e:
    raise ImportError(
        "Free-Form Flow (FFF) not found. Install the dependencies with:\n"
        "pip install matplotlib && pip install -e third_party/FFF"
    ) from e


class Toy2DDataset(Dataset):
    """
    Samples from an 8-component Gaussian mixture arranged on a circle.
    """

    def __init__(self, n_samples=1000, seed=0, cluster_std=0.1, radius=2.0):
        np.random.seed(seed)
        # assign each sample to one of 8 clusters
        cluster_ids = np.random.randint(0, 8, size=n_samples)
        angles = cluster_ids / 8 * 2 * np.pi
        means = np.stack([np.cos(angles), np.sin(angles)], axis=1) * radius
        samples = means + np.random.randn(n_samples, 2) * cluster_std
        self.x = torch.from_numpy(samples).float()

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx]
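
# Illustrative usage of the dataset (shapes follow from the code above):
#   ds = Toy2DDataset(n_samples=4, seed=0)
#   ds[0]  -> tensor of shape (2,); the 8 cluster means lie at angles 2*pi*k/8, k = 0..7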


def parse_args():
    parser = argparse.ArgumentParser(
        description="Toy Free-Form Flow (FFF) on a circular 8-mode 2D Gaussian mixture"
    )
    # model/training hyperparameters
    parser.add_argument("--hidden-units", type=int, default=256,
                        help="Width of hidden layers (default: 256)")
    parser.add_argument("--latent-dim", type=int, default=2,
                        help="Dimensionality of latent space (default: 2; must equal data dim)")
    parser.add_argument("--batch-size", type=int, default=128,
                        help="Batch size for training (default: 128)")
    parser.add_argument("--lr", type=float, default=1e-3,
                        help="Initial learning rate (default: 1e-3)")
    parser.add_argument("--weight-decay", type=float, default=1e-4,
                        help="Weight decay for optimizer (default: 1e-4)")
    parser.add_argument("--beta", type=float, default=0.2,
                        help="Reconstruction loss weight (default: 0.2)")
    parser.add_argument("--hutchinson-samples", type=int, default=2,
                        help="Number of Hutchinson samples (<= latent-dim, default: 2)")
    parser.add_argument("--epochs", type=int, default=100,
                        help="Number of training epochs (default: 100)")
    # data parameters
    parser.add_argument("--n-samples", type=int, default=1000,
                        help="Number of samples in dataset (default: 1000)")
    parser.add_argument("--seed", type=int, default=1,
                        help="Random seed for data (default: 1)")
    parser.add_argument("--cluster-std", type=float, default=0.1,
                        help="Standard deviation of clusters (default: 0.1)")
    parser.add_argument("--radius", type=float, default=2.0,
                        help="Radius of the circle (default: 2.0)")
    parser.add_argument("--no-cuda", action="store_true",
                        help="Disable CUDA even if available")
    return parser.parse_args()


def main(args):
    # device
    use_cuda = torch.cuda.is_available() and not args.no_cuda
    device = torch.device("cuda" if use_cuda else "cpu")
    # data loader
    dataset = Toy2DDataset(
        n_samples=args.n_samples,
        seed=args.seed,
        cluster_std=args.cluster_std,
        radius=args.radius,
    )
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    # encoder/decoder with extended capacity (3 hidden layers),
    # latent space dim = args.latent_dim
    encoder = torch.nn.Sequential(
        torch.nn.Linear(2, args.hidden_units),
        torch.nn.GELU(),
        torch.nn.Linear(args.hidden_units, args.hidden_units),
        torch.nn.GELU(),
        torch.nn.Linear(args.hidden_units, args.hidden_units),
        torch.nn.GELU(),
        torch.nn.Linear(args.hidden_units, args.latent_dim),
    ).to(device)
    decoder = torch.nn.Sequential(
        torch.nn.Linear(args.latent_dim, args.hidden_units),
        torch.nn.GELU(),
        torch.nn.Linear(args.hidden_units, args.hidden_units),
        torch.nn.GELU(),
        torch.nn.Linear(args.hidden_units, args.hidden_units),
        torch.nn.GELU(),
        torch.nn.Linear(args.hidden_units, 2),
    ).to(device)
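    # Background: free-form flows do not require an invertible architecture,
    # so plain MLPs suffice here. The beta-weighted reconstruction term in the
    # loss pushes the decoder to approximately invert the encoder (see the
    # Free-Form Flows paper for the derivation).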
    # latent prior
    latent_loc = torch.zeros(args.latent_dim, device=device)
    latent_scale = torch.ones(args.latent_dim, device=device)
    latent_dist = torch.distributions.Normal(latent_loc, latent_scale)
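    # Assumption about the library API: fff_loss scores latents under
    # latent_dist via torch.distributions' log_prob, so a standard normal is
    # the usual choice and other distributions with a compatible log_prob
    # should also work.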
    # optimizer & scheduler
    optimizer = optim.AdamW(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.epochs
    )
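    # With T_max=args.epochs and PyTorch's default eta_min=0, cosine annealing
    # decays the learning rate from args.lr to 0 over the course of training.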
    loss_history = []
    encoder.train()
    decoder.train()
    # clamp Hutchinson samples to latent dimension
    latent_dim = args.latent_dim
    hutch = min(args.hutchinson_samples, latent_dim)
    if args.hutchinson_samples > latent_dim:
        print(f"Warning: hutchinson-samples reduced to {latent_dim}")
    # training loop
    for epoch in range(1, args.epochs + 1):
        total = 0.0
        for x in loader:
            x = x.to(device)
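            # fff_loss returns a per-sample loss that, per the FFF paper,
            # combines a negative log-likelihood estimate (prior log-prob plus
            # a Hutchinson-based surrogate for the encoder's log-det-Jacobian)
            # with beta times the reconstruction error.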
            loss_vals = fff_loss(
                x, encoder, decoder, latent_dist,
                beta=args.beta,
                hutchinson_samples=hutch,
            )
            loss = loss_vals.mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total += loss.item() * x.size(0)
        avg = total / len(dataset)
        loss_history.append(avg)
        print(f"Epoch {epoch}/{args.epochs}, Loss: {avg:.4f}")
        scheduler.step()
    # plotting
    try:
        import matplotlib.pyplot as plt

        # loss curve
        plt.figure()
        plt.plot(range(1, args.epochs + 1), loss_history, marker='o')
        plt.xlabel('Epoch')
        plt.ylabel('Avg Loss')
        plt.title('FFF Toy Loss')
        plt.savefig('fff_toy_loss.png')
        print('Saved fff_toy_loss.png')
        # samples scatter: sample in latent space and decode
        encoder.eval()
        decoder.eval()
        with torch.no_grad():
            # sample latent variables of dimension args.latent_dim
            z = torch.randn(400, args.latent_dim, device=device)
            samples = decoder(z).cpu().numpy()
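            # Decoding prior samples is the generative direction: since the
            # decoder approximately inverts the encoder, these points should
            # land near the learned data density.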
        data_np = dataset.x.numpy()
        plt.figure(figsize=(5, 5))
        plt.scatter(data_np[:, 0], data_np[:, 1], s=4, alpha=0.3, label='True')
        plt.scatter(samples[:, 0], samples[:, 1], s=4, alpha=0.6, label='Flow')
        plt.legend()
        plt.title('True vs Flow Samples')
        plt.savefig('fff_toy_samples.png')
        print('Saved fff_toy_samples.png')
    except ImportError:
        print('matplotlib not available, skipping plots.')


if __name__ == "__main__":
    args = parse_args()
    main(args)