Last active
July 13, 2020 19:51
-
-
Save vzhong/8affecb2a3382eca3cccf1cc91125bbc to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv
import itertools
import random
# Hand-designed training scenes; each inner list names the containers placed
# in one scene. Entries 1-11 use a single container family (plates, pots or
# bins); entries 12-20 combine plates with either pots or bins. No training
# scene contains both a pot and a bin.
train_envs = [
    # --- plates only ---
    ['round plate', 'square plate'],
    # --- pots only ---
    ['round flat pot', 'round tall pot'],
    ['square flat pot', 'square tall pot'],
    ['round flat pot', 'square flat pot'],
    ['round tall pot', 'square tall pot'],
    ['round flat pot', 'square flat pot', 'round tall pot', 'square tall pot'],
    # --- bins only ---
    ['round flat bin', 'round tall bin'],
    ['square flat bin', 'square tall bin'],
    ['round flat bin', 'square flat bin'],
    ['round tall bin', 'square tall bin'],
    ['round flat bin', 'square flat bin', 'round tall bin', 'square tall bin'],
    # --- plates + pots ---
    ['round plate', 'square plate', 'round flat pot', 'round tall pot'],
    ['round plate', 'square flat pot', 'square tall pot'],
    ['round plate', 'square plate', 'round flat pot', 'square flat pot'],
    ['round plate', 'round tall pot', 'square tall pot'],
    ['round plate', 'round flat pot', 'square flat pot', 'round tall pot',
     'square tall pot'],
    # --- plates + bins ---
    ['square plate', 'round flat bin', 'round tall bin'],
    ['round plate', 'square plate', 'square flat bin', 'square tall bin'],
    ['round plate', 'square plate', 'round flat bin', 'square flat bin'],
    ['square plate', 'round tall bin', 'square tall bin'],
]
# Held-out evaluation scenes (same format as train_envs). Every one of them
# contains at least one pot together with at least one bin, a combination
# that never appears in a training scene.
# NOTE(review): 'flat plate' appears in four scenes here but never in
# train_envs, which only uses 'round plate' / 'square plate'. It may be a
# typo for 'square plate' -- confirm before changing, since that would alter
# the coverage counts and the rows written to scenes.csv.
eval_envs = [
    ['round flat pot', 'round tall pot', 'round flat bin'],
    ['round plate', 'flat plate', 'square flat pot', 'square tall pot',
     'round flat bin', 'round tall bin'],
    ['round flat pot', 'square flat pot', 'round flat bin', 'round tall bin',
     'square flat bin'],
    ['round plate', 'flat plate', 'round tall pot', 'square tall pot',
     'round flat bin', 'round tall bin', 'square flat bin', 'square tall bin'],
    ['round flat pot', 'square flat pot', 'round tall pot', 'square tall pot',
     'round flat bin'],
    ['round flat pot', 'round tall pot', 'round flat bin', 'round tall bin'],
    ['round plate', 'flat plate', 'square flat pot', 'square tall pot',
     'round flat bin', 'round tall bin', 'square flat bin', 'square tall bin'],
    ['round plate', 'flat plate', 'round flat pot', 'square flat pot',
     'round flat bin'],
    ['round tall pot', 'square tall pot', 'round flat bin', 'round tall bin'],
    ['round plate', 'round flat pot', 'square flat pot', 'round tall pot',
     'square tall pot', 'round flat bin', 'round tall bin', 'square flat bin'],
]
def get_train_combos(train_envs):
    """Collect every unordered pair of objects that co-occurs in some scene.

    Args:
        train_envs: iterable of scenes, each a sequence of object names.

    Returns:
        A ``set`` of ``frozenset`` pairs, one per distinct co-occurring pair.
        The number of distinct pairs is printed as a side effect.
    """
    pairs = {
        frozenset(pair)
        for scene in train_envs
        for pair in itertools.combinations(scene, 2)
    }
    print('train combos: {}'.format(len(pairs)))
    return pairs
# Object pairs seen together in at least one hand-written training scene.
train_combos = get_train_combos(train_envs=train_envs)
def get_eval_coverage(train_combos, eval_envs, verbose=True):
    """Report, per evaluation scene, the object pairs never seen in training.

    For each scene, every unordered pair of its objects is checked against
    ``train_combos``; pairs not found there are "new".

    Args:
        train_combos: set of ``frozenset`` pairs, e.g. from ``get_train_combos``.
        eval_envs: iterable of scenes, each a sequence of object names.
        verbose: if True, print each scene's number of new pairs followed by
            the pairs themselves; if False, print only the number.

    Returns:
        A list with one entry per evaluation scene: the list of new pairs
        (each a ``set``) in ``itertools.combinations`` order.
    """
    coverage = []
    for env in eval_envs:
        new = []
        for pair in itertools.combinations(env, 2):
            if frozenset(pair) not in train_combos:
                new.append(set(pair))
        if verbose:
            # Previously this branch was a bare `pass`, so the verbose mode
            # printed nothing. Pairs are sorted so the report is stable
            # across runs (string hashing, and thus set order, is randomized).
            shown = sorted(tuple(sorted(p)) for p in new)
            print('new {}: {}'.format(len(new), shown))
        else:
            print('new {}'.format(len(new)))
        coverage.append(new)
    return coverage
# Coverage of the evaluation scenes against the hand-written training scenes.
get_eval_coverage(train_combos, eval_envs, verbose=True)
print()
print('Random binomial assignment of items')
# Fixed seed so the sampled scenes below are reproducible between runs.
random.seed(0)
# Extra objects that sample_items may add to each scene.
items = [
    'apple', 'banana', 'bottle', 'can', 'cup', 'drinking glass',
    'fork', 'knife', 'pear', 'rolling pin', 'spatula', 'spoon',
]
def sample_items(env, items, prob=0.3, rng=None):
    """Return a copy of a scene with each extra item independently added.

    Each element of ``items`` is appended, in order, when a uniform draw in
    [0, 1) is below ``prob`` -- one independent draw per item.

    Args:
        env: sequence of object names already in the scene; not modified.
        items: candidate extra objects.
        prob: inclusion probability for each item.
        rng: object with a ``random()`` method (e.g. ``random.Random(seed)``).
            Defaults to the module-level ``random`` generator, so callers
            relying on ``random.seed`` get exactly the same draws as before.

    Returns:
        A new list: the objects of ``env`` followed by the sampled items.
    """
    draw = random.random if rng is None else rng.random
    # list() (rather than env[:]) so tuple scenes are accepted too.
    sampled = list(env)
    sampled.extend(item for item in items if draw() < prob)
    return sampled
# Augment every scene with randomly sampled extra items. Train scenes are
# sampled before eval scenes; both draw from the global random generator.
new_train_envs = [sample_items(scene, items) for scene in train_envs]
new_eval_envs = [sample_items(scene, items) for scene in eval_envs]
# Recompute pair coverage on the augmented scenes, reporting counts only.
new_train_combos = get_train_combos(train_envs=new_train_envs)
get_eval_coverage(new_train_combos, new_eval_envs, verbose=False)
# Write one CSV row per scene: split name ('train' / 'eval'), 1-based scene
# index, then the scene's object names.
with open('scenes.csv', 'w', encoding='utf-8', newline='') as f:
    # csv.writer quotes any field containing a comma or quote character,
    # which the previous manual ','.join did not. newline='' is required by
    # the csv module; lineterminator='\n' keeps the previous line endings
    # (csv's default terminator is '\r\n').
    writer = csv.writer(f, lineterminator='\n')
    for split, scenes in (('train', new_train_envs), ('eval', new_eval_envs)):
        for index, scene in enumerate(scenes, start=1):
            writer.writerow([split, index, *scene])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment