Last active
February 21, 2024 11:22
-
-
Save piercus/07d03f258907542d312c0c735445e793 to your computer and use it in GitHub Desktop.
Study batch discrepancy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from re import S | |
from numpy import c_ | |
from torch.nn import Conv2d | |
import torch | |
from refiners import fluxion | |
from refiners.fluxion import manual_seed | |
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import DownBlocks | |
from torch import Tensor, no_grad | |
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import StableDiffusion_1 | |
from refiners.foundationals.latent_diffusion.unet import ResidualBlock | |
# Sanity check on a plain torch Conv2d: stack two identical samples into one
# batch and verify both batch entries give the same output.
latents = torch.randn(1, 4, 32, 32).to(device="cuda:0")
latents2 = latents.clone()
# Both inputs already live on cuda:0, so the concatenation does too.
both_latent = torch.cat(tensors=(latents, latents2))
assert torch.allclose(both_latent[0], both_latent[1])

conv = Conv2d(in_channels=4, out_channels=320, kernel_size=3, padding=1, device="cuda:0")
out = conv(both_latent)
assert torch.allclose(out[0], out[1])
# Seed first, then draw the random inputs and build the modules in this exact
# order: reordering the statements would change which random values each one
# receives.
manual_seed(0)

text_embedding = torch.randn(1, 77, 768).to(device="cuda:0")
step = 0
timestep = torch.randint(low=0, high=999, size=(1, 1)).to(device="cuda:0")
x_b2 = torch.randn(2, 4, 32, 32).to(device="cuda:0")  # batch of two latents
condition_scale = 7.5

# Modules under study, from the full pipeline down to a single residual block.
sd15 = StableDiffusion_1(device="cuda:0")
unet = sd15.unet
down_blocks = DownBlocks(in_channels=4, device="cuda:0")
residual_block = ResidualBlock(in_channels=320, out_channels=1280, device="cuda:0")
def run_sd15(x, clip_cfg_text_embedding, step=0, condition_scale=7.5):
    """Run one classifier-free-guidance step of the module-level ``sd15`` on ``x``.

    Args:
        x: Latents to denoise.
        clip_cfg_text_embedding: Text embedding handed to the UNet context;
            the UNet output is split in two halves, read as the unconditional
            and conditional predictions in that order.
        step: Index into ``sd15.solver.timesteps``.
        condition_scale: Guidance weight applied to the difference between
            the conditional and unconditional predictions.

    Returns:
        Whatever ``sd15.solver`` returns for ``x`` and the guided noise.
    """
    solver = sd15.solver
    current_timestep = solver.timesteps[step].unsqueeze(dim=0)
    sd15.set_unet_context(
        timestep=current_timestep,
        clip_text_embedding=clip_cfg_text_embedding,
    )

    # Duplicate the batch for classifier-free guidance, then let the solver
    # rescale the model input (some solvers need it).
    doubled = torch.cat(tensors=(x, x))
    doubled = solver.scale_model_input(doubled, step=step)

    noise_uncond, noise_cond = sd15.unet(doubled).chunk(2)

    # Classifier-free guidance.
    guided_noise = noise_uncond + condition_scale * (noise_cond - noise_uncond)
    return solver(x, predicted_noise=guided_noise, step=step)
def run_unet(x: Tensor, text_embedding: Tensor, timestep: Tensor) -> Tensor:
    """Run the module-level ``unet`` on ``x`` concatenated with itself.

    Args:
        x: Input latents; the UNet receives ``torch.cat((x, x))``.
        text_embedding: Value installed via ``unet.set_clip_text_embedding``.
        timestep: Value installed via ``unet.set_timestep``.

    Returns:
        The UNet output for the doubled batch.
    """
    # The context set here is not flushed between forwards, so it is set
    # again on every call.
    unet.set_clip_text_embedding(clip_text_embedding=text_embedding)
    unet.set_timestep(timestep=timestep)
    doubled = torch.cat(tensors=(x, x))
    return unet(doubled)
def run_unet_simple(x: Tensor, text_embedding: Tensor, timestep: Tensor) -> Tensor:
    """Run the module-level ``unet`` on ``x`` as-is (no batch duplication).

    Args:
        x: Input latents, passed to the UNet unchanged.
        text_embedding: Value installed via ``unet.set_clip_text_embedding``.
        timestep: Value installed via ``unet.set_timestep``.

    Returns:
        The UNet output for ``x``.
    """
    # The context set here is not flushed between forwards, so it is set
    # again on every call.
    unet.set_clip_text_embedding(clip_text_embedding=text_embedding)
    unet.set_timestep(timestep=timestep)
    prediction = unet(x)
    return prediction
def run_down_blocks(x: Tensor, clip_text_embedding: Tensor) -> Tensor:
    """Run the standalone module-level ``down_blocks`` on ``x``.

    This function used to be defined twice with identical bodies; the second
    definition silently replaced the first, so a single one is kept.

    Args:
        x: Input latents fed to ``down_blocks``.
        clip_text_embedding: Stored under the ``"clip_text_embedding"`` key of
            the ``"cross_attention_block"`` context.

    Returns:
        The output of ``down_blocks(x)``.
    """
    down_blocks.set_context("cross_attention_block", {"clip_text_embedding": clip_text_embedding})
    # Start every call from an empty "shapes" list in the "sampling" context
    # (presumably filled by the chain's resampling layers -- TODO confirm).
    down_blocks.set_context("sampling", {"shapes": []})
    return down_blocks(x)
def run_residual_block(x: Tensor) -> Tensor:
    """Apply the module-level ``residual_block`` to ``x`` and return its output."""
    block_output = residual_block(x)
    return block_output
# Standalone fluxion convolution, built with the same hyper-parameters as the
# torch ``Conv2d`` created earlier in this script.
conv2d = fluxion.layers.Conv2d(
    in_channels=4,
    out_channels=320,
    kernel_size=3,
    padding=1,
    device="cuda:0",
)


def run_conv2d(x: Tensor) -> Tensor:
    """Apply the module-level fluxion ``conv2d`` layer to ``x``."""
    conv_output = conv2d(x)
    return conv_output
def distance(x: Tensor, y: Tensor) -> float:
    """Return the largest absolute element-wise difference between ``x`` and ``y``.

    Floating-point and complex inputs are subtracted in their own (promoted)
    dtype. When both inputs are integral (including bool and unsigned types)
    they are widened to int64 first: otherwise an unsigned subtraction wraps
    around (uint8 ``5 - 7`` yields ``254``) and bool tensors do not support
    ``-`` at all.

    Args:
        x: First tensor.
        y: Second tensor; must be broadcastable against ``x``.

    Returns:
        The maximum of ``|x - y|`` as a Python scalar.
    """
    x_is_integral = not (x.is_floating_point() or x.is_complex())
    y_is_integral = not (y.is_floating_point() or y.is_complex())
    if x_is_integral and y_is_integral:
        # Only touch the all-integer case so float results stay bit-identical.
        x, y = x.long(), y.long()
    return (x - y).abs().max().item()
with no_grad():
    # Two identical (prompt, negative prompt) pairs, so both entries of the
    # batch-of-two run describe the same request.
    prompt1 = "a cute cat, detailed high-quality professional image"
    negative_prompt1 = "lowres, bad anatomy, bad hands, cropped, worst quality"
    prompt2 = prompt1
    negative_prompt2 = negative_prompt1

    clip_text_embedding_cfg_b2 = sd15.compute_clip_text_embedding(
        text=[prompt1, prompt2],
        negative_text=[negative_prompt1, negative_prompt2],
    )

    # Alternative setup using the random embeddings instead of real prompts:
    # clip_text_embedding_cfg = torch.cat(tensors=(text_embedding.clone(), text_embedding.clone()))
    # clip_text_embedding_cfg_b2 = torch.cat(tensors=(clip_text_embedding_cfg.clone(), clip_text_embedding_cfg.clone()))
    # x_b2 = torch.cat(tensors=(x.clone(), x.clone()))

    # Build the per-sample embeddings from rows (0, 2) and (1, 3) of the
    # batch-of-two embedding.
    # NOTE(review): presumably the rows are [neg1, neg2, pos1, pos2] -- confirm.
    clip_text_embedding_cfg_1 = torch.cat(
        tensors=(clip_text_embedding_cfg_b2[0:1], clip_text_embedding_cfg_b2[2:3])
    )
    clip_text_embedding_cfg_2 = torch.cat(
        tensors=(clip_text_embedding_cfg_b2[1:2], clip_text_embedding_cfg_b2[3:4])
    )
    print(
        "clip_text_embedding_cfg_b2[0] vs clip_text_embedding_cfg_1 : ",
        distance(clip_text_embedding_cfg_b2[0], clip_text_embedding_cfg_1[0]),
    )
    print(
        "clip_text_embedding_cfg_b2[1] vs clip_text_embedding_cfg_2 : ",
        distance(clip_text_embedding_cfg_b2[1], clip_text_embedding_cfg_2[0]),
    )

    # Single-sample views of the batch-of-two latents.
    x_b1_1 = x_b2[0:1]
    x_b1_2 = x_b2[1:2]

    # Full SD1.5 step: batch of two vs each sample run on its own.
    result_b2 = run_sd15(x_b2, clip_text_embedding_cfg_b2)
    result_b1_1 = run_sd15(x_b1_1, clip_text_embedding_cfg_1)
    print("sd15 : b2[0] vs b1_1 : ", distance(result_b1_1[0], result_b2[0]))
    result_b1_2 = run_sd15(x_b1_2, clip_text_embedding_cfg_2)
    print("sd15 : b2[1] vs b1_2 : ", distance(result_b1_2[0], result_b2[1]))

    # UNet only.
    latent_b2 = run_unet(x_b2, clip_text_embedding_cfg_b2, timestep)
    latent_b1_1 = run_unet(x_b1_1, clip_text_embedding_cfg_1, timestep)
    print("unet : latent_b2[0] vs latent_b1_1 : ", distance(latent_b1_1[0], latent_b2[0]))

    # Down blocks only.
    down_blocks_b2 = run_down_blocks(x_b2, clip_text_embedding_cfg_b2[0:2])
    down_blocks_b1_1 = run_down_blocks(x_b1_1, clip_text_embedding_cfg_b2[0:1])
    print(
        "down_blocks : down_blocks_b2[0] vs down_blocks_b1_1 : ",
        distance(down_blocks_b1_1[0], down_blocks_b2[0]),
    )

    # Convolution followed by a single residual block.
    c_b2 = run_conv2d(x_b2)
    c_b1_1 = run_conv2d(x_b1_1)
    r_b2 = run_residual_block(c_b2)
    r_b1_1 = run_residual_block(c_b1_1)
    print("residual_block : r_b2[0] vs r_b1_1 : ", distance(r_b1_1[0], r_b2[0]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment