Created
October 13, 2021 08:52
-
-
Save asears/f1fa947680bd3b8aded6343148a5927b to your computer and use it in GitHub Desktop.
Kornia - convert image to tensor
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| %matplotlib inline | |
| import matplotlib.pyplot as plt | |
| import kornia as K | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| """https://kornia-tutorials.readthedocs.io/en/latest/geometric_transforms.html#bonus-backprop-to-the-future""" | |
def imread(data_path: str) -> torch.Tensor:
    """Load an image from disk as a float RGB tensor scaled to [0, 1].

    Args:
        data_path: path of the image file to read with OpenCV.

    Returns:
        A BxCxHxW float tensor in RGB channel order with values in [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read an image from ``data_path``.
    """
    # OpenCV returns the image as HxWxC in BGR order, and returns None
    # (instead of raising) when the file is missing or unreadable.
    img: np.ndarray = cv2.imread(data_path, cv2.IMREAD_COLOR)
    if img is None:
        raise FileNotFoundError(f"could not read image: {data_path}")
    # keepdim=False adds a leading batch dimension: BxCxHxW.
    img_t: torch.Tensor = K.utils.image_to_tensor(img, keepdim=False)
    # OpenCV loads BGR; convert to RGB for plotting and augmentation.
    img_t = K.color.bgr_to_rgb(img_t)
    # IMREAD_COLOR yields 8-bit data, hence the 255 scale to reach [0, 1].
    return img_t.float() / 255.0
def imshow(image: np.ndarray, height: int, width: int):
    """Display an image in a new matplotlib figure with the axes hidden.

    Args:
        image: array accepted by ``plt.imshow`` (e.g. an HxWxC float image).
        height: figure height in inches.
        width: figure width in inches.
    """
    # matplotlib's figsize is (width, height); passing (height, width)
    # would swap the two named parameters.
    plt.figure(figsize=(width, height))
    plt.imshow(image)
    plt.axis('off')
    plt.show()
def draw_points(
    img_t: torch.Tensor,
    points: torch.Tensor,
    radius: int = 10,
    color: tuple = (0, 0, 255),
    thickness: int = 5,
) -> np.ndarray:
    """Draw a circle at each point on a copy of an image.

    Args:
        img_t: image tensor, converted to HxWxC numpy via
            ``K.utils.tensor_to_image``; presumably a float image in [0, 1]
            (TODO confirm for other callers).
        points: iterable of points whose first two entries are (x, y)
            pixel coordinates; they are truncated to ints.
        radius: circle radius in pixels.
        color: circle colour per channel. Channel values above 1 saturate
            to 1 after the final clip, so the default draws blue on an
            RGB image.
        thickness: circle outline thickness in pixels.

    Returns:
        A new HxWxC array with values clipped to [0, 1].
    """
    # cast image to numpy (HxWxC)
    img: np.ndarray = K.utils.tensor_to_image(img_t)
    # Draw on a copy so the input's pixel data is left untouched (the array
    # from tensor_to_image may share memory with img_t).
    img_out: np.ndarray = img.copy()
    for pt in points:
        x, y = int(pt[0]), int(pt[1])
        img_out = cv2.circle(
            img_out, (x, y), radius=radius, color=color, thickness=thickness
        )
    # Clip because colour values such as 255 exceed the image's [0, 1] range.
    return np.clip(img_out, 0, 1)
# Read the reference image as a BxCxHxW float tensor in [0, 1] and grab its size.
img1: torch.Tensor = imread('img1.ppm')
B, CH, H, W = img1.shape

# Sample N random (x, y) locations spread uniformly over the image:
# column 0 is scaled by the width, column 1 by the height.
N: int = 10  # the number of points
points1: torch.Tensor = torch.rand(1, N, 2) * torch.tensor([W, H])

# Overlay the sampled points on the first image and show the result.
img1_vis: np.ndarray = draw_points(img1[0], points1[0])
imshow(img1_vis, 10, 10)
# Declare an instance of our random affine generation with `return_transform`
# set to True, so that we receive a tuple with the transformed image and the
# transformation applied to the original image.
# NOTE(review): `return_transform` is deprecated/removed in newer kornia
# releases (the matrix is exposed via `transform.transform_matrix` there);
# confirm against the installed kornia version.
transform: nn.Module = K.augmentation.RandomAffine(
    degrees=[-45., 45.], return_transform=True, p=1.
)
# Transform the image and retrieve the transformation.
img2, trans = transform(img1)
# Map the original points through the same transformation.
points2: torch.Tensor = K.geometry.transform_points(trans, points1)
# Visualize both images side by side; index batch element 0 as done for img1
# so draw_points always receives a single CxHxW image.
img2_vis: np.ndarray = draw_points(img2[0], points2[0])
img_vis = np.concatenate([img1_vis, img2_vis], axis=1)
imshow(img_vis, 15, 15)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment