import os
from dataclasses import dataclass
from glob import glob

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm

import clip
from accelerate import Accelerator
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DConditionModel
from diffusers.optimization import get_cosine_schedule_with_warmup
class StableDataset(Dataset):
    """Pairs each garment image with its precomputed CLIP text embedding."""

    def __init__(self, root_dir, transforms):
        self.img_dir = os.path.join(root_dir, "cloth")
        self.np_embedding_dir = os.path.join(root_dir, "clip_txt_embeddings")
        self.transforms = transforms
        # Derive the names from the globbed .jpg files so that stray non-image
        # files in the directory cannot slip into the dataset.
        self.names = [os.path.splitext(os.path.basename(p))[0]
                      for p in glob(os.path.join(self.img_dir, "*.jpg"))]

    def __len__(self):
        return len(self.names)

    def __getitem__(self, index):
        name = self.names[index]
        img_path = os.path.join(self.img_dir, f"{name}.jpg")
        np_embedding_path = os.path.join(self.np_embedding_dir, f"{name}.np.gz")
        img = Image.open(img_path).convert("RGB")  # the UNet expects 3 channels
        img = self.transforms(img)
        np_embedding = np.loadtxt(np_embedding_path, dtype=np.float32)
        # (768,) -> (1, 768): a one-token "sequence" for cross-attention
        np_embedding = torch.from_numpy(np_embedding).unsqueeze(0)
        return {"images": img, "np_embedding": np_embedding}
device = "cuda" | |
@dataclass
class TrainingConfig:
    image_size = 128  # the generated image resolution
    train_batch_size = 1
    eval_batch_size = 16  # how many images to sample during evaluation
    num_epochs = 50
    gradient_accumulation_steps = 1
    learning_rate = 1e-4
    lr_warmup_steps = 500
    save_image_epochs = 10
    save_model_epochs = 30
    mixed_precision = 'fp16'  # `no` for float32, `fp16` for automatic mixed precision
    output_dir = 'ddpm-fashion-128-v0'  # the model name locally and on the HF Hub
    push_to_hub = False  # whether to upload the saved model to the HF Hub
    hub_private_repo = False
    overwrite_output_dir = True  # overwrite the old model when re-running the notebook
    seed = 0

config = TrainingConfig()
preprocess = transforms.Compose(
    [
        transforms.Resize((config.image_size, config.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),  # map pixel values to [-1, 1]
    ]
)
model = UNet2DConditionModel(
    sample_size=config.image_size,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    cross_attention_dim=768,  # width of the CLIP text embeddings used as conditioning
).cuda()
stable_dataset = StableDataset("/data/dataset/VITON-hD/train/", transforms=preprocess)
train_dataloader = DataLoader(
    stable_dataset,
    batch_size=config.train_batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
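
# Sketch (not in the original gist): a one-batch shape check before training.
# "images" should come out as (B, 3, 128, 128) after the transforms and
# "np_embedding" as (B, 1, 768), i.e. a one-token sequence per image that can
# be fed directly as `encoder_hidden_states`.
_batch = next(iter(train_dataloader))
assert _batch["images"].shape[1:] == (3, config.image_size, config.image_size)
assert _batch["np_embedding"].shape[1:] == (1, 768)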
# `tensor_format="pt"` was dropped from later diffusers releases; the scheduler
# returns PyTorch tensors by default.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=len(train_dataloader) * config.num_epochs,
)
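
# Illustrative sketch (not in the original gist): `add_noise` implements the
# closed-form forward process x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps,
# e.g. for a dummy image at a mid-range timestep:
#
#   x0 = torch.randn(1, 3, config.image_size, config.image_size)
#   eps = torch.randn_like(x0)
#   xt = noise_scheduler.add_noise(x0, eps, torch.tensor([500]))  # same shape as x0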
def make_grid(images, rows, cols):
    # Tile a list of PIL images into a single grid image.
    w, h = images[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid

def evaluate(config, epoch, pipeline):
    # Sample some images from random noise (this is the backward diffusion process).
    # The default pipeline output type is `List[PIL.Image]`.
    # NOTE: DDPMPipeline never passes `encoder_hidden_states`, so with this
    # conditional UNet a manual loop is needed instead; see `sample_conditional`
    # just below.
    images = pipeline(
        batch_size=config.eval_batch_size,
        generator=torch.manual_seed(config.seed),
    ).images
    # Make a grid out of the images
    image_grid = make_grid(images, rows=4, cols=4)
    # Save the images
    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    image_grid.save(f"{test_dir}/{epoch:04d}.png")
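
# Sketch under stated assumptions (not in the original gist): a minimal
# reverse-diffusion loop for the text-conditional UNet, since DDPMPipeline
# cannot forward the conditioning. `embeddings` is a (B, 1, 768) tensor of
# CLIP text embeddings like the ones the dataset loads.
@torch.no_grad()
def sample_conditional(model, noise_scheduler, embeddings, image_size, device="cuda"):
    model.eval()
    image = torch.randn(embeddings.shape[0], 3, image_size, image_size, device=device)
    for t in noise_scheduler.timesteps:
        noise_pred = model(image, t, encoder_hidden_states=embeddings).sample
        image = noise_scheduler.step(noise_pred, t, image).prev_sample
    model.train()
    return (image / 2 + 0.5).clamp(0, 1)  # undo the [-1, 1] normalization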
def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=os.path.join(config.output_dir, "logs"),  # `logging_dir` in older accelerate releases
    )
    if accelerator.is_main_process:
        if config.push_to_hub:
            # `init_git_repo` is the helper from the hub utilities of early
            # diffusers releases; it is not imported above.
            repo = init_git_repo(config, at_init=True)
        accelerator.init_trackers("train_example")

    # Prepare everything.
    # There is no specific order to remember, you just need to unpack the
    # objects in the same order you gave them to the prepare method.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    global_step = 0
    # Do NOT call model.half() here: with mixed_precision='fp16' the
    # accelerator keeps fp32 master weights and autocasts the forward pass.
    # Now you train the model
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            # The prepared dataloader already places batches on the right device.
            clean_images = batch["images"]
            embeddings = batch["np_embedding"]
            # Sample noise to add to the images
            noise = torch.randn(clean_images.shape, device=clean_images.device)
            bs = clean_images.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device
            ).long()
            # Add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            # The original `with a and b and c:` only entered the last context
            # manager, and `torch.no_grad()` would have disabled training
            # entirely; the accelerator handles autocasting on its own.
            with accelerator.accumulate(model):
                # Predict the noise residual
                noise_pred = model(noisy_images, timesteps, encoder_hidden_states=embeddings).sample
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1
        # After each epoch you optionally sample some demo images with evaluate() and save the model
        if accelerator.is_main_process:
            pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

            if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
                evaluate(config, epoch, pipeline)

            if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
                if config.push_to_hub:
                    # `push_to_hub` here is the helper from early diffusers hub
                    # utilities, not a method on the pipeline.
                    push_to_hub(config, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=True)
                else:
                    pipeline.save_pretrained(config.output_dir)
if __name__ == '__main__':
    train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
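
# Usage sketch (assumption): because the loop is built on accelerate, launching
# through its CLI picks up the configured mixed precision and any multi-GPU setup:
#
#   accelerate launch train_fashion_ddpm.py
#
# The filename is hypothetical; a plain `python` invocation also works for a
# single-GPU run.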