Created
August 2, 2019 02:55
-
-
Save jayelm/8918f0bb6580dcc9af81f2f842784dd5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# Use this commit: | |
# 82b3c3a106a5d5f6d8afe98c34b301e3ed696865 | |
# https://github.com/AlexKuhnle/ShapeWorld/tree/82b3c3a106a5d5f6d8afe98c34b301e3ed696865 | |
from collections import defaultdict | |
import numpy as np | |
import os | |
from shapeworld import dataset | |
import json | |
import itertools | |
from tqdm import tqdm, trange | |
from shapeworld.world.shape import Shape | |
from shapeworld.world.color import Color | |
# From generators/generator.py
# Sampling ranges forwarded to Color.random_instance (SHADE_RANGE) and
# Shape.random_instance (SIZE_RANGE, DISTORTION_RANGE) when negate_world()
# re-draws the color or shape of an entity.
SHADE_RANGE = 0.33
SIZE_RANGE = (0.1, 0.2)
DISTORTION_RANGE = (2.0, 3.0)
from PIL import Image | |
import os | |
def save_world(world, filename):
    """Render ``world`` and write it to ``filename`` as an 8-bit image.

    Args:
        world: Object exposing ``get_array()``. The array is multiplied by
            255 before being cast to ``uint8``, so it is assumed to hold
            floats in [0, 1] -- TODO confirm against ShapeWorld.
        filename: Destination path handed to ``PIL.Image.Image.save``.
    """
    # np.round_ was removed in NumPy 2.0; np.round is the same function and
    # exists in every NumPy release.
    pixels = np.round(world.get_array() * 255).astype(np.uint8)
    Image.fromarray(pixels).save(filename)
def negate_world(world, save=False):
    """Perturb, in place, every held-out entity found in ``world``.

    An entity counts as held-out when its string form appears in
    ``IGNORE_ENTITIES``. For each such entity a coin flip on the module-level
    ``random`` decides whether its shape or its color is replaced; the
    replacement is re-sampled until its ``name`` differs from the current one.

    Args:
        world: World whose ``entities`` are inspected and mutated.
        save: When true, write ``test_before.png`` / ``test_after.png``
            snapshots of the world around the modification.

    Raises:
        RuntimeError: If no entity of the world was held-out, i.e. nothing
            was changed.
    """
    if save:
        save_world(world, 'test_before.png')

    # Names are captured once, up front, so edits made below cannot influence
    # which entities are selected.
    entity_names = [str(entity) for entity in world.entities]
    hits = [idx for idx, label in enumerate(entity_names)
            if label in IGNORE_ENTITIES]

    for idx in hits:
        target = world.entities[idx]
        if random.randint(2):
            # Swap the shape for one with a different name.
            while True:
                replacement = Shape.random_instance(
                    Shape.shapes, SIZE_RANGE, DISTORTION_RANGE)
                if replacement is not None and replacement.name != target.shape.name:
                    break
            target.shape = replacement
        else:
            # Swap the color for one with a different name.
            while True:
                replacement = Color.random_instance(Color.colors, SHADE_RANGE)
                if replacement is not None and replacement.name != target.color.name:
                    break
            target.color = replacement

    if save:
        save_world(world, 'test_after.png')

    if not hits:
        raise RuntimeError("I didn't modify this world: {}".format(entity_names))
# Number of distinct captions to collect for each split.
N_TRAIN = 2000
N_VAL = 500
N_TEST = 300
# Smaller alternative sizes (presumably for quick debugging runs).
# N_TRAIN = 50
# N_VAL = 25
# N_TEST = 10
N_CAPTIONS = N_TRAIN + N_VAL + N_TEST
# Dimensions of the image arrays allocated in generate(), and the number of
# support images stored per example.
WIDTH = 64
HEIGHT = 64
CHANNELS = 3
EXAMPLES = 4
DATASET = dataset(dtype="agreement", name="spatial_jda")
# Seeded generator shared by the whole script. NOTE: despite its name this is
# a numpy RandomState, not the stdlib `random` module.
random = np.random.RandomState(0)
# Held-out entities as (shape, color, <third attribute>) tuples; the third
# field is presumably the texture -- TODO confirm against ShapeWorld.
IGNORE_ENTITIES_RAW = [
    ('square', 'red', 'solid'),
    ('rectangle', 'green', 'solid'),
    ('triangle', 'blue', 'solid'),
    ('pentagon', 'yellow', 'solid'),
    ('cross', 'magenta', 'solid'),
    ('circle', 'cyan', 'solid'),
    ('semicircle', 'white', 'solid'),
]
# "<color> <shape>" substrings (e.g. "red square") searched for in realized
# captions to route a caption to the test pool.
IGNORE_CAPTION_NAMES = ['{} {}'.format(e[1], e[0]) for e in IGNORE_ENTITIES_RAW]
# Tuple reprs with the quotes stripped (e.g. "(square, red, solid)"), compared
# against str(entity) for the entities of a world.
IGNORE_ENTITIES = [str(e).replace("\'", '') for e in IGNORE_ENTITIES_RAW]
# Realized caption (tuple of tokens) -> caption object. Captions that mention
# a held-out "<color> <shape>" go to the test pool, all others to train/val.
caption_data = {}
test_caption_data = {}

trainval_pbar = tqdm(total=N_TRAIN + N_VAL, desc='Train/val captions')
test_pbar = tqdm(total=N_TEST, desc='Test captions')

# Sample randomly for train/val, skipping forbiddens. Sampling continues until
# BOTH pools are full, so either pool may end up with more than its target.
while not (len(caption_data) >= (N_TRAIN + N_VAL)
           and len(test_caption_data) >= N_TEST):
    if len(caption_data) > (N_TRAIN + N_VAL):
        trainval_pbar.set_description('Train/val captions (extra)')

    DATASET.world_generator.sample_values(mode="train")
    DATASET.world_captioner.sample_values(mode="train", correct=True)

    # Retry until the generator yields a world AND the captioner yields a
    # caption for it.
    caption = None
    while caption is None:
        world = DATASET.world_generator()
        if world is not None:
            caption = DATASET.world_captioner(entities=world.entities)

    realized, = DATASET.caption_realizer.realize(captions=[caption])
    realized_str = ' '.join(realized)
    realized = tuple(realized)

    # Pick the destination pool (and its progress bar) for this caption.
    if any(phrase in realized_str for phrase in IGNORE_CAPTION_NAMES):
        target_captions, target_pbar = test_caption_data, test_pbar
    else:
        target_captions, target_pbar = caption_data, trainval_pbar

    # Only first-time captions advance the bar; the stored caption object is
    # refreshed either way.
    if realized not in target_captions:
        target_pbar.update(1)
    target_captions[realized] = caption

trainval_pbar.close()
test_pbar.close()
# Compositional split: leave out "red triangles" at train time. Test time is | |
# done with red triangles. Make sure generate makes sense, and inspect the | |
# results! | |
# How to swap shapes | |
# import ipdb; ipdb.set_trace() | |
# print([str(x) for x in world.entities]) | |
# save_world(world, 'test.png') | |
# # Modify entities - modify shape of red triangle | |
# center0 = world.entities[0].center | |
# world.entities[0] = world.entities[1].copy() | |
# world.entities[0].set_center(center0) | |
# save_world(world, 'test2.png') | |
# PROBLEMS: | |
# (1) We want to test systematic generalization which means we need "tricky" | |
# negative examples, e.g. examples where the red triangle is a red shape but | |
# not a triangle, or a triangle but not a red shape. Can we modify the world to | |
# make that happen? | |
# (2) no way to verify that captions are "new" in that they haven't appeared in | |
# some hidden form in the train set (specifically, consider above vs below/left | |
# vs right, I think that's the only issue). This is not a major issue for the | |
# compositional split since we guarantee that entities haven't appeared.
# Sort first so the order is deterministic, then shuffle with the seeded
# generator; train takes the first N_TRAIN captions and val the next N_VAL.
captions = sorted(caption_data)
random.shuffle(captions)
train_captions = captions[:N_TRAIN]
val_captions = captions[N_TRAIN:][:N_VAL]

test_captions = sorted(test_caption_data)
random.shuffle(test_captions)

# Combine: afterwards caption_data resolves keys from every split.
caption_data.update(test_caption_data)
def has_ignore_entity(world, ignore_entities):
    """Tell whether ``world`` contains at least one of ``ignore_entities``.

    Entities of the world are compared by their ``str()`` form, so
    ``ignore_entities`` is expected to hold strings in that same format.
    """
    present = [str(entity) for entity in world.entities]
    for forbidden in ignore_entities:
        if forbidden in present:
            return True
    return False
def generate(name, captions, n_examples, ignore_entities=None, hard_negatives=False,
             save=False, save_hard_negatives=None, examples_ratio=20):
    """Build one dataset split and write it into the directory ``name``.

    Phase 1 samples worlds from ``DATASET`` and files each one under every
    caption in ``captions`` it agrees with, until at least
    ``n_examples * examples_ratio`` (world, caption) matches are collected.
    Phase 2 assembles ``n_examples`` examples: ``EXAMPLES`` support images for
    a randomly drawn caption plus one query image which, on a coin flip, is a
    positive (label 1) or a negative (label 0).

    Output files in ``name``: ``examples.npy``, ``inputs.npy``,
    ``labels.npy``, ``hints.json`` (caption of the support set) and
    ``test_hints.json`` (caption associated with the query image).

    Args:
        name: Split name; used as output directory and progress-bar label.
        captions: Sequence of caption keys (tuples of tokens) present in the
            module-level ``caption_data``.
        n_examples: Number of examples to produce.
        ignore_entities: Optional collection of entity strings; sampled
            worlds containing any of them are discarded.
        hard_negatives: If true, a negative is made by perturbing (via
            ``negate_world``) a world of the *same* caption; only permitted
            when ``name == 'test'``. Otherwise a negative is a world filed
            under a different caption.
        save: If true, dump up to five worlds per caption under ``vis/``.
        save_hard_negatives: If not None, examples with an index below this
            value have their images written under ``vis/``.
        examples_ratio: Number of (world, caption) matches to collect per
            requested example.

    Raises:
        RuntimeError: If the collected scenes run out before ``n_examples``
            examples could be assembled (previously this looped forever), or
            if ``negate_world`` finds nothing to modify.
    """
    import copy

    mappings = defaultdict(list)
    max_scenes = n_examples * examples_ratio
    # Here, generate many many scenes. For each scene, check if it agrees with
    # any of the captions - if so, add. Brute force approach: generate as much
    # as you can for captions, so you can get positive examples. Negative
    # examples are easier since they are taken from the scenes collected for
    # other captions (or produced by negate_world for hard negatives).
    total_scenes = 0
    pbar = tqdm(total=max_scenes, desc='{} scenes'.format(name))
    while total_scenes < max_scenes:
        DATASET.world_generator.sample_values(mode="train")
        world = DATASET.world_generator()
        if world is None:
            continue
        if ignore_entities is not None and has_ignore_entity(world, ignore_entities):
            # Discard world as it contains a forbidden entity
            continue
        for key in captions:
            # Add this image to whichever captions align. Note this means an
            # image can appear multiple times in a single split
            # (train/val/test) aligned with different concepts, but no images
            # will be shared between train/val/test
            caption = caption_data[key]
            if caption.agreement(entities=world.entities) > 0:
                mappings[key].append(world)
                total_scenes += 1
                pbar.update(1)
    pbar.close()

    for key in mappings:
        print(" ".join(key), len(mappings[key]))

    if save:
        for key, worlds in mappings.items():
            key_dirname = os.path.join('vis', '_'.join(key)[:-2])
            os.makedirs(key_dirname, exist_ok=True)
            for i, world in enumerate(worlds[:5]):
                worldname = os.path.join(key_dirname, '{}.png'.format(i))
                save_world(world, worldname)

    examples = np.zeros((n_examples, EXAMPLES, WIDTH, HEIGHT, CHANNELS))
    inputs = np.zeros((n_examples, WIDTH, HEIGHT, CHANNELS))
    labels = np.zeros((n_examples,), dtype=np.uint8)
    hints = []
    test_hints = []

    # A caption is usable only if it can supply the support set plus one
    # more world for the query image.
    min_worlds = EXAMPLES + 1
    stalled = 0  # consecutive draws that hit a caption with too few worlds
    i_example = 0
    pbar = tqdm(total=n_examples, desc='{} examples'.format(name))
    while i_example < n_examples:
        key = captions[random.randint(len(captions))]
        worlds = mappings[key]
        if len(worlds) < min_worlds:
            # Rejection sampling usually recovers on the next draw. Only after
            # a long dry streak pay for a full scan, and fail loudly instead
            # of spinning forever when no caption is usable any more.
            stalled += 1
            if stalled >= len(captions):
                if not any(len(mappings[k]) >= min_worlds for k in captions):
                    pbar.close()
                    raise RuntimeError(
                        "Ran out of scenes after {} of {} '{}' examples; "
                        "try a larger examples_ratio".format(
                            i_example, n_examples, name))
                stalled = 0
            continue
        stalled = 0

        dump = save_hard_negatives is not None and i_example < save_hard_negatives
        if dump:
            i_dir = os.path.join('vis', '_'.join(key)[:-2])
            os.makedirs(i_dir, exist_ok=True)

        for i_world in range(EXAMPLES):
            world = worlds.pop()
            if dump:
                save_world(world, os.path.join(i_dir, 'train_{}.png'.format(i_world)))
            examples[i_example, i_world, ...] = world.get_array()

        if random.randint(2) == 0:
            # Positive example: sample from this class.
            world = worlds.pop()
            inputs[i_example, ...] = world.get_array()
            labels[i_example] = 1
            test_hint = key
            if dump:
                save_world(world, os.path.join(i_dir, 'test_pos.png'))
        else:
            # Negative example: tweak the entity in question. hard_negatives
            # is True for test captions. figure out which entity has not been
            # seen (or both). Of those entities, permute either the color or
            # the shape (this is how we systematically test compositionality).
            if hard_negatives:
                assert name == 'test'
                # The same world object may also be filed under other
                # captions, so perturb a private copy rather than corrupting
                # those other entries.
                world = copy.deepcopy(worlds.pop())
                # Modify world
                negate_world(world)
                inputs[i_example, ...] = world.get_array()
                labels[i_example] = 0
                # No clue what the test hint is after negation.
                test_hint = ('not', ) + key
                if dump:
                    save_world(world, os.path.join(i_dir, 'test_neg.png'))
            else:
                # Sample randomly from another caption. Note this does NOT
                # guarantee that the negative example does not belong to the
                # caption! Drawing from `key` itself, however, would yield a
                # certain positive labelled as negative, so that is excluded.
                attempts = 0
                while True:
                    # Try a different caption
                    other_key = captions[random.randint(len(captions))]
                    # If there are worlds available for this caption, get it
                    if other_key != key and mappings[other_key]:
                        other_world = mappings[other_key].pop()
                        break
                    attempts += 1
                    if attempts >= len(captions):
                        if not any(k != key and mappings[k] for k in captions):
                            pbar.close()
                            raise RuntimeError(
                                "No scenes left to draw a negative for '{}' "
                                "example {}; try a larger "
                                "examples_ratio".format(name, i_example))
                        attempts = 0
                # Set this as a negative example
                inputs[i_example, ...] = other_world.get_array()
                labels[i_example] = 0
                # Set the test hint's key
                test_hint = other_key

        hints.append(" ".join(key))
        test_hints.append(" ".join(test_hint))
        i_example += 1
        pbar.update(1)
    pbar.close()
    print("\n\n")

    os.makedirs(name, exist_ok=True)
    np.save(os.path.join(name, "examples.npy"), examples)
    np.save(os.path.join(name, "inputs.npy"), inputs)
    np.save(os.path.join(name, "labels.npy"), labels)
    with open(os.path.join(name, "hints.json"), "w") as hint_f:
        json.dump(hints, hint_f)
    with open(os.path.join(name, "test_hints.json"), "w") as t_hint_f:
        json.dump(test_hints, t_hint_f)
# Build the three splits. Train and val discard any sampled world containing
# a held-out entity; test keeps them, builds its negatives with negate_world
# (hard negatives) and writes the images of its first 100 examples to vis/.
generate("train", train_captions, 9000, ignore_entities=IGNORE_ENTITIES)
generate("val", val_captions, 1000, ignore_entities=IGNORE_ENTITIES)
generate("test", test_captions, 1000, hard_negatives=True, save_hard_negatives=100)
# generate("val_same", train_captions, 500)
# generate("test_same", train_captions, 500)
# Smaller variants of the calls above:
# generate("train", train_captions, 50, ignore_entities=IGNORE_ENTITIES)
# generate("val", val_captions, 50, ignore_entities=IGNORE_ENTITIES)
# generate("test", test_captions, 10, hard_negatives=True, save_hard_negatives=5)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment