
Single-pass Superconditioning

Motivation

Guided diffusion sampling typically uses two forward passes per step:

  1. One caption-conditional forward pass, to compute E[flow | noisy image, noise level, caption]
  2. One unconditional forward pass, to compute E[flow | noisy image, noise level]

These results are then linearly combined to form a single guided/superconditioned flow prediction.
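
Concretely, the standard two-pass recipe looks something like this (a minimal sketch; `model`, `null_caption`, and `guidance_scale` are illustrative names, not from any particular codebase):

pred_cond = model(noisy_image, noise_level, caption)         # pass 1: conditional
pred_uncond = model(noisy_image, noise_level, null_caption)  # pass 2: unconditional
guided = pred_uncond + guidance_scale * (pred_cond - pred_uncond)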

This means we are, annoyingly, spending 2x the memory and 2x the FLOPs to get 1x prediction.

If we could combine these tasks into a single forward pass, we could sample caption-superconditioned images at 1/2 the cost.

Concept

Conditioning a prediction on a statistically independent variable is equivalent to not conditioning at all.
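
In expectation terms (my paraphrase of the claim above): if the substituted caption c′ is drawn independently of the image, then

E[flow | noisy image, noise level, c′] = E[flow | noisy image, noise level]

so an output head trained only on independent captions learns exactly the unconditional prediction.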

This means that if we train a model to consume (noisy image, noise level, caption) and predict two outputs (flow if caption is relevant, flow if caption is irrelevant), we can treat "flow if caption is irrelevant" as our unconditional prediction and get superconditioned results from a single forward pass.

To train both outputs at once, we can make 10% of samples in each batch use a random caption, and supervise the flow-if-caption-is-irrelevant output if and only if the caption was randomized for that sample.

Execution in code

The model now has 2x output channels:

from torch import nn

class Denoiser(nn.Module):
    def __init__(self, n_image_channels, n_embed):
        ...
        # two predictions, one for "with normal caption", one for "with random caption"
        self.end = nn.Conv2d(n_embed, n_image_channels * 2, 1)

    def forward(self, noisy_image, noise_level, caption):
        ...
        # split the doubled output channels into the two predictions
        pred_cond, pred_uncond = self.end(x).chunk(2, dim=1)
        return pred_cond, pred_uncond
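
For concreteness, a hypothetical instantiation (the sizes here are made up):

model = Denoiser(n_image_channels=4, n_embed=512)  # illustrative sizes
pred_cond, pred_uncond = model(noisy_image, noise_level, caption)
# pred_cond and pred_uncond each have shape (batch, n_image_channels, height, width)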

The sampler now retrieves the conditional and unconditional predictions from the same forward-pass result:

def get_guided_prediction(noisy_image, noise_level, caption, guidance_scale=2.0):
    pred_cond, pred_uncond = model(noisy_image, noise_level, caption)
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)
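
To show where this fits, here's a hedged sketch of an Euler sampling loop; the flow convention (flow points from image toward noise, with noise_level running from 1 down to 0) and the step count are my assumptions, not specified above:

x = th.randn(batch_size, n_image_channels, height, width)  # start from pure noise
n_steps = 32
for i in range(n_steps):
    noise_level = th.full((batch_size,), 1.0 - i / n_steps)
    flow = get_guided_prediction(x, noise_level, caption)
    x = x - flow / n_steps  # one Euler step from noise toward the image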

The only tricky changes are in the GPU preprocessing of inputs, where we need to occasionally randomize the caption:

# pick a set of ~0.1 * batch_size items to be uncond, and replace their captions
# with the previous caption in the batch (i.e. an irrelevant caption)
is_uncond = th.rand(size=(batch_size,), device=caption.device) < 0.1
caption[is_uncond] = th.roll(caption, shifts=1, dims=0)[is_uncond]
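
To make the roll concrete, a toy example (integer caption ids are just for illustration):

caption = th.tensor([11, 22, 33, 44])
is_uncond = th.tensor([False, True, False, True])
caption[is_uncond] = th.roll(caption, shifts=1, dims=0)[is_uncond]
# caption is now [11, 11, 33, 33]; items 1 and 3 received their neighbor's caption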

...and in the computation of losses, where we now need to supervise both model outputs:

pred_cond, pred_uncond = model(noisy_image, noise_level, caption)
# per-sample losses, reduced over channel/height/width dims to shape (batch_size,)
losses_cond = F.mse_loss(pred_cond, target, reduction="none").mean(dim=(1, 2, 3))
losses_uncond = F.mse_loss(pred_uncond, target, reduction="none").mean(dim=(1, 2, 3))
# supervise the uncond output only where the caption was randomized, the cond output elsewhere
losses = th.where(is_uncond, losses_uncond, losses_cond)
loss = losses.mean()
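
Putting the pieces together, here's one hedged sketch of a full training step; the linear noising schedule, the flow-target convention (target = noise - image), and the optimizer settings are my assumptions, not prescribed by the note above:

import torch as th
import torch.nn.functional as F

opt = th.optim.AdamW(model.parameters(), lr=1e-4)  # illustrative hyperparameters
for image, caption in loader:  # `loader` yields (image, caption-embedding) batches
    t = th.rand(image.shape[0], device=image.device)  # per-sample noise level in [0, 1)
    noise = th.randn_like(image)
    # linear-interpolation noising with flow target = noise - image (one common convention)
    noisy_image = (1 - t.view(-1, 1, 1, 1)) * image + t.view(-1, 1, 1, 1) * noise
    target = noise - image
    # randomize ~10% of captions, as above
    caption = caption.clone()  # avoid mutating the loader's tensor in place
    is_uncond = th.rand(image.shape[0], device=caption.device) < 0.1
    caption[is_uncond] = th.roll(caption, shifts=1, dims=0)[is_uncond]
    # a single forward pass produces both heads; supervise whichever one applies
    pred_cond, pred_uncond = model(noisy_image, t, caption)
    losses_cond = F.mse_loss(pred_cond, target, reduction="none").mean(dim=(1, 2, 3))
    losses_uncond = F.mse_loss(pred_uncond, target, reduction="none").mean(dim=(1, 2, 3))
    loss = th.where(is_uncond, losses_uncond, losses_cond).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()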

Conclusion

By training a single model that produces both conditional and unconditional outputs in the same forward pass, we get both the efficiency of single-pass (unconditional or distilled-guidance) models and the flexibility to vary guidance_scale at inference time, so I think this approach is fairly low-downside. You should be able to extend the same approach to support negative prompting as well (by adding two prompt inputs to the model), although I haven't tested that yet.
