import wandb
import time
import torch
from PIL import Image
import open_clip
from argparse import Namespace
from collections import OrderedDict
from contextlib import suppress
from tqdm import tqdm
from torch import nn
from data import get_data
from distributed import init_distributed_device, is_master
from params import parse_args
from zero_shot import zero_shot_eval
from scheduler import cosine_lr
from torch.profiler import profile, record_function, ProfilerActivity

class MLPCLIP(torch.nn.Module):
    """Wraps a CLIP model with projection heads applied to its image/text embeddings."""
    def __init__(self, model, img_mlp, txt_mlp, device):
        super().__init__()
        self.model = model
        self.img_mlp = img_mlp
        self.txt_mlp = txt_mlp
        self.dev = device

    def encode_text(self, text):
        temp = self.model.encode_text(text)
        temp = self.txt_mlp(temp)
        return temp

    def encode_image(self, image):
        temp = self.model.encode_image(image)
        temp = self.img_mlp(temp)
        return temp

    def forward(self, img, txt):
        img_feat = self.encode_image(img)
        txt_feat = self.encode_text(txt)
        return img_feat, txt_feat, self.model.logit_scale

def main():
    # Args
    args = parse_args()
    dev = init_distributed_device(args)

    if is_master(args):
        pass
        # wandb.init(project="h14_distillation", entity="iejmac", name=args.name)

    # Model: pretrained ViT-L/14 acts as the teacher; ViT-H/14 is the student being distilled
    model_l, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='laion400m_e32')
    model_h, _, preprocess = open_clip.create_model_and_transforms('ViT-H-14')  # MAX BS = ~256
    # model_h, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32')  # MAX BS = ~1536
    model_h.set_grad_checkpointing()

    d_model_l = 768   # ViT-L/14 embedding width
    d_model_h = 1024  # ViT-H/14 embedding width
    # Alternative: MLP projection heads instead of plain linear maps
    '''
    mlp_width = d_model_l * 4
    act_layer = nn.GELU
    img_mlp = nn.Sequential(OrderedDict([
        ("c_fc", nn.Linear(d_model_h, mlp_width)),
        ("gelu", act_layer()),
        ("c_proj", nn.Linear(mlp_width, d_model_l))
    ]))
    txt_mlp = nn.Sequential(OrderedDict([
        ("c_fc", nn.Linear(d_model_h, mlp_width)),
        ("gelu", act_layer()),
        ("c_proj", nn.Linear(mlp_width, d_model_l))
    ]))
    '''
    # Projection heads mapping H-space (1024-d) features into L-space (768-d)
    img_mlp = nn.Linear(d_model_h, d_model_l, bias=False)
    txt_mlp = nn.Linear(d_model_h, d_model_l, bias=False)

    mlp_model_l = MLPCLIP(model_l, nn.Identity(), nn.Identity(), dev).to(dev)
    mlp_model_h = MLPCLIP(model_h, img_mlp, txt_mlp, dev).to(dev)
    mlp_model_h = torch.nn.parallel.DistributedDataParallel(mlp_model_h, device_ids=[dev], find_unused_parameters=False)

    # Loss and Opt:
    loss = nn.MSELoss()
    params = mlp_model_h.parameters()
    opt = torch.optim.AdamW(
        params=params,
        lr=args.lr,
        betas=(args.beta1, args.beta2),
        eps=args.eps,
        weight_decay=args.wd,
    )
    '''
    opt = torch.optim.SGD(
        params=params,
        lr=args.lr,
    )
    '''

    TOTAL_STEPS = 10000
    WARMUP = 500

    # Scheduler:
    scheduler = cosine_lr(opt, args.lr, WARMUP, TOTAL_STEPS)

    # Data:
    data = get_data(args, (preprocess, preprocess))
    data['train'].set_epoch(0)
    tr_dat = data["train"].dataloader

    step = 1
    autocast = torch.cuda.amp.autocast if args.precision == 'amp' else suppress
    mlp_model_h.train()
    for batch in tr_dat:
        if step > TOTAL_STEPS:
            break
        t0 = time.perf_counter()
        metrics = {}
        scheduler(step)
        metrics.update({"train/lr": opt.param_groups[0]["lr"]})

        images, texts = batch
        images = images.to(dev, non_blocking=True)
        texts = texts.to(dev, non_blocking=True)

        # Teacher forward (no gradients)
        t0_l_forward = time.perf_counter()
        with torch.no_grad():
            with autocast():
                l_img_feat, l_txt_feat, l_logit_scale = mlp_model_l(images, texts)
                # l_img_feat = model_l.encode_image(images)
                # l_txt_feat = model_l.encode_text(texts)
        t_l_for = time.perf_counter() - t0_l_forward
        metrics.update({"train/l_forward_samples_per_s": images.shape[0] / t_l_for})

        # Student forward + backward
        t0_h_forward = time.perf_counter()
        with autocast():
            h_img_feat, h_txt_feat, h_logit_scale = mlp_model_h(images, texts)
            # "+ h_logit_scale - h_logit_scale" is a no-op that keeps logit_scale in the
            # autograd graph so DDP (find_unused_parameters=False) sees it as used
            loss_img = loss(h_img_feat, l_img_feat) + h_logit_scale - h_logit_scale
            loss_txt = loss(h_txt_feat, l_txt_feat) + h_logit_scale - h_logit_scale
            total_loss = loss_img + loss_txt
        total_loss.backward()
        t_h_for_back = time.perf_counter() - t0_h_forward
        metrics.update({"train/h_forward_backward_samples_per_s": images.shape[0] / t_h_for_back})

        opt.step()
        opt.zero_grad()
        metrics.update({"train/img_loss": loss_img.item(), "train/txt_loss": loss_txt.item()})
        # Zero-shot eval
        if step % args.zeroshot_frequency == 0:
            mlp_model_h.eval()
            zero_shot_metrics = zero_shot_eval(mlp_model_h, data, 0, args)
            metrics.update(zero_shot_metrics)
            mlp_model_h.train()

        # MSE eval
        if step % args.val_frequency == 0:
            mlp_model_h.eval()
            val_dat = data['val'].dataloader
            n_batch = 0
            tot_img_loss, tot_txt_loss = 0.0, 0.0
            with torch.no_grad():
                for batch in tqdm(val_dat, unit_scale=args.batch_size):
                    if n_batch > 3:  # TODO: remove
                        break
                    val_img, val_txt = batch
                    val_img, val_txt = val_img.to(dev, non_blocking=True), val_txt.to(dev, non_blocking=True)
                    n_batch += 1
                    with autocast():
                        img_feat, txt_feat, logit_scale = mlp_model_h(val_img, val_txt)
                        # targ_img_feat, targ_txt_feat = model_l.encode_image(val_img), model_l.encode_text(val_txt)
                        targ_img_feat, targ_txt_feat, l_log_scale = mlp_model_l(val_img, val_txt)
                        val_loss_img = loss(img_feat, targ_img_feat)
                        val_loss_txt = loss(txt_feat, targ_txt_feat)
                    tot_img_loss += val_loss_img.item()
                    tot_txt_loss += val_loss_txt.item()
            tot_txt_loss /= n_batch
            tot_img_loss /= n_batch
            eval_metrics = {"val/img_loss": tot_img_loss, "val/txt_loss": tot_txt_loss}
            metrics.update(eval_metrics)
            mlp_model_h.train()
        tf = time.perf_counter()
        metrics.update({"train/samples_per_s": images.shape[0] / (tf - t0)})

        if is_master(args):
            for name, val in metrics.items():
                if not (name.startswith("train") or name.startswith("val")):
                    name = "val/" + name  # hack for zero-shot stuff
                # wandb.log({name: val}, step=step)
                print(name, val)

        step += 1


if __name__ == "__main__":
    main()

#!/bin/bash
torchrun --nproc_per_node 8 train.py \
    --train-data "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/{000000..231348}.tar -" \
    --train-num-samples 2000000000 \
    --val-data "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/{231349..231349}.tar -" \
    --val-num-samples 10000 \
    --imagenet-val "/fsx/rom1504/imagenetval/imagenetvalv1" \
    --dataset-type "webdataset" \
    --batch-size 16 \
    --lr 5e-4 \
    --beta1 0.9 \
    --beta2 0.98 \
    --eps 1e-6 \
    --wd 0.0 \
    --workers 6 \
    --epochs 1 \
    --zeroshot-frequency 1000 \
    --val-frequency 100 \
    --name "H=B (grad check)"