Fine-tune GLIDE on a captioned-image dataset, e.g. COCO or LAION.
# https://wandb.ai/afiaka87/glide_finetune/runs/3fj69lfc?workspace=user-afiaka87
import os

import bitsandbytes as bnb
import torch as th
import wandb
from IPython.display import clear_output, display
from ipywidgets import Output
from PIL import Image
from tqdm import tqdm

from glide_text2im.download import load_checkpoint
from glide_text2im.model_creation import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
)
from loader import TextImageDataset
# %%
has_cuda = th.cuda.is_available()
fp16 = False  # fp16 hurts training here, perhaps due to the low batch size/high noise schedule
device = th.device('cuda' if has_cuda else 'cpu')
# %%
# Create the base model.
options = model_and_diffusion_defaults()
options['cache_text_emb'] = False
# options['use_checkpoint'] = True  # gradient checkpointing trades compute for memory
options['use_fp16'] = has_cuda and fp16
options['dropout'] = 0.1
options['timestep_respacing'] = '100'  # use 100 diffusion steps for fast sampling
model, diffusion = create_model_and_diffusion(**options)
model.train()
model.requires_grad_(True)
# To finetune only the text transformer instead, freeze everything else:
# model.requires_grad_(False); model.transformer.requires_grad_(True)
if has_cuda and fp16:
    model.convert_to_fp16()
model.to(device)
model.load_state_dict(load_checkpoint('base', device))
print('total base parameters:', sum(x.numel() for x in model.parameters() if x.requires_grad))
print(f"transformer parameters: {sum(x.numel() for x in model.transformer.parameters() if x.requires_grad)}")
# %%
def show_images(batch: th.Tensor):
    """Display a batch of [-1, 1] images inline as one horizontal strip."""
    scaled = ((batch + 1) * 127.5).round().clamp(0, 255).to(th.uint8).cpu()
    # (N, 3, H, W) -> (H, N, W, 3) -> (H, N*W, 3): images laid out side by side.
    reshaped = scaled.permute(2, 0, 3, 1).reshape([batch.shape[2], -1, 3])
    display(Image.fromarray(reshaped.numpy()))
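# Illustrative usage (hypothetical `samples` tensor, not defined at this point
# in the gist): given an (N, 3, H, W) batch in [-1, 1], e.g. the output of
# diffusion.p_sample_loop, this renders the batch as a single strip:
# show_images(samples)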
# %%
batch_size = 1
grad_acc = 4  # accumulate gradients over this many micro-batches
guidance_scale = 3.0
learning_rate = 1e-6
side_x = 64  # base model resolution
side_y = 64
upsample_x = 4  # upsampler factor (64 -> 256); unused while finetuning the base model
base_dir = './finetune_checkpoints'
os.makedirs(base_dir, exist_ok=True)
dataset = TextImageDataset(
    folder="/home/samsepiol/DatasetWorkspace/CurrentDatasets",
    shuffle=True,
    batch_size=batch_size,
    device=device,
)
assert len(dataset) > 0, "Dataset is empty"
print(f"Dataset contains {len(dataset)} images")
def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """Gather per-timestep values from a numpy array and broadcast them to an image batch shape."""
    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)
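# Illustrative shapes (hypothetical values): with arr of shape (100,), timesteps
# of shape (B,), and broadcast_shape (B, 3, 64, 64), the result is a
# (B, 3, 64, 64) tensor holding arr[t] for each sample's timestep t.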
print(f"Dataset has {len(dataset)} images") | |
dataloader = th.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0) | |
print(f"Dataset has {len(dataloader)} batches") | |
def prompt_to_model_kwargs(prompt: str = '', _batch_size: int = 1, device: str = 'cpu'):
    """Tokenize a prompt and stack it with the empty (unconditional) prompt for classifier-free guidance."""
    prompt = prompt.lower()
    assert len(prompt) > 0, 'prompt must be a non-empty string'
    tokens = model.tokenizer.encode(prompt)
    tokens, mask = model.tokenizer.padded_tokens_and_mask(tokens, options['text_ctx'])
    uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask([], options['text_ctx'])
    return dict(
        tokens=th.tensor(
            [tokens] * _batch_size + [uncond_tokens] * _batch_size,
            device=device,
        ),
        mask=th.tensor(
            [mask] * _batch_size + [uncond_mask] * _batch_size,
            dtype=th.bool,
            device=device,
        ),
    )
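# Illustrative output (assuming the base model's default text_ctx of 128): for
# _batch_size=1, `tokens` and `mask` each have shape (2, 128); the first row
# conditions on the prompt, the second is the empty prompt used for
# classifier-free guidance.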
optim = bnb.optim.Adam8bit([x for x in model.parameters() if x.requires_grad], lr=learning_rate)
out = Output()
display(out)
losses = []
running_loss = 0.0
full_batch_size = batch_size * 2  # conditional + unconditional halves
config = {
    'batch_size': batch_size,
    'grad_acc': grad_acc,
    'side_x': side_x,
    'side_y': side_y,
    'learning_rate': learning_rate,
}
wandb_run = wandb.init(project="glide_finetune", config=config)
try:
    for i, (captions, images) in tqdm(enumerate(dataloader), total=len(dataloader)):
        images = images.to(device)
        for prompt, x in zip(captions, images):
            # Duplicate the image so the conditional and unconditional halves see the same input.
            x = x.repeat((full_batch_size, 1, 1, 1))
            model_kwargs = prompt_to_model_kwargs(prompt=prompt, _batch_size=batch_size, device=device)
            # Sample a respaced timestep; the respaced schedule has len(diffusion.betas) == 100 steps.
            ts = th.randint(0, len(diffusion.betas), (full_batch_size,), device=device)
            # Perturb x with one step's worth of noise variance (beta_t), keeping the
            # unscaled noise around as the regression target.
            noise_variance = _extract_into_tensor(diffusion.betas, ts, x.shape)
            orig_noise = th.randn_like(x, device=x.device)
            noise = (noise_variance ** 0.5).to(x.device) * orig_noise
            # ts * 10 maps the 100-step respaced index back onto the model's original
            # 1000-step schedule.
            output = model(x + noise, ts * 10, **model_kwargs)
            eps = output[..., :3, :, :]  # first 3 channels are predicted noise; the rest is variance
            loss = th.nn.functional.mse_loss(eps, orig_noise)
            running_loss += loss.item()
            loss.backward()
        if i % 1000 == 0:
            model_dict = {
                'weights': model.state_dict(),
                'optim': optim.state_dict(),
                'options': options,
            }
            th.save(model_dict, os.path.join(base_dir, f'glide-ft-{i}.pt'))
            th.save(model_dict, os.path.join(base_dir, 'glide-ft.pt'))
            print(f'Saved checkpoint {i} to {base_dir}/glide-ft-{i}.pt')
        if i % grad_acc == grad_acc - 1:
            optim.step()
            optim.zero_grad()
            running_loss /= grad_acc
            losses.append(running_loss)
            with out:
                clear_output(wait=True)
                wandb_run.log({"loss": running_loss})
            running_loss = 0.0  # reset so the average doesn't accumulate across steps
except KeyboardInterrupt:
    print("Interrupted")