Last active
March 13, 2025 06:17
-
-
Save francois-rozet/fd6a820e052157f8ac6e2aa39e16c1aa to your computer and use it in GitHub Desktop.
Flow Matching in 100 LOC
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import math | |
import matplotlib.pyplot as plt | |
import torch | |
import torch.nn as nn | |
from sklearn.datasets import make_moons | |
from torch import Tensor | |
from tqdm import tqdm | |
from typing import * | |
from zuko.utils import odeint | |
def log_normal(x: Tensor) -> Tensor:
    """Standard-normal log-density of ``x``, summed over its last dimension.

    Each coordinate contributes ``-(x_i**2 + log(2*pi)) / 2``; the result has
    shape ``x.shape[:-1]``.
    """
    per_coordinate = x.square() + math.log(2 * math.pi)
    return -per_coordinate.sum(dim=-1) / 2
class MLP(nn.Sequential):
    """Fully-connected network of ``nn.Linear`` layers separated by ``nn.ELU``.

    The last layer is linear: the activation that would follow it is dropped.

    Args:
        in_features: Input dimension.
        out_features: Output dimension.
        hidden_features: Widths of the hidden layers, in order. An empty
            sequence yields a single linear layer. A tuple default is used
            instead of a list so the default cannot be mutated across calls.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: Sequence[int] = (64, 64),
    ):
        widths = (in_features, *hidden_features, out_features)
        layers = []
        # One (Linear, ELU) pair per consecutive pair of widths.
        for a, b in zip(widths[:-1], widths[1:]):
            layers.extend([nn.Linear(a, b), nn.ELU()])
        # Drop the trailing activation so the output is unbounded.
        super().__init__(*layers[:-1])
class CNF(nn.Module):
    """Continuous normalizing flow whose velocity field is an ``MLP``.

    The network sees the state ``x`` concatenated with a sinusoidal embedding
    of the time ``t`` (cosines and sines of ``t * k * pi`` for k = 1..freqs).

    Args:
        features: Dimension of the state ``x``.
        freqs: Number of time frequencies in the embedding.
        **kwargs: Forwarded to ``MLP`` (e.g. ``hidden_features``).
    """

    def __init__(self, features: int, freqs: int = 3, **kwargs):
        super().__init__()
        self.net = MLP(2 * freqs + features, features, **kwargs)
        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        """Velocity at time ``t`` and state ``x``.

        The time embedding is expanded to ``x.shape[:-1]``, so ``t`` must
        broadcast to that shape (e.g. a scalar, or one time per sample).
        """
        t = self.freqs * t[..., None]
        t = torch.cat((t.cos(), t.sin()), dim=-1)
        t = t.expand(*x.shape[:-1], -1)
        return self.net(torch.cat((t, x), dim=-1))

    def encode(self, x: Tensor) -> Tensor:
        """Integrate the flow from t=0 (state ``x``) to t=1."""
        return odeint(self, x, 0.0, 1.0, phi=self.parameters())

    def decode(self, z: Tensor) -> Tensor:
        """Integrate the flow backwards from t=1 (state ``z``) to t=0."""
        return odeint(self, z, 1.0, 0.0, phi=self.parameters())

    def log_prob(self, x: Tensor) -> Tensor:
        """Log-density of ``x`` via the instantaneous change of variables.

        Integrates ``x`` from t=0 to t=1 jointly with the integral of the
        exact Jacobian trace ``tr(d v / d x)``, then adds that integral to the
        standard-normal log-density of the endpoint. The trace is exact, so
        the cost grows with the state dimension (one batched VJP per axis).
        """
        # Identity rearranged to (d, *batch, d): slice k is e_k for every sample.
        I = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device)
        I = I.expand(*x.shape, x.shape[-1]).movedim(-1, 0)

        def augmented(t: Tensor, x: Tensor, ladj: Tensor) -> Tuple[Tensor, Tensor]:
            """Time derivative of the augmented state ``(x, ladj)``."""
            # Grad mode is forced on so the Jacobian is available even when
            # log_prob itself runs under torch.no_grad().
            with torch.enable_grad():
                x = x.requires_grad_()
                dx = self(t, x)

            # Batched VJPs with each e_k give row k of d(dx)/dx: (d, *batch, d).
            jacobian = torch.autograd.grad(dx, x, I, create_graph=True, is_grads_batched=True)[0]
            # Sum of diagonal entries (first and last axes) -> shape (*batch,).
            trace = torch.einsum('i...i', jacobian)

            # The trace is integrated at 1e-2 scale and rescaled by 1e2 below;
            # presumably to keep it comparable to x for the solver's error
            # control — NOTE(review): confirm intent.
            return dx, trace * 1e-2

        ladj = torch.zeros_like(x[..., 0])
        z, ladj = odeint(augmented, (x, ladj), 0.0, 1.0, phi=self.parameters())

        return log_normal(z) + ladj * 1e2
class FlowMatchingLoss(nn.Module):
    """Conditional flow-matching regression loss for a velocity field.

    For a data batch ``x``, draws one ``t ~ U(0, 1)`` per sample and
    ``z ~ N(0, I)``, forms the interpolant

        y = (1 - t) * x + (sigma_min + (1 - sigma_min) * t) * z,

    which is ``x + sigma_min * z`` at ``t = 0`` and ``z`` at ``t = 1``. It then
    regresses ``v(t, y)`` onto the path's time derivative
    ``u = (1 - sigma_min) * z - x`` with a mean squared error.

    Args:
        v: Velocity network, called as ``v(t, y)`` with ``t`` of shape
            ``x.shape[:-1]`` and ``y`` of the same shape as ``x``.
        sigma_min: Noise scale left at ``t = 0``. Defaults to ``1e-4``.
    """

    def __init__(self, v: nn.Module, sigma_min: float = 1e-4):
        super().__init__()
        self.v = v
        self.sigma_min = sigma_min

    def forward(self, x: Tensor) -> Tensor:
        """Return the scalar flow-matching loss for the batch ``x``."""
        s = self.sigma_min
        # Shape (..., 1): one time per sample, broadcast over the feature axis.
        t = torch.rand_like(x[..., 0, None])
        z = torch.randn_like(x)
        y = (1 - t) * x + (s + (1 - s) * t) * z
        u = (1 - s) * z - x  # dy/dt along the interpolant above
        return (self.v(t.squeeze(-1), y) - u).square().mean()
if __name__ == '__main__':
    flow = CNF(2, hidden_features=[64] * 3)

    # Training: fit the velocity field by flow matching on two-moons data.
    criterion = FlowMatchingLoss(flow)
    optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)

    moons, _ = make_moons(16384, noise=0.05)
    data = torch.from_numpy(moons).float()

    for _ in tqdm(range(16384), ncols=88):
        indices = torch.randint(0, len(data), (256,))
        batch = data[indices]
        criterion(batch).backward()
        optimizer.step()
        optimizer.zero_grad()

    # Sampling: draw Gaussian noise and integrate the flow back from t=1 to t=0.
    with torch.no_grad():
        noise = torch.randn(16384, 2)
        samples = flow.decode(noise)

    plt.figure(figsize=(4.8, 4.8), dpi=150)
    plt.hist2d(*samples.T, bins=64)
    plt.savefig('moons_fm.pdf')

    # Log-likelihood of the first four training points.
    with torch.no_grad():
        log_likelihood = flow.log_prob(data[:4])

    print(log_likelihood)
To demystify `odeint` for people new to ODEs like myself, I tried to implement a simple forward Euler version. It seems to generate a similar moons plot at the end, but I couldn't figure out how to make it work with `log_prob`.
def odeint(
    f: Callable[..., Any],
    x: Union[Tensor, Sequence[Tensor]],
    t0: float,
    t1: float,
    phi: Iterable[Tensor] = (),
    dt: float = 0.01,
) -> Union[Tensor, Tuple[Tensor, ...]]:
    """Integrate ``dx/dt = f(t, x)`` from ``t0`` to ``t1`` with forward Euler.

    A fixed-step, non-adaptive stand-in for the ``odeint`` calls in this gist.
    The interval is split into ``ceil(|t1 - t0| / dt)`` equal steps, so it is
    covered exactly even when ``dt`` does not divide it. Each step evaluates
    ``f`` at the *start* of the step (forward Euler).

    Args:
        f: Vector field. Called as ``f(t, x)`` when ``x`` is a tensor, or as
            ``f(t, *x)`` returning a tuple of derivatives when ``x`` is a
            tuple/list of tensors (as ``CNF.log_prob`` does with ``(x, ladj)``).
        x: State at ``t0``: a tensor or a tuple/list of tensors.
        t0: Start time.
        t1: End time; ``t1 < t0`` integrates backwards.
        phi: Ignored; accepted only so existing call sites keep working.
        dt: Maximum step size; must be positive.

    Returns:
        The state at ``t1``, with the same structure as ``x`` (a tensor, or a
        tuple of tensors).

    Raises:
        ValueError: If ``dt`` is not positive.
    """
    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt}")

    packed = isinstance(x, (tuple, list))
    state = tuple(x) if packed else (x,)
    span = t1 - t0
    if span == 0:
        return state if packed else state[0]

    # Small tolerance so e.g. 1.0 / 0.01 does not round up to an extra step.
    n_steps = max(1, math.ceil(abs(span) / dt - 1e-9))
    h = span / n_steps  # signed step size
    ref = state[0]

    for i in range(n_steps):
        # Time as a tensor on the state's device/dtype so f can broadcast it.
        t = torch.as_tensor(t0 + i * h, dtype=ref.dtype, device=ref.device)
        derivs = f(t, *state) if packed else (f(t, state[0]),)
        state = tuple(s + h * ds for s, ds in zip(state, derivs))

    return state if packed else state[0]
Hi @AlienKevin. If you want your `odeint` to work with `log_prob`, you will need to pack `x` and `ladj` into a single tensor representing the state of the ODE, and unpack it inside the function to integrate.
For example, if you want to integrate a function `f(t, x1, x2)` that returns the pair of derivatives `(dx1, dx2)`, you can wrap it as follows:
# Shapes and element counts of the two state tensors, used to unpack the flat state.
s1, s2 = x1.shape, x2.shape
n1, n2 = x1.numel(), x2.numel()  # n2 is not needed below; x[n1:] holds x2
def g(t, x):
    # Unpack the flat state into (x1, x2), evaluate f, and re-pack the
    # derivatives in the same order so the result matches the flat state.
    x1, x2 = x[:n1].reshape(s1), x[n1:].reshape(s2)
    dx1, dx2 = f(t, x1, x2)
    return torch.cat((dx1.flatten(), dx2.flatten()))
# Pack the initial state the same way g packs the derivatives.
x = torch.cat((x1.flatten(), x2.flatten()))
y = odeint(g, x, ...) # instead of odeint(f, (x1, x2), ...)
# y is flat as well: recover the parts with y[:n1].reshape(s1) and y[n1:].reshape(s2).
Got it, thanks!
> … it is common to use the (unbiased) Hutchinson trace estimator instead. I have not implemented this here, but I can point you to an implementation if you want.
Yes I'm interested!
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi @jenkspt, yes, this operation is indeed expensive. Instead of computing the full Jacobian, it is common to use the (unbiased) Hutchinson trace estimator. I have not implemented this here, but I can point you to an implementation if you want.
Note that computing the Jacobian "per-pixel" is the same as computing the diagonal of the Jacobian, which would be enough to compute the trace, but I don't think there is an algorithm to do that cheaply.