import torch


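# Embeds integer timesteps of shape (B,), values in [0, n_time), into vectors of shape (B, n_out).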
class TimestepEmbedding(torch.nn.Module):
    def __init__(self, n_time=1024, n_emb=320, n_out=1280) -> None:
        super().__init__()
        self.emb = torch.nn.Embedding(n_time, n_emb)
        self.f_1 = torch.nn.Linear(n_emb, n_out)
        # self.act = torch.nn.SiLU()
        self.f_2 = torch.nn.Linear(n_out, n_out)

    def forward(self, x) -> torch.Tensor:
        x = self.emb(x)
        x = self.f_1(x)
        x = torch.nn.functional.silu(x)
        return self.f_2(x)


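# Projects the 7-channel input (3 image channels + 4 upsampled latent channels) to 320 features.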
class ImageEmbedding(torch.nn.Module):
    def __init__(self, in_channels=7, out_channels=320) -> None:
        super().__init__()
        self.f = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x) -> torch.Tensor:
        return self.f(x)


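# Final GroupNorm -> SiLU -> 3x3 conv back to image space; note the 6 output channels,
# even though only 3 RGB channels are ultimately wanted (see the note on ConsistencyDecoder below).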
class ImageUnembedding(torch.nn.Module):
    def __init__(self, in_channels=320, out_channels=6) -> None:
        super().__init__()
        self.gn = torch.nn.GroupNorm(32, in_channels)
        self.f = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x) -> torch.Tensor:
        return self.f(torch.nn.functional.silu(self.gn(x)))


# Improved universal block with fixes from gh:madebyollin
class ConvResblock(torch.nn.Module):
    def __init__(self, in_features=320, out_features=320, skip_conv=False, up=False, down=False) -> None:
        super().__init__()
        self.f_t = torch.nn.Linear(1280, out_features * 2)

        self.gn_1 = torch.nn.GroupNorm(32, in_features)
        self.f_1 = torch.nn.Conv2d(in_features, out_features, kernel_size=3, padding=1)

        self.gn_2 = torch.nn.GroupNorm(32, out_features)
        self.f_2 = torch.nn.Conv2d(out_features, out_features, kernel_size=3, padding=1)

        self.skip_conv = skip_conv
        self.f_s = torch.nn.Identity() if not skip_conv else torch.nn.Conv2d(in_features, out_features, kernel_size=1, padding=0)

        self.f_x = torch.nn.Identity()
        self.up = up
        self.down = down
        assert not (up and down), "Can't be up and down at the same time!"
        if up:
            # torch.nn.functional.upsample_nearest(gn_1, scale_factor=2)
            self.f_x = torch.nn.UpsamplingNearest2d(scale_factor=2)
        elif down:
            # torch.nn.functional.avg_pool2d(gn_1, kernel_size=(2, 2), stride=None)
            self.f_x = torch.nn.AvgPool2d(kernel_size=(2, 2), stride=None)

    def forward(self, x, t):
        x_skip = x
        t: torch.Tensor = self.f_t(torch.nn.functional.silu(t))
        t = t.chunk(2, dim=1)
        # AdaGN-style conditioning: the first chunk is a scale (offset by +1 so a zero
        # embedding leaves the activations untouched), the second chunk is a shift.
        t_1 = t[0].unsqueeze(dim=2).unsqueeze(dim=3) + 1
        t_2 = t[1].unsqueeze(dim=2).unsqueeze(dim=3)
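        # t was (B, out_features * 2); each chunk becomes (B, out_features, 1, 1) for broadcasting.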

        gn_1 = torch.nn.functional.silu(self.gn_1(x))
        f_1 = self.f_1(self.f_x(gn_1))

        gn_2 = self.gn_2(f_1)

        # torch.addcmul(input, tensor1, tensor2) computes input + tensor1 * tensor2,
        # so with input=t_2, tensor1=gn_2, tensor2=t_1 this is exactly gn_2 * t_1 + t_2.
        addcmul = torch.nn.functional.silu(gn_2 * t_1 + t_2)
        return self.f_s(self.f_x(x_skip)) + self.f_2(addcmul)


# ConsistencyDecoder, a.k.a. super resolution from a 4-channel latent to a 3-channel image!
class ConsistencyDecoder(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embed_image = ImageEmbedding()
        self.embed_time = TimestepEmbedding()

        # No attention is needed here!
        # We only "upscale" (48x by element count, or 64x spatially if you ignore the channel difference).
        # I was close to doing that, but I had CrossAttn over the VAE latent reshaped
        # to Bx(HW)x1024 alongside DiffNeXt's skip.

        # 3 ResBlocks before each downsample, repeated 4 times.
        # Down channels are [320, 640, 1024, 1024].
        # In reality the checkpoint distinguishes between conv and downsamp blocks.
        # Chess Battle Advanced
        down_0 = torch.nn.ModuleList([
            ConvResblock(320, 320),
            ConvResblock(320, 320),
            ConvResblock(320, 320),
            # Downsample(320),
            ConvResblock(320, 320, down=True),
        ])
        down_1 = torch.nn.ModuleList([
            ConvResblock(320, 640, skip_conv=True),
            ConvResblock(640, 640),
            ConvResblock(640, 640),
            # Downsample(640),
            ConvResblock(640, 640, down=True),
        ])
        down_2 = torch.nn.ModuleList([
            ConvResblock(640, 1024, skip_conv=True),
            ConvResblock(1024, 1024),
            ConvResblock(1024, 1024),
            # Downsample(1024),
            ConvResblock(1024, 1024, down=True),
        ])
        down_3 = torch.nn.ModuleList([
            ConvResblock(1024, 1024),
            ConvResblock(1024, 1024),
            ConvResblock(1024, 1024),
        ])
        self.down = torch.nn.ModuleList([
            down_0,
            down_1,
            down_2,
            down_3,
        ])
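
        # e.g. with a 256x256 input image the feature maps along self.down are:
        # 320@256 -> 320@128 (after down_0) -> 640@64 (after down_1) -> 1024@32 (after down_2) -> 1024@32 (after down_3).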

        # mid has 2
        self.mid = torch.nn.ModuleList([
            ConvResblock(1024, 1024),
            ConvResblock(1024, 1024),
        ])

        # Again,
        # Chess Battle Advanced
        up_3 = torch.nn.ModuleList([
            ConvResblock(1024*2, 1024, skip_conv=True),
            ConvResblock(1024*2, 1024, skip_conv=True),
            ConvResblock(1024*2, 1024, skip_conv=True),
            ConvResblock(1024*2, 1024, skip_conv=True),
            # Upsample(1024),
            ConvResblock(1024, 1024, up=True),
        ])
        up_2 = torch.nn.ModuleList([
            ConvResblock(1024*2, 1024, skip_conv=True),
            ConvResblock(1024*2, 1024, skip_conv=True),
            ConvResblock(1024*2, 1024, skip_conv=True),
            ConvResblock(1024+640, 1024, skip_conv=True),
            # Upsample(1024),
            ConvResblock(1024, 1024, up=True),
        ])
        up_1 = torch.nn.ModuleList([
            ConvResblock(1024+640, 640, skip_conv=True),
            ConvResblock(640*2, 640, skip_conv=True),
            ConvResblock(640*2, 640, skip_conv=True),
            ConvResblock(320+640, 640, skip_conv=True),
            # Upsample(640),
            ConvResblock(640, 640, up=True),
        ])
        up_0 = torch.nn.ModuleList([
            ConvResblock(320+640, 320, skip_conv=True),
            ConvResblock(320*2, 320, skip_conv=True),
            ConvResblock(320*2, 320, skip_conv=True),
            ConvResblock(320*2, 320, skip_conv=True),
        ])
        self.up = torch.nn.ModuleList([
            up_0,
            up_1,
            up_2,
            up_3,
        ])

        # ImageUnembedding
        self.output = ImageUnembedding()

    def forward(self, x, t, features) -> torch.Tensor:
        t = self.embed_time(t)
        # LITERAL SUPER RESOLUTION: upsample the latent 8x and concatenate it with the image.
        # (torch.nn.functional.upsample_nearest is deprecated, hence interpolate with mode='nearest'.)
        x = torch.cat(
            [x, torch.nn.functional.interpolate(features, scale_factor=8, mode='nearest')],
            dim=1
        )
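        # (B, 3, H, W) image + (B, 4, H, W) upsampled latent -> (B, 7, H, W), matching ImageEmbedding's in_channels=7.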
        x = self.embed_image(x)

        # DOWN
        block_outs = [x]
        for mod in self.down:
            for f in mod:
                x = f(x, t)
                block_outs.append(x)
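
        # block_outs now holds 1 + 4 + 4 + 4 + 3 = 16 feature maps; the up path below
        # pops them in reverse, one per non-upsampling block (4 blocks x 4 stages).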

        # mid
        for f in self.mid:
            x = f(x, t)

        # UP
        for mod in self.up[::-1]:
            for f in mod:
                if not f.up:
                    x = torch.concat([x, block_outs.pop()], dim=1)
                x = f(x, t)

        # OUT
        # GN -> silu -> f
        x = self.output(x)
        return x
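

# Hedged usage sketch (not part of the original gist): one raw decoder call with the
# expected tensor shapes. The actual consistency sampler also scales inputs/outputs
# according to a noise schedule, which is not reproduced here; timestep 16 is just an
# example value (ts[0] in the __main__ block below is 16).
def example_decode_shapes(model: ConsistencyDecoder, latent: torch.Tensor) -> torch.Tensor:
    b, _, h, w = latent.shape
    x = torch.zeros(b, 3, h * 8, w * 8)          # stand-in for the noised RGB input
    t = torch.full((b,), 16, dtype=torch.long)   # integer timesteps, one per batch item
    return model(x, t, latent)                   # -> (b, 6, h * 8, w * 8)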


if __name__ == "__main__":
    model = ConsistencyDecoder()
    print(model.state_dict().keys(), model.embed_time.emb.weight.shape)

    import safetensors.torch
    cd_orig = safetensors.torch.load_file("consistency_decoder.safetensors")
    # print(cd_orig.keys())

    # prefix
    cd_orig = {k.replace("blocks.", ""): v for k,v in cd_orig.items()}

    # layer names
    cd_orig = {k.replace("down_0_", "down.0."): v for k,v in cd_orig.items()}
    cd_orig = {k.replace("down_1_", "down.1."): v for k,v in cd_orig.items()}
    cd_orig = {k.replace("down_2_", "down.2."): v for k,v in cd_orig.items()}
    cd_orig = {k.replace("down_3_", "down.3."): v for k,v in cd_orig.items()}

    cd_orig = {k.replace("up_0_", "up.0."): v for k,v in cd_orig.items()}
    cd_orig = {k.replace("up_1_", "up.1."): v for k,v in cd_orig.items()}
    cd_orig = {k.replace("up_2_", "up.2."): v for k,v in cd_orig.items()}
    cd_orig = {k.replace("up_3_", "up.3."): v for k,v in cd_orig.items()}
    cd_orig = {k.replace("up_4_", "up.4."): v for k,v in cd_orig.items()}

    cd_orig = {k.replace("conv_0.", "0."): v for k,v in cd_orig.items()}
    cd_orig = {k.replace("conv_1.", "1."): v for k,v in cd_orig.items()}
    cd_orig = {k.replace("conv_2.", "2."): v for k,v in cd_orig.items()}
    cd_orig = {k.replace("conv_3.", "3."): v for k,v in cd_orig.items()}

    cd_orig = {k.replace("upsamp.", "4."): v for k,v in cd_orig.items()}
    cd_orig = {k.replace("downsamp.", "3."): v for k,v in cd_orig.items()}

    cd_orig = {k.replace("mid_0", "mid.0"): v for k,v in cd_orig.items()}
    cd_orig = {k.replace("mid_1", "mid.1"): v for k,v in cd_orig.items()}

    # conv+linear
    cd_orig = {k.replace("f_t.w", "f_t.weight").replace("f_t.b", "f_t.bias"): v for k,v in cd_orig.items()}
    cd_orig = {k.replace("f_1.w", "f_1.weight").replace("f_1.b", "f_1.bias"): v for k,v in cd_orig.items()}
    cd_orig = {k.replace("f_2.w", "f_2.weight").replace("f_2.b", "f_2.bias"): v for k,v in cd_orig.items()}
    cd_orig = {k.replace("f_s.w", "f_s.weight").replace("f_s.b", "f_s.bias"): v for k,v in cd_orig.items()}
    cd_orig = {k.replace("f.w", "f.weight").replace("f.b", "f.bias"): v for k,v in cd_orig.items()}

    # GN
    cd_orig = {k.replace("gn_1.g", "gn_1.weight").replace("gn_1.b", "gn_1.bias"): v for k,v in cd_orig.items()}
    cd_orig = {k.replace("gn_2.g", "gn_2.weight").replace("gn_2.b", "gn_2.bias"): v for k,v in cd_orig.items()}
    cd_orig = {k.replace("gn.g", "gn.weight").replace("gn.b", "gn.bias"): v for k,v in cd_orig.items()}

    print(cd_orig.keys())

    cd_orig["embed_time.emb.weight"] = safetensors.torch.load_file("embedding.safetensors")["weight"]

    model.load_state_dict(cd_orig)

    print(cd_orig["embed_time.emb.weight"][1][0])

    def round_timesteps(
        timesteps, total_timesteps, n_distilled_steps, truncate_start=True
    ):
        with torch.no_grad():
            space = torch.div(total_timesteps, n_distilled_steps, rounding_mode="floor")
            rounded_timesteps = (
                torch.div(timesteps, space, rounding_mode="floor") + 1
            ) * space
            if truncate_start:
                rounded_timesteps[rounded_timesteps == total_timesteps] -= space
            else:
                rounded_timesteps[rounded_timesteps == total_timesteps] -= space
                rounded_timesteps[rounded_timesteps == 0] += space
            return rounded_timesteps
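
    # e.g. with total_timesteps=1024 and n_distilled_steps=64, space = 16, so
    # timestep 0 rounds up to 16 and (with truncate_start=False) 1023 rounds down to 1008.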

    ts = round_timesteps(
        torch.arange(0, 1024),
        1024,
        64,
        truncate_start=False,
    )

    print(ts[0], ts.shape)

    # model.forward(torch.zeros(1, 3, 256, 256), torch.zeros(1, dtype=torch.int), torch.zeros(1, 4, 256//8, 256//8))
    model.forward(torch.zeros(1, 3, 256, 256), torch.tensor([ts[0].item()] * 1), torch.zeros(1, 4, 256//8, 256//8))

    safetensors.torch.save_file(model.state_dict(), "stk_consistency_decoder_amalgamated.safetensors")