#!/usr/bin/env python

import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from sklearn.datasets import make_moons
from torch import Tensor
from tqdm import tqdm
from typing import *
from zuko.utils import odeint


def log_normal(x: Tensor) -> Tensor:
    # Log-density of the standard normal, summed over the last dimension.
    return -(x.square() + math.log(2 * math.pi)).sum(dim=-1) / 2


class MLP(nn.Sequential):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: List[int] = [64, 64],
    ):
        layers = []

        for a, b in zip(
            (in_features, *hidden_features),
            (*hidden_features, out_features),
        ):
            layers.extend([nn.Linear(a, b), nn.ELU()])

        super().__init__(*layers[:-1])


class CNF(nn.Module):
    def __init__(self, features: int, freqs: int = 3, **kwargs):
        super().__init__()

        self.net = MLP(2 * freqs + features, features, **kwargs)

        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        # Sinusoidal embedding of the time t, concatenated with x.
        t = self.freqs * t[..., None]
        t = torch.cat((t.cos(), t.sin()), dim=-1)
        t = t.expand(*x.shape[:-1], -1)

        return self.net(torch.cat((t, x), dim=-1))

    def encode(self, x: Tensor) -> Tensor:
        # Integrate the ODE from data (t = 0) to noise (t = 1).
        return odeint(self, x, 0.0, 1.0, phi=self.parameters())

    def decode(self, z: Tensor) -> Tensor:
        # Integrate the ODE from noise (t = 1) to data (t = 0).
        return odeint(self, z, 1.0, 0.0, phi=self.parameters())

    def log_prob(self, x: Tensor) -> Tensor:
        I = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device)
        I = I.expand(*x.shape, x.shape[-1]).movedim(-1, 0)

        def augmented(t: Tensor, x: Tensor, ladj: Tensor) -> Tensor:
            # Augmented dynamics: the trace of the Jacobian accumulates the
            # log-absolute-determinant term of the change of variables.
            with torch.enable_grad():
                x = x.requires_grad_()
                dx = self(t, x)

            jacobian = torch.autograd.grad(dx, x, I, create_graph=True, is_grads_batched=True)[0]
            trace = torch.einsum('i...i', jacobian)

            # The trace is scaled down so that it does not dominate the
            # solver's error estimate; the scaling is undone below.
            return dx, trace * 1e-2

        ladj = torch.zeros_like(x[..., 0])
        z, ladj = odeint(augmented, (x, ladj), 0.0, 1.0, phi=self.parameters())

        return log_normal(z) + ladj * 1e2


class FlowMatchingLoss(nn.Module):
    def __init__(self, v: nn.Module):
        super().__init__()

        self.v = v

    def forward(self, x: Tensor) -> Tensor:
        # OT flow-matching target with reversed time: t = 0 is data, t = 1 is noise.
        t = torch.rand_like(x[..., 0, None])
        z = torch.randn_like(x)
        y = (1 - t) * x + (1e-4 + (1 - 1e-4) * t) * z
        u = (1 - 1e-4) * z - x

        return (self.v(t.squeeze(-1), y) - u).square().mean()


if __name__ == '__main__':
    flow = CNF(2, hidden_features=[64] * 3)

    # Training
    loss = FlowMatchingLoss(flow)
    optimizer = torch.optim.Adam(flow.parameters(), lr=1e-3)

    data, _ = make_moons(16384, noise=0.05)
    data = torch.from_numpy(data).float()

    for epoch in tqdm(range(16384), ncols=88):
        subset = torch.randint(0, len(data), (256,))
        x = data[subset]

        loss(x).backward()

        optimizer.step()
        optimizer.zero_grad()

    # Sampling
    with torch.no_grad():
        z = torch.randn(16384, 2)
        x = flow.decode(z)

    plt.figure(figsize=(4.8, 4.8), dpi=150)
    plt.hist2d(*x.T, bins=64)
    plt.savefig('moons_fm.pdf')

    # Log-likelihood
    with torch.no_grad():
        log_p = flow.log_prob(data[:4])

    print(log_p)
@francois-rozet Sorry for the slow reply
It's called the v-objective too. I think it was first used by NVIDIA for distillation.
Katherine used it here https://github.com/crowsonkb/v-diffusion-pytorch
It's used in some of the v2 stable diffusion models https://huggingface.co/stabilityai/stable-diffusion-2
@samedii Thanks for the references! It is indeed similar in that a difference between the noise and the data is used as the regression target.
Hi! I copy-pasted your code and got a bug saying TypeError: grad() got an unexpected keyword argument 'is_grads_batched'. I checked the PyTorch torch.autograd.grad function, and it does contain the is_grads_batched parameter. I am confused, not sure what went wrong.
File "a.py", line 122, in <module>
log_p = flow.log_prob(data[:4])
File "a.py", line 74, in log_prob
z, ladj = odeint(augmented, (x, ladj), 0.0, 1.0, phi=self.parameters())
File "/opt/anaconda3/envs/antibody/lib/python3.8/site-packages/zuko/utils.py", line 314, in odeint
return tuple(unpack(AdaptiveCheckpointAdjoint.apply(g, x, t0, t1, *phi)))
File "/opt/anaconda3/envs/antibody/lib/python3.8/site-packages/zuko/utils.py", line 415, in forward
y, error = dopri45(f, x, t, dt, error=True)
File "/opt/anaconda3/envs/antibody/lib/python3.8/site-packages/zuko/utils.py", line 330, in dopri45
k1 = dt * f(t, x)
File "/opt/anaconda3/envs/antibody/lib/python3.8/site-packages/zuko/utils.py", line 304, in <lambda>
g = lambda t, x: pack(f(t, *unpack(x)))
File "a.py", line 68, in augmented
jacobian = torch.autograd.grad(dx, x, I, is_grads_batched=True, create_graph=True)[0]
TypeError: grad() got an unexpected keyword argument 'is_grads_batched'
Hello @pengzhangzhi, the is_grads_batched option is available since PyTorch 1.11. You probably use an older version.
The only way to fix this bug is to upgrade the torch version, right? Thanks! You are very nice!
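(A sketch, not from the thread: on PyTorch versions older than 1.11, the same Jacobian can also be built without is_grads_batched by looping over the rows of I inside augmented, at the cost of one grad call per feature.)

# Sketch: drop-in for the batched grad call inside `augmented` on PyTorch < 1.11.
# The first dimension of I indexes one basis direction per output feature.
rows = [
    torch.autograd.grad(dx, x, I[i], create_graph=True, retain_graph=True)[0]
    for i in range(I.shape[0])
]
jacobian = torch.stack(rows)  # same tensor as with is_grads_batched=True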
@francois-rozet Thank you so much for making this code! Just to make sure, is this an implementation of the work "Flow Matching for Generative Modeling (https://arxiv.org/abs/2210.02747)" by Lipman et al.? If so, may I ask which loss objective in the work is your training objective based on?
I am glad you like it @hamrel-cxu! It is the optimal transport (OT) flow matching loss. Note that the 0 and 1 extremities of the time are reversed here.
Really cool! Any particular reason for inverting the time extremities?
Hello @DebajyotiS, thanks! In score-based generative modeling, it is standard to set t = 0 at the data and t = 1 at the noise, which is the convention followed here: encode integrates from 0 to 1 (data to noise) and decode from 1 to 0 (noise to data).
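To make the reversal explicit (this derivation is mine, not from the thread): with $\sigma = 10^{-4}$, the loss above draws $t \sim \mathcal{U}(0, 1)$, $z \sim \mathcal{N}(0, I)$ and regresses

$$
y = (1 - t)\,x + \big(\sigma + (1 - \sigma)\,t\big)\,z,
\qquad
u = (1 - \sigma)\,z - x ,
$$

where $u$ is exactly $\frac{\mathrm{d}y}{\mathrm{d}t}$ along the path. Lipman et al. use $\psi_s(z) = \big(1 - (1 - \sigma)\,s\big)\,z + s\,x$ with target $x - (1 - \sigma)\,z$. Substituting $s = 1 - t$ gives $\psi_{1-t}(z) = y$ and $\frac{\mathrm{d}}{\mathrm{d}t}\,\psi_{1-t}(z) = (1 - \sigma)\,z - x = u$, so the two objectives coincide up to the direction of time.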
@francois-rozet Thank you so much for making this code! I have a question about the code.
self.register_buffer('frequencies', 2 ** torch.arange(frequencies) * torch.pi)
t = self.frequencies * t[..., None]
t = torch.cat((t.cos(), t.sin()), dim=-1)
t is transformed by this function. Could you please explain the reason behind it?
Hello @fd873630, this is a sinusoidal (Fourier) embedding of the time: the scalar t is mapped to cosine and sine features at several frequencies, which gives the network a richer representation of time than the raw scalar and makes it easier to model how the vector field changes along the trajectory.
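As a quick illustration (mine, not from the thread), with the three frequencies of the quoted snippet a scalar time becomes a six-dimensional feature vector:

import torch

freqs = 2 ** torch.arange(3) * torch.pi  # frequencies as in the quoted snippet
t = torch.tensor(0.25)

emb = torch.cat(((freqs * t).cos(), (freqs * t).sin()), dim=-1)
print(emb.shape)  # torch.Size([6]): three cosines and three sines replace the scalar t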
Thank you for the fantastic breakdown of the code! The code helped me a lot in understanding the equations of the paper.
If I have to run this on a GPU, do you have any suggestions on how I could replace Zuko's odeint with torchdyn? I found Zuko's odeint to be slow, but perhaps, given the nature of the solution, it would take just as long with torchdyn. Do you have any ideas about this?
Hello @shivammehta25, thanks! I never tried to use torchdyn, mainly because of the lack of documentation. I did try with torchdiffeq's odeint_adjoint, but it was always (1.5-2x) slower than Zuko's.
# Encode
z = torchdiffeq.odeint_adjoint(flow, x, torch.tensor((0.0, 1.0)))[-1]
# Decode
x = torchdiffeq.odeint_adjoint(flow, z, torch.tensor((1.0, 0.0)))[-1]
Note that all adaptive ODE integrators rely on CPU synchronization, so this might be a bottleneck when solving on GPU. Also, the smoother the solution, the faster the integrator, so don't be afraid to train your network for a LONG time (continue even if the loss seems to "have converged"), and use learning rate scheduling. I usually use linear scheduling.
Finally, score-based generative modeling (which flow-matching is a special case of) is slow by design. Sampling requires a lot of network evaluations, and there's not much you can do about it.
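For the linear learning-rate scheduling mentioned above, a minimal sketch with PyTorch's built-in LinearLR (the model, factors, and step count below are placeholders, not values from the thread):

import torch
import torch.nn as nn

model = nn.Linear(2, 2)  # placeholder for the CNF above
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Decay the learning rate linearly from lr to 1e-2 * lr over 16384 steps.
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1.0, end_factor=1e-2, total_iters=16384
)

for step in range(16384):
    loss = model(torch.randn(256, 2)).square().mean()  # placeholder objective
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()  # one scheduler step per optimizer step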
Awesome! Thanks :)
adaptive ODE integrators rely on CPU synchronization
Interesting! Would you mind elaborating on this? Why would this be the case? Sorry for spamming here; otherwise I would reach out to you by email, but the answer might be useful for other people as well.
score-based generative modelling (which flow-matching is a special case of)
I thought it was the reverse: that flow matching is the umbrella framework and score-based modeling is one of its special cases.
Adaptive ODE solvers modify their integration step size according to an estimation of the integration error. If the error is too large, the step is rejected and the step size is reduced. The "if" can only be evaluated on CPU, and hence requires CPU-GPU synchronization.
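To illustrate where the synchronization happens (a toy accept/reject loop, not Zuko's actual code):

import torch

def adaptive_heun(f, y, t0, t1, dt=0.1, tol=1e-4):
    # Toy adaptive integrator with an embedded Euler/Heun error estimate.
    t = t0
    while t < t1:
        dt = min(dt, t1 - t)
        k1 = f(t, y)
        k2 = f(t + dt, y + dt * k1)
        y_euler = y + dt * k1              # 1st-order estimate
        y_heun = y + dt * (k1 + k2) / 2    # 2nd-order estimate
        error = (y_heun - y_euler).abs().max()
        if error.item() < tol:             # .item() forces a GPU -> CPU sync
            y, t = y_heun, t + dt
            dt = 1.5 * dt                  # accept: grow the step size
        else:
            dt = 0.5 * dt                  # reject: shrink and retry
    return y

print(adaptive_heun(lambda t, y: -y, torch.ones(3), 0.0, 1.0))  # ~ exp(-1)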
I thought it was the reverse: that flow matching is the umbrella framework and score-based modeling is one of its special cases.
You can view this either way. The main difference is that flow-matching approximates an ODE while score-matching approximates an SDE.
Thanks for the nice implementation! If I understand correctly, in lines 88-89 you implement the conditional flow matching (CFM) loss based on Equation 23 in the flow matching paper (https://arxiv.org/pdf/2210.02747.pdf). However, shouldn't it be as follows?
y = (1 - (1 - 1e-4) * t) * z + t * x
u = x - (1 - 1e-4) * z
However, if I change the code accordingly, the CFM model doesn't work. Could you please help me with that? Thanks a lot!
Hello @yuyangw 👋 As mentioned earlier in the thread, the 0 and 1 extremities of the time are reversed in this implementation, which is why the loss is slightly different. If you want to change the loss, you also have to switch the initial and final times (0↔1) of the odeint calls (in encode, decode and log_prob).
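To make that concrete, a sketch of the alternative convention (hypothetical subclasses, not part of the gist; they assume the CNF and FlowMatchingLoss classes defined above):

import torch
from torch import Tensor
from zuko.utils import odeint

class StandardCNF(CNF):
    # Standard convention of Lipman et al.: t = 0 is noise, t = 1 is data.
    def encode(self, x: Tensor) -> Tensor:  # data (t = 1) -> noise (t = 0)
        return odeint(self, x, 1.0, 0.0, phi=self.parameters())

    def decode(self, z: Tensor) -> Tensor:  # noise (t = 0) -> data (t = 1)
        return odeint(self, z, 0.0, 1.0, phi=self.parameters())

    # log_prob would likewise integrate from 1.0 to 0.0.

class StandardFlowMatchingLoss(FlowMatchingLoss):
    def forward(self, x: Tensor) -> Tensor:
        t = torch.rand_like(x[..., 0, None])
        z = torch.randn_like(x)
        y = (1 - (1 - 1e-4) * t) * z + t * x  # noise at t = 0, data at t = 1
        u = x - (1 - 1e-4) * z
        return (self.v(t.squeeze(-1), y) - u).square().mean()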
Hi @francois-rozet, thank you so much for the reply!
Hi @francois-rozet 👋, could you give a bit more detail about the probability calculation? In particular, where do the 1e2 (1e-2) factors come from?
Hello @radiradev, as explained in the first comment,
Adaptive ODE solvers choose their step size according to an estimation of the integration error. For the trace-augmented ODE, odeint overestimates the integration error because the trace has large(r) absolute values, which leads to small step sizes. To mitigate this without significant loss of accuracy, I multiply the trace by a factor $10^{-2}$.
To paraphrase, I don't want the computation of the log-absolute-determinant of the Jacobian (ladj) to influence the step size of the solver. But because the trace has a high magnitude compared to the derivative (dx), it does influence it (and makes it much slower) in practice. To mitigate this, I multiply trace by a factor $10^{-2}$, and at the end multiply ladj by the inverse factor $10^{2}$.
Hi @francois-rozet, should I change this factor when dealing with other data (e.g. image embeddings), or keep it the same?
Hello @thangld201, the best would be to try different values for the factor (basically it's a tradeoff between log-prob accuracy and efficiency) and pick what suits your needs. Note that this code expects x to be a vector or a batch of vectors. If x has the shape of an image, it will likely not work.
@francois-rozet Thanks for your answer. So if the factor is lower (e.g. 1e-6), it gets less accurate but faster?
Exactly, but potentially much less accurate, while being marginally faster. That's why you should try a few values (with the same input, to compare the results).
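One way to compare values on the same batch is to expose the factor as an argument (a hypothetical variant of the gist's log_prob, not the author's code; it assumes the CNF class and log_normal defined above):

import torch
from torch import Tensor
from zuko.utils import odeint

def log_prob(self, x: Tensor, scale: float = 1e-2) -> Tensor:
    # Same as CNF.log_prob above, but with the trace scaling factor exposed.
    I = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device)
    I = I.expand(*x.shape, x.shape[-1]).movedim(-1, 0)

    def augmented(t: Tensor, x: Tensor, ladj: Tensor) -> Tensor:
        with torch.enable_grad():
            x = x.requires_grad_()
            dx = self(t, x)

        jacobian = torch.autograd.grad(dx, x, I, create_graph=True, is_grads_batched=True)[0]
        trace = torch.einsum('i...i', jacobian)

        return dx, trace * scale  # down-weight the trace in the error estimate

    ladj = torch.zeros_like(x[..., 0])
    z, ladj = odeint(augmented, (x, ladj), 0.0, 1.0, phi=self.parameters())

    return log_normal(z) + ladj / scale  # undo the scaling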
For decoding - I don't see anything that necessitates z being from a normal distribution. Does this mean z can be sampled from any probability distribution?
@jenkspt I would think so. I am aware of at least one study (in the context of data unfolding in high-energy physics) that does data-to-data with this formulation: https://arxiv.org/abs/2311.17175
I would have to think a bit more deeply about whether that makes sense, though. (The results look good nonetheless.)
@jenkspt As long as the distribution of z is the same at training time (here a standard Gaussian, via torch.randn_like in FlowMatchingLoss) and at sampling time, decoding gives samples from the data distribution; to start from a different source distribution, you would have to train the flow with samples from it (and adapt log_normal if you want log-probabilities).
I'm looking at the log_prob function. For e.g. an image dataset this is quite expensive. Is it reasonable to treat pixels as independent predictions in this case, and only compute the Jacobian per pixel?
Hi @jenkspt, yes, this operation is indeed expensive. Instead of computing the full Jacobian, it is common to use the (unbiased) Hutchinson trace estimator. I have not implemented this here, but I can point you to an implementation if you want.
Note that computing the Jacobian "per-pixel" is the same as computing the diagonal of the Jacobian, which would be enough to compute the trace, but I don't think there is an algorithm to do that cheaply.
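For reference, a sketch of what a Hutchinson-style variant could look like (my own sketch, not the implementation mentioned above): the exact trace inside the augmented closure of log_prob is replaced by an unbiased single-probe estimate, which needs one vector-Jacobian product instead of one per feature.

# Sketch: replacement for the `augmented` closure inside `log_prob`.
# Hutchinson estimator: E[eps^T J eps] = trace(J) for a random probe eps with
# identity covariance (here a Gaussian probe, kept fixed for the integration).
eps = torch.randn_like(x)

def augmented(t: Tensor, x: Tensor, ladj: Tensor) -> Tensor:
    with torch.enable_grad():
        x = x.requires_grad_()
        dx = self(t, x)

    (eps_J,) = torch.autograd.grad(dx, x, eps, create_graph=True)  # eps^T J
    trace = (eps_J * eps).sum(dim=-1)  # unbiased estimate of trace(J)

    return dx, trace * 1e-2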
Hello @samedii, I am not sure to understand what you mean by "velocity objective"?