Last active
February 20, 2021 15:08
-
-
Save crowsonkb/a93904fbb88aff0302aac98dfdb26b5f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import argparse | |
from collections import defaultdict | |
import csv | |
import math | |
import torch | |
from torch import nn, optim | |
from torch.nn import functional as F | |
from torch.utils import data | |
from torchvision import datasets, transforms | |
from tqdm import tqdm | |
# COCO 2017 caption annotation files and their image directories.
# NOTE(review): the "-160" suffix presumably marks images pre-resized to 160px — confirm.
TRAIN_ANN = 'annotations/captions_train2017.json'
TRAIN_ROOT = 'train2017-160'
VAL_ANN = 'annotations/captions_val2017.json'
VAL_ROOT = 'val2017-160'

# Samples per contrastive batch; every other sample in a batch acts as a negative.
BATCH_SIZE = 2500
# Largest chunk pushed through an encoder at once (bounds activation memory).
MICROBATCH_SIZE = 50
# Stem of the output files: PREFIX + '.csv' (metrics) and PREFIX + '.pth' (checkpoint).
PREFIX = 'clip_coco_2'
# Fixed caption length in bytes; captions are left-padded with NUL up to this length.
# NOTE(review): presumably the maximum reported by get_seq_len_and_tokens — confirm.
SEQ_LEN = 630

# Byte alphabet of the lowercased captions:
#   10 ('\n'), 32-59 (' ' .. ';'), 61-64 ('=' '>' '?' '@'),
#   91-93 ('[' '\\' ']'), 95-122 ('_', '`', 'a' .. 'z').
BYTES = [10, *range(32, 60), *range(61, 65), *range(91, 94), *range(95, 123)]

# Token id of each byte: BYTES[i] -> i + 1. Any other byte (the NUL padding
# included) maps to 0, the id that TextEncoder masks out as padding.
TOKENS = defaultdict(int, {byte: idx for idx, byte in enumerate(BYTES, start=1)})
class ConvBlock(nn.Sequential):
    """A 3x3 convolution (padding 1, so height/width are preserved) followed by an in-place ReLU.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
    """

    def __init__(self, c_in, c_out):
        conv = nn.Conv2d(c_in, c_out, 3, padding=1)
        activation = nn.ReLU(inplace=True)
        super().__init__(conv, activation)
class ImageEncoder(nn.Sequential):
    """Convolutional image tower producing 64-dimensional embeddings.

    The first layer normalizes with the ImageNet mean/std, so the input is
    presumably an (N, 3, H, W) float tensor in [0, 1] as produced by
    ``transforms.ToTensor`` — confirm for callers other than ``main``.
    Five conv stages (two ConvBlocks + a 2x max-pool each) shrink the
    128x128 crops used in ``main`` to 4x4. ``AdaptiveAvgPool2d`` pins the
    map to 4x4 before the flatten and the final linear projection.
    """

    def __init__(self):
        layers = [transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
        # (input channels, output channels) of each conv stage, in order.
        stages = [(3, 64), (64, 128), (128, 128), (128, 128), (128, 128)]
        for c_in, c_out in stages:
            layers += [ConvBlock(c_in, c_out), ConvBlock(c_out, c_out), nn.MaxPool2d(2)]
        layers += [
            nn.AdaptiveAvgPool2d([4, 4]),
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 64),
        ]
        super().__init__(*layers)
class TextEncoder(nn.Module):
    """Transformer encoder mapping byte-token captions to 64-dimensional embeddings.

    ``forward`` takes a (SEQ_LEN, batch) tensor of token ids, with 0 meaning
    padding. Positions whose id is 0 are excluded through the key padding
    mask. The output is the projection of the encoder state at the last
    sequence position. Captions are left-padded, so this is the final real
    byte.
    """

    def __init__(self):
        super().__init__()
        d_model = 256
        # Ids run 0..max(TOKENS.values()), so the vocabulary needs one extra row.
        self.embed = nn.Embedding(max(TOKENS.values()) + 1, d_model)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead=4, dim_feedforward=d_model),
            num_layers=6,
        )
        self.proj = nn.Linear(d_model, 64)
        self.register_buffer('pe', self._positional_table(SEQ_LEN, d_model))

    @staticmethod
    def _positional_table(length, width):
        """Return a (length, width) sinusoidal table: sin on even channels, cos on odd.

        NOTE(review): the exponent uses ``2 * dim`` for every channel, whereas
        the textbook encoding shares ``2 * (dim // 2)`` within each sin/cos
        pair. This is kept unchanged because the table is saved as a buffer in
        checkpoints.
        """
        pos, dim = torch.meshgrid([torch.arange(length), torch.arange(width)])
        angle = pos / 10000**(2 * dim / width)
        return torch.where(dim % 2 == 0, torch.sin(angle), torch.cos(angle))

    def forward(self, input):
        # Key padding mask is (batch, seq), hence the transpose of the (seq, batch) input.
        padding = input.T == 0
        x = self.embed(input) + self.pe.unsqueeze(1)
        hidden = self.encoder(x, src_key_padding_mask=padding)
        return self.proj(hidden[-1])
class CLIPLoss(nn.Module):
    """Symmetric contrastive loss over a batch of paired image/text embeddings.

    Row i of ``image_embed`` is the positive for row i of ``text_embed``.
    Every other row is a negative. Both inputs are L2-normalized along dim 1,
    and their similarity matrix is scaled by ``exp(t)``, where ``t`` is a
    learned scalar initialized to 0. The matrix is then scored with cross
    entropy in both directions (image->text over rows, text->image over
    columns).

    Returns:
        (loss, accuracy): the mean of the two cross-entropy losses, and the
        mean top-1 matching accuracy over both directions, in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # Log of the logit scale; exp(0) == 1 at initialization.
        self.t = nn.Parameter(torch.tensor(0.))

    def forward(self, image_embed, text_embed):
        batch = image_embed.shape[0]
        scale = torch.exp(self.t)
        logits = F.normalize(image_embed) @ F.normalize(text_embed).T * scale
        targets = torch.arange(batch, device=self.t.device)
        loss = (F.cross_entropy(logits, targets)
                + F.cross_entropy(logits.T, targets)) / 2
        correct = ((logits.argmax(dim=1) == targets).sum()
                   + (logits.argmax(dim=0) == targets).sum())
        return loss, correct / batch / 2
def get_seq_len_and_tokens(datasets):
    """Find the longest caption byte string and the set of bytes used across datasets.

    For every item, the captions in ``item[1]`` are joined with single
    spaces, lowercased and encoded with ``str.encode`` (UTF-8). This is the
    same transformation ``collate`` applies. Each dataset is wrapped in a
    tqdm progress bar.

    Args:
        datasets: iterable of datasets whose items expose their caption list at index 1.

    Returns:
        (max_len, byte_values): the maximum encoded length (0 if there are no
        items) and the sorted list of distinct byte values seen.
    """
    longest = 0
    alphabet = set()
    for dataset in datasets:
        for item in tqdm(dataset):
            encoded = ' '.join(item[1]).lower().encode()
            if len(encoded) > longest:
                longest = len(encoded)
            alphabet.update(encoded)  # iterating bytes yields ints
    return longest, sorted(alphabet)
def collate(samples):
    """DataLoader ``collate_fn`` turning (image, captions) samples into tensors.

    Images are stacked along a new leading dimension. For each sample, the
    captions are joined with spaces, lowercased, encoded, and left-padded with
    NUL bytes to SEQ_LEN. Each byte is then mapped through TOKENS, so NUL and
    bytes outside the alphabet become 0.

    Returns:
        (image_batch, text_batch), where text_batch is (SEQ_LEN, batch).
        NOTE(review): ``rjust`` never truncates, so a caption longer than
        SEQ_LEN bytes yields ragged rows and ``torch.tensor`` fails.
    """
    image_batch = torch.stack([sample[0] for sample in samples])
    token_rows = []
    for sample in samples:
        padded = ' '.join(sample[1]).lower().encode().rjust(SEQ_LEN, b'\0')
        token_rows.append([TOKENS[byte] for byte in padded])
    text_batch = torch.tensor(token_rows).T
    return image_batch, text_batch
def main():
    """Train the image and text encoders contrastively on COCO Captions until Ctrl-C.

    Each epoch makes one pass over the training set. It then evaluates on the
    validation set, appending ``epoch, loss, accuracy`` to ``PREFIX + '.csv'``,
    and overwrites the checkpoint ``PREFIX + '.pth'``. A KeyboardInterrupt
    ends training quietly, and the CSV log is closed on the way out.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--seed', type=int, default=0, help='the random seed')
    args = p.parse_args()

    torch.manual_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    # Both pipelines resize to 160x160 and crop to 128x128; training crops at random.
    train_tf = transforms.Compose([
        transforms.Resize([160, 160]),
        transforms.RandomCrop([128, 128]),
        transforms.ToTensor(),
    ])
    val_tf = transforms.Compose([
        transforms.Resize([160, 160]),
        transforms.CenterCrop([128, 128]),
        transforms.ToTensor(),
    ])
    train_set = datasets.CocoCaptions(TRAIN_ROOT, TRAIN_ANN, transform=train_tf)
    train_dl = data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,
                               num_workers=2, collate_fn=collate, pin_memory=True)
    val_set = datasets.CocoCaptions(VAL_ROOT, VAL_ANN, transform=val_tf)
    val_dl = data.DataLoader(val_set, batch_size=BATCH_SIZE,
                             num_workers=2, collate_fn=collate, pin_memory=True)

    image_enc = ImageEncoder().to(device)
    text_enc = TextEncoder().to(device)
    clip_loss = CLIPLoss().to(device)
    print('Image encoder parameters:', sum(x.numel() for x in image_enc.parameters()))
    print('Text encoder parameters:', sum(x.numel() for x in text_enc.parameters()))
    print('CLIP loss parameters:', sum(x.numel() for x in clip_loss.parameters()))
    params = [*image_enc.parameters(), *text_enc.parameters(), *clip_loss.parameters()]
    opt = optim.AdamW(params, lr=1e-4, weight_decay=0.01)

    epoch = 1
    # The csv module requires newline='' so it controls line endings itself.
    # The file is closed by the `with csvfile:` block below.
    csvfile = open(PREFIX + '.csv', 'w', newline='')
    writer = csv.writer(csvfile)

    def num_microbatches(batch_size):
        """Number of torch.chunk pieces so that none exceeds MICROBATCH_SIZE samples."""
        return math.ceil(batch_size / MICROBATCH_SIZE)

    def train():
        """Run one training epoch of large-batch contrastive learning with gradient caching.

        The loss couples every sample of a BATCH_SIZE batch, but the encoders
        only process MICROBATCH_SIZE samples at a time:
          1. embed the whole batch chunk by chunk without autograd;
          2. evaluate the loss once on those cached embeddings and backprop
             it, which yields d(loss)/d(embedding) for every sample and the
             gradient of the CLIP temperature exactly once;
          3. re-encode each chunk with autograd and push its slice of the
             cached embedding gradient into the encoder weights.
        Backpropagating the full loss once per chunk instead would recompute
        the BATCH_SIZE x BATCH_SIZE loss 2n times. It would also add the
        temperature's gradient 2n times.
        NOTE(review): if the encoders use dropout in train mode (PyTorch's
        TransformerEncoderLayer does by default — confirm), the re-encoded
        embeddings in step 3 differ slightly from the cached ones of step 1.
        """
        image_enc.train()
        text_enc.train()
        clip_loss.train()
        for i, (image_batch, text_batch) in enumerate(tqdm(train_dl), start=1):
            image_batch = image_batch.to(device, non_blocking=True)
            text_batch = text_batch.to(device, non_blocking=True)
            # Size the split from the real batch so a short final batch isn't over-split.
            n = num_microbatches(len(image_batch))
            image_mbs = torch.chunk(image_batch, n)
            text_mbs = torch.chunk(text_batch, n, dim=1)  # text is (seq, batch)

            # Step 1: cache embeddings for the whole batch without building graphs.
            with torch.no_grad():
                images = torch.cat([image_enc(mb) for mb in image_mbs])
                texts = torch.cat([text_enc(mb) for mb in text_mbs])

            # Step 2: one backward through the loss w.r.t. the cached embeddings and t.
            images.requires_grad_()
            texts.requires_grad_()
            opt.zero_grad()
            loss, acc = clip_loss(images, texts)
            loss.backward()
            tqdm.write(f'{i} {loss.item():g} {acc.item():g}')

            # Step 3: re-encode each chunk with autograd and backprop its gradient slice.
            image_grads = torch.split(images.grad, [len(mb) for mb in image_mbs])
            for mb, grad in zip(image_mbs, image_grads):
                image_enc(mb).backward(grad)
            text_grads = torch.split(texts.grad, [mb.shape[1] for mb in text_mbs])
            for mb, grad in zip(text_mbs, text_grads):
                text_enc(mb).backward(grad)
            opt.step()

    def val():
        """Evaluate on the validation set, print the results and append a CSV row for this epoch."""
        print('Validating...')
        image_enc.eval()
        text_enc.eval()
        clip_loss.eval()
        losses, accs = [], []
        for image_batch, text_batch in tqdm(val_dl):
            image_batch = image_batch.to(device, non_blocking=True)
            text_batch = text_batch.to(device, non_blocking=True)
            n = num_microbatches(len(image_batch))
            with torch.no_grad():
                images = [image_enc(mb) for mb in torch.chunk(image_batch, n)]
                texts = [text_enc(mb) for mb in torch.chunk(text_batch, n, dim=1)]
                loss, acc = clip_loss(torch.cat(images), torch.cat(texts))
            # Weight by batch size so the epoch average is per sample.
            losses.append(loss.item() * len(image_batch))
            accs.append(acc.item() * len(image_batch))
        avg_loss = sum(losses) / len(val_set)
        avg_acc = sum(accs) / len(val_set)
        print(f'Validation loss: {avg_loss:g}, accuracy: {avg_acc:g}')
        writer.writerow([epoch, avg_loss, avg_acc])
        csvfile.flush()

    def save():
        """Overwrite PREFIX + '.pth' with model, loss and optimizer state."""
        state = {'image_enc': image_enc.state_dict(),
                 'text_enc': text_enc.state_dict(),
                 'clip_loss': clip_loss.state_dict(),
                 'opt': opt.state_dict()}
        torch.save(state, PREFIX + '.pth')
        print(f'Wrote checkpoint to {PREFIX}.pth.')

    with csvfile:
        writer.writerow(['epoch', 'loss', 'accuracy'])
        csvfile.flush()
        try:
            while True:
                print('Epoch', epoch)
                train()
                val()
                save()
                epoch += 1
        except KeyboardInterrupt:
            # Ctrl-C is the intended way to stop training.
            pass


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment