""" | |
Toy Free‑Form Flow (FFF) training demo on a 2D mixture of Gaussians. | |
Requires: | |
pip install matplotlib | |
pip install -e third_party/FFF | |
""" | |
import argparse

import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader, Dataset

try:
    from fff.loss import fff_loss
except ImportError as e:
    raise ImportError(
        "Free-Form Flow (FFF) not found. Install it in editable mode:\n"
        "pip install -e third_party/FFF"
    ) from e
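
# `fff_loss` comes from the reference Free-form Flows implementation
# (https://github.com/vislearn/FFF); the third_party/FFF path above assumes
# that repository is vendored there, e.g. as a git submodule.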


class Toy2DDataset(Dataset):
    """
    Samples from an 8-component Gaussian mixture arranged on a circle.
    """

    def __init__(self, n_samples=1000, seed=0, cluster_std=0.1, radius=2.0):
        np.random.seed(seed)
        # assign each sample to one of 8 clusters
        cluster_ids = np.random.randint(0, 8, size=n_samples)
        angles = cluster_ids / 8 * 2 * np.pi
        means = np.stack([np.cos(angles), np.sin(angles)], axis=1) * radius
        samples = means + np.random.randn(n_samples, 2) * cluster_std
        self.x = torch.from_numpy(samples).float()

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx]
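
# Quick shape check (illustrative): Toy2DDataset(n_samples=4).x is a float32
# tensor of shape (4, 2), one 2D point per sample.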


def parse_args():
    parser = argparse.ArgumentParser(
        description="Toy Free-Form Flow (FFF) on a circular 8-mode 2D Gaussian mixture"
    )
    # model/training hyperparameters
    parser.add_argument("--hidden-units", type=int, default=256,
                        help="Width of hidden layers (default: 256)")
    parser.add_argument("--latent-dim", type=int, default=2,
                        help="Dimensionality of latent space (default: 2; must equal data dim)")
    parser.add_argument("--batch-size", type=int, default=128,
                        help="Batch size for training (default: 128)")
    parser.add_argument("--lr", type=float, default=1e-3,
                        help="Initial learning rate (default: 1e-3)")
    parser.add_argument("--weight-decay", type=float, default=1e-4,
                        help="Weight decay for optimizer (default: 1e-4)")
    parser.add_argument("--beta", type=float, default=0.2,
                        help="Reconstruction loss weight (default: 0.2)")
    parser.add_argument("--hutchinson-samples", type=int, default=2,
                        help="Number of Hutchinson samples (<= latent-dim, default: 2)")
    parser.add_argument("--epochs", type=int, default=100,
                        help="Number of training epochs (default: 100)")
    # data parameters
    parser.add_argument("--n-samples", type=int, default=1000,
                        help="Number of samples in dataset (default: 1000)")
    parser.add_argument("--seed", type=int, default=1,
                        help="Random seed for data (default: 1)")
    parser.add_argument("--cluster-std", type=float, default=0.1,
                        help="Standard deviation of clusters (default: 0.1)")
    parser.add_argument("--radius", type=float, default=2.0,
                        help="Radius of the circle (default: 2.0)")
    parser.add_argument("--no-cuda", action="store_true",
                        help="Disable CUDA even if available")
    return parser.parse_args()


def main(args):
    # device
    use_cuda = torch.cuda.is_available() and not args.no_cuda
    device = torch.device("cuda" if use_cuda else "cpu")

    # data loader
    dataset = Toy2DDataset(
        n_samples=args.n_samples,
        seed=args.seed,
        cluster_std=args.cluster_std,
        radius=args.radius,
    )
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    # encoder/decoder with extended capacity (3 hidden layers); latent dim = args.latent_dim
    encoder = torch.nn.Sequential(
        torch.nn.Linear(2, args.hidden_units),
        torch.nn.GELU(),
        torch.nn.Linear(args.hidden_units, args.hidden_units),
        torch.nn.GELU(),
        torch.nn.Linear(args.hidden_units, args.hidden_units),
        torch.nn.GELU(),
        torch.nn.Linear(args.hidden_units, args.latent_dim),
    ).to(device)
    decoder = torch.nn.Sequential(
        torch.nn.Linear(args.latent_dim, args.hidden_units),
        torch.nn.GELU(),
        torch.nn.Linear(args.hidden_units, args.hidden_units),
        torch.nn.GELU(),
        torch.nn.Linear(args.hidden_units, args.hidden_units),
        torch.nn.GELU(),
        torch.nn.Linear(args.hidden_units, 2),
    ).to(device)
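    # Note: unlike classical normalizing flows, FFF places no architectural
    # constraints on these networks; the decoder is only trained to be an
    # approximate inverse of the encoder via the beta-weighted reconstruction
    # term in the loss (Draxler et al., "Free-form Flows", AISTATS 2024).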

    # latent prior: standard normal over the latent space
    latent_loc = torch.zeros(args.latent_dim, device=device)
    latent_scale = torch.ones(args.latent_dim, device=device)
    latent_dist = torch.distributions.Normal(latent_loc, latent_scale)

    # optimizer & scheduler
    optimizer = optim.AdamW(
        list(encoder.parameters()) + list(decoder.parameters()),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.epochs
    )

    loss_history = []
    encoder.train()
    decoder.train()

    # clamp Hutchinson samples to the latent dimension
    latent_dim = args.latent_dim
    hutch = min(args.hutchinson_samples, latent_dim)
    if args.hutchinson_samples > latent_dim:
        print(f"Warning: hutchinson-samples reduced to {latent_dim}")

    # training loop
    for epoch in range(1, args.epochs + 1):
        total = 0.0
        for x in loader:
            x = x.to(device)
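            # fff_loss returns a per-sample loss: the negative log-likelihood
            # under the change-of-variables formula, with the gradient of the
            # log-det-Jacobian term estimated via Hutchinson trace probes,
            # plus beta * squared reconstruction error (per the FFF paper).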
            loss_vals = fff_loss(
                x, encoder, decoder, latent_dist,
                beta=args.beta,
                hutchinson_samples=hutch,
            )
            loss = loss_vals.mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total += loss.item() * x.size(0)
        avg = total / len(dataset)
        loss_history.append(avg)
        print(f"Epoch {epoch}/{args.epochs}, Loss: {avg:.4f}")
        scheduler.step()

    # plotting
    try:
        import matplotlib.pyplot as plt

        # loss curve
        plt.figure()
        plt.plot(range(1, args.epochs + 1), loss_history, marker='o')
        plt.xlabel('Epoch')
        plt.ylabel('Avg Loss')
        plt.title('FFF Toy Loss')
        plt.savefig('fff_toy_loss.png')
        print('Saved fff_toy_loss.png')

        # samples scatter: sample in latent space and decode
        encoder.eval()
        decoder.eval()
        with torch.no_grad():
            # sample latent variables of dimension args.latent_dim
            z = torch.randn(400, args.latent_dim, device=device)
            samples = decoder(z).cpu().numpy()
        data_np = dataset.x.numpy()
        plt.figure(figsize=(5, 5))
        plt.scatter(data_np[:, 0], data_np[:, 1], s=4, alpha=0.3, label='True')
        plt.scatter(samples[:, 0], samples[:, 1], s=4, alpha=0.6, label='Flow')
        plt.legend()
        plt.title('True vs Flow Samples')
        plt.savefig('fff_toy_samples.png')
        print('Saved fff_toy_samples.png')
    except ImportError:
        print('matplotlib not available, skipping plots.')


if __name__ == "__main__":
    args = parse_args()
    main(args)
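
# Example invocation (the script filename here is hypothetical):
#   python fff_toy.py --epochs 200 --beta 0.5 --hutchinson-samples 2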