Guided diffusion sampling typically uses two forward passes per step:
- One caption-conditional forward pass, to compute
E[flow | noisy image, noise level, caption]
- One unconditional forward pass, to compute
E[flow | noisy image, noise level]
These results are then linearly combined to form a single guided/superconditioned flow prediction.
This means we are, annoyingly, spending 2x memory and 2x flops to get 1x prediction.
If we could combine these tasks into a single forward pass, we could sample caption-superconditioned images at 1/2 the cost.
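For reference, a standard two-pass guided sampling step looks roughly like this (a sketch: the single-output model here, which accepts caption=None or a learned null caption for the unconditional case, is a stand-in, not the model defined below):

def get_guided_prediction_two_pass(noisy_image, noise_level, caption, guidance_scale=2.0):
    # forward pass 1: caption-conditional flow prediction
    pred_cond = model(noisy_image, noise_level, caption)
    # forward pass 2: unconditional flow prediction (caption dropped)
    pred_uncond = model(noisy_image, noise_level, None)
    # linear combination: move away from the unconditional prediction, toward the conditional one
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)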
Conditioning a prediction on a variable that is independent of the target is equivalent to not conditioning on it at all: if the caption is random and unrelated to the image, then E[flow | noisy image, noise level, random caption] = E[flow | noisy image, noise level].
This means that if we train a model to consume (noisy image, noise level, caption) and predict two outputs (flow if the caption is relevant, flow if the caption is irrelevant), we can treat "flow if the caption is irrelevant" as our unconditional prediction and get superconditioned results from a single forward pass.
To train both outputs at once, we can make 10% of samples in each batch use a random caption, and supervise the flow-if-caption-is-irrelevant output if and only if the caption was randomized for that sample.
The model now has 2x output channels:
class Denoiser(nn.Module):
    def __init__(self, n_image_channels, n_embed):
        ...
        # two predictions, one for "with normal caption", one for "with random caption"
        self.end = nn.Conv2d(n_embed, n_image_channels * 2, 1)

    def forward(self, noisy_image, noise_level, caption):
        ...
        pred_cond, pred_uncond = self.end(x).chunk(2, dim=1)
        return pred_cond, pred_uncond
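If the chunk call looks opaque: it just splits the doubled channel dimension back into two predictions of the original image shape. A quick shape check with made-up sizes:

import torch as th

x = th.randn(4, 2 * 3, 32, 32)              # (batch, 2 * n_image_channels, H, W)
pred_cond, pred_uncond = x.chunk(2, dim=1)  # each is (4, 3, 32, 32)
print(pred_cond.shape, pred_uncond.shape)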
The sampling code now retrieves the conditional and unconditional predictions from the same forward-pass result:
def get_guided_prediction(noisy_image, noise_level, caption, guidance_scale=2.0):
    pred_cond, pred_uncond = model(noisy_image, noise_level, caption)
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)
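Everything else in the sampler stays the same. As a sketch, a plain Euler sampling loop around get_guided_prediction might look like the following, assuming the interpolation convention noisy_image = (1 - noise_level) * image + noise_level * noise with a flow target of noise - image (flip the step sign if your parameterization differs):

import torch as th

@th.no_grad()
def sample(caption, shape, n_steps=50, guidance_scale=2.0):
    x = th.randn(shape)                      # start from pure noise at noise_level = 1
    ts = th.linspace(1.0, 0.0, n_steps + 1)  # noise levels from 1 (noise) down to 0 (clean)
    for t, t_next in zip(ts[:-1], ts[1:]):
        noise_level = t.expand(shape[0])     # one (identical) noise level per batch item
        flow = get_guided_prediction(x, noise_level, caption, guidance_scale)
        x = x + (t_next - t) * flow          # Euler step along the guided flow
    return x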
The only tricky changes are in the GPU preprocessing of inputs, where we need to occasionally randomize the caption:
# pick a set of ~0.1 * batch_size items to be uncond,
# and replace their caption with the previous item's caption (i.e. an irrelevant caption)
is_uncond = th.rand(size=(batch_size,), device=caption.device) < 0.1
caption[is_uncond] = th.roll(caption, shifts=1, dims=0)[is_uncond]
...and in the computation of losses, where we now need to supervise both model outputs:
pred_cond, pred_uncond = model(noisy_image, noise_level, caption)
losses_cond = F.mse_loss(pred_cond, target, reduction="none").mean(dim=(1, 2, 3))
losses_uncond = F.mse_loss(pred_uncond, target, reduction="none").mean(dim=(1, 2, 3))
# each sample supervises exactly one head: the uncond head if its caption was randomized,
# the cond head otherwise
losses = is_uncond * losses_uncond + (~is_uncond) * losses_cond
loss = losses.mean()
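Putting the training-side pieces together, one full training step might look like this sketch (the optimizer handling and the noisy_image = (1 - noise_level) * image + noise_level * noise parameterization are assumptions; the only actual changes from a vanilla flow-matching step are the caption randomization and the dual-head loss):

import torch as th
import torch.nn.functional as F

def training_step(model, opt, image, caption):
    batch_size = image.shape[0]

    # randomize captions for ~10% of the batch; those samples supervise the uncond head
    is_uncond = th.rand(size=(batch_size,), device=image.device) < 0.1
    caption = caption.clone()
    caption[is_uncond] = th.roll(caption, shifts=1, dims=0)[is_uncond]

    # build flow-matching inputs and targets (assumed parameterization)
    noise = th.randn_like(image)
    noise_level = th.rand(batch_size, device=image.device)
    t = noise_level.view(-1, 1, 1, 1)
    noisy_image = (1 - t) * image + t * noise
    target = noise - image

    # one forward pass produces both predictions; each sample supervises exactly one of them
    pred_cond, pred_uncond = model(noisy_image, noise_level, caption)
    losses_cond = F.mse_loss(pred_cond, target, reduction="none").mean(dim=(1, 2, 3))
    losses_uncond = F.mse_loss(pred_uncond, target, reduction="none").mean(dim=(1, 2, 3))
    loss = (is_uncond * losses_uncond + (~is_uncond) * losses_cond).mean()

    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.detach()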
By training a single model that produces both conditional and unconditional outputs in the same forward pass, we get both the efficiency of single-pass (unconditional or distilled-guidance) models and the flexibility to vary guidance_scale at inference time, so I think this approach is fairly low-downside. You should be able to extend the same approach even to negative prompting (by adding two prompt inputs to the model), although I haven't tested that yet.
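As a rough, untested sketch of that extension: if the model consumed a positive and a negative caption and predicted one flow per caption, the negative-caption prediction could stand in for the unconditional one at sampling time:

# untested sketch: assumes a model variant that consumes two captions and returns
# (flow given the positive caption, flow given the negative caption)
def get_negative_prompted_prediction(noisy_image, noise_level, caption_pos, caption_neg,
                                     guidance_scale=2.0):
    pred_pos, pred_neg = model(noisy_image, noise_level, caption_pos, caption_neg)
    # guide away from the negative-caption prediction, toward the positive-caption one
    return pred_neg + guidance_scale * (pred_pos - pred_neg)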