Last active
February 5, 2023 20:21
-
-
Save nilesh0109/f98eed779844c6b570740d5ef78868a3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
import torch | |
import torch.nn.functional as F | |
import torchvision.transforms as transforms | |
## Fisheye Transformation
def get_of_fisheye(height, width, center, magnitude):
    """Build a fisheye sampling grid suitable for ``F.grid_sample``.

    Starts from the identity grid in normalized [-1, 1] coordinates and
    displaces every sample location along the direction to ``center`` by
    ``magnitude`` times the squared-distance-scaled offset, i.e.
    ``offset * |offset| * magnitude``.

    Args:
        height: Number of rows of the grid.
        width: Number of columns of the grid.
        center: Distortion centre ``(cx, cy)`` in normalized coordinates.
            May be a tensor or any two-element sequence.
        magnitude: Scalar strength of the distortion; ``0`` yields the
            identity grid.

    Returns:
        Tensor of shape ``(1, height, width, 2)`` whose last axis holds the
        ``(x, y)`` sampling locations.
    """
    xx = torch.linspace(-1, 1, width)
    yy = torch.linspace(-1, 1, height)
    # Pass indexing explicitly: omitting it is deprecated and emits a
    # warning; "ij" is the layout the original call relied on.
    gridy, gridx = torch.meshgrid(yy, xx, indexing="ij")  # identity grid
    grid = torch.stack([gridx, gridy], dim=-1)

    # Accept plain sequences as well as tensors for the centre.
    center = torch.as_tensor(center, dtype=grid.dtype)

    offset = center - grid  # (cx - x, cy - y)
    # Euclidean distance to the centre, kept as a trailing axis so it
    # broadcasts against the two offset components.
    distance = torch.sqrt((offset ** 2).sum(dim=-1, keepdim=True))
    grid += offset * distance * magnitude  # add (dx, dy) to the identity grid
    return grid.unsqueeze(0)  # grid_sample expects a 4D (N, H, W, 2) grid
## Horizontal Wave Transformation
def get_of_horizontalwave(height, width, freq, amplitude):
    """Build a horizontal-wave sampling grid suitable for ``F.grid_sample``.

    Starts from the identity grid in normalized [-1, 1] coordinates and
    shifts the y coordinate of every sample location by
    ``amplitude * cos(freq * x)``; x coordinates are left untouched.

    Args:
        height: Number of rows of the grid.
        width: Number of columns of the grid.
        freq: Multiplier applied to the normalized x coordinate inside the
            cosine.
        amplitude: Scale of the vertical shift; ``0`` yields the identity
            grid.

    Returns:
        Tensor of shape ``(1, height, width, 2)`` whose last axis holds the
        ``(x, y)`` sampling locations.
    """
    xx = torch.linspace(-1, 1, width)
    yy = torch.linspace(-1, 1, height)
    # Pass indexing explicitly: omitting it is deprecated and emits a
    # warning; "ij" is the layout the original call relied on.
    gridy, gridx = torch.meshgrid(yy, xx, indexing="ij")  # identity grid

    dy = amplitude * torch.cos(freq * gridx)  # vertical shift per column
    grid = torch.stack([gridx, gridy + dy], dim=-1)
    return grid.unsqueeze(0)  # grid_sample expects a 4D (N, H, W, 2) grid
## UTILITY FUNCTIONS
## Create Image Batch
def get_image_batch(img):
    """Turn a single image into a one-element batch tensor.

    Runs torchvision's ``ToTensor`` transform on ``img`` and prepends a
    dimension of size 1 so the result can be used as a batch.

    Args:
        img: Image accepted by ``transforms.ToTensor`` (the script passes a
            PIL image).

    Returns:
        The converted tensor with an extra leading dimension of size 1.
    """
    to_tensor = transforms.ToTensor()
    return to_tensor(img).unsqueeze(0)
def plot(img, fisheye_output, hwave_output):
    """Show the input image next to its two warped versions.

    Draws a single-row figure with three panels: the original image, the
    fisheye result and the horizontal-wave result, then displays it.

    Args:
        img: Image to show in the first panel, passed straight to
            ``imshow``.
        fisheye_output: Batched tensor; its first element is shown in the
            second panel.
        hwave_output: Batched tensor; its first element is shown in the
            third panel.
    """
    def first_as_channels_last(batch):
        # Take the first batch element and move its leading axis to the end
        # so that imshow receives the channel axis last.
        return np.moveaxis(batch[0].numpy(), 0, -1)

    panels = (
        (img, 'Input Image(Checkerboard)'),
        (first_as_channels_last(fisheye_output), 'Fisheye'),
        (first_as_channels_last(hwave_output), 'Horizontal Wave Tfms'),
    )

    _, axes = plt.subplots(1, 3, figsize=(16, 4))
    for axis, (picture, title) in zip(axes, panels):
        axis.imshow(picture)
        axis.set_title(title)
    plt.show()
# Load the sample image and wrap it into a one-image batch.
img = Image.open('checkerboard.png')
imgs = get_image_batch(img)
N, C, H, W = imgs.shape

# Build both sampling grids at the image's own resolution.
fisheye_grid = get_of_fisheye(
    height=H, width=W, center=torch.tensor([0, 0]), magnitude=0.4
)
hwave_grid = get_of_horizontalwave(height=H, width=W, freq=10, amplitude=0.1)

# Resample the batch through each grid, then show everything side by side.
fisheye_output, hwave_output = (
    F.grid_sample(imgs, sampling_grid, align_corners=True)
    for sampling_grid in (fisheye_grid, hwave_grid)
)
plot(img, fisheye_output, hwave_output)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment