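# Train a DiffusionPrior that maps T5 text embeddings to OpenCLIP image
# embeddings (a T5-conditioned DALL-E 2 prior) on CC3M webdataset shards,
# logging metrics to wandb.
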
import wandb
import torch
from dalle2_pytorch import (
    T5OpenClipAdapter,
    DiffusionPriorNetwork,
    DiffusionPrior,
)
from dalle2_pytorch.trainer import DiffusionPriorTrainer
from transformers import T5Tokenizer
from accelerate import Accelerator
from torch.utils.data import DataLoader
from webdataset import WebDataset
from dalle2_pytorch.utils import Timer

# MODEL KWARGS
T5_MODEL = "t5-large"
OPEN_CLIP_MODEL = "ViT-B/32"
PRE_TRAINED = "laion400m_e32"

# DATA KWARGS
BATCH_SIZE = 200
DATASET = "/mnt/k/cc3m/cc3m/{00000..00266}.tar"
VALIDATION_DATA = "/mnt/k/cc3m/cc3m/00267.tar"

# TRAINING KWARGS
LR = 1.1e-4
WD = 6e-4
EPOCHS = 3
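
# record the run configuration so the hyperparameters are tracked alongside
# the training metrics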
wandb.init(
    config={
        "t5_model": T5_MODEL,
        "open_clip_model": OPEN_CLIP_MODEL,
        "pre_trained": PRE_TRAINED,
        "batch_size": BATCH_SIZE,
        "dataset": DATASET,
        "lr": LR,
        "wd": WD,
        "epochs": EPOCHS,
    },
    entity="nousr_laion",
    project="t5_prior_test",
)
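
# each webdataset shard is a tar of paired .txt captions and .jpg/.png images,
# streamed and decoded without unpacking the archives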
# TODO: make multi-node
def get_data(dataset):
    wds = WebDataset(dataset).decode("torchrgb").to_tuple("txt", "jpg;png")
    dataloader = DataLoader(wds, batch_size=BATCH_SIZE, num_workers=10)
    return dataloader
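
# cap captions at 77 tokens to match the prior network's max_text_len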
def load_tokenizer():
    return T5Tokenizer.from_pretrained(T5_MODEL, model_max_length=77)

def load_trainer(accelerator):
    t5clip = T5OpenClipAdapter(
        t5_model=T5_MODEL,
        clip_name=OPEN_CLIP_MODEL,
        clip_pretrained=PRE_TRAINED,
    )
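
    # the prior's denoising network: a causal transformer over the text,
    # timestep, and image embedding tokens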
    dpn = DiffusionPriorNetwork(
        dim=512,
        dim_head=32,
        depth=3,
        heads=4,
        attn_dropout=5e-1,
        ff_dropout=5e-1,
        normformer=True,
        num_time_embeds=1,
        num_image_embeds=1,
        num_text_embeds=1,
        max_text_len=77,
        self_cond=False,
    )
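
    # image_embed_dim=512 matches ViT-B/32's embedding size, and
    # text_proj_in_dim=1024 matches the hidden size of t5-large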
    prior = DiffusionPrior(
        net=dpn,
        clip=t5clip,
        image_embed_dim=512,
        timesteps=100,
        condition_on_text_encodings=False,
        text_proj_in_dim=1024,
        norm_text_proj=False,
    )
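
    # the trainer wraps the prior with the optimizer, an EMA copy, gradient
    # clipping, and a warmup + cosine decay LR schedule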
    trainer = DiffusionPriorTrainer(
        diffusion_prior=prior,
        accelerator=accelerator,
        use_ema=True,
        lr=LR,
        wd=WD,
        max_grad_norm=0.5,
        group_wd_params=True,
        warmup_steps=50,
        cosine_decay_max_steps=100,
    )
    return trainer

def train(trainer, dataloader, validation_dataloader, tokenizer):
    samples_sec_timer = Timer()
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch}")
        for caption, image in dataloader:
            trainer.train()
            samples_sec_timer.reset()
            current_step = trainer.step.item()

            # tokenize the captions and move the batch to the accelerator device
            tokenized_caption = tokenizer(
                caption, return_tensors="pt", padding=True, truncation=True
            ).to(trainer.device)
            input_ids, attention_mask = (
                tokenized_caption["input_ids"],
                tokenized_caption["attention_mask"],
            )
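
            # embed_image and embed_text return (embedding, encodings) tuples;
            # only the pooled embedding (index 0) is used below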

            # embed the image
            image_embedding = trainer.accelerator.unwrap_model(
                trainer.diffusion_prior
            ).clip.embed_image(image.to(trainer.device))[0]

            # embed the text
            text_embedding = trainer.embed_text(input_ids, attention_mask)[0]

            # train the model
            loss = trainer(text_embed=text_embedding, image_embed=image_embedding)
            trainer.update()

            # get the ema decay
            ema_decay = trainer.ema_diffusion_prior.get_current_decay()

            # get the samples/sec
            samples_sec_stop = samples_sec_timer.elapsed()
            samples_sec = BATCH_SIZE / samples_sec_stop

            # log the training metrics
            wandb.log({
                "loss": loss,
                "ema_decay": ema_decay,
                "samples_sec": samples_sec,
                "train_step": current_step,
            })

            # print the loss every 100 steps
            if current_step % 100 == 0:
                print(f"Step {current_step}: {loss:.5f} | {samples_sec:.2f} samples/sec")

            # compute the validation loss every 1000 steps
            if current_step % 1000 == 0:
                trainer.eval()
                with torch.no_grad():
                    for caption, image in validation_dataloader:
                        tokenized_caption = tokenizer(
                            caption, return_tensors="pt", padding=True, truncation=True
                        ).to(trainer.device)
                        input_ids, attention_mask = (
                            tokenized_caption["input_ids"],
                            tokenized_caption["attention_mask"],
                        )

                        # embed the image
                        image_embedding = trainer.accelerator.unwrap_model(
                            trainer.diffusion_prior
                        ).clip.embed_image(image.to(trainer.device))[0]

                        # embed the text
                        text_embedding = trainer.embed_text(input_ids, attention_mask)[0]

                        # compute the validation loss
                        loss = trainer(text_embed=text_embedding, image_embed=image_embedding)

                        wandb.log({
                            "val_loss": loss,
                            "val_step": current_step,
                        })
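                        # only validate on a single batch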
                        break

def main():
    # set up the accelerator
    accelerator = Accelerator()

    # load the model
    trainer = load_trainer(accelerator)

    # load the datasets
    dataloader = get_data(DATASET)
    validation_dataloader = get_data(VALIDATION_DATA)

    # load the tokenizer
    tokenizer = load_tokenizer()

    # train the model
    train(trainer, dataloader, validation_dataloader, tokenizer)


if __name__ == "__main__":
    main()
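
# a sketch of how to launch this (assuming the file is saved as train_prior.py
# and accelerate has already been configured):
#   accelerate launch train_prior.py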