Multi MNIST in VOC format
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import os
from tqdm import tqdm, trange
from skimage.io import imsave
from easydict import EasyDict
from copy import deepcopy
import xmltodict
import json
from skimage.transform import rescale

config = EasyDict({
    "dir_data": "data",
    "dir_output": "VOCMNIST",
    "image_size": 64,
    "num_obj": 3,
    "grids": 2,
})

dirs = EasyDict({
    "image": os.path.join(config['dir_output'], "JPEGImages"),
    "annotation": os.path.join(config['dir_output'], "Annotations"),
    "set": os.path.join(config['dir_output'], "ImageSets", 'Main')
})

base_annotation_file = {
    "annotation": {
        "folder": "VOCMNIST",
        "filename": "",
        "source": {
            "database": "The VOCMNIST Dataset",
            "annotation": "VOCMNIST"
        },
        "size": {
            "width": config.image_size,
            "height": config.image_size,
            "depth": 1
        },
        "segmented": "0",
        "object": []
    }
}

base_annotation_object = {
    "name": "",
    "pose": "Left",
    "truncated": "1",
    "difficult": "0",
    "bndbox": {
        "xmin": 0,
        "ymin": 0,
        "xmax": 0,
        "ymax": 0
    }
}

transform = transforms.Compose([
    transforms.RandomRotation(30),
    transforms.ToTensor()
])

dataset = {
    "train": torchvision.datasets.MNIST(root=config.dir_data, train=True, download=True, transform=transform),
    "test": torchvision.datasets.MNIST(root=config.dir_data, train=False, download=True, transform=transform)
}

loader = {
    # Both loaders are shuffled because make() below samples digits by
    # restarting the loader iterator; without shuffling, every synthesized
    # test image would contain the same first digit over and over.
    "train": DataLoader(dataset['train'], batch_size=1, shuffle=True, num_workers=2),
    "test": DataLoader(dataset['test'], batch_size=1, shuffle=True, num_workers=2)
}

os.makedirs(config['dir_output'], exist_ok=True)
for k, v in dirs.items():
    os.makedirs(v, exist_ok=True)

ith = 0

def make(data, load, setname):
    global ith
    try:
        with open(setname, 'w') as sf:
            for _ in trange(len(data)):
                ith += 1
                filename = "{:0>5}".format(ith)
                filename_image = "{}.png".format(filename)
                filename_annotation = "{}.xml".format(filename)
                annotation = deepcopy(base_annotation_file)
                annotation['annotation']['filename'] = filename_image
                # Blank canvas; np.float is deprecated, so use the builtin float dtype
                image = np.zeros([config.image_size, config.image_size], dtype=float)
                grid_index = np.random.permutation(config.grids * config.grids)
                # Paste 1..num_obj digits, each into its own randomly chosen grid cell
                for ith_img in range(np.random.randint(config.num_obj) + 1):
                    # Re-creating the iterator draws a fresh shuffled sample each time
                    img, label = next(iter(load))
                    im = img[0, 0, :].numpy()
                    im = rescale(im, np.random.uniform(0.5, 1.2), mode='reflect')
                    offset_x = np.random.randint(max(1, config.image_size // config.grids - im.shape[0]))
                    offset_y = np.random.randint(max(1, config.image_size // config.grids - im.shape[1]))
                    x0 = (config.image_size // config.grids) * (grid_index[ith_img] // config.grids) + offset_x
                    y0 = (config.image_size // config.grids) * (grid_index[ith_img] % config.grids) + offset_y
                    w = min(im.shape[0], config.image_size - x0)
                    h = min(im.shape[1], config.image_size - y0)
                    x1 = x0 + w
                    y1 = y0 + h
                    image[x0:x1, y0:y1] += im[:w, :h]
                    anno = deepcopy(base_annotation_object)
                    # label is a LongTensor; convert to a plain int before stringifying
                    anno['name'] = str(int(label[0]))
                    anno['bndbox']['xmin'] = x0
                    anno['bndbox']['ymin'] = y0
                    anno['bndbox']['xmax'] = x1
                    anno['bndbox']['ymax'] = y1
                    annotation["annotation"]['object'].append(anno)
                image = np.clip(image, 0, 1)
                # Save as 8-bit grayscale to avoid lossy float-conversion warnings from skimage
                imsave(os.path.join(dirs.image, filename_image), (image * 255).astype(np.uint8))
                with open(os.path.join(dirs.annotation, filename_annotation), 'w') as f:
                    xmltodict.unparse(annotation, f, full_document=False, pretty=True)
                sf.write(filename + "\n")
    except Exception as e:
        print(e)
        ith -= 1


for x in ['train', 'test']:
    make(dataset[x], loader[x], os.path.join(dirs.set, '{}.txt'.format(x)))
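
As a quick sanity check (not part of the original gist), the sketch below reads one generated image and its XML annotation back and prints the recorded boxes. It assumes the script above has already produced a sample named 00001 under the default VOCMNIST output directory.

```python
import os

import xmltodict
from skimage.io import imread

# Minimal sketch: inspect one generated sample (assumes "00001" exists).
sample = "00001"
image = imread(os.path.join("VOCMNIST", "JPEGImages", sample + ".png"))
print("image shape:", image.shape)

with open(os.path.join("VOCMNIST", "Annotations", sample + ".xml")) as f:
    # The file was written with full_document=False, so it has a single
    # <annotation> root element and parses directly.
    anno = xmltodict.parse(f.read())

objects = anno["annotation"]["object"]
# xmltodict yields a dict for a single object and a list for several.
if not isinstance(objects, list):
    objects = [objects]
for obj in objects:
    box = obj["bndbox"]
    print(obj["name"], box["xmin"], box["ymin"], box["xmax"], box["ymax"])
```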