# Skip to content
#
# Instantly share code, notes, and snippets.
#
# @nousr
# Created November 11, 2022 19:57
# Show Gist options
#   • Save nousr/e780ba0044f4855357b3f3790c10d3d2 to your computer and use it in GitHub Desktop.
import wandb
import torch
from dalle2_pytorch import (
T5OpenClipAdapter,
DiffusionPriorNetwork,
DiffusionPrior,
)
from dalle2_pytorch.trainer import DiffusionPriorTrainer
from transformers import T5Tokenizer
from accelerate import Accelerator
from torch.utils.data import DataLoader
from webdataset import WebDataset
from dalle2_pytorch.utils import Timer
# --- model settings ---------------------------------------------------------
T5_MODEL = "t5-large"
OPEN_CLIP_MODEL = "ViT-B/32"
PRE_TRAINED = "laion400m_e32"

# --- data settings ----------------------------------------------------------
BATCH_SIZE = 200
DATASET = "/mnt/k/cc3m/cc3m/{00000..00266}.tar"
VALIDATION_DATA = "/mnt/k/cc3m/cc3m/00267.tar"

# --- optimisation settings --------------------------------------------------
LR = 1.1e-4
WD = 6e-4
EPOCHS = 3

# Open the Weights & Biases run as soon as the module is executed and record
# the settings above as the run's config.
wandb.init(
    entity="nousr_laion",
    project="t5_prior_test",
    config=dict(
        t5_model=T5_MODEL,
        open_clip_model=OPEN_CLIP_MODEL,
        pre_trained=PRE_TRAINED,
        batch_size=BATCH_SIZE,
        dataset=DATASET,
        lr=LR,
        wd=WD,
        epochs=EPOCHS,
    ),
)
# TODO: make multi-node
def get_data(dataset, batch_size=None, num_workers=10):
    """Build a batched DataLoader over a WebDataset of caption/image pairs.

    Args:
        dataset: Shard specification handed straight to ``WebDataset``
            (this file passes tar paths / brace-expanded tar patterns).
        batch_size: Number of samples per batch. ``None`` (the default)
            means "use the module-level ``BATCH_SIZE``", resolved when the
            function is called.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A ``DataLoader`` whose batches are built from ``(txt, image)``
        tuples taken from each sample.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    # Each sample is reduced to a 2-tuple: the "txt" entry and the "jpg;png"
    # entry, decoded with the "torchrgb" handler (presumably an RGB torch
    # tensor -- confirm against the webdataset docs).
    samples = WebDataset(dataset).decode("torchrgb").to_tuple("txt", "jpg;png")
    return DataLoader(samples, batch_size=batch_size, num_workers=num_workers)
def load_tokenizer():
    """Return the pretrained T5 tokenizer named by ``T5_MODEL``.

    The tokenizer is created with ``model_max_length=77``, the same value
    ``load_trainer`` passes as ``max_text_len`` to the prior network.
    """
    tokenizer_options = {"model_max_length": 77}
    return T5Tokenizer.from_pretrained(T5_MODEL, **tokenizer_options)
def load_trainer(accelerator):
    """Assemble the diffusion prior and wrap it in a ``DiffusionPriorTrainer``.

    The model names, learning rate and weight decay come from the
    module-level constants; every other hyper-parameter is fixed here.

    Args:
        accelerator: Object forwarded unchanged to ``DiffusionPriorTrainer``
            (``main`` passes an ``accelerate.Accelerator``).

    Returns:
        The constructed ``DiffusionPriorTrainer``.
    """
    # Text/image encoder pair used to produce the embeddings.
    clip_adapter = T5OpenClipAdapter(
        t5_model=T5_MODEL,
        clip_name=OPEN_CLIP_MODEL,
        clip_pretrained=PRE_TRAINED,
    )

    # Network that the diffusion prior trains.
    prior_network = DiffusionPriorNetwork(
        # transformer shape
        dim=512,
        depth=3,
        heads=4,
        dim_head=32,
        normformer=True,
        # regularisation
        attn_dropout=0.5,
        ff_dropout=0.5,
        # token layout
        num_time_embeds=1,
        num_image_embeds=1,
        num_text_embeds=1,
        max_text_len=77,
        self_cond=False,
    )

    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=clip_adapter,
        image_embed_dim=512,
        timesteps=100,
        condition_on_text_encodings=False,
        text_proj_in_dim=1024,
        norm_text_proj=False,
    )

    return DiffusionPriorTrainer(
        diffusion_prior=diffusion_prior,
        accelerator=accelerator,
        use_ema=True,
        # optimiser
        lr=LR,
        wd=WD,
        group_wd_params=True,
        max_grad_norm=0.5,
        # learning-rate schedule
        warmup_steps=50,
        cosine_decay_max_steps=100,
    )
def _embed_batch(trainer, tokenizer, caption, image):
    """Convert one (caption, image) batch into (text_embedding, image_embedding)."""
    # Tokenize the captions and move the resulting tensors to the trainer's device.
    tokens = tokenizer(
        caption, return_tensors="pt", padding=True, truncation=True
    ).to(trainer.device)
    input_ids = tokens["input_ids"]
    attention_mask = tokens["attention_mask"]

    # The image encoder is reached through the unwrapped prior's ``clip``
    # attribute; only the first element of its result is used.
    clip = trainer.accelerator.unwrap_model(trainer.diffusion_prior).clip
    image_embedding = clip.embed_image(image.to(trainer.device))[0]

    # Likewise, only the first element of the text-embedding result is used.
    text_embedding = trainer.embed_text(input_ids, attention_mask)[0]
    return text_embedding, image_embedding


def _log_validation_loss(trainer, validation_dataloader, tokenizer, current_step):
    """Log the loss on the first batch of ``validation_dataloader`` to wandb."""
    trainer.eval()
    with torch.no_grad():
        for caption, image in validation_dataloader:
            text_embedding, image_embedding = _embed_batch(
                trainer, tokenizer, caption, image
            )
            val_loss = trainer(text_embed=text_embedding, image_embed=image_embedding)
            wandb.log({
                "val_loss": val_loss,
                "val_step": current_step,
            })
            # A single batch is enough for this periodic check.
            break


def train(trainer, dataloader, validation_dataloader, tokenizer):
    """Run the training loop for ``EPOCHS`` passes over ``dataloader``.

    For every batch the captions and images are embedded, the trainer is
    called on the embeddings and then updated, and loss / EMA decay /
    throughput are logged to wandb. The loss is printed every 100 steps and
    a one-batch validation loss is logged every 1000 steps (step 0 included).

    Args:
        trainer: The ``DiffusionPriorTrainer`` built by ``load_trainer``.
        dataloader: Iterable of ``(caption, image)`` training batches.
        validation_dataloader: Iterable of ``(caption, image)`` validation
            batches; only its first batch is consumed per validation pass.
        tokenizer: Callable that tokenizes a batch of captions.
    """
    samples_sec_timer = Timer()

    for epoch in range(EPOCHS):
        print(f"Epoch {epoch}")

        for caption, image in dataloader:
            # Validation switches the trainer to eval mode, so re-enable
            # training mode at the start of every step.
            trainer.train()
            samples_sec_timer.reset()
            current_step = trainer.step.item()

            text_embedding, image_embedding = _embed_batch(
                trainer, tokenizer, caption, image
            )

            # train the model
            loss = trainer(text_embed=text_embedding, image_embed=image_embedding)
            trainer.update()

            ema_decay = trainer.ema_diffusion_prior.get_current_decay()

            # Throughput uses the number of captions actually in this batch:
            # the final batch can be smaller than BATCH_SIZE, and dividing
            # the constant by the elapsed time would over-report it.
            elapsed = samples_sec_timer.elapsed()
            samples_sec = len(caption) / elapsed

            wandb.log({
                "loss": loss,
                "ema_decay": ema_decay,
                "samples_sec": samples_sec,
                "train_step": current_step,
            })

            # print the loss every 100 steps
            if current_step % 100 == 0:
                print(f"Step {current_step}: {loss:.5f} | {samples_sec:.2f} samples/sec")

            # compute validation loss every 1000 steps
            if current_step % 1000 == 0:
                _log_validation_loss(
                    trainer, validation_dataloader, tokenizer, current_step
                )
def main():
    """Script entry point: build every component, then start training."""
    # The accelerator is handed to the trainer, so it is created first.
    accelerator = Accelerator()
    trainer = load_trainer(accelerator)

    # One loader for the training shards, one for the validation shard.
    train_loader = get_data(DATASET)
    validation_loader = get_data(VALIDATION_DATA)

    tokenizer = load_tokenizer()

    train(
        trainer=trainer,
        dataloader=train_loader,
        validation_dataloader=validation_loader,
        tokenizer=tokenizer,
    )


if __name__ == "__main__":
    main()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.