Created August 17, 2017 15:15
-
-
Save Puzer/9daab9b41017254518d1f0ab3a69360f to your computer and use it in GitHub Desktop.
Perspective transformation for imgaug
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Usage example: a one-step imgaug pipeline that applies a random perspective warp.
# NOTE(review): this snippet does not define `iaa` (presumably imgaug.augmenters),
# `Normal` (presumably imgaug.parameters.Normal) or `Perspective` (the class in
# this gist); they must already be in scope when it runs.
seq = iaa.Sequential(
    [
        # perspective transformation: x_projection drawn from Normal(0, 0.03),
        # y_projection drawn from Normal(0, 0.15)
        Perspective(Normal(0, 0.03), Normal(0, 0.15)),
    ]
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cv2 | |
import numpy as np | |
import imgaug as ia | |
import six.moves as sm | |
from imgaug.augmenters import Augmenter | |
from imgaug.parameters import StochasticParameter | |
# https://en.wikipedia.org/wiki/3D_projection
def generate_map_matrix(x_projection, y_projection, img_height, img_width):
    """Build a perspective matrix that pushes one horizontal and one vertical edge outward.

    The four image corners (top-left, top-right, bottom-right, bottom-left) are
    mapped to a quadrilateral in which each projection factor, depending on its
    sign, widens exactly one edge of the image (or none when it is zero).

    Args:
        x_projection: Horizontal factor. A positive value moves both top corners
            outward by ``x_projection * img_width`` each. A negative value moves
            both bottom corners outward by ``|x_projection| * img_width`` each.
        y_projection: Vertical factor. A positive value moves both left corners
            outward (up and down) by ``y_projection * img_height`` each. A
            negative value does the same to the right corners with
            ``|y_projection|``.
        img_height: Image height in pixels.
        img_width: Image width in pixels.

    Returns:
        The matrix produced by ``cv2.getPerspectiveTransform`` that maps the
        original corners onto the displaced ones.
    """
    h, w = img_height, img_width

    # Signed corner offsets (each is 0 or negative). At most one of each pair is
    # non-zero, chosen by the sign of the corresponding projection factor.
    top_dx = -x_projection * w if x_projection > 0 else 0
    bottom_dx = x_projection * w if x_projection < 0 else 0
    left_dy = -y_projection * h if y_projection > 0 else 0
    right_dy = y_projection * h if y_projection < 0 else 0

    source_corners = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
    target_corners = np.float32([
        [top_dx, left_dy],               # top-left
        [w - top_dx, right_dy],          # top-right
        [w - bottom_dx, h - right_dy],   # bottom-right
        [bottom_dx, h - left_dy],        # bottom-left
    ])
    return cv2.getPerspectiveTransform(source_corners, target_corners)
class Perspective(Augmenter):
    """Augmenter that applies a random perspective warp to images and keypoints.

    For each image, one x- and one y-projection factor are sampled and turned
    into a perspective matrix by ``generate_map_matrix``. Images are warped with
    ``cv2.warpPerspective`` and keep their original height and width. Keypoints
    are mapped with ``cv2.perspectiveTransform`` and rounded to integers.

    Args:
        x_projection: StochasticParameter that yields the horizontal projection
            factor per image (see ``generate_map_matrix``).
        y_projection: StochasticParameter that yields the vertical projection
            factor per image.
        name, deterministic, random_state: Forwarded to ``Augmenter``.

    Raises:
        AssertionError: If either projection argument is not a
            StochasticParameter.
    """

    def __init__(self, x_projection, y_projection, name=None, deterministic=False, random_state=None):
        super(Perspective, self).__init__(name=name, deterministic=deterministic, random_state=random_state)
        # Raise explicitly rather than with a bare ``assert`` so the check also
        # runs under ``python -O``. The exception type is kept for existing callers.
        for arg_name, value in (("x_projection", x_projection), ("y_projection", y_projection)):
            if not isinstance(value, StochasticParameter):
                raise AssertionError(
                    "Expected %s to be a StochasticParameter, got %s." % (arg_name, type(value)))
        self.x_projection = x_projection
        self.y_projection = y_projection

    def get_parameters(self):
        """Return the stochastic parameters of this augmenter."""
        return [self.x_projection, self.y_projection]

    def _augment_images(self, images, random_state, parents, hooks):
        """Warp every image and write the result back into ``images``.

        The input container is modified in place and returned.
        """
        nb_images = len(images)
        result = images
        x_projection, y_projection = self._draw_samples(nb_images, random_state)
        for i in sm.xrange(nb_images):
            image = images[i]
            height, width = image.shape[0], image.shape[1]
            map_matrix = generate_map_matrix(x_projection[i], y_projection[i], height, width)
            warped = cv2.warpPerspective(image, map_matrix, (width, height))
            # OpenCV returns (H, W) for a (H, W, 1) input. Restore the channel axis
            # so the output keeps the input's shape and can be stored in result[i].
            if image.ndim == 3 and warped.ndim == 2:
                warped = warped[..., np.newaxis]
            result[i] = warped
        return result

    def _augment_keypoints(self, keypoints_on_images, random_state, parents, hooks):
        """Map the keypoints of every image through that image's perspective matrix.

        Transformed coordinates are rounded to the nearest integer.
        """
        result = []
        nb_images = len(keypoints_on_images)
        x_projection, y_projection = self._draw_samples(nb_images, random_state)
        for i, keypoints_on_image in enumerate(keypoints_on_images):
            coords = keypoints_on_image.get_coords_array()
            if len(coords) == 0:
                # No keypoints to transform: skip the OpenCV call on an empty array
                # and return an empty keypoint set of the same image shape.
                transformed = np.zeros((0, 2), dtype=np.int32)
            else:
                height, width = keypoints_on_image.height, keypoints_on_image.width
                map_matrix = generate_map_matrix(x_projection[i], y_projection[i], height, width)
                # cv2.perspectiveTransform takes a (1, N, 2) float array.
                points = np.array([coords], dtype='float32')
                transformed = cv2.perspectiveTransform(points, map_matrix)[0]
                transformed = np.around(transformed).astype(np.int32)
            result.append(ia.KeypointsOnImage.from_coords_array(transformed,
                                                                shape=keypoints_on_image.shape))
        return result

    def _draw_samples(self, nb_samples, random_state):
        """Draw the x- and y-projection factors for ``nb_samples`` images.

        Both factors are derived from a single seed taken from ``random_state``.
        Equal random states therefore yield equal warps for images and keypoints
        (presumably relied on for imgaug's deterministic mode - confirm).

        Returns:
            Tuple ``(x_projection_samples, y_projection_samples)``, each drawn
            with size ``(nb_samples,)``.
        """
        seed = random_state.randint(0, 10 ** 6, 1)[0]
        x_projection_samples = self.x_projection.draw_samples((nb_samples,),
                                                              random_state=ia.new_random_state(seed + 30))
        y_projection_samples = self.y_projection.draw_samples((nb_samples,),
                                                              random_state=ia.new_random_state(seed + 40))
        return x_projection_samples, y_projection_samples
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.