Last active
June 2, 2022 15:15
-
-
Save jwatte/c744cace32961d55465f29123c88a779 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from imagen_pytorch import Unet, Imagen, ImagenTrainer | |
from imagen_pytorch.t5 import t5_encode_text, DEFAULT_T5_NAME | |
from torchvision import transforms | |
# This doesn't yet format the text embedding tensors right | |
# TODO: check out https://gist.github.com/Netruk44/38d793e6d04a53cc4d9acbfadbb04a5c | |
import json | |
from PIL import Image | |
import sys | |
import random | |
# Shared torchvision ToTensor transform; load_images() uses it to turn each
# cropped PIL image into a tensor before copying it into the image buffer.
img_to_tensor = transforms.ToTensor()
def progress(s):
    """Emit the progress message *s* on standard error, exactly as given.

    No newline is appended; callers include their own line endings.
    """
    stream = sys.stderr
    stream.write(s)
# COCO split whose caption annotations and images are used for training.
valortrain = 'val2017'
# Samples per image set, training iterations per image set,
# image sets per checkpoint round, and number of checkpoint rounds.
maxtrain, itersperimgset, imgsetspercheckpoint, maxcheckpoints = 16, 50, 50, 50
# Load the COCO caption annotations for the selected split.
infile = 'coco/annotations/captions_%s.json' % valortrain
progress("Loading JSON from %s.\n" % infile)
# JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
with open(infile, 'r', encoding='utf-8') as f:
    annot = json.load(f)
numimg = len(annot['images'])
numannot = len(annot['annotations'])
# Never allocate more training slots than there are annotations to fill them.
maxtrain = min(maxtrain, numannot)
imglist = annot['images']
# Map each image id to the path of its file on disk.
imgpathbyid = {
    i['id']: 'coco/images/%s/%s' % (valortrain, i['file_name'])
    for i in imglist
}
progress("Allocating %d slots for %d images and %d annotations from %s\n" %
         (maxtrain, numimg, numannot, infile))
# Preallocated GPU buffers for one training image set; load_images() overwrites
# them in place. Row k of each buffer holds sample k.
# 256 = token slots per caption and image edge length; 1024 presumably matches
# the embedding width of the T5-large text encoder Imagen is built with — confirm.
# Allocate directly on the GPU (and as bool) instead of building on the CPU and copying.
text_embeds = torch.zeros(maxtrain, 256, 1024, device='cuda')
text_masks = torch.ones(maxtrain, 256, dtype=torch.bool, device='cuda')
images = torch.zeros(maxtrain, 3, 256, 256, device='cuda')
# Hyper-parameters common to both unets of the cascade.
_shared_unet_kwargs = dict(dim=32, cond_dim=512, dim_mults=(1, 2, 4, 8))

# First unet of the cascade (paired with the 64px entry of image_sizes below).
unet1 = Unet(
    **_shared_unet_kwargs,
    num_resnet_blocks=3,
    layer_attns=(False, True, True, True),
)

# Second unet (paired with the 256px entry); attention and cross-attention
# are enabled only in the last of its four stages.
unet2 = Unet(
    **_shared_unet_kwargs,
    num_resnet_blocks=(2, 4, 8, 8),
    layer_attns=(False, False, False, True),
    layer_cross_attns=(False, False, False, True),
)

# The Imagen cascade wrapping both unets, moved to the GPU.
imagen = Imagen(
    unets=(unet1, unet2),
    text_encoder_name='t5-large',
    image_sizes=(64, 256),
    beta_schedules=('cosine', 'linear'),
    timesteps=1000,
    cond_drop_prob=0.5,
).cuda()

# Trainer that drives optimisation of the cascade's unets.
trainer = ImagenTrainer(imagen)
# Pick N random annotations from the list of annotations.
# Load/scale/crop the corresponding image.
# Generate the T5 text embedding of the given prompt.
# Upload the data to the appropriate slot in each tensor.


def _random_square_crop(img, size):
    """Scale *img* so its shorter side is *size*, then crop a random size x size square."""
    width, height = img.size
    if width > height:
        img = img.resize((int(size * width / height), size), Image.BICUBIC)
    else:
        img = img.resize((size, int(size * height / width)), Image.BICUBIC)
    # After the resize at most one side exceeds `size`; slide the crop along it.
    if img.width > size:
        offset = random.randint(0, img.width - size)
        return img.crop((offset, 0, offset + size, size))
    offset = random.randint(0, img.height - size)
    return img.crop((0, offset, size, offset + size))


def load_images(annots, pathsbyid, tensordim, t_embeds, t_masks, t_images,
                t5_name='t5-large'):
    """Fill the training buffers in place with a fresh random set of samples.

    Chooses ``tensordim`` distinct annotations at random. For each one it loads
    the referenced image, crops a random square, and stores it in ``t_images``.
    The captions are T5-encoded into ``t_embeds``/``t_masks``.

    Args:
        annots: list of annotation dicts, each with 'caption' and 'image_id'.
        pathsbyid: mapping from image id to image file path.
        tensordim: number of samples to load; must not exceed ``len(annots)``
            or the leading dimension of the buffers.
        t_embeds: text-embedding buffer, assumed (N, tokens, embed_dim).
        t_masks: boolean token-mask buffer, assumed (N, tokens).
        t_images: image buffer, assumed (N, C, size, size); the crop size is
            taken from its last dimension.
        t5_name: T5 model used to encode captions. Its embedding width must
            match the Imagen ``text_encoder_name`` (this script uses 't5-large')
            and ``t_embeds``' last dimension.

    Raises:
        ValueError: if ``tensordim`` exceeds ``len(annots)`` (from random.sample).
    """
    size = t_images.shape[-1]
    chosen = random.sample(annots, tensordim)
    for ix, item in enumerate(chosen):
        caption = item['caption']
        imgpath = pathsbyid[item['image_id']]
        progress("%d, %s, %s\n" % (ix, imgpath, caption))
        # `with` closes the file handle; convert('RGB') guarantees 3 channels
        # even for grayscale or CMYK source images.
        with Image.open(imgpath) as img:
            crop = _random_square_crop(img.convert('RGB'), size)
        t_images[ix] = img_to_tensor(crop)

    # Encode every caption in one batch. The encoder output length varies with
    # the captions, so write only the first `ntok` positions of each slot and
    # clear the rest. This prevents stale data from a previous set from leaking
    # into the padding.
    captions = [item['caption'] for item in chosen]
    embeds, masks = t5_encode_text(captions, name=t5_name)
    ntok = min(embeds.shape[1], t_embeds.shape[1])
    t_embeds[:tensordim].zero_()
    t_masks[:tensordim] = False
    t_embeds[:tensordim, :ntok] = embeds[:, :ntok].to(t_embeds.device)
    t_masks[:tensordim, :ntok] = masks[:, :ntok].to(t_masks.device)
# Train both unets of the cascade. Each randomly loaded image set is reused for
# `itersperimgset` iterations. After every `imgsetspercheckpoint` image sets the
# trainer state is saved so progress survives a crash or the end of the run.
for cp in range(maxcheckpoints):
    for iset in range(imgsetspercheckpoint):
        progress("checkpoint %d imageset %d iters %d\n" %
                 (cp, iset, itersperimgset))
        load_images(annot['annotations'], imgpathbyid,
                    maxtrain, text_embeds, text_masks, images)
        for it in range(itersperimgset):
            progress("iter %d/%d/%d\n" % (cp, iset, it))
            for unet_number in (1, 2):
                loss = trainer(images, text_embeds=text_embeds,
                               text_masks=text_masks, unet_number=unet_number)
                trainer.update(unet_number=unet_number)
                progress("  unet %d loss %.4f\n" % (unet_number, float(loss)))
    # Overwrite a single checkpoint file each round to bound disk usage.
    trainer.save('imagen-checkpoint.pt')
# Generate images from the trained cascade for a couple of example prompts.
# cond_scale=2. is presumably the guidance strength — see imagen-pytorch docs.
sample_prompts = [
    'a puppy looking anxiously at a giant donut on the table',
    'the milky way galaxy in the style of monet',
]
images = trainer.sample(texts=sample_prompts, cond_scale=2.)
print(images.shape)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment