@piercus
Last active February 21, 2024 11:22
Study batch discrepancy
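
This script investigates a numerical discrepancy between batched and unbatched inference in refiners' Stable Diffusion 1.5: the same latent is run once inside a batch of two and once as a batch of one, and the maximum absolute difference between the two results is printed at several levels of the stack (full SD 1.5 step, raw UNet forward, DownBlocks alone, a single ResidualBlock, and a plain Conv2d) to isolate where the divergence first appears.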
import torch
from torch import Tensor, no_grad
from torch.nn import Conv2d

from refiners import fluxion
from refiners.fluxion import manual_seed
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import StableDiffusion_1
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import DownBlocks
from refiners.foundationals.latent_diffusion.unet import ResidualBlock
# sanity check: duplicating a latent into a batch of two and running a plain
# torch Conv2d keeps both batch items identical
latents = torch.randn(1, 4, 32, 32).to("cuda:0")
latents2 = latents.clone()
both_latent = torch.cat(tensors=(latents, latents2)).to("cuda:0")
assert torch.allclose(both_latent[0], both_latent[1])

conv = Conv2d(in_channels=4, out_channels=320, kernel_size=3, padding=1, device="cuda:0")
out = conv(both_latent.to("cuda:0"))
assert torch.allclose(out[0], out[1])
manual_seed(0)
text_embedding = torch.randn(1, 77, 768).to("cuda:0")
step = 0
timestep = torch.randint(0, 999, size=(1, 1)).to("cuda:0")
x_b2 = torch.randn(2, 4, 32, 32).to("cuda:0")
condition_scale = 7.5

# modules under test, from the full model down to a single block
sd15 = StableDiffusion_1(device="cuda:0")
unet = sd15.unet
down_blocks = DownBlocks(in_channels=4, device="cuda:0")
residual_block = ResidualBlock(in_channels=320, out_channels=1280, device="cuda:0")
def run_sd15(x, clip_cfg_text_embedding, step=0, condition_scale=7.5):
    timestep = sd15.solver.timesteps[step].unsqueeze(dim=0)
    sd15.set_unet_context(timestep=timestep, clip_text_embedding=clip_cfg_text_embedding)
    latents = torch.cat(tensors=(x, x))  # for classifier-free guidance
    # scale latents for solvers that need it
    latents = sd15.solver.scale_model_input(latents, step=step)
    unconditional_prediction, conditional_prediction = sd15.unet(latents).chunk(2)
    # classifier-free guidance
    predicted_noise = unconditional_prediction + condition_scale * (
        conditional_prediction - unconditional_prediction
    )
    return sd15.solver(x, predicted_noise=predicted_noise, step=step)
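
# run_sd15 above performs a full denoising step, solver included; the helpers
# below exercise progressively smaller pieces of the stack, so the level at
# which the batch discrepancy first appears can be isolated.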
def run_unet(x: Tensor, text_embedding: Tensor, timestep: Tensor) -> Tensor:
    unet.set_clip_text_embedding(clip_text_embedding=text_embedding)  # not flushed between forward-s
    unet.set_timestep(timestep=timestep)
    return unet(torch.cat(tensors=(x, x)))

def run_unet_simple(x: Tensor, text_embedding: Tensor, timestep: Tensor) -> Tensor:
    unet.set_clip_text_embedding(clip_text_embedding=text_embedding)  # not flushed between forward-s
    unet.set_timestep(timestep=timestep)
    return unet(x)
def run_down_blocks(x: Tensor, clip_text_embedding: Tensor) -> Tensor:
    down_blocks.set_context("cross_attention_block", {"clip_text_embedding": clip_text_embedding})
    down_blocks.set_context("sampling", {"shapes": []})
    return down_blocks(x)
def run_residual_block(x: Tensor) -> Tensor:
    return residual_block(x)

conv2d = fluxion.layers.Conv2d(in_channels=4, out_channels=320, kernel_size=3, padding=1, device="cuda:0")

def run_conv2d(x: Tensor) -> Tensor:
    return conv2d(x)

def distance(x: Tensor, y: Tensor) -> float:
    return torch.max((x - y).abs()).item()
with no_grad():
    prompt1 = "a cute cat, detailed high-quality professional image"
    negative_prompt1 = "lowres, bad anatomy, bad hands, cropped, worst quality"
    prompt2 = prompt1
    negative_prompt2 = negative_prompt1

    clip_text_embedding_cfg_b2 = sd15.compute_clip_text_embedding(
        text=[prompt1, prompt2], negative_text=[negative_prompt1, negative_prompt2]
    )
    # clip_text_embedding_cfg = torch.cat(tensors=(text_embedding.clone(), text_embedding.clone()))
    # clip_text_embedding_cfg_b2 = torch.cat(tensors=(clip_text_embedding_cfg.clone(), clip_text_embedding_cfg.clone()))
    # x_b2 = torch.cat(tensors=(x.clone(), x.clone()))

    # rebuild the per-sample embeddings from the batched one and check they match
    clip_text_embedding_cfg_1 = torch.cat(tensors=(clip_text_embedding_cfg_b2[0:1], clip_text_embedding_cfg_b2[2:3]))
    clip_text_embedding_cfg_2 = torch.cat(tensors=(clip_text_embedding_cfg_b2[1:2], clip_text_embedding_cfg_b2[3:4]))
    print("clip_text_embedding_cfg_b2[0] vs clip_text_embedding_cfg_1 : ", distance(clip_text_embedding_cfg_b2[0], clip_text_embedding_cfg_1[0]))
    print("clip_text_embedding_cfg_b2[1] vs clip_text_embedding_cfg_2 : ", distance(clip_text_embedding_cfg_b2[1], clip_text_embedding_cfg_2[0]))

    x_b1_1 = x_b2[0:1]
    x_b1_2 = x_b2[1:2]

    # full SD 1.5 step: batch of 2 vs two batches of 1
    result_b2 = run_sd15(x_b2, clip_text_embedding_cfg_b2)
    result_b1_1 = run_sd15(x_b1_1, clip_text_embedding_cfg_1)
    print("sd15 : b2[0] vs b1_1 : ", distance(result_b1_1[0], result_b2[0]))
    result_b1_2 = run_sd15(x_b1_2, clip_text_embedding_cfg_2)
    print("sd15 : b2[1] vs b1_2 : ", distance(result_b1_2[0], result_b2[1]))

    # raw UNet forward
    latent_b2 = run_unet(x_b2, clip_text_embedding_cfg_b2, timestep)
    latent_b1_1 = run_unet(x_b1_1, clip_text_embedding_cfg_1, timestep)
    print("unet : latent_b2[0] vs latent_b1_1 : ", distance(latent_b1_1[0], latent_b2[0]))

    # DownBlocks alone
    down_blocks_b2 = run_down_blocks(x_b2, clip_text_embedding_cfg_b2[0:2])
    down_blocks_b1_1 = run_down_blocks(x_b1_1, clip_text_embedding_cfg_b2[0:1])
    print("down_blocks : down_blocks_b2[0] vs down_blocks_b1_1 : ", distance(down_blocks_b1_1[0], down_blocks_b2[0]))

    # a single Conv2d followed by a single ResidualBlock
    c_b2 = run_conv2d(x_b2)
    c_b1_1 = run_conv2d(x_b1_1)
    r_b2 = run_residual_block(c_b2)
    r_b1_1 = run_residual_block(c_b1_1)
    print("residual_block : r_b2[0] vs r_b1_1 : ", distance(r_b1_1[0], r_b2[0]))