Multi MNIST in VOC format

This script composites several randomly rotated and rescaled MNIST digits onto 64x64 canvases and writes them out as a Pascal-VOC-style detection dataset: images under JPEGImages, per-image XML annotations under Annotations, and train/test split lists under ImageSets/Main.
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import os
from tqdm import trange
from skimage.io import imsave
from skimage.transform import rescale
from easydict import EasyDict
from copy import deepcopy
import xmltodict
config = EasyDict({
    "dir_data": "data",
    "dir_output": "VOCMNIST",
    "image_size": 64,
    "num_obj": 3,
    "grids": 2,
})

dirs = EasyDict({
    "image": os.path.join(config.dir_output, "JPEGImages"),
    "annotation": os.path.join(config.dir_output, "Annotations"),
    "set": os.path.join(config.dir_output, "ImageSets", "Main"),
})
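
# The canvas is split into a grids x grids layout (2 x 2 = 4 cells here) and
# each sampled digit lands in its own randomly chosen cell, so at most
# grids * grids digits fit per image. The three directories above mirror the
# standard Pascal VOC layout (JPEGImages / Annotations / ImageSets/Main).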
base_annotation_file = {
    "annotation": {
        "folder": "VOCMNIST",
        "filename": "",
        "source": {
            "database": "The VOCMNIST Dataset",
            "annotation": "VOCMNIST"
        },
        "size": {
            "width": config.image_size,
            "height": config.image_size,
            "depth": 1
        },
        "segmented": "0",
        "object": []
    }
}

base_annotation_object = {
    "name": "",
    "pose": "Left",
    "truncated": "1",
    "difficult": "0",
    "bndbox": {
        "xmin": 0,
        "ymin": 0,
        "xmax": 0,
        "ymax": 0
    }
}
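
# These two dicts mirror the Pascal VOC XML skeleton. They are deepcopy'd and
# filled in per generated image / object, then serialized with xmltodict.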
transform = transforms.Compose([
    transforms.RandomRotation(30),
    transforms.ToTensor()
])

dataset = {
    "train": torchvision.datasets.MNIST(root=config.dir_data, train=True, download=True, transform=transform),
    "test": torchvision.datasets.MNIST(root=config.dir_data, train=False, download=True, transform=transform)
}

loader = {
    "train": DataLoader(dataset['train'], batch_size=1, shuffle=True, num_workers=2),
    "test": DataLoader(dataset['test'], batch_size=1, shuffle=False, num_workers=2)
}
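
# batch_size=1, so each draw from a loader yields one digit as a
# (1, 1, 28, 28) tensor together with its label.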
# os.makedirs creates missing parents, so creating the three leaf directories
# also creates config.dir_output itself.
for v in dirs.values():
    os.makedirs(v, exist_ok=True)
ith = 0

def make(data, load, setname):
    global ith
    # Keep one persistent iterator per split. The original `next(iter(load))`
    # built a fresh iterator for every digit, which always returns the first
    # sample when shuffle=False, so the whole test split would reuse one digit.
    it = iter(load)

    def draw():
        nonlocal it
        try:
            return next(it)
        except StopIteration:  # loader exhausted: start over
            it = iter(load)
            return next(it)

    try:
        with open(setname, 'w') as sf:
            for _ in trange(len(data)):
                ith += 1
                filename = "{:0>5}".format(ith)
                filename_image = "{}.png".format(filename)
                filename_annotation = "{}.xml".format(filename)

                annotation = deepcopy(base_annotation_file)
                annotation['annotation']['filename'] = filename_image

                image = np.zeros([config.image_size, config.image_size], dtype=np.float32)
                grid_index = np.random.permutation(config.grids * config.grids)

                # Place 1 .. num_obj digits, each in its own grid cell.
                for ith_img in range(np.random.randint(config.num_obj) + 1):
                    img, label = draw()
                    im = img[0, 0, :].numpy()
                    im = rescale(im, np.random.uniform(0.5, 1.2), mode='reflect')

                    # NumPy's first axis is the row (vertical), so `row`/`col`
                    # here correspond to VOC's y/x respectively.
                    offset_row = np.random.randint(max(1, config.image_size // config.grids - im.shape[0]))
                    offset_col = np.random.randint(max(1, config.image_size // config.grids - im.shape[1]))
                    row0 = (config.image_size // config.grids) * (grid_index[ith_img] // config.grids) + offset_row
                    col0 = (config.image_size // config.grids) * (grid_index[ith_img] % config.grids) + offset_col
                    h = min(im.shape[0], config.image_size - row0)
                    w = min(im.shape[1], config.image_size - col0)
                    row1 = row0 + h
                    col1 = col0 + w
                    image[row0:row1, col0:col1] += im[:h, :w]

                    anno = deepcopy(base_annotation_object)
                    anno['name'] = str(int(label[0]))  # plain "7", not "tensor(7)"
                    anno['bndbox']['xmin'] = col0
                    anno['bndbox']['ymin'] = row0
                    anno['bndbox']['xmax'] = col1
                    anno['bndbox']['ymax'] = row1
                    annotation["annotation"]['object'].append(anno)

                # Digits can spill out of their cells and overlap, pushing pixel
                # sums above 1; clip, then save as an 8-bit image.
                image = np.clip(image, 0, 1)
                imsave(os.path.join(dirs.image, filename_image), (image * 255).astype(np.uint8))

                with open(os.path.join(dirs.annotation, filename_annotation), 'w') as f:
                    xmltodict.unparse(annotation, f, full_document=False, pretty=True)
                sf.write(filename + "\n")
    except Exception as e:
        print(e)
        ith -= 1  # discard the half-written entry from the counter

for x in ['train', 'test']:
    make(dataset[x], loader[x], os.path.join(dirs.set, '{}.txt'.format(x)))
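
As a quick sanity check, the generated annotations can be read back with xmltodict. This is a minimal sketch, not part of the original gist; it assumes the generator above has already been run and produced VOCMNIST/Annotations/00001.xml.

import xmltodict

# Parse one generated annotation and print its class names and boxes.
with open("VOCMNIST/Annotations/00001.xml") as f:
    anno = xmltodict.parse(f.read())["annotation"]

objs = anno["object"]
if not isinstance(objs, list):  # xmltodict collapses a single child to a dict
    objs = [objs]
for obj in objs:
    box = obj["bndbox"]
    print(obj["name"], box["xmin"], box["ymin"], box["xmax"], box["ymax"])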