Skip to content

Instantly share code, notes, and snippets.

@jwatte
Created May 30, 2022 21:10
Show Gist options
  • Save jwatte/2cab43839a1c6912f0e173ac12ee4847 to your computer and use it in GitHub Desktop.
Save jwatte/2cab43839a1c6912f0e173ac12ee4847 to your computer and use it in GitHub Desktop.
# Pick N random annotations from the list of annotations.
# Load/ccale/sample the corresponding image.
# Generate the T5 text embedding of the given prompt.
# Upload the data to the appropriate slot in each tensor.
def load_images(annots, pathsbyid, tensordim, t_embeds, t_masks, t_images):
# pick N random annotations
todo = annots.copy()
random.shuffle(todo)
todo = todo[0: tensordim]
for ix in range(0, tensordim):
atxt = todo[ix]['caption']
iid = todo[ix]['image_id']
imgpath = pathsbyid[iid]
progress("%d, %s, %s\n" % (ix, imgpath, atxt))
img = Image.open(imgpath)
width = img.width
height = img.height
# pick a random square sub-block of the image
# TODO: maybe subsample a little more, for stretching?
if width > height:
img = img.resize((int(256*width/height), 256), Image.BICUBIC)
else:
img = img.resize((256, int(256*height/width)), Image.BICUBIC)
if img.width > 256:
l = int(random.uniform(0, img.width-256))
img = img.crop((l, 0, l+256, 256))
else:
l = int(random.uniform(0, img.height-256))
img = img.crop((0, l, 256, l+256))
t_images[ix, :] = img_to_tensor(img)
text_embeds, text_masks = t5_encode_text([atxt], name=DEFAULT_T5_NAME)
t_embeds[ix, :] = text_embeds
t_masks[ix, :] = text_masks
print(text_masks)
# error I get:
# File "train.py", line 116, in load_images
# t_embeds[ix, :] = text_embeds
# RuntimeError: The expanded size of the tensor (1024) must match the existing size (768) at
# non-singleton dimension 1. Target sizes: [256, 1024]. Tensor sizes: [19, 768]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment