AccessMath VOC triplet data loader: a PyTorch Dataset that samples (anchor, positive, negative) region triplets from annotated lecture-video frames.
import os
import xml.etree.ElementTree as ET

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset

from data.utils import filter_by_shape, bbox_to_quad
class AccessMathVOCTriplets(Dataset):
    TrainValLectures = ["lecture_01", "lecture_06", "lecture_18", "NM_lecture_01", "NM_lecture_03"]
    TestLectures = ["lecture_02", "lecture_07", "lecture_08", "lecture_10", "lecture_15", "NM_lecture_02",
                    "NM_lecture_05"]

    def __init__(self, accessmathVOC_root='/run1/dataset-txt/AccessMathVOC/', data_transform=None,
                 target_transform=None, lectures=None, gt_type="bbox", imageset="trainval", seed=5,
                 imageset_dir='Main'):
        self.accessmathVOC_root = accessmathVOC_root
        self.data_transform = data_transform
        self.target_transform = target_transform
        self.gt_type = gt_type
        valid_imagesets = ["trainval", "train", "val", "test"]
        self.imageset = "trainval" if imageset.lower() not in valid_imagesets else imageset.lower()
        self.imageset_dir = imageset_dir
        self.seed = seed
        if self.imageset == "trainval":
            self.lectures = AccessMathVOCTriplets.TrainValLectures
        elif self.imageset in ["train", "val"]:
            self.lectures = lectures if lectures is not None else AccessMathVOCTriplets.TrainValLectures
        else:
            self.lectures = lectures if lectures is not None else AccessMathVOCTriplets.TestLectures
        print('finding unique regions...')
        self.region_frame_map = {}
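        # Maps (lecture, region_id) -> [(frame_id, (xmin, ymin, xmax, ymax)), ...] so each
        # unique region can later be sampled from any frame it appears in.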
        max_height = 0
        max_width = 0
        max_area = 0
        for lecture in self.lectures:
            imageset_ = "trainval" if self.imageset in ["train", "val", "trainval"] else "test"
            imageset_f = os.path.join(self.accessmathVOC_root, lecture,
                                      "ImageSets", self.imageset_dir, "{}.txt".format(imageset_))
            frame_ids = self._read_imageset_f(imageset_f)
            frame_ids = sorted(frame_ids, key=lambda frame_id: int(frame_id))
            for frame_id in frame_ids:
                anno_path = self._id_to_anno_path(lecture, frame_id)
                bboxes, region_ids = self._anno_from_xml(os.path.join(self.accessmathVOC_root, anno_path))
                if bboxes.shape[0] == 0:
                    # skip frames with no annotated regions; .max() on an empty array would raise
                    continue
                heights = bboxes[:, 3] - bboxes[:, 1]
                widths = bboxes[:, 2] - bboxes[:, 0]
                areas = heights * widths
                max_height = max(max_height, heights.max())
                max_width = max(max_width, widths.max())
                max_area = max(max_area, areas.max())
                for i, region_id in enumerate(region_ids):
                    if (lecture, region_id) in self.region_frame_map:
                        self.region_frame_map[(lecture, region_id)] += [(frame_id, tuple(bboxes[i, :]))]
                    else:
                        self.region_frame_map[(lecture, region_id)] = [(frame_id, tuple(bboxes[i, :]))]
        self.unique_regions = list(self.region_frame_map.keys())
        print('... found')
        print('max region height:', max_height, 'width:', max_width, 'area:', max_area)
    def __len__(self):
        return len(self.unique_regions)

    def __getitem__(self, item):
        # pick a lecture and region_id
        lecture, region_id = self.unique_regions[item]
        # corner case: a sampled box may be invalid, so filter with filter_by_shape and
        # resample until all three boxes survive
        all_bboxes = np.empty(shape=(0, 4))
        while all_bboxes.shape[0] != 3:
            anchor_frame, anchor_bbox, pos_frame, pos_bbox = self._get_positive_samples(lecture, region_id)
            neg_frame, neg_bbox = self._get_negative_sample(lecture, region_id, anchor_frame, anchor_bbox)
            all_bboxes = filter_by_shape(np.asarray([pos_bbox, anchor_bbox, neg_bbox]))
        anchor_gt_path = self._id_to_anno_path(lecture, anchor_frame)
        anchor_gt, _ = self._anno_from_xml(os.path.join(self.accessmathVOC_root, anchor_gt_path))
        # load the image for each of the three frames
        anchor_img = self._get_img(lecture, anchor_frame)
        pos_img = self._get_img(lecture, pos_frame)
        neg_img = self._get_img(lecture, neg_frame)
        # convert the bbox tuples to np.array of shape (1, 4)
        anchor_bbox = np.asarray(anchor_bbox).reshape(1, -1)
        pos_bbox = np.asarray(pos_bbox).reshape(1, -1)
        neg_bbox = np.asarray(neg_bbox).reshape(1, -1)
        triplet_sample = {"anchor_image": anchor_img, "anchor_bbox": anchor_bbox, "anchor_gt": anchor_gt,
                          "pos_image": pos_img, "pos_bbox": pos_bbox,
                          "neg_image": neg_img, "neg_bbox": neg_bbox}
        if self.gt_type == "quad":
            all_bboxes = np.concatenate([anchor_bbox, pos_bbox, neg_bbox], axis=0)
            all_quads = bbox_to_quad(all_bboxes)
            triplet_sample["anchor_bbox"] = all_quads[0, :].reshape(1, -1)
            triplet_sample["pos_bbox"] = all_quads[1, :].reshape(1, -1)
            triplet_sample["neg_bbox"] = all_quads[2, :].reshape(1, -1)
        if self.data_transform is not None:
            triplet_sample = self.data_transform(triplet_sample)
        if self.target_transform is not None:
            triplet_sample = self.target_transform(triplet_sample)
        return triplet_sample
    @staticmethod
    def _read_imageset_f(fpath):
        with open(fpath, 'r') as f:
            lines = f.readlines()
        lines = [l.strip() for l in lines if len(l.strip()) > 0]
        return lines

    @staticmethod
    def _id_to_image_path(lecture, _id):
        return os.path.join(lecture, "JPEGImages", "{}.jpg".format(_id))

    @staticmethod
    def _id_to_binary_image_path(lecture, _id):
        return os.path.join(lecture, "binary", "{}.png".format(_id))

    @staticmethod
    def _id_to_anno_path(lecture, _id):
        return os.path.join(lecture, "Annotations", "{}.xml".format(_id))

    @staticmethod
    def _anno_from_xml(fpath):
        root = ET.parse(fpath).getroot()
        bboxes = []
        ids = []
        for obj in root.iter('object'):
            ids += [obj.find('ID').text]
            bndbox = obj.find('bndbox')
            bbox = []
            for pt in ["xmin", "ymin", "xmax", "ymax"]:
                coord = int(bndbox.find(pt).text)
                bbox += [coord]
            bboxes += [bbox]
        return np.asarray(bboxes), ids
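    # For reference, _anno_from_xml expects Pascal-VOC-style annotations with an extra
    # per-object ID tag; a minimal illustrative file (values made up) would look like:
    #
    #   <annotation>
    #     <object>
    #       <ID>region_17</ID>
    #       <bndbox>
    #         <xmin>120</xmin><ymin>340</ymin><xmax>560</xmax><ymax>410</ymax>
    #       </bndbox>
    #     </object>
    #   </annotation>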
    def _get_positive_samples(self, lecture, region_id):
        positive_samples = self.region_frame_map[(lecture, region_id)]
        # corner case: if a region occurs in only one frame, sample with replacement
        # (the anchor and positive will then come from the same frame)
        replacement = len(positive_samples) == 1
        s1, s2 = np.random.choice(len(positive_samples), 2, replace=replacement)
        anchor_frame, anchor_bbox = positive_samples[s1]
        pos_frame, pos_bbox = positive_samples[s2]
        return anchor_frame, anchor_bbox, pos_frame, pos_bbox
    def _get_negative_sample(self, lecture, region_id, anchor_frame_id, anchor_bbox):
        # pick a region with a different id from the same lecture, preferring one whose
        # box lies within some distance of the anchor
        def bbox_distance(bbox1, bbox2):
            bb1 = np.asarray(bbox1, dtype='float')
            bb1[::2] /= 1920.
            bb1[1::2] /= 1080.
            bb2 = np.asarray(bbox2, dtype='float')
            bb2[::2] /= 1920.
            bb2[1::2] /= 1080.
            return np.linalg.norm(bb1 - bb2)
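        # NOTE: the 1920 x 1080 divisors assume full-HD lecture frames, so distances are
        # computed on coordinates normalized to [0, 1].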
        negative_regions = {k: v for k, v in self.region_frame_map.items() if k[0] == lecture and k[1] != region_id}
        valid_negative_regions = {}
        th = 0.33
        # keep doubling the distance threshold until at least one negative region qualifies;
        # this assumes the lecture contains at least one other region, otherwise the loop
        # would never terminate
        while len(valid_negative_regions) < 1:
            for k, v in negative_regions.items():
                for (frame_id, bbox) in v:
                    if bbox_distance(bbox, anchor_bbox) <= th:
                        if k in valid_negative_regions:
                            valid_negative_regions[k] += [(frame_id, bbox)]
                        else:
                            valid_negative_regions[k] = [(frame_id, bbox)]
            th *= 2.
        # pick a random negative region_id from valid_negative_regions
        s3 = np.random.randint(0, len(valid_negative_regions))
        # get the list of (frame_id, bbox) pairs for that region
        negative_samples = list(valid_negative_regions.values())[s3]
        # pick a random (frame_id, bbox) from the negative samples
        s4 = np.random.randint(0, len(negative_samples))
        neg_frame, neg_bbox = negative_samples[s4]
        return neg_frame, neg_bbox

    def _get_img(self, lecture, frame_id):
        img_path = self._id_to_image_path(lecture, frame_id)
        return cv2.imread(os.path.join(self.accessmathVOC_root, img_path))
    @staticmethod
    def detection_collate(batch):
        anchor_images = []
        pos_images = []
        neg_images = []
        anchor_bboxes = []
        pos_bboxes = []
        neg_bboxes = []
        for sample in batch:
            anchor_images += [sample["anchor_image"][np.newaxis, :, :, :]]
            pos_images += [sample["pos_image"][np.newaxis, :, :, :]]
            neg_images += [sample["neg_image"][np.newaxis, :, :, :]]
            anchor_bboxes += [sample["anchor_bbox"][np.newaxis, :, :]]
            pos_bboxes += [sample["pos_bbox"][np.newaxis, :, :]]
            neg_bboxes += [sample["neg_bbox"][np.newaxis, :, :]]
        stacked_images = np.concatenate(anchor_images + pos_images + neg_images, axis=0)
        stacked_bboxes = np.concatenate(anchor_bboxes + pos_bboxes + neg_bboxes, axis=0)
        images = torch.from_numpy(stacked_images).permute(0, 3, 1, 2)
        bboxes = torch.from_numpy(stacked_bboxes)
        return images, bboxes
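    # For a batch of B samples with H x W x 3 images and one box each, the collated output
    # is images of shape (3B, 3, H, W) and bboxes of shape (3B, 1, 4), stacked in
    # [anchors | positives | negatives] order (np.concatenate requires all images in the
    # batch to share one size, e.g. after the crop/pad transforms).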
if __name__ == "__main__":
    """
    Example of creating dataset objects with a train/val split
    """
    from torch.utils.data import DataLoader
    from data.transforms import AMVOCTripletTransform
    from nnet.layers.text_align_bbox import TextAlign

    lectures = AccessMathVOCTriplets.TrainValLectures
    v = np.random.randint(0, len(lectures))
    # hold out lecture v for validation; train on the rest
    train_lectures = lectures[:v] + lectures[v + 1:]
    val_lecture = [lectures[v]]
    amvoc_triplets_train = AccessMathVOCTriplets('/run1/dataset-txt/AccessMathVOC', AMVOCTripletTransform(),
                                                 None, train_lectures, imageset='train')
    amvoc_triplets_val = AccessMathVOCTriplets('/run1/dataset-txt/AccessMathVOC', AMVOCTripletTransform(),
                                               None, val_lecture, imageset='val')
    resample = TextAlign((1, 3, 30, 30), device='cpu', pool_type=0, rescale=1.0)
    dataloader = DataLoader(amvoc_triplets_train, batch_size=5, collate_fn=AccessMathVOCTriplets.detection_collate)
    dataiter = iter(dataloader)
    for i, (images, boxes) in enumerate(dataiter):
        print(images.shape, boxes.shape)
        resampled = resample(images.float(), boxes.float())
        print(resampled[0].shape, resampled[1])
        input()  # pause so the shapes can be inspected
        if i == 2:
            break
    # sanity check TextAlign with random images and per-image ROI lists
    test_im = torch.tensor(np.random.random(size=(4, 3, 60, 60)))
    n_rois = [5, 9, 2, 4]
    all_test_coords = []
    for n in n_rois:
        test_coords = torch.tensor(np.random.random(size=(n, 4)))
        test_coords[:, 2] = test_coords[:, 0] + 1.
        test_coords[:, 3] = test_coords[:, 1] + 1.
        test_coords *= 30.
        all_test_coords += [test_coords.float()]
    resampled2 = resample(test_im.float(), all_test_coords)
    print(resampled2[0].shape, resampled2[1])
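
    # A minimal sketch of how a collated batch could feed a triplet loss. TinyEmbed is a
    # hypothetical stand-in embedder, NOT part of this repo; the real pipeline presumably
    # embeds TextAlign-pooled region features rather than whole frames.
    import torch.nn as nn

    class TinyEmbed(nn.Module):
        def __init__(self):
            super(TinyEmbed, self).__init__()
            self.net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                                     nn.AdaptiveAvgPool2d(1))

        def forward(self, x):
            x = self.net(x)                  # (N, 8, 1, 1)
            return x.view(x.size(0), -1)     # (N, 8)

    embed = TinyEmbed()
    criterion = nn.TripletMarginLoss(margin=1.0)
    images, _ = next(iter(dataloader))       # images: (3B, 3, H, W)
    b = images.shape[0] // 3                 # stacked as [anchors | positives | negatives]
    feats = embed(images.float())
    loss = criterion(feats[:b], feats[b:2 * b], feats[2 * b:])
    print('example triplet loss:', loss.item())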