Skip to content

Instantly share code, notes, and snippets.

Forked from mrsteyk/
Last active April 2, 2024 13:50
Show Gist options
  • Save madebyollin/865fa6a18d9099351ddbdfbe7299ccbf to your computer and use it in GitHub Desktop.
Save madebyollin/865fa6a18d9099351ddbdfbe7299ccbf to your computer and use it in GitHub Desktop.

Consistency Decoder PyTorch Model Code

Cleaned-up version of mrsteyk's gist (link lost in this copy), which is in turn based on `public_diff_vae.ConvUNetVAE` from OpenAI's consistencydecoder release.

Example Usage

Install the consistency decoder code (for the inference logic) and download the extracted weights:

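pip install -q git+https://github.com/openai/consistencydecoder.git
git clone <repository with the extracted weights>  # URL missing from this copy; it should provide consistency_decoder.safetensors and embedding.safetensors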

Then, run the standard sample code (but replace the jitted checkpoint with a ConvUNetVAE instance):

import torch
from diffusers import StableDiffusionPipeline
from consistencydecoder import ConsistencyDecoder, save_image, load_image

from conv_unet_vae import ConvUNetVAE, rename_state_dict
from safetensors.torch import load_file as stl

# encode with stable diffusion vae
# (the pipeline must live on the same device as the image passed to encode)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda:0")

# construct original decoder with jitted model
decoder_consistency = ConsistencyDecoder(device="cuda:0")

# construct UNet code, load the extracted weights into it,
# then overwrite the decoder's jitted checkpoint with conv_unet_vae
model = ConvUNetVAE()
model.load_state_dict(
    rename_state_dict(
        stl("consistency_decoder.safetensors"),
        stl("embedding.safetensors"),
    )
)
model = model.cuda()
decoder_consistency.ckpt = model

image = load_image("test_dog_image.jpg", size=(256, 256), center_crop=True)
latent = pipe.vae.encode(image.half().cuda()).latent_dist.sample()

# decode with gan
sample_gan = pipe.vae.decode(latent).sample.detach()
save_image(sample_gan, "gan.png")

# decode with conv_unet_vae
sample_consistency = decoder_consistency(latent)
save_image(sample_consistency, "con.png")

The result should be a faithful reconstruction of the original image:


#!/usr/bin/env python3
"""
Cleaned-up reimplementation of public_diff_vae.ConvUNetVAE,
thanks to mrsteyk's original gist (link lost in this copy).
"""
import torch
import torch.nn.functional as F
import torch.nn as nn
class TimestepEmbedding(nn.Module):
    """Map integer timestep indices to a conditioning vector.

    A learned lookup table (``emb``) is followed by a two-layer MLP with a SiLU
    between the layers.

    Args:
        n_time: Number of rows in the embedding table, i.e. how many distinct
            integer indices are accepted.
        n_emb: Width of each embedding row.
        n_out: Width of the returned vector.
    """

    def __init__(self, n_time=1024, n_emb=320, n_out=1280) -> None:
        # nn.Module.__init__ must run before any submodule is assigned;
        # otherwise PyTorch raises "cannot assign module before
        # Module.__init__() call".
        super().__init__()
        self.emb = nn.Embedding(n_time, n_emb)
        self.f_1 = nn.Linear(n_emb, n_out)
        self.f_2 = nn.Linear(n_out, n_out)

    def forward(self, x) -> torch.Tensor:
        """Embed integer indices ``x`` (values in ``[0, n_time)``).

        Returns a tensor of shape ``x.shape + (n_out,)``.
        """
        x = self.emb(x)
        x = self.f_1(x)
        x = F.silu(x)
        return self.f_2(x)
class ImageEmbedding(nn.Module):
    """Stem convolution projecting the input stack to the UNet base width.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by the 3x3 convolution.
    """

    def __init__(self, in_channels=7, out_channels=320) -> None:
        # Required before assigning submodules on an nn.Module.
        super().__init__()
        self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x) -> torch.Tensor:
        """Apply the 3x3 convolution; padding=1 keeps the spatial size."""
        return self.f(x)
class ImageUnembedding(nn.Module):
    """Output head: GroupNorm, then SiLU, then a 3x3 convolution.

    Args:
        in_channels: Channels of the incoming feature map. Must be divisible
            by 32 because the GroupNorm uses 32 groups.
        out_channels: Channels of the returned tensor.
    """

    def __init__(self, in_channels=320, out_channels=6) -> None:
        # Required before assigning submodules on an nn.Module.
        super().__init__()
        # The attribute is named ``gn`` so it lines up with the
        # "gn.weight"/"gn.bias" keys that rename_state_dict_key produces.
        self.gn = nn.GroupNorm(32, in_channels)
        self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x) -> torch.Tensor:
        """Normalize, activate and project ``x``; spatial size is preserved."""
        return self.f(F.silu(self.gn(x)))
class ConvResblock(nn.Module):
    """Residual conv block with timestep-conditioned scale and shift.

    Computes::

        scale, shift = f_t(silu(t)).chunk(2, dim=1)
        h = gn_2(f_1(silu(gn_1(x))))
        out = f_s(x) + f_2(silu(h * (scale + 1) + shift))

    Args:
        in_features: Channels of ``x`` (divisible by 32 for GroupNorm).
        out_features: Channels of the output (divisible by 32 for GroupNorm).
        time_channels: Width of the conditioning vector ``t``. The default of
            1280 keeps the original layout.
    """

    def __init__(self, in_features=320, out_features=320, time_channels=1280) -> None:
        # Required before assigning submodules on an nn.Module.
        super().__init__()
        self.f_t = nn.Linear(time_channels, out_features * 2)
        self.gn_1 = nn.GroupNorm(32, in_features)
        self.f_1 = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, out_features)
        self.f_2 = nn.Conv2d(out_features, out_features, kernel_size=3, padding=1)
        # A 1x1 projection is only needed on the skip path when the channel
        # count changes; otherwise the input is added back unchanged.
        skip_conv = in_features != out_features
        self.f_s = (
            nn.Conv2d(in_features, out_features, kernel_size=1, padding=0)
            if skip_conv
            else nn.Identity()
        )

    def forward(self, x, t):
        """Apply the block to ``x`` conditioned on ``t``.

        Assumes ``t`` is 2-D ``(batch, time_channels)`` (TODO confirm). It is
        split along dim 1 and expanded to ``(batch, C, 1, 1)`` so it
        broadcasts over the spatial dims of ``x``.
        """
        x_skip = x
        t_1, t_2 = self.f_t(F.silu(t)).chunk(2, dim=1)
        # "+ 1" makes a zero projection an identity scale.
        t_1 = t_1.unsqueeze(dim=2).unsqueeze(dim=3) + 1
        t_2 = t_2.unsqueeze(dim=2).unsqueeze(dim=3)
        gn_1 = F.silu(self.gn_1(x))
        f_1 = self.f_1(gn_1)
        gn_2 = self.gn_2(f_1)
        return self.f_s(x_skip) + self.f_2(F.silu(gn_2 * t_1 + t_2))
# Also ConvResblock
class Downsample(nn.Module):
    """ConvResblock variant that halves height and width.

    The residual path is ``silu(gn_1(x))``. It is 2x2 average-pooled, then
    passed through ``f_1``, ``gn_2``, the timestep scale/shift, SiLU and
    ``f_2``. The skip path is the 2x2 average-pooled input. The channel count
    is unchanged.

    Args:
        in_channels: Channels of ``x`` (divisible by 32 for GroupNorm).
        time_channels: Width of the conditioning vector ``t``. The default of
            1280 keeps the original layout.
    """

    def __init__(self, in_channels=320, time_channels=1280) -> None:
        # Required before assigning submodules on an nn.Module.
        super().__init__()
        self.f_t = nn.Linear(time_channels, in_channels * 2)
        self.gn_1 = nn.GroupNorm(32, in_channels)
        self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, in_channels)
        self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x, t) -> torch.Tensor:
        """Downsample ``x`` by 2 in each spatial dim, conditioned on ``t``.

        Assumes ``t`` is 2-D ``(batch, time_channels)`` (TODO confirm).
        """
        x_skip = x
        t_1, t_2 = self.f_t(F.silu(t)).chunk(2, dim=1)
        t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
        t_2 = t_2.unsqueeze(2).unsqueeze(3)
        gn_1 = F.silu(self.gn_1(x))
        # stride=None means stride == kernel_size, i.e. non-overlapping 2x2 pooling.
        pooled = F.avg_pool2d(gn_1, kernel_size=(2, 2), stride=None)
        f_1 = self.f_1(pooled)
        gn_2 = self.gn_2(f_1)
        f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
        return f_2 + F.avg_pool2d(x_skip, kernel_size=(2, 2), stride=None)
# Also ConvResblock
class Upsample(nn.Module):
    """ConvResblock variant that doubles height and width.

    The residual path is ``silu(gn_1(x))``. It is nearest-neighbour upsampled
    by 2, then passed through ``f_1``, ``gn_2``, the timestep scale/shift, SiLU
    and ``f_2``. The skip path is the nearest-upsampled input. The channel
    count is unchanged.

    Args:
        in_channels: Channels of ``x`` (divisible by 32 for GroupNorm).
        time_channels: Width of the conditioning vector ``t``. The default of
            1280 keeps the original layout.
    """

    def __init__(self, in_channels=1024, time_channels=1280) -> None:
        # Required before assigning submodules on an nn.Module.
        super().__init__()
        self.f_t = nn.Linear(time_channels, in_channels * 2)
        self.gn_1 = nn.GroupNorm(32, in_channels)
        self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, in_channels)
        self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x, t) -> torch.Tensor:
        """Upsample ``x`` by 2 in each spatial dim, conditioned on ``t``.

        Assumes ``t`` is 2-D ``(batch, time_channels)`` (TODO confirm).
        """
        x_skip = x
        t_1, t_2 = self.f_t(F.silu(t)).chunk(2, dim=1)
        t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
        t_2 = t_2.unsqueeze(2).unsqueeze(3)
        gn_1 = F.silu(self.gn_1(x))
        # F.interpolate(mode="nearest") is the non-deprecated equivalent of
        # F.upsample_nearest.
        upsampled = F.interpolate(gn_1, scale_factor=2, mode="nearest")
        f_1 = self.f_1(upsampled)
        gn_2 = self.gn_2(f_1)
        f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
        return f_2 + F.interpolate(x_skip, scale_factor=2, mode="nearest")
class ConvUNetVAE(nn.Module):
    """UNet used as the consistency-decoder denoiser.

    Module layout. The indices matter because they must line up with the
    checkpoint keys produced by ``rename_state_dict_key``, which maps
    "downsamp." to index 3 and "upsamp." to index 4:

    * ``down[0..2]``: three ConvResblocks, then a Downsample at index 3.
    * ``down[3]``: three ConvResblocks and no Downsample.
    * ``mid``: two 1024-channel ConvResblocks.
    * ``up[1..3]``: four ConvResblocks, then an Upsample at index 4.
    * ``up[0]``: four ConvResblocks and no Upsample.

    Each ConvResblock in ``up`` concatenates one saved down-path activation
    with its input. That is why the up-path input widths are sums of two
    channel counts.
    """

    def __init__(self) -> None:
        # Required before assigning submodules on an nn.Module.
        super().__init__()
        self.embed_image = ImageEmbedding()
        self.embed_time = TimestepEmbedding()

        down_0 = nn.ModuleList(
            [
                ConvResblock(320, 320),
                ConvResblock(320, 320),
                ConvResblock(320, 320),
                Downsample(320),
            ]
        )
        down_1 = nn.ModuleList(
            [
                ConvResblock(320, 640),
                ConvResblock(640, 640),
                ConvResblock(640, 640),
                Downsample(640),
            ]
        )
        down_2 = nn.ModuleList(
            [
                ConvResblock(640, 1024),
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
                Downsample(1024),
            ]
        )
        down_3 = nn.ModuleList(
            [
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
            ]
        )
        self.down = nn.ModuleList([down_0, down_1, down_2, down_3])

        self.mid = nn.ModuleList(
            [
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
            ]
        )

        up_3 = nn.ModuleList(
            [
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                Upsample(1024),
            ]
        )
        up_2 = nn.ModuleList(
            [
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 + 640, 1024),
                Upsample(1024),
            ]
        )
        up_1 = nn.ModuleList(
            [
                ConvResblock(1024 + 640, 640),
                ConvResblock(640 * 2, 640),
                ConvResblock(640 * 2, 640),
                ConvResblock(320 + 640, 640),
                Upsample(640),
            ]
        )
        up_0 = nn.ModuleList(
            [
                ConvResblock(320 + 640, 320),
                ConvResblock(320 * 2, 320),
                ConvResblock(320 * 2, 320),
                ConvResblock(320 * 2, 320),
            ]
        )
        # Stored shallow-to-deep; forward walks it in reverse.
        self.up = nn.ModuleList([up_0, up_1, up_2, up_3])

        self.output = ImageUnembedding()

    def forward(self, x, t, features) -> torch.Tensor:
        """Run the UNet.

        Args:
            x: 4-D tensor that is concatenated channel-wise with the
                upsampled ``features``. After concatenation the channel count
                must equal ImageEmbedding's ``in_channels`` (7 by default;
                presumably a 3-channel image plus 4 latent channels, TODO
                confirm). Height and width must be divisible by 8.
            t: Integer timestep indices for the TimestepEmbedding lookup.
            features: Conditioning map at 1/8 of ``x``'s spatial size. It is
                nearest-upsampled by 8 before concatenation.

        Returns:
            ImageUnembedding output (6 channels by default) at ``x``'s
            spatial size.
        """
        upsampled_features = F.interpolate(features, scale_factor=8, mode="nearest")
        x = torch.cat([x, upsampled_features], dim=1)
        t = self.embed_time(t)
        x = self.embed_image(x)

        # Save the stem output and every down-path activation. Each up-path
        # ConvResblock pops exactly one of these, last-in first-out.
        skips = [x]
        for down in self.down:
            for block in down:
                x = block(x, t)
                skips.append(x)

        for block in self.mid:
            x = block(x, t)

        for up in self.up[::-1]:
            for block in up:
                # Only ConvResblocks take a skip; Upsample blocks do not.
                if isinstance(block, ConvResblock):
                    x = torch.cat([x, skips.pop()], dim=1)
                x = block(x, t)

        return self.output(x)
def rename_state_dict_key(k):
    """Translate one key from the original checkpoint into this module's layout.

    Replacements run in a fixed order:

    1. Remove the "blocks." prefix.
    2. For each index 0-4, rewrite "down_i_", "conv_i.", "up_i_" and "mid_i"
       as dotted ModuleList paths.
    3. Map "upsamp." to slot 4 and "downsamp." to slot 3.
    4. Expand the abbreviated parameter suffixes (".w", ".b", ".g") to
       "weight"/"bias".

    The order matters because later patterns are written against the output
    of earlier ones.
    """
    k = k.replace("blocks.", "")

    for idx in range(5):
        for old, new in (
            (f"down_{idx}_", f"down.{idx}."),
            (f"conv_{idx}.", f"{idx}."),
            (f"up_{idx}_", f"up.{idx}."),
            (f"mid_{idx}", f"mid.{idx}"),
        ):
            k = k.replace(old, new)

    k = k.replace("upsamp.", "4.")
    k = k.replace("downsamp.", "3.")

    # Linear/conv layers abbreviate weight as "w" and bias as "b".
    for layer in ("f_t", "f_1", "f_2", "f_s", "f"):
        k = k.replace(f"{layer}.w", f"{layer}.weight")
        k = k.replace(f"{layer}.b", f"{layer}.bias")

    # GroupNorm layers abbreviate weight (gain) as "g" and bias as "b".
    for norm in ("gn_1", "gn_2", "gn"):
        k = k.replace(f"{norm}.g", f"{norm}.weight")
        k = k.replace(f"{norm}.b", f"{norm}.bias")

    return k
def rename_state_dict(sd, embedding):
    """Build a state dict loadable by ConvUNetVAE from the extracted weights.

    Every key of ``sd`` is passed through ``rename_state_dict_key``; values
    are carried over unchanged. The timestep embedding table is taken from
    ``embedding["weight"]`` and stored as "embed_time.emb.weight".

    Args:
        sd: Mapping of original checkpoint keys to tensors.
        embedding: Mapping that must contain a "weight" entry.

    Returns:
        A new dict; neither input is modified.
    """
    renamed = {}
    for old_key, tensor in sd.items():
        renamed[rename_state_dict_key(old_key)] = tensor
    renamed["embed_time.emb.weight"] = embedding["weight"]
    return renamed
if __name__ == "__main__":
    # Smoke check: build the model, load the extracted checkpoint files from
    # the working directory, and print load_state_dict's result.
    model = ConvUNetVAE()
    import safetensors.torch

    original_weights = safetensors.torch.load_file("consistency_decoder.safetensors")
    time_embedding = safetensors.torch.load_file("embedding.safetensors")
    remapped = rename_state_dict(original_weights, time_embedding)
    load_result = model.load_state_dict(remapped)
    print(load_result)
Copy link

@city96 Yeah, your interposer is difficult to beat 😅 switching to diffusion for the interposer itself might help a bit (instead of L1 / perceptual losses), but there will still inevitably be some destruction of info if we do SDXL Latents->Ideal Interposer->SD Latents->CD, since SDXL's latent format can express some information which the SD latent format can't.

For super high quality SDXL decoding, I expect fine-tuning the CD model would ultimately be the best option.

Copy link

city96 commented Nov 11, 2023


switching to diffusion for the interposer itself might help a bit

Not sure I follow. Do you mean doing something similar to what CD does (i.e. a multi-step process)? I fear that it might get too slow if I do something like that. And if it isn't really faster than doing a VAE decode->encode between the two models then the whole thing becomes a bit useless lol.

there will still inevitably be some destruction of info if we do SDXL Latents->Ideal Interposer->SD Latents->CD

Yup. So even if the XL latent encoded a letter on the image correctly, that'd get lost on the way and CD would just make up a similar looking letter from the data that's there, since it was only ever trained to work with the KL-F8 encoder. (Reverse also applies, hence why the v1->xl interposer has higher loss. At that point you're asking the tiny 5MB model to not only match the format, but to make up fake details...)

For super high quality SDXL decoding, I expect fine-tuning the CD model would ultimately be the best option.

Agreed. Hope they end up releasing the training code. Would be interesting to see CD fine-tuned for specialized use cases as well (realism vs. art, etc.).

Copy link

@city96 Yeah, I was imagining a multistep interposer (small diffusion model, small # sampling steps). But I agree - it's easier to just roundtrip to pixel space at that point 😆

Copy link

Thank you for your excellent work. May I ask how to train a ConsistencyDecoder from scratch? OpenAI's repository does not provide a complete training process. Could you offer some suggestions?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment