Created
May 13, 2020 09:49
-
-
Save sizhky/0108d3d0f684265ef2ecfe1e89445e9a to your computer and use it in GitHub Desktop.
Load VOC
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from snippets.loader import * | |
from PIL import Image | |
import xml.etree.ElementTree as ET | |
from torchvision import transforms | |
# Device string for model/tensor placement. NOTE(review): hard-coded, and not
# referenced anywhere else in this section.
device = 'cuda'

# Object class names; the position in this tuple determines each class id.
voc_labels = (
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
    'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
)

# Class name -> integer id. Object classes take ids 1..len(voc_labels) in
# tuple order; 'background' is registered last with the reserved id 0.
label_map = {name: idx for idx, name in enumerate(voc_labels, start=1)}
label_map.update(background=0)

# Integer id -> class name (inverse of label_map).
rev_label_map = {idx: name for name, idx in label_map.items()}
def parse_annotation(annotation_path):
    """Parse one VOC-style XML annotation file.

    Args:
        annotation_path: Anything ``ET.parse`` accepts (a path or a file
            object) pointing at the annotation XML.

    Returns:
        A dict with three parallel lists, one entry per kept ``<object>``:
        ``'boxes'`` holds ``[xmin, ymin, xmax, ymax]`` with every XML value
        reduced by 1, ``'labels'`` holds the ids looked up in ``label_map``,
        and ``'difficulties'`` holds 0 or 1.

    Objects whose lower-cased, stripped ``<name>`` is not a key of
    ``label_map`` are skipped.
    """
    root = ET.parse(annotation_path).getroot()

    boxes = []
    labels = []
    difficulties = []

    # Named `obj` rather than `object` so the builtin is not shadowed.
    for obj in root.iter('object'):
        label = obj.find('name').text.lower().strip()
        if label not in label_map:
            continue

        # A missing or empty <difficult> tag counts as "not difficult"
        # instead of raising AttributeError on None.
        difficult = int(obj.findtext('difficult', default='0').strip() == '1')

        bbox = obj.find('bndbox')
        # The XML coordinates are shifted down by one.
        box = [int(bbox.find(tag).text) - 1
               for tag in ('xmin', 'ymin', 'xmax', 'ymax')]

        boxes.append(box)
        labels.append(label_map[label])
        difficulties.append(difficult)

    return {'boxes': boxes, 'labels': labels, 'difficulties': difficulties}
def get_items(root, phase):
    """List the (image, annotation) path pairs of one dataset split.

    Args:
        root: Dataset root supporting the ``/`` operator (e.g. a
            ``pathlib.Path``); it must contain ``ImageSets/Main/{phase}.txt``.
        phase: Either ``'train'`` or ``'val'``.

    Returns:
        A list of ``(root/JPEGImages/<id>.jpg, root/Annotations/<id>.xml)``
        tuples, one per id listed in the split file, in file order.

    Raises:
        AssertionError: If ``phase`` is not ``'train'`` or ``'val'``.
    """
    assert phase in {'train', 'val'}, f"phase must be 'train' or 'val', got {phase!r}"

    with open(root / f'ImageSets/Main/{phase}.txt', 'r') as f:
        # splitlines() plus blank filtering keeps the final id even when the
        # file lacks a trailing newline, and never yields an empty id.
        ids = [line.strip() for line in f.read().splitlines() if line.strip()]

    return [
        (root / f'JPEGImages/{image_id}.jpg', root / f'Annotations/{image_id}.xml')
        for image_id in ids
    ]
from imgaug import augmenters as iaa | |
from imgaug.augmentables.bbs import BoundingBox | |
# Training-time augmentation: random affine jitter (rotation, pixel
# translation, shear, with edge padding), horizontal flip with probability
# 0.5, a crop to a square, then a resize to 300.
aug_trn = iaa.Sequential([
    iaa.geometric.Affine(
        rotate=(-20, 20),
        translate_px=(-20, 20),
        shear=(-5, 5),
        mode='edge',
    ),
    iaa.Fliplr(0.5),
    iaa.size.CropToSquare(),
    iaa.size.Resize(300),
])

# Validation-time pipeline: resize only.
# This used to be the chained assignment `aug_val = aug_trn = ...`, which
# rebound aug_trn as well and silently discarded the training pipeline above.
aug_val = iaa.Sequential([
    iaa.size.Resize(300),
])
def augment_image_with_bbs(image, bbs, aug_func):
    """Run ``aug_func`` on an image together with its bounding boxes.

    Args:
        image: Image array handed to ``aug_func`` as ``image=``.
        bbs: Iterable of ``(x1, y1, x2, y2)`` boxes; each one is wrapped in
            a ``BoundingBox`` before augmentation.
        aug_func: Callable taking ``image=`` and ``bounding_boxes=`` and
            returning ``(augmented_image, augmented_boxes)``.

    Returns:
        ``(augmented_image, boxes)`` where ``boxes`` is a list of 4-tuples
        whose coordinates are rounded to integers and clipped to
        ``[0, width]`` for x and ``[0, height]`` for y of the augmented image.
    """
    wrapped = [BoundingBox(*coords) for coords in bbs]
    im, augmented = aug_func(image=image, bounding_boxes=wrapped)
    h, w = im.shape[:2]

    clipped = []
    for box in augmented:
        # Round to whole pixels first, then clamp to the image extent.
        x1, y1, x2, y2 = [int(round(v)) for v in (box.x1, box.y1, box.x2, box.y2)]
        clipped.append((
            np.clip(x1, 0, w),
            np.clip(y1, 0, h),
            np.clip(x2, 0, w),
            np.clip(y2, 0, h),
        ))
    return im, clipped
class VOCDataset(Dataset):
    """Dataset over ``(image_path, annotation_path)`` pairs.

    ``__getitem__`` yields ``(PIL image, boxes, class ids, difficulties)``
    for one item; ``collate_fn`` turns a list of such tuples into a batch.
    """

    # Shared PIL -> tensor conversion and per-channel normalisation, applied
    # in collate_fn.
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    def __init__(self, items, tfms=aug_val):
        """Store the item list and the augmentation pipeline.

        Args:
            items: Sequence of ``(image_path, annotation_path)`` pairs.
            tfms: Augmenter applied to each image and its boxes, or ``None``
                to leave them untouched. Defaults to ``aug_val``.
        """
        # Zero-argument super(): the one-argument form `super(VOCDataset)`
        # creates an unbound proxy, so its __init__() call never reached the
        # base class.
        super().__init__()
        self.items = items
        self.tfms = tfms

    def __len__(self):
        """Return the number of items."""
        return len(self.items)

    def __getitem__(self, ix):
        """Load, parse and optionally augment the item at index ``ix``.

        Returns:
            ``(image, boxes, labels, difficulties)`` where ``image`` is a
            PIL image and the other three are the per-object lists.
        """
        image_path, annot_path = self.items[ix]
        image = np.array(Image.open(image_path).convert('RGB'))

        annot = parse_annotation(annot_path)
        bbs = annot['boxes']
        difficulties = annot['difficulties']

        if self.tfms is not None:
            image, bbs = augment_image_with_bbs(image, bbs, self.tfms)

        clss = list(annot['labels'])
        return Image.fromarray(image), bbs, clss, difficulties

    def sample(self):
        """Return whatever ``choose`` picks from this dataset."""
        return choose(self)

    def collate_fn(self, batch):
        """Combine a list of ``__getitem__`` results into one batch.

        Args:
            batch: Iterable of ``(image, boxes, labels, difficulties)``.

        Returns:
            ``(images, boxes, labels, difficulties)``: ``images`` is a single
            stacked tensor; the other three are lists with one tensor per
            image.
        """
        images = []
        boxes = []
        labels = []
        difficulties = []

        for _image, _boxes, _labels, _difficulties in batch:
            images.append(self.normalize(self.to_tensor(_image)))
            # NOTE(review): 300 is hard-coded; it matches the Resize(300) in
            # the module's augmenters and assumes images arrive at that size.
            boxes.append(torch.FloatTensor(_boxes) / 300.)
            labels.append(torch.LongTensor(_labels))
            difficulties.append(torch.ByteTensor(_difficulties))

        images = torch.stack(images, dim=0)
        return images, boxes, labels, difficulties
if __name__ == '__main__':
    # Demo: build the VOC2007 + VOC2012 item lists, draw one sample from the
    # training set and display it with its boxes and class names.
    from pathlib import Path

    # NOTE(review): machine-specific dataset locations.
    _2007_root = Path("/home/yyr/data/VOCdevkit/VOC2007")
    _2012_root = Path("/home/yyr/data/VOCdevkit/VOC2012")

    train_items = get_items(_2007_root, 'train') + get_items(_2012_root, 'train')
    val_items = get_items(_2007_root, 'val') + get_items(_2012_root, 'val')
    logger.info(f'\n{len(train_items)} training images\n{len(val_items)} validation images')

    # `seq` is not defined anywhere in this file; the training augmenter is
    # named aug_trn. NOTE(review): assumes `seq` was a stale name for it.
    x = VOCDataset(train_items, tfms=aug_trn)
    np.random.seed(12)

    # __getitem__ returns four values (image, boxes, labels, difficulties),
    # so the old three-name unpacking could not succeed.
    # NOTE(review): assumes sample()/choose returns one dataset item.
    im, bbs, clss, _difficulties = x.sample()

    # Class ids start at 1, hence the -1 when indexing voc_labels.
    show(im, bbs=bbs, texts=[voc_labels[label - 1] for label in clss], sz=5, text_sz=10)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment