Skip to content

Instantly share code, notes, and snippets.

@francois-rozet
Last active October 19, 2025 22:48
Show Gist options
  • Save francois-rozet/fd6a820e052157f8ac6e2aa39e16c1aa to your computer and use it in GitHub Desktop.
Save francois-rozet/fd6a820e052157f8ac6e2aa39e16c1aa to your computer and use it in GitHub Desktop.
Flow Matching in 100 LOC
#!/usr/bin/env python
import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from sklearn.datasets import make_moons
from torch import Tensor
from tqdm import tqdm
from typing import *
from zuko.utils import odeint
def log_normal(x: Tensor) -> Tensor:
    """Log-density of a standard multivariate normal, reduced over the last axis.

    Computes ``-1/2 * sum_i (x_i^2 + log(2*pi))`` for each vector stored along
    the last dimension of ``x``.

    Args:
        x: Tensor whose last dimension holds the coordinates of each point.

    Returns:
        Tensor with the last dimension of ``x`` removed.
    """
    log_two_pi = math.log(2 * math.pi)
    per_coordinate = x.square() + log_two_pi

    return -per_coordinate.sum(dim=-1) / 2
class MLP(nn.Sequential):
    """Fully connected network: Linear layers interleaved with activations.

    An activation follows every Linear layer except the last one, so the
    network's output is left unconstrained.

    Args:
        in_features: Size of the input's last dimension.
        out_features: Size of the output's last dimension.
        hidden_features: Widths of the hidden layers, in order. An empty
            sequence yields a single Linear layer.
        activation: Zero-argument factory for the activation module placed
            after each hidden Linear layer (a fresh instance per layer).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: Sequence[int] = (64, 64),
        activation: Callable[[], nn.Module] = nn.ELU,
    ):
        # A tuple default avoids the shared-mutable-default pitfall of a list.
        widths = (in_features, *hidden_features, out_features)

        layers = []
        for a, b in zip(widths[:-1], widths[1:]):
            layers.extend([nn.Linear(a, b), activation()])

        # Drop the trailing activation so the network ends with a plain Linear.
        super().__init__(*layers[:-1])
class CNF(nn.Module):
    """Continuous normalizing flow driven by a learned time-dependent vector field.

    The network ``net`` maps ``(embedding(t), x)`` to a velocity with the same
    last dimension as ``x``. Integrating that velocity from t = 0 to t = 1 maps
    data to latents (``encode``); integrating from t = 1 to t = 0 maps latents
    back to data (``decode``).

    Args:
        features: Size of the state's last dimension.
        freqs: Number of cosine/sine frequencies used to embed ``t``.
        **kwargs: Forwarded to ``MLP`` (e.g. ``hidden_features``).
    """

    def __init__(self, features: int, freqs: int = 3, **kwargs):
        super().__init__()

        self.net = MLP(2 * freqs + features, features, **kwargs)

        # Angular frequencies pi, 2*pi, ..., freqs*pi for the time embedding.
        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        """Evaluate the velocity field at time ``t`` and state ``x``.

        ``t`` is embedded as ``[cos(k*pi*t), sin(k*pi*t)]`` for k = 1..freqs,
        expanded over the leading dimensions of ``x``, and concatenated in
        front of ``x`` before being fed to the network.
        """
        t = self.freqs * t[..., None]
        t = torch.cat((t.cos(), t.sin()), dim=-1)
        t = t.expand(*x.shape[:-1], -1)

        return self.net(torch.cat((t, x), dim=-1))

    def encode(self, x: Tensor) -> Tensor:
        """Integrate data ``x`` from t = 0 to t = 1."""
        return odeint(self, x, 0.0, 1.0, phi=self.parameters())

    def decode(self, z: Tensor) -> Tensor:
        """Integrate latents ``z`` from t = 1 back to t = 0."""
        return odeint(self, z, 1.0, 0.0, phi=self.parameters())

    def log_prob(self, x: Tensor, *, exact: bool = True) -> Tensor:
        """Log-density of ``x`` via the instantaneous change of variables.

        The state is integrated from t = 0 to t = 1 together with the
        divergence ``tr(dv/dx)`` of the velocity field, and the result is
        ``log_normal(z) + integral of tr(dv/dx) dt``.

        Args:
            x: Points to evaluate; the last dimension is the feature dimension.
            exact: If True (default), the trace is read off the full Jacobian,
                which needs a batch of d vector-Jacobian products per
                evaluation (d = ``x.shape[-1]``). If False, Hutchinson's
                unbiased estimator ``eps^T (dv/dx) eps`` is used instead, with
                a single Rademacher probe ``eps`` drawn once and held fixed
                along the trajectory: one vector-Jacobian product per
                evaluation, but the returned log-density is a stochastic
                (unbiased) estimate.

        Returns:
            Tensor of log-densities with the last dimension of ``x`` removed.
        """
        if exact:
            # Identity rows as a batch of cotangents: the batched VJP then
            # yields the full Jacobian, with the output index moved to dim 0.
            eye = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device)
            eye = eye.expand(*x.shape, x.shape[-1]).movedim(-1, 0)

            def divergence(dx: Tensor, x: Tensor) -> Tensor:
                jacobian = torch.autograd.grad(
                    dx, x, eye, create_graph=True, is_grads_batched=True
                )[0]
                # Sum the diagonal: first axis (output index) vs last (input index).
                return torch.einsum('i...i', jacobian)

        else:
            # Rademacher probe (entries +/-1), fixed for the whole integration so
            # the augmented dynamics stay deterministic for the ODE solver.
            eps = torch.randint_like(x, 2) * 2 - 1

            def divergence(dx: Tensor, x: Tensor) -> Tensor:
                vjp = torch.autograd.grad(dx, x, eps, create_graph=True)[0]
                return (vjp * eps).sum(dim=-1)

        def augmented(t: Tensor, x: Tensor, ladj: Tensor) -> Tuple[Tensor, Tensor]:
            # Dynamics of the augmented state (x, ladj): (dx/dt, scaled divergence).
            # Grad mode is forced on so the VJP works even under torch.no_grad().
            with torch.enable_grad():
                x = x.requires_grad_()
                dx = self(t, x)
                trace = divergence(dx, x)

            return dx, trace * 1e-2

        ladj = torch.zeros_like(x[..., 0])
        z, ladj = odeint(augmented, (x, ladj), 0.0, 1.0, phi=self.parameters())

        # The divergence was integrated scaled by 1e-2; undo that here.
        # NOTE(review): presumably the scaling keeps the ladj component small
        # in the adaptive solver's error control — confirm.
        return log_normal(z) + ladj * 1e2
class FlowMatchingLoss(nn.Module):
    """Conditional flow matching loss for a velocity field ``v(t, y)``.

    For data ``x``, noise ``z ~ N(0, I)`` and ``t ~ U(0, 1)``, it builds the
    interpolant ``y = (1 - t) x + (sigma_min + (1 - sigma_min) t) z`` and
    regresses ``v(t, y)`` onto its time derivative ``u = (1 - sigma_min) z - x``
    with a mean squared error. At t = 0 the path sits on the data (plus
    ``sigma_min`` noise); at t = 1 it is pure noise.

    Args:
        v: Module called as ``v(t, y)``, where ``t`` has the shape of ``x``
            without its last dimension and ``y`` has the shape of ``x``.
        sigma_min: Noise scale kept at t = 0. Defaults to 1e-4.
    """

    def __init__(self, v: nn.Module, sigma_min: float = 1e-4):
        super().__init__()

        self.v = v
        self.sigma_min = sigma_min

    def forward(self, x: Tensor) -> Tensor:
        """Return the scalar flow matching loss for a batch of samples ``x``."""
        # One time per sample, with a trailing singleton axis to broadcast over features.
        t = torch.rand_like(x[..., 0, None])
        z = torch.randn_like(x)

        y = (1 - t) * x + (self.sigma_min + (1 - self.sigma_min) * t) * z
        u = (1 - self.sigma_min) * z - x

        return (self.v(t.squeeze(-1), y) - u).square().mean()
if __name__ == '__main__':
    flow = CNF(2, hidden_features=[64] * 3)

    # --- Training: fit the CNF velocity field with flow matching ---
    criterion = FlowMatchingLoss(flow)
    optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)

    points, _ = make_moons(16384, noise=0.05)
    data = torch.from_numpy(points).float()

    n_steps, batch_size = 16384, 256
    for _ in tqdm(range(n_steps), ncols=88):
        # Each iteration is one optimizer step on a random minibatch.
        batch = data[torch.randint(0, len(data), (batch_size,))]
        criterion(batch).backward()
        optimizer.step()
        optimizer.zero_grad()

    # --- Sampling: push Gaussian noise through the reverse ODE ---
    with torch.no_grad():
        samples = flow.decode(torch.randn(16384, 2))

    plt.figure(figsize=(4.8, 4.8), dpi=150)
    plt.hist2d(*samples.T, bins=64)
    plt.savefig('moons_fm.pdf')

    # --- Log-likelihood of a few training points ---
    with torch.no_grad():
        log_p = flow.log_prob(data[:4])

    print(log_p)
@francois-rozet
Copy link
Author

@jenkspt As long as the distribution of $z$ is the same during training and sampling, I think it should work.

@jenkspt
Copy link

jenkspt commented Oct 26, 2024

I'm looking at the `log_prob` function. For an image dataset, for example, this is quite expensive. Is it reasonable to treat pixels as independent predictions in this case, and only compute the Jacobian per pixel?

@francois-rozet
Copy link
Author

Hi @jenkspt, yes, this operation is indeed expensive. Instead of computing the full Jacobian, it is common to use the (unbiased) Hutchinson trace estimator. I have not implemented it here, but I can point you to an implementation if you want.

Note that computing the Jacobian "per-pixel" is the same as computing the diagonal of the Jacobian, which would be enough to compute the trace, but I don't think there is an algorithm to do that cheaply.

@AlienKevin
Copy link

To demystify `odeint` for people new to ODEs like myself, I tried to implement a simple forward Euler version. It seems to generate a similar moons plot at the end, but I couldn't figure out how to make it work with `log_prob`.

def odeint(
    f: Callable[..., Union[Tensor, Sequence[Tensor]]],
    x: Union[Tensor, Sequence[Tensor]],
    t0: float,
    t1: float,
    phi: Iterable[Tensor] = (),
    dt: float = 0.01,
) -> Union[Tensor, Sequence[Tensor]]:
    """Integrate dx/dt = f(t, x) from t0 to t1 with the explicit (forward) Euler method.

    Fixed-step stand-in for the ``zuko.utils.odeint`` calls in the gist. The
    interval is split into ``ceil(|t1 - t0| / dt)`` equal steps, so integration
    ends exactly at ``t1``. Each step evaluates ``f`` at its start:
    ``x_{k+1} = x_k + h * f(t_k, x_k)``.

    Args:
        f: Vector field. Called as ``f(t, x)`` when ``x`` is a tensor, or as
            ``f(t, *x)`` returning one derivative per component when ``x`` is a
            sequence of tensors (e.g. the ``(x, ladj)`` state of ``log_prob``).
            ``t`` is passed as a 0-dim tensor.
        x: Initial state, either a tensor or a sequence of tensors.
        t0: Initial time.
        t1: Final time; may be smaller than ``t0`` to integrate backward.
        phi: Accepted for signature compatibility and ignored; gradients with
            respect to parameters flow through ordinary autograd over the
            unrolled steps.
        dt: Maximum step size (positive).

    Returns:
        The state at ``t1``, with the same structure as ``x`` (a tuple when
        ``x`` is a sequence).
    """
    is_tensor = torch.is_tensor(x)
    state = (x,) if is_tensor else tuple(x)

    ref = state[0]
    t_dtype = ref.dtype if ref.is_floating_point() else torch.get_default_dtype()

    # Number of equal steps of size at most dt; the small tolerance keeps
    # floating-point noise in span / dt from adding a spurious extra step.
    span = float(t1) - float(t0)
    n_steps = max(0, math.ceil(abs(span) / dt - 1e-9))
    h = span / n_steps if n_steps else 0.0  # signed step, lands exactly on t1

    for k in range(n_steps):
        # Left end of the step, as a tensor on x's device (f may index it).
        t = torch.tensor(float(t0) + k * h, dtype=t_dtype, device=ref.device)
        derivs = f(t, *state)
        if is_tensor:
            derivs = (derivs,)
        state = tuple(s + h * d for s, d in zip(state, derivs))

    return state[0] if is_tensor else state

@francois-rozet
Copy link
Author

francois-rozet commented Dec 2, 2024

Hi @AlienKevin. If you want your `odeint` to work with `log_prob`, you will need to pack `x` and `ladj` into a single tensor representing the state of the ODE, and unpack it inside the function to integrate.

For example, if you want to integrate a function $f$ that takes $x_1$ and $x_2$ and returns $\dot{x}_1$ and $\dot{x}_2$:

# Shapes and sizes of the two state components, needed to unpack the flat state.
s1, s2 = x1.shape, x2.shape
n1, n2 = x1.numel(), x2.numel()

def g(t, x):
    # Unpack the flat state, evaluate f, and repack the derivatives in the same order.
    x1, x2 = x[:n1].reshape(s1), x[n1:].reshape(s2)
    dx1, dx2 = f(t, x1, x2)
    return torch.cat((dx1.flatten(), dx2.flatten()))

# Pack the initial state into one flat tensor.
x = torch.cat((x1.flatten(), x2.flatten()))

y = odeint(g, x, ...)  # instead of odeint(f, (x1, x2), ...)

# The result is packed too: unpack it to recover the final x1 and x2.
y1, y2 = y[:n1].reshape(s1), y[n1:].reshape(s2)

@AlienKevin
Copy link

Got it, thanks!

@jenkspt
Copy link

jenkspt commented Feb 17, 2025

is common to use the (unbiased) Hutchinson trace estimator instead. I have not implemented this here, but I can point you to an implementation if you want

Yes I'm interested!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment