#!/usr/bin/env python

import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from sklearn.datasets import make_moons
from torch import Tensor
from tqdm import tqdm
from typing import *
from zuko.utils import odeint


def log_normal(x: Tensor) -> Tensor:
    return -(x.square() + math.log(2 * math.pi)).sum(dim=-1) / 2


class MLP(nn.Sequential):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: List[int] = [64, 64],
    ):
        layers = []

        for a, b in zip(
            (in_features, *hidden_features),
            (*hidden_features, out_features),
        ):
            layers.extend([nn.Linear(a, b), nn.ELU()])

        super().__init__(*layers[:-1])


class CNF(nn.Module):
    def __init__(self, features: int, freqs: int = 3, **kwargs):
        super().__init__()

        self.net = MLP(2 * freqs + features, features, **kwargs)

        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        t = self.freqs * t[..., None]
        t = torch.cat((t.cos(), t.sin()), dim=-1)
        t = t.expand(*x.shape[:-1], -1)

        return self.net(torch.cat((t, x), dim=-1))

    def encode(self, x: Tensor) -> Tensor:
        return odeint(self, x, 0.0, 1.0, phi=self.parameters())

    def decode(self, z: Tensor) -> Tensor:
        return odeint(self, z, 1.0, 0.0, phi=self.parameters())

    def log_prob(self, x: Tensor) -> Tensor:
        I = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device)
        I = I.expand(*x.shape, x.shape[-1]).movedim(-1, 0)

        def augmented(t: Tensor, x: Tensor, ladj: Tensor) -> Tensor:
            with torch.enable_grad():
                x = x.requires_grad_()
                dx = self(t, x)

                jacobian = torch.autograd.grad(dx, x, I, create_graph=True, is_grads_batched=True)[0]

            trace = torch.einsum('i...i', jacobian)

            return dx, trace * 1e-2

        ladj = torch.zeros_like(x[..., 0])
        z, ladj = odeint(augmented, (x, ladj), 0.0, 1.0, phi=self.parameters())

        return log_normal(z) + ladj * 1e2


class FlowMatchingLoss(nn.Module):
    def __init__(self, v: nn.Module):
        super().__init__()

        self.v = v

    def forward(self, x: Tensor) -> Tensor:
        t = torch.rand_like(x[..., 0, None])
        z = torch.randn_like(x)
        y = (1 - t) * x + (1e-4 + (1 - 1e-4) * t) * z
        u = (1 - 1e-4) * z - x

        return (self.v(t.squeeze(-1), y) - u).square().mean()


if __name__ == '__main__':
    flow = CNF(2, hidden_features=[64] * 3)

    # Training
    loss = FlowMatchingLoss(flow)
    optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)

    data, _ = make_moons(16384, noise=0.05)
    data = torch.from_numpy(data).float()

    for epoch in tqdm(range(16384), ncols=88):
        subset = torch.randint(0, len(data), (256,))
        x = data[subset]

        loss(x).backward()

        optimizer.step()
        optimizer.zero_grad()

    # Sampling
    with torch.no_grad():
        z = torch.randn(16384, 2)
        x = flow.decode(z)

    plt.figure(figsize=(4.8, 4.8), dpi=150)
    plt.hist2d(*x.T, bins=64)
    plt.savefig('moons_fm.pdf')

    # Log-likelihood
    with torch.no_grad():
        log_p = flow.log_prob(data[:4])

    print(log_p)
Hi @jenkspt, yes, this operation is indeed expensive. Rather than computing the Jacobian, it is common to use the (unbiased) Hutchinson trace estimator instead. I have not implemented this here, but I can point you to an implementation if you want.
Note that computing the Jacobian "per-pixel" is the same as computing the diagonal of the Jacobian, which would be enough to compute the trace, but I don't think there is an algorithm to do that cheaply.
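For reference, a minimal sketch of what that estimator could look like as a drop-in replacement for the augmented closure inside CNF.log_prob. This is not part of the gist; the single Gaussian probe eps (one sample per evaluation) is an assumption, and the 1e-2 rescaling is simply kept from the exact-trace version.

def augmented(t: Tensor, x: Tensor, ladj: Tensor) -> Tensor:
    # Random Gaussian probe for the Hutchinson trace estimator (assumed choice)
    eps = torch.randn_like(x)

    with torch.enable_grad():
        x = x.requires_grad_()
        dx = self(t, x)

        # One vector-Jacobian product gives eps^T J; (eps^T J) . eps is an
        # unbiased estimate of tr(J), so the full Jacobian is never built
        eps_J = torch.autograd.grad(dx, x, eps, create_graph=True)[0]

    trace = (eps_J * eps).sum(dim=-1)

    return dx, trace * 1e-2

The estimate is noisier than the exact trace, but it costs a single vector-Jacobian product per evaluation instead of one per feature, which is what makes it attractive for images.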
To demystify odeint for people new to ODEs like myself, I tried to implement a simple forward Euler version. It seems to generate a similar moon plot at the end, but I couldn't figure out how to make it work with log_prob.
def odeint(
    f: Callable[[Tensor, Tensor], Tensor],
    x: Tensor,
    t0: float,
    t1: float,
    phi: Iterable[Tensor] = (),  # unused in this simple version
    dt: float = 0.01,
) -> Tensor:
    # Number of steps and the (signed) step size covering [t0, t1] exactly
    n_steps = max(int(abs(t1 - t0) / dt), 1)
    dt = (t1 - t0) / n_steps

    # Integrate with the forward Euler method: x <- x + dt * f(t, x)
    t = torch.tensor(t0)
    state = x

    for _ in range(n_steps):
        state = state + dt * f(t, state)
        t = t + dt

    return state
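A sketch of how this Euler version can be swapped in for the sampling step of the script (the step size dt=0.01 is an arbitrary choice):

with torch.no_grad():
    z = torch.randn(16384, 2)
    x = odeint(flow, z, 1.0, 0.0, dt=0.01)  # Euler equivalent of flow.decode(z)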
Hi @AlienKevin. If you want your odeint to work with log_prob, you will need to pack x and ladj as a single tensor representing the state of the ODE, and unpack it inside the function to integrate.
For example, if you want to integrate a function f(t, x1, x2) that takes and returns two tensors, you can wrap it like this:
# Shapes and sizes of the two components of the state
s1, s2 = x1.shape, x2.shape
n1, n2 = x1.numel(), x2.numel()

def g(t, x):
    # Unpack the flat state, evaluate f, and flatten the derivatives again
    x1, x2 = x[:n1].reshape(s1), x[n1:].reshape(s2)
    dx1, dx2 = f(t, x1, x2)
    return torch.cat((dx1.flatten(), dx2.flatten()))

x = torch.cat((x1.flatten(), x2.flatten()))
y = odeint(g, x, ...)  # instead of odeint(f, (x1, x2), ...)
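Applied to log_prob, the end of the method could look roughly like this with your Euler odeint (a sketch, not tested; it replaces the odeint(augmented, (x, ladj), 0.0, 1.0, ...) call, reuses the augmented closure and the 1e2 rescaling from the gist, and assumes log_prob is called under torch.no_grad() as in the script):

ladj = torch.zeros_like(x[..., 0])
s_x, s_l = x.shape, ladj.shape
n_x = x.numel()

def g(t: Tensor, state: Tensor) -> Tensor:
    # Unpack the flat state into (x, ladj), evaluate augmented, re-flatten
    xi, li = state[:n_x].reshape(s_x), state[n_x:].reshape(s_l)
    dx, dladj = augmented(t, xi, li)
    return torch.cat((dx.flatten(), dladj.flatten()))

state = odeint(g, torch.cat((x.flatten(), ladj.flatten())), 0.0, 1.0)
z, ladj = state[:n_x].reshape(s_x), state[n_x:].reshape(s_l)

return log_normal(z) + ladj * 1e2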
Got it, thanks!
> … is common to use the (unbiased) Hutchinson trace estimator instead. I have not implemented this here, but I can point you to an implementation if you want
Yes, I'm interested!
I'm looking at the log_prob function. For an image dataset, for example, this is quite expensive. Is it reasonable to treat pixels as independent predictions in this case, and only compute the Jacobian per pixel?