Created
November 25, 2021 03:23
-
-
Save tamnguyenvan/f20f5021900d16fe6f5cd00216a1caf7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
| import logging | |
| import pathlib | |
| import xml.etree.ElementTree as ET | |
| import cv2 | |
| import os | |
class VOCDataset:
    """Object-detection dataset read from a Pascal VOC style directory tree.

    Image ids come from ``ImageSets/Main/{trainval,test}.txt``, annotations
    from ``Annotations/<id>.xml`` and images from ``JPEGImages/<id>.jpg``,
    all relative to ``root``.
    """

    def __init__(self, root, transform=None, target_transform=None, is_test=False, keep_difficult=False, label_file=None):
        """Dataset for VOC data.

        Args:
            root: the root of the VOC2007 or VOC2012 dataset, the directory contains the following sub-directories:
                Annotations, ImageSets, JPEGImages, SegmentationClass, SegmentationObject.
            transform: optional callable applied in ``__getitem__`` as
                ``transform(image, boxes, labels)``.
            target_transform: optional callable applied in ``__getitem__`` as
                ``target_transform(boxes, labels)``.
            is_test: read ids from ``test.txt`` instead of ``trainval.txt``.
            keep_difficult: when False, ``__getitem__`` drops objects whose
                ``difficult`` flag is non-zero.
            label_file: accepted for interface compatibility but not read by
                this implementation; class names are always looked up in
                ``<root>/labels.txt``.
        """
        self.root = pathlib.Path(root)
        self.transform = transform
        self.target_transform = target_transform
        if is_test:
            image_sets_file = self.root / "ImageSets/Main/test.txt"
        else:
            image_sets_file = self.root / "ImageSets/Main/trainval.txt"
        self.ids = VOCDataset._read_image_ids(image_sets_file)
        self.keep_difficult = keep_difficult

        # If the labels file exists, read the class names from it.
        label_file_name = self.root / "labels.txt"
        if os.path.isfile(label_file_name):
            # Lines are concatenated (trailing whitespace removed) and the
            # result is treated as one comma separated list.
            with open(label_file_name, 'r') as infile:
                class_string = "".join(line.rstrip() for line in infile)
            classes = class_string.split(',')
            # BACKGROUND is always class index 0.
            classes.insert(0, 'BACKGROUND')
            classes = [elem.replace(" ", "") for elem in classes]
            self.class_names = tuple(classes)
            logging.info("VOC Labels read from file: " + str(self.class_names))
        else:
            logging.info("No labels file, using default VOC classes.")
            self.class_names = ('BACKGROUND', 'face')

        # Maps class name -> integer label used in the returned arrays.
        self.class_dict = {class_name: i for i, class_name in enumerate(self.class_names)}

    def __getitem__(self, index):
        """Return ``(image, boxes, labels)`` for the sample at ``index``.

        Difficult objects are filtered out unless ``keep_difficult`` is set;
        ``transform`` and ``target_transform`` are applied afterwards, in
        that order, when they are set.
        """
        image_id = self.ids[index]
        boxes, labels, is_difficult = self._get_annotation(image_id)
        if not self.keep_difficult:
            keep = is_difficult == 0
            boxes = boxes[keep]
            labels = labels[keep]
        image = self._read_image(image_id)
        if self.transform:
            image, boxes, labels = self.transform(image, boxes, labels)
        if self.target_transform:
            boxes, labels = self.target_transform(boxes, labels)
        return image, boxes, labels

    def get_image(self, index):
        """Return only the image at ``index``.

        NOTE(review): here the transform is called with the image alone and
        must return a 2-tuple, unlike the 3-argument call in ``__getitem__``;
        confirm the transforms in use support both call styles.
        """
        image_id = self.ids[index]
        image = self._read_image(image_id)
        if self.transform:
            image, _ = self.transform(image)
        return image

    def get_annotation(self, index):
        """Return ``(image_id, (boxes, labels, is_difficult))`` for ``index``."""
        image_id = self.ids[index]
        return image_id, self._get_annotation(image_id)

    def __len__(self):
        """Return the number of image ids in the selected image set."""
        return len(self.ids)

    @staticmethod
    def _read_image_ids(image_sets_file):
        """Read one image id per line from ``image_sets_file``.

        Trailing whitespace is removed from each line. Blank lines are
        skipped so that a stray empty line does not become an empty id
        (which would later resolve to ``Annotations/.xml``).
        """
        with open(image_sets_file) as f:
            stripped = (line.rstrip() for line in f)
            return [image_id for image_id in stripped if image_id]

    def _get_annotation(self, image_id):
        """Parse ``Annotations/<image_id>.xml``.

        Only objects whose lower-cased, stripped name is a key of
        ``self.class_dict`` are kept.

        Returns:
            A tuple ``(boxes, labels, is_difficult)`` of numpy arrays:
            ``boxes`` is float32 with shape ``(n, 4)`` holding
            ``[x1, y1, x2, y2]`` rows, ``labels`` is int64 with shape
            ``(n,)`` and ``is_difficult`` is uint8 with shape ``(n,)``.
            ``n`` may be 0.
        """
        annotation_file = self.root / f"Annotations/{image_id}.xml"
        objects = ET.parse(annotation_file).findall("object")
        boxes = []
        labels = []
        is_difficult = []
        for obj in objects:
            class_name = obj.find('name').text.lower().strip()
            # We're only concerned with classes in our list.
            if class_name not in self.class_dict:
                continue
            bbox = obj.find('bndbox')
            # VOC annotations follow the Matlab convention of 1-based pixel
            # indexes; subtract 1 to obtain 0-based coordinates.
            x1 = float(bbox.find('xmin').text) - 1
            y1 = float(bbox.find('ymin').text) - 1
            x2 = float(bbox.find('xmax').text) - 1
            y2 = float(bbox.find('ymax').text) - 1
            boxes.append([x1, y1, x2, y2])
            labels.append(self.class_dict[class_name])
            # The <difficult> tag may be absent or empty; both count as
            # "not difficult".
            difficult_node = obj.find('difficult')
            difficult_text = difficult_node.text if difficult_node is not None else None
            is_difficult.append(int(difficult_text) if difficult_text else 0)
        # reshape keeps the (n, 4) layout even when no object matched.
        return (np.array(boxes, dtype=np.float32).reshape(-1, 4),
                np.array(labels, dtype=np.int64),
                np.array(is_difficult, dtype=np.uint8))

    def _read_image(self, image_id):
        """Load ``JPEGImages/<image_id>.jpg`` and return it in RGB order.

        Raises:
            OSError: if OpenCV could not read the file.
        """
        image_file = self.root / f"JPEGImages/{image_id}.jpg"
        image = cv2.imread(str(image_file))
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise OSError(f"Failed to read image: {image_file}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return image
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment