Torch - Sliding Window transform for Object Detection
import numpy as np
import torch
class WindowUtils:

    @staticmethod
    def is_overlapping(window, annot):
        crop_tolerance = 0.25
        win_min_x, win_min_y, win_max_x, win_max_y = window
        ann_min_x, ann_min_y, ann_max_x, ann_max_y, cat = annot
        orig_ann_area = (ann_max_x - ann_min_x) * (ann_max_y - ann_min_y)
        # Only return True if overlapping_area >= crop_tolerance * original_area
        min_crop_area = crop_tolerance * orig_ann_area
        if (ann_min_x < win_max_x and
                ann_max_x > win_min_x and
                ann_min_y < win_max_y and
                ann_max_y > win_min_y and
                (
                    (np.min([ann_max_x, win_max_x]) - np.max([ann_min_x, win_min_x])) *
                    (np.min([ann_max_y, win_max_y]) - np.max([ann_min_y, win_min_y])) >= min_crop_area
                )):
            return True
        else:
            return False
    @staticmethod
    def calculate_sub_img_annot(window, ann):
        # Translate an annotation from full-image coordinates into the window's
        # local coordinates, clipping it to the window bounds.
        win_min_x, win_min_y, win_max_x, win_max_y = window
        ann_min_x, ann_min_y, ann_max_x, ann_max_y, cat = ann
        # Window max coordinates are inclusive, so the chip spans (max - min + 1) pixels.
        sub_img_width = win_max_x - win_min_x + 1
        sub_img_height = win_max_y - win_min_y + 1
        sub_min_x = np.max([ann_min_x - win_min_x, 0.])
        sub_max_x = np.min([ann_max_x - win_min_x, sub_img_width - 1])
        sub_min_y = np.max([ann_min_y - win_min_y, 0.])
        sub_max_y = np.min([ann_max_y - win_min_y, sub_img_height - 1])
        return [sub_min_x, sub_min_y, sub_max_x, sub_max_y, cat]
    @staticmethod
    def get_annotations_for_window(window, annots):
        sub_img_annots = []
        for ann in annots:
            if WindowUtils.is_overlapping(window, ann):
                sub_img_annots.append(WindowUtils.calculate_sub_img_annot(window, ann))
        return sub_img_annots
    @staticmethod
    def chip_img(img, window):
        # Window max coordinates are inclusive, hence the +1 on the slice bounds.
        return img[window[1]:window[3] + 1, window[0]:window[2] + 1]
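
# Illustrative example with assumed values (an 800x800 window at the image origin
# and a 100x100 box that extends 41px past the window's right edge):
#
#   window = (0, 0, 799, 799)            # (min_x, min_y, max_x, max_y), inclusive
#   ann = (740., 100., 840., 200., 3.)   # (min_x, min_y, max_x, max_y, category)
#
#   WindowUtils.is_overlapping(window, ann)
#   # -> True: the computed overlap (59 * 100 px^2) exceeds 25% of the box's 10,000 px^2 area
#   WindowUtils.calculate_sub_img_annot(window, ann)
#   # -> [740.0, 100.0, 799.0, 200.0, 3.0]  (box clipped to the window's local bounds)
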
class Windower(object):
    """Splits a sample into overlapping windows, returning one sub-sample per window."""

    def __call__(self, sample, win_height=800, win_width=800, min_overlap=200):
        windows = []
        image, annots = sample['img'], sample['annot']
        rows, cols, cns = image.shape
        if rows <= win_height:
            num_vertical_windows = 1
        else:
            # Number of win_height windows, stepped by (win_height - min_overlap), needed
            # to cover the image; e.g. a 900px-tall image gets 2 vertical windows.
            num_vertical_windows = int(np.ceil((rows - win_height) / (win_height - min_overlap))) + 1
        if cols <= win_width:
            num_horizontal_windows = 1
        else:
            num_horizontal_windows = int(np.ceil((cols - win_width) / (win_width - min_overlap))) + 1
        # Generate x, y coords for each window
        for h_win_idx in range(num_horizontal_windows):
            for v_win_idx in range(num_vertical_windows):
                if h_win_idx + 1 == num_horizontal_windows:
                    # Last horizontal window, so hug right side of image
                    min_x = cols - win_width
                    max_x = min_x + win_width - 1
                else:
                    min_x = h_win_idx * (win_width - min_overlap)
                    max_x = min_x + win_width - 1
                if v_win_idx + 1 == num_vertical_windows:
                    # Last vertical window, so hug bottom of image
                    min_y = rows - win_height
                    max_y = min_y + win_height - 1
                else:
                    min_y = v_win_idx * (win_height - min_overlap)
                    max_y = min_y + win_height - 1
                # Format (min_x, min_y, max_x, max_y) from top-left
                windows.append((min_x, min_y, max_x, max_y))
        # For debugging (requires matplotlib.pyplot as plt and matplotlib.patches as patches):
        # def plot_sub_img(image, window, annots):
        #     fig, ax = plt.subplots(1, figsize=(30, 30))
        #     ax.imshow(WindowUtils.chip_img(image, window))
        #     for an in annots:
        #         ax.add_patch(
        #             patches.Rectangle(
        #                 (an[0], an[1]),
        #                 width=an[2] - an[0],
        #                 height=an[3] - an[1],
        #                 linewidth=5,
        #                 edgecolor=np.random.choice(['b', 'g']),
        #                 facecolor='none'))
        # for win_idx, win in enumerate(windows):
        #     win_annots = WindowUtils.get_annotations_for_window(win, annots)
        #     if len(win_annots):
        #         # print('OVERLAP FOR WINDOW {}'.format(win_idx))
        #         plot_sub_img(image, win, win_annots)
        ret = [{
            'img': torch.from_numpy(WindowUtils.chip_img(image, window)),
            'annot': torch.from_numpy(np.array(WindowUtils.get_annotations_for_window(window, annots))),
            'sub_img_px_coords': window  # To re-anchor detections in large image.
        } for window in windows]
        return ret
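A rough usage sketch against the classes above: the image size, box values, and the hypothetical chip_box detection below are made-up illustrations; the transform itself only relies on the 'img' and 'annot' keys used in the code.

sample = {
    'img': np.zeros((1000, 1500, 3), dtype=np.uint8),   # dummy HxWxC image
    'annot': np.array([[100., 120., 220., 260., 0.],    # [min_x, min_y, max_x, max_y, category]
                       [900., 700., 1050., 860., 2.]]),
}

windowed = Windower()(sample)  # defaults: 800x800 windows with at least 200px overlap
for sub in windowed:
    print(sub['sub_img_px_coords'], tuple(sub['img'].shape), tuple(sub['annot'].shape))

# Chip-space detections can be re-anchored in the full image by adding back the
# window's top-left offset stored in 'sub_img_px_coords':
win_min_x, win_min_y = windowed[0]['sub_img_px_coords'][:2]
chip_box = [15., 40., 230., 310.]  # hypothetical detection on the first chip
full_img_box = [chip_box[0] + win_min_x, chip_box[1] + win_min_y,
                chip_box[2] + win_min_x, chip_box[3] + win_min_y]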