Last active
July 11, 2023 18:02
-
-
Save fgolemo/94b5caf0e209a6e71ab0ce2d75ad3ed8 to your computer and use it in GitHub Desktop.
3D rotation and reprojection in PyTorch (i.e., differentiable)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import math | |
import torch | |
from torchvision import datasets | |
import cv2 # OpenCV, this is only used for visualization, see bottom of file | |
def rotation_matrix(axis, theta):
    """
    Generalized 3D rotation via the Euler-Rodrigues formula,
    https://en.wikipedia.org/wiki/Euler%E2%80%93Rodrigues_formula

    Return the rotation matrix associated with counterclockwise rotation about
    the given axis by theta radians.

    The matrix is assembled with ``torch.stack`` (not ``torch.tensor``) so that
    gradients flow back to ``axis`` and ``theta``; ``torch.tensor`` copies the
    values and silently detaches them from the autograd graph.

    Args:
        axis: length-3 rotation axis (tensor or sequence). It does not need to
            be normalized. Integer input is converted to float.
        theta: rotation angle in radians. It may be a Python number, a 0-dim
            tensor or a one-element tensor. It is cast to the dtype and device
            of ``axis``.

    Returns:
        A ``(3, 3)`` tensor with the dtype and device of ``axis``.

    Raises:
        ValueError: if ``theta`` holds more than one value or ``axis`` has
            zero length.
    """
    axis = torch.as_tensor(axis)
    if not axis.is_floating_point():
        axis = axis.float()
    # .to()-style cast keeps theta in the autograd graph and matches axis' dtype
    theta = torch.as_tensor(theta, dtype=axis.dtype, device=axis.device)
    if theta.numel() != 1:
        raise ValueError(
            "theta must be a single angle, got shape %s" % (tuple(theta.shape),)
        )
    # Collapse e.g. shape (1,) to 0-dim so every entry below is a scalar and the
    # stacked result is exactly (3, 3).
    theta = theta.reshape(())

    norm = torch.sqrt(torch.dot(axis, axis))
    if norm == 0:
        raise ValueError("rotation axis must be non-zero")
    axis = axis / norm

    a = torch.cos(theta / 2.0)
    # The negated vector part combined with the transposed layout below gives
    # the standard counterclockwise rotation matrix.
    b, c, d = -axis * torch.sin(theta / 2.0)
    aa, bb, cc, dd = a * a, b * b, c * c, d * d
    bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
    return torch.stack([
        torch.stack([aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)]),
        torch.stack([2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)]),
        torch.stack([2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]),
    ])
ds = datasets.MNIST('../data', train=True, download=True)
digit = np.array(ds[2][0])  # a single dataset sample as a 2-D (rows, cols) array
axis = [0, 0, 1]
print(digit.shape)

# The following does a manual homography transformation by calculating a flow grid
# (a list of pixel coordinates; the rotation matrix is then applied to each one).
#
# Generate the pixel coordinates, one (x, y, 1) triple per pixel. x indexes columns
# (width, axis 1) and y indexes rows (height, axis 0): grid_sample reads grid[..., 0]
# as the horizontal and grid[..., 1] as the vertical coordinate, so deriving x from
# the row count would transpose non-square images.
height, width = digit.shape[:2]
x_pos = np.arange(0, width).astype(np.float32)
y_pos = np.arange(0, height).astype(np.float32)
print(x_pos[:3], x_pos[-3:])
print(y_pos[:3], y_pos[-3:])

# center pixel coords so the image center sits at the origin
x_pos -= (width - 1) / 2
y_pos -= (height - 1) / 2
print(x_pos[:3], x_pos[-3:])
print(y_pos[:3], y_pos[-3:])

# normalize pixel coords to [-1, 1]; the corner pixel centers land exactly on -1 and +1
x_pos /= (width - 1) / 2
y_pos /= (height - 1) / 2
print(x_pos[:3], x_pos[-3:])
print(y_pos[:3], y_pos[-3:])

# one entry for every pixel; with meshgrid's default 'xy' indexing both grids have
# shape (height, width), so flattening walks the image row by row
xs, ys = np.meshgrid(x_pos, y_pos)
print(xs.shape, xs[:3, :3])
print(ys.shape, ys[:3, :3])

# now merge them all into a list of homogeneous pixel coords, shape (height * width, 3)
coordinates = np.stack(
    (xs.flatten(), ys.flatten(), np.ones(xs.size, dtype=np.float32)), axis=1
)
print(coordinates.shape, coordinates[:3])

# this variable is fixed for all images: the list of initial pixel locations
coordinates = torch.from_numpy(coordinates).float()
axes = [
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
]
# (1, 1, H, W): batch and channel dims as expected by grid_sample
digit = torch.from_numpy(digit).unsqueeze(0).unsqueeze(0).float()
print(digit.size())
height, width = digit.shape[-2:]

for axis in axes:
    axis = torch.tensor(axis).float()
    for theta_deg in np.linspace(-90, 90, 100):
        # degrees -> radians as a 0-dim tensor in the coordinates' dtype, so the
        # rotation matrix matches `coordinates` for torch.matmul
        theta = torch.tensor(math.radians(theta_deg), dtype=coordinates.dtype)
        rot = rotation_matrix(axis, theta)
        # apply rotation to each pixel coordinate; rows are multiplied on the
        # right, i.e. rot.T is applied to each column vector
        new_locations = torch.matmul(coordinates, rot)
        # divide by z for a perspective instead of an orthographic camera; done
        # out of place because in-place writes into views of new_locations
        # invalidate tensors autograd saved for the backward pass
        flow_grid = new_locations[:, :2] / new_locations[:, 2:3]
        # the 2 at the end are the (x, y) source pixel coords for each output pixel
        flow_grid = flow_grid.reshape(1, height, width, 2)
        # sample original image with this flow field, i.e. reproject.
        # align_corners=True matches the grid normalization above, where -1/+1
        # are the centers of the corner pixels.
        out = torch.nn.functional.grid_sample(digit, flow_grid, align_corners=True)
        out = out.squeeze(0).squeeze(0).detach().numpy()
        # display the image; cv2.imshow multiplies float images by 255, so map the
        # 8-bit intensities into [0, 1] instead of letting them saturate to white
        cv2.imshow("ladeeda", out / 255.0)
        cv2.waitKey(50)  # this is required, otherwise the opencv window will close immediately
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
nice work