@cloneofsimo
Created July 24, 2024 14:06
AuraFlow v0.2, sampling that handles self-unconditioning CFG
#### Inference utils
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL
from tqdm import tqdm

class Fp32LayerNorm(nn.LayerNorm):
    # LayerNorm computed in fp32 regardless of the input dtype, cast back afterwards.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)

def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)

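# Illustrative sketch (not part of the original gist): find_multiple rounds a width up to
# the next multiple of k, and modulate applies adaLN-style shift/scale broadcast over the
# sequence dimension.
def _helpers_example():
    assert find_multiple(683, 256) == 768  # 683 -> next multiple of 256
    x = torch.randn(2, 16, 8)  # B, T, D
    shift, scale = torch.zeros(2, 8), torch.zeros(2, 8)
    assert torch.allclose(modulate(x, shift, scale), x)  # zero shift/scale is identity
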
class MLP(nn.Module):
    def __init__(self, dim, hidden_dim=None) -> None:
        super().__init__()
        if hidden_dim is None:
            hidden_dim = 4 * dim

        n_hidden = int(2 * hidden_dim / 3)
        n_hidden = find_multiple(n_hidden, 256)

        self.c_fc1 = nn.Linear(dim, n_hidden, bias=False)
        self.c_fc2 = nn.Linear(dim, n_hidden, bias=False)
        self.c_proj = nn.Linear(n_hidden, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
        x = self.c_proj(x)
        return x

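# Rough sketch (assumption, not from the original gist): the MLP above is a SwiGLU-style
# gated feed-forward; with hidden_dim = 4 * dim the effective width is 2/3 of that,
# rounded up to a multiple of 256 by find_multiple.
def _mlp_example():
    mlp = MLP(dim=256)  # hidden_dim defaults to 1024, n_hidden becomes 768
    x = torch.randn(2, 16, 256)
    assert mlp(x).shape == x.shape  # c_proj maps back to dim
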
class MultiHeadLayerNorm(nn.Module):
    def __init__(self, hidden_size=None, eps=1e-5):
        # Copy pasta from https://github.com/huggingface/transformers/blob/e5f71ecaae50ea476d1e12351003790273c4b2ed/src/transformers/models/cohere/modeling_cohere.py#L78
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        mean = hidden_states.mean(-1, keepdim=True)
        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
        hidden_states = (hidden_states - mean) * torch.rsqrt(
            variance + self.variance_epsilon
        )
        hidden_states = self.weight.to(torch.float32) * hidden_states
        return hidden_states.to(input_dtype)

class SingleAttention(nn.Module):
    def __init__(self, dim, n_heads, mh_qknorm=False):
        super().__init__()

        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        # single qkv/output projections over the fused (cond + image) sequence
        self.w1q = nn.Linear(dim, dim, bias=False)
        self.w1k = nn.Linear(dim, dim, bias=False)
        self.w1v = nn.Linear(dim, dim, bias=False)
        self.w1o = nn.Linear(dim, dim, bias=False)

        self.q_norm1 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim))
            if mh_qknorm
            else Fp32LayerNorm(self.head_dim, bias=False, elementwise_affine=False)
        )
        self.k_norm1 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim))
            if mh_qknorm
            else Fp32LayerNorm(self.head_dim, bias=False, elementwise_affine=False)
        )

    def forward(self, c):
        bsz, seqlen1, _ = c.shape

        q, k, v = self.w1q(c), self.w1k(c), self.w1v(c)
        q = q.view(bsz, seqlen1, self.n_heads, self.head_dim)
        k = k.view(bsz, seqlen1, self.n_heads, self.head_dim)
        v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
        q, k = self.q_norm1(q), self.k_norm1(k)

        output = F.scaled_dot_product_attention(
            q.permute(0, 2, 1, 3),
            k.permute(0, 2, 1, 3),
            v.permute(0, 2, 1, 3),
            dropout_p=0.0,
            is_causal=False,
            scale=1 / self.head_dim**0.5,
        ).permute(0, 2, 1, 3)
        output = output.flatten(-2)
        c = self.w1o(output)
        return c

class DoubleAttention(nn.Module):
    def __init__(self, dim, n_heads, mh_qknorm=False):
        super().__init__()

        self.n_heads = n_heads
        self.head_dim = dim // n_heads

        # this is for cond
        self.w1q = nn.Linear(dim, dim, bias=False)
        self.w1k = nn.Linear(dim, dim, bias=False)
        self.w1v = nn.Linear(dim, dim, bias=False)
        self.w1o = nn.Linear(dim, dim, bias=False)

        # this is for x
        self.w2q = nn.Linear(dim, dim, bias=False)
        self.w2k = nn.Linear(dim, dim, bias=False)
        self.w2v = nn.Linear(dim, dim, bias=False)
        self.w2o = nn.Linear(dim, dim, bias=False)

        self.q_norm1 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim))
            if mh_qknorm
            else Fp32LayerNorm(self.head_dim, bias=False, elementwise_affine=False)
        )
        self.k_norm1 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim))
            if mh_qknorm
            else Fp32LayerNorm(self.head_dim, bias=False, elementwise_affine=False)
        )
        self.q_norm2 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim))
            if mh_qknorm
            else Fp32LayerNorm(self.head_dim, bias=False, elementwise_affine=False)
        )
        self.k_norm2 = (
            MultiHeadLayerNorm((self.n_heads, self.head_dim))
            if mh_qknorm
            else Fp32LayerNorm(self.head_dim, bias=False, elementwise_affine=False)
        )

    def forward(self, c, x):
        bsz, seqlen1, _ = c.shape
        bsz, seqlen2, _ = x.shape
        seqlen = seqlen1 + seqlen2

        cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c)
        cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim)
        ck = ck.view(bsz, seqlen1, self.n_heads, self.head_dim)
        cv = cv.view(bsz, seqlen1, self.n_heads, self.head_dim)
        cq, ck = self.q_norm1(cq), self.k_norm1(ck)

        xq, xk, xv = self.w2q(x), self.w2k(x), self.w2v(x)
        xq = xq.view(bsz, seqlen2, self.n_heads, self.head_dim)
        xk = xk.view(bsz, seqlen2, self.n_heads, self.head_dim)
        xv = xv.view(bsz, seqlen2, self.n_heads, self.head_dim)
        xq, xk = self.q_norm2(xq), self.k_norm2(xk)

        # concat all
        q, k, v = (
            torch.cat([cq, xq], dim=1),
            torch.cat([ck, xk], dim=1),
            torch.cat([cv, xv], dim=1),
        )

        output = F.scaled_dot_product_attention(
            q.permute(0, 2, 1, 3),
            k.permute(0, 2, 1, 3),
            v.permute(0, 2, 1, 3),
            dropout_p=0.0,
            is_causal=False,
            scale=1 / self.head_dim**0.5,
        ).permute(0, 2, 1, 3)
        output = output.flatten(-2)
        c, x = output.split([seqlen1, seqlen2], dim=1)
        c = self.w1o(c)
        x = self.w2o(x)
        return c, x

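# Shape-check sketch (not part of the original gist): DoubleAttention keeps separate
# qkv/out projections for the conditioning stream c and the image stream x, but runs a
# single joint attention over their concatenated token sequences.
def _double_attention_example():
    attn = DoubleAttention(dim=64, n_heads=4)
    c = torch.randn(2, 8, 64)   # B, T_c, D  (text + register tokens)
    x = torch.randn(2, 16, 64)  # B, T_x, D  (patchified latent tokens)
    c_out, x_out = attn(c, x)
    assert c_out.shape == c.shape and x_out.shape == x.shape
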
class MMDiTBlock(nn.Module):
    def __init__(self, dim, heads=8, global_conddim=1024, is_last=False):
        super().__init__()

        self.normC1 = Fp32LayerNorm(dim, elementwise_affine=False, bias=False)
        self.normC2 = Fp32LayerNorm(dim, elementwise_affine=False, bias=False)
        if not is_last:
            self.mlpC = MLP(dim, hidden_dim=dim * 4)
            self.modC = nn.Sequential(
                nn.SiLU(),
                nn.Linear(global_conddim, 6 * dim, bias=False),
            )
        else:
            self.modC = nn.Sequential(
                nn.SiLU(),
                nn.Linear(global_conddim, 2 * dim, bias=False),
            )

        self.normX1 = Fp32LayerNorm(dim, elementwise_affine=False, bias=False)
        self.normX2 = Fp32LayerNorm(dim, elementwise_affine=False, bias=False)
        self.mlpX = MLP(dim, hidden_dim=dim * 4)
        self.modX = nn.Sequential(
            nn.SiLU(),
            nn.Linear(global_conddim, 6 * dim, bias=False),
        )

        self.attn = DoubleAttention(dim, heads)
        self.is_last = is_last

    def forward(self, c, x, global_cond, **kwargs):
        cres, xres = c, x

        cshift_msa, cscale_msa, cgate_msa, cshift_mlp, cscale_mlp, cgate_mlp = (
            self.modC(global_cond).chunk(6, dim=1)
        )
        c = modulate(self.normC1(c), cshift_msa, cscale_msa)

        # x path
        xshift_msa, xscale_msa, xgate_msa, xshift_mlp, xscale_mlp, xgate_mlp = (
            self.modX(global_cond).chunk(6, dim=1)
        )
        x = modulate(self.normX1(x), xshift_msa, xscale_msa)

        # attention
        c, x = self.attn(c, x)

        c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
        c = cgate_mlp.unsqueeze(1) * self.mlpC(modulate(c, cshift_mlp, cscale_mlp))
        c = cres + c

        x = self.normX2(xres + xgate_msa.unsqueeze(1) * x)
        x = xgate_mlp.unsqueeze(1) * self.mlpX(modulate(x, xshift_mlp, xscale_mlp))
        x = xres + x

        return c, x

class DiTBlock(nn.Module):
    # like MMDiTBlock, but it only has X
    def __init__(self, dim, heads=8, global_conddim=1024):
        super().__init__()

        self.norm1 = Fp32LayerNorm(dim, elementwise_affine=False, bias=False)
        self.norm2 = Fp32LayerNorm(dim, elementwise_affine=False, bias=False)

        self.modCX = nn.Sequential(
            nn.SiLU(),
            nn.Linear(global_conddim, 6 * dim, bias=False),
        )

        self.attn = SingleAttention(dim, heads)
        self.mlp = MLP(dim, hidden_dim=dim * 4)

    def forward(self, cx, global_cond, **kwargs):
        cxres = cx
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
            global_cond
        ).chunk(6, dim=1)
        cx = modulate(self.norm1(cx), shift_msa, scale_msa)
        cx = self.attn(cx)
        cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
        mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
        cx = gate_mlp.unsqueeze(1) * mlpout
        cx = cxres + cx
        return cx

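# Minimal sketch (assumption, not from the original gist): modCX turns the global
# (timestep) conditioning vector into shift/scale/gate triples for the attention and MLP
# sub-blocks, applied over the fused cond + image token sequence.
def _dit_block_example():
    block = DiTBlock(dim=64, heads=4, global_conddim=32)
    cx = torch.randn(2, 24, 64)       # fused cond + image tokens
    global_cond = torch.randn(2, 32)  # e.g. a timestep embedding
    assert block(cx, global_cond).shape == cx.shape
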
class TimestepEmbedder(nn.Module):
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half = dim // 2
        freqs = 1000 * torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half) / half
        ).to(t.device)
        args = t[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(
            dtype=next(self.parameters()).dtype
        )
        t_emb = self.mlp(t_freq)
        return t_emb

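# Sketch (not part of the original gist): the sinusoidal embedding is built from
# frequency_embedding_size // 2 cos/sin pairs of the (scaled) timestep, then mapped to
# hidden_size by a small MLP.
def _timestep_embedder_example():
    embedder = TimestepEmbedder(hidden_size=128, frequency_embedding_size=256)
    t = torch.rand(4)  # rectified-flow time in [0, 1]
    assert embedder(t).shape == (4, 128)
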
class MMDiT(nn.Module):
    def __init__(
        self,
        in_channels=4,
        out_channels=4,
        patch_size=2,
        dim=2048,
        n_layers=8,
        n_double_layers=4,
        n_heads=4,
        global_conddim=1024,
        cond_seq_dim=2048,
        max_seq=16 * 16,
        early_branch_out_index=None,
    ):
        super().__init__()

        self.t_embedder = TimestepEmbedder(global_conddim)

        self.cond_seq_linear = nn.Linear(
            cond_seq_dim, dim, bias=False
        )  # linear for something like text sequence.
        self.init_x_linear = nn.Linear(
            patch_size * patch_size * in_channels, dim
        )  # init linear for patchified image.

        self.positional_encoding = nn.Parameter(torch.randn(1, max_seq, dim) * 0.1)
        self.register_tokens = nn.Parameter(torch.randn(1, 8, dim) * 0.02)

        self.double_layers = nn.ModuleList([])
        self.single_layers = nn.ModuleList([])

        for idx in range(n_double_layers):
            self.double_layers.append(
                MMDiTBlock(dim, n_heads, global_conddim, is_last=(idx == n_layers - 1))
            )

        for idx in range(n_double_layers, n_layers):
            self.single_layers.append(DiTBlock(dim, n_heads, global_conddim))

        self.final_linear = nn.Linear(
            dim, patch_size * patch_size * out_channels, bias=False
        )

        self.modF = nn.Sequential(
            nn.SiLU(),
            nn.Linear(global_conddim, 2 * dim, bias=False),
        )

        nn.init.constant_(self.final_linear.weight, 0)

        self.out_channels = out_channels
        self.patch_size = patch_size
        self.n_double_layers = n_double_layers
        self.n_layers = n_layers

        for pn, p in self.named_parameters():
            if ".mod" in pn:
                nn.init.constant_(p, 0)
                print("zeroed", pn)

        # if cond_seq_linear
        nn.init.constant_(self.cond_seq_linear.weight, 0)

        self.h_max = int(max_seq**0.5)
        self.w_max = int(max_seq**0.5)

        if early_branch_out_index is not None:
            self.early_branch_out_index = early_branch_out_index
            self.early_dits = nn.ModuleList(
                [DiTBlock(dim, n_heads, global_conddim) for _ in range(3)]
            )
            self.early_linear = nn.Linear(
                dim, patch_size * patch_size * out_channels, bias=False
            )

    @torch.no_grad()
    def copy_early_from_final(self):
        # self.early_dit.load_state_dict(self.single_layers[-1].state_dict())
        # Mirror the last three single-stream blocks into the early branch.
        # (Indexing with -(i + 1) avoids the -0 == 0 pitfall of indexing with -i.)
        for i in range(3):
            self.early_dits[-(i + 1)].load_state_dict(
                self.single_layers[-(i + 1)].state_dict()
            )
        self.early_linear.load_state_dict(self.final_linear.state_dict())

    @torch.no_grad()
    def extend_pe(self, init_dim=(16, 16), target_dim=(64, 64)):
        # extend pe
        pe_data = self.positional_encoding.data.squeeze(0)[: init_dim[0] * init_dim[1]]

        pe_as_2d = pe_data.view(init_dim[0], init_dim[1], -1).permute(2, 0, 1)

        # now we need to extend this to target_dim. for this we will use interpolation.
        # we will use torch.nn.functional.interpolate
        pe_as_2d = F.interpolate(
            pe_as_2d.unsqueeze(0), size=target_dim, mode="bilinear"
        )
        pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
        self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
        self.h_max, self.w_max = target_dim
        print("PE extended to", target_dim)

    def pe_selection_index_based_on_dim(self, h, w):
        h_p, w_p = h // self.patch_size, w // self.patch_size
        original_pe_indexes = torch.arange(self.positional_encoding.shape[1])
        original_pe_indexes = original_pe_indexes.view(self.h_max, self.w_max)
        original_pe_indexes = original_pe_indexes[
            self.h_max // 2 - h_p // 2 : self.h_max // 2 + h_p // 2,
            self.w_max // 2 - w_p // 2 : self.w_max // 2 + w_p // 2,
        ]
        return original_pe_indexes.flatten()

    def unpatchify(self, x, h, w):
        c = self.out_channels
        p = self.patch_size

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def patchify(self, x):
        B, C, H, W = x.size()
        x = x.view(
            B,
            C,
            H // self.patch_size,
            self.patch_size,
            W // self.patch_size,
            self.patch_size,
        )
        x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        return x

    def forward(self, x, t, conds, **kwargs):
        # patchify x, add PE
        b, c, h, w = x.shape

        # pe_indexes = self.pe_selection_index_based_on_dim(h, w)
        # print(pe_indexes.shape)

        x = self.init_x_linear(self.patchify(x))  # B, T_x, D
        x = x + self.positional_encoding[:, : x.size(1)]

        # process conditions for MMDiT Blocks
        c_seq = conds["c_seq"][0:b]  # B, T_c, D_c
        t = t[0:b]
        c = self.cond_seq_linear(c_seq)  # B, T_c, D
        c = torch.cat([self.register_tokens.repeat(c.size(0), 1, 1), c], dim=1)

        global_cond = self.t_embedder(t)  # B, D

        fshift, fscale = self.modF(global_cond).chunk(2, dim=1)

        if len(self.double_layers) > 0:
            for layer in self.double_layers:
                c, x = layer(c, x, global_cond, **kwargs)

        early_x = None
        if len(self.single_layers) > 0:
            c_len = c.size(1)
            cx = torch.cat([c, x], dim=1)
            for idx, layer in enumerate(self.single_layers):
                cx = layer(cx, global_cond, **kwargs)
                # early branch: a small detached head that predicts the output from an
                # intermediate layer (only present when early_branch_out_index is set)
                if getattr(self, "early_branch_out_index", None) == idx:
                    early_cx = cx.detach()
                    gcd = global_cond.detach()
                    for eidx in range(3):
                        early_cx = self.early_dits[eidx](early_cx, gcd)
                    early_x = early_cx[:, c_len:]
                    early_x = modulate(early_x, fshift.detach(), fscale.detach())
                    early_x = self.early_linear(early_x)
                    early_x = self.unpatchify(
                        early_x, h // self.patch_size, w // self.patch_size
                    )

            x = cx[:, c_len:]

        x = modulate(x, fshift, fscale)
        x = self.final_linear(x)
        x = self.unpatchify(x, h // self.patch_size, w // self.patch_size)
        return x, early_x

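# Smoke-test sketch (assumption, not from the original gist): a tiny MMDiT config run on
# random latents and text features; it returns the main velocity prediction plus the
# early-branch ("student") prediction when early_branch_out_index is set.
def _mmdit_smoke_test():
    model = MMDiT(
        dim=96, n_layers=4, n_double_layers=2, n_heads=4,
        global_conddim=32, cond_seq_dim=64, max_seq=16 * 16,
        early_branch_out_index=0,
    )
    x = torch.randn(1, 4, 32, 32)              # B, C, H, W latents -> 16x16 patches
    t = torch.rand(1)
    conds = {"c_seq": torch.randn(1, 12, 64)}  # B, T_c, cond_seq_dim
    out, early_out = model(x, t, conds)
    assert out.shape == x.shape and early_out.shape == x.shape
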
class RFv0_2(torch.nn.Module):
    def __init__(self, model, ln=True):
        super().__init__()
        self.model = model
        self.ln = ln
        self.stratified = False

    @torch.no_grad()
    def sample(
        self,
        z,
        cond,
        null_cond=None,
        sample_steps=50,
        cfg=4.0,
        tff=lambda t: (math.sqrt(3) * t / (1 + (math.sqrt(3) - 1) * t)),
        use_self_as_nullcond=True,
        use_student=False,
    ):
        b = z.size(0)
        dt = 1.0 / sample_steps
        dt = torch.tensor([dt] * b).to(z.device).view([b, *([1] * len(z.shape[1:]))]).half()
        images = [z]
        for i in tqdm(range(sample_steps, 0, -1)):
            t = tff(i / sample_steps)
            next_t = tff((i + 1) / sample_steps)
            dt = torch.tensor([next_t - t] * b).to(z.device).view([b, *([1] * len(z.shape[1:]))]).half()
            t = torch.tensor([t] * b).to(z.device).half()

            vc, vcs = self.model(z, t, cond)  # main and early-branch velocity predictions
            if use_self_as_nullcond:
                # self-unconditioning CFG: the early-branch output stands in for the uncond term
                vc = vcs + cfg * (vc - vcs)
            elif null_cond is not None:
                # standard CFG with a separate null-prompt forward pass
                vu, vus = self.model(z, t, null_cond)
                vc = vu + cfg * (vc - vu)
                vcs = vus + cfg * (vcs - vus)
            if use_student:
                vc = vcs

            x = z - i * dt * vc  # rough data estimate at the current step
            z = z - dt * vc  # Euler step of the rectified-flow ODE
            images.append(x)
        return images

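# Note (explanatory sketch, not part of the original gist): with use_self_as_nullcond the
# model's own early-branch prediction vcs replaces the unconditional velocity, so the
# guided update is v = vcs + cfg * (vc - vcs) and no second null-prompt forward pass is
# needed. The default warp t -> sqrt(3) * t / (1 + (sqrt(3) - 1) * t) maps [0, 1] to
# [0, 1] and concentrates the discretized timesteps near the high-noise end (t close to 1).
def _tff_example(tff=lambda t: math.sqrt(3) * t / (1 + (math.sqrt(3) - 1) * t)):
    assert abs(tff(0.0)) < 1e-8 and abs(tff(1.0) - 1.0) < 1e-8
    assert tff(0.5) > 0.5  # midpoints are pushed toward t = 1
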
class RFPipelinev2:
    def __init__(self, ckpt_path=None, device="cuda:0", test=False):
        self.device = device

        if test:
            self.model = RFv0_2(
                MMDiT(
                    in_channels=4,
                    out_channels=4,
                    dim=32 * 12,
                    global_conddim=32,
                    n_layers=36,
                    n_heads=12,
                    cond_seq_dim=2048,
                    max_seq=64 * 64,
                    early_branch_out_index=8,
                ),
                True,
            )
        else:
            self.model = RFv0_2(
                MMDiT(
                    in_channels=4,
                    out_channels=4,
                    dim=256 * 12,
                    global_conddim=256 * 12,
                    n_layers=36,
                    n_heads=12,
                    cond_seq_dim=2048,
                    max_seq=64 * 64,
                    early_branch_out_index=8,
                ),
                True,
            )
            self.model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))

        self.model.eval().to(device).half()

        self.t5tokenizer = AutoTokenizer.from_pretrained(
            "EleutherAI/pile-t5-xl", use_fast=True
        )
        self.t5tokenizer.pad_token = self.t5tokenizer.bos_token
        t5model = AutoModelForSeq2SeqLM.from_pretrained("EleutherAI/pile-t5-xl").half()
        self.t5model = t5model.to(device).eval()
        self.vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)

    @torch.no_grad()
    def _make_text_cond(self, cond_prompts, uncond_prompts=None):
        uncond_prompts = [""] * len(cond_prompts)
        cond_prompts = [f"{prompt}" for prompt in cond_prompts]
        allprompts = cond_prompts + uncond_prompts
        # print(allprompts)
        t5_inputs = self.t5tokenizer(
            allprompts,
            truncation=True,
            max_length=256,
            padding="max_length",
            return_tensors="pt",
        )
        t5_inputs = {k: v.to(self.device) for k, v in t5_inputs.items()}
        t5_outputs = self.t5model.encoder(**t5_inputs)[0]  # B, T, D
        # mask that by 0 for padding tokens
        mask = t5_inputs["attention_mask"].unsqueeze(-1).expand(t5_outputs.shape)
        t5_outputs = t5_outputs * mask
        return (
            {"c_seq": t5_outputs[: len(cond_prompts)]},
            {"c_seq": t5_outputs[len(cond_prompts):]},
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        num_inference_steps=50,
        height=256,
        width=256,
        guidance_scale=2.5,
        num_images_per_prompt=1,
        negative_prompt=None,
        seed=None,
        use_self_as_uncond=True,
        use_student=False,
    ):
        # make cond
        cond, null_cond = self._make_text_cond([prompt] * num_images_per_prompt, None)
        # print(cond, null_cond)
        L = num_images_per_prompt
        if seed is not None:
            gen = torch.Generator().manual_seed(seed)
            init_noise = torch.randn(L, 4, height // 8, width // 8, generator=gen).to(self.device).half()
        else:
            init_noise = torch.randn(L, 4, height // 8, width // 8).to(self.device).half()

        images = self.model.sample(
            init_noise,
            cond,
            null_cond,
            num_inference_steps,
            guidance_scale,
            use_self_as_nullcond=use_self_as_uncond,
            use_student=use_student,
        )

        pil_images = []
        for i in range(L):
            x = self.vae.decode(images[-1][i : i + 1].to(self.device).float() / 0.13025).sample
            img = VaeImageProcessor().postprocess(
                image=x.detach(), do_denormalize=[True]
            )[0]
            pil_images.append(img)
        return pil_images

    def to(self, device):
        self.device = device
        self.model.to(device)
        self.t5model.to(device)
        self.vae.to(device)
        return

if __name__ == "__main__":
    rf = RFPipelinev2("/home/ubuntu/geneval/ema1.pt")
    images = rf(
        "a gray cat playing a saxophone is inside a large, transparent water tank strapped to the belly of a massive mecha robot, which is stomping down a bustling san francisco street, the mecha has large metal legs and arms with glowing joints and wires, towering over buildings and streetlights, the cat's water tank has bubbles and a soft blue glow, in the sky above, several UFOs are hovering, each with a metallic, disc-like shape and glowing lights underneath, below the mecha, there are elephants of various sizes walking along the street, some are near storefronts, while others are in the middle of the road, causing a commotion among the people, the scene is chaotic with a blend of futuristic elements and everyday city life, capturing a surreal and imaginative moment in vivid detail",
        num_images_per_prompt=1,
        height=1024,
        width=1024,
    )
    images[0].show()
    images[0].save("output.png")