Skip to content

Instantly share code, notes, and snippets.

@fgolemo
Last active July 11, 2023 18:02
Show Gist options
  • Save fgolemo/94b5caf0e209a6e71ab0ce2d75ad3ed8 to your computer and use it in GitHub Desktop.
Save fgolemo/94b5caf0e209a6e71ab0ce2d75ad3ed8 to your computer and use it in GitHub Desktop.
3D rotation and reprojection in pytorch, i.e. differentiable
import numpy as np
import math
import torch
from torchvision import datasets
import cv2 # OpenCV, this is only used for visualization, see bottom of file
def rotation_matrix(axis, theta):
    """
    Build a 3x3 rotation matrix with the Euler-Rodrigues formula.

    https://www.wikiwand.com/en/Euler%E2%80%93Rodrigues_formula

    Args:
        axis: 1-D tensor with 3 elements, the rotation axis. It is normalized
            here, so it does not need to be unit length (but must be nonzero).
        theta: rotation angle in radians, as a single-element tensor
            (0-dim or shape (1,)) or a plain Python number.

    Returns:
        A (3, 3) tensor for the counterclockwise rotation about ``axis`` by
        ``theta`` radians. The matrix is assembled with ``torch.stack`` so the
        autograd graph back to ``axis`` and ``theta`` is preserved
        (``torch.tensor([...])`` would copy the values and drop gradients).
    """
    axis = axis / torch.sqrt(torch.dot(axis, axis))
    # Collapse a shape-(1,) angle to 0-dim so every matrix entry is a scalar
    # and the stacked result is exactly (3, 3).
    half_angle = torch.as_tensor(theta).reshape(()) / 2.0
    a = torch.cos(half_angle)
    b, c, d = -axis * torch.sin(half_angle)
    aa, bb, cc, dd = a * a, b * b, c * c, d * d
    bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
    rows = [
        torch.stack([aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)]),
        torch.stack([2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)]),
        torch.stack([2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]),
    ]
    return torch.stack(rows)
# Load one MNIST digit as a 2-D (height, width) numpy array.
ds = datasets.MNIST('../data', train=True, download=True)
digit = np.array(ds[2][0])
print(digit.shape)

# The following prepares a manual homography transformation by building a
# flow grid: a list of pixel coordinates to which a rotation matrix is
# applied later, one (x, y, 1) triple per pixel.

# x runs along the image width (columns, shape[1]) and y along the height
# (rows, shape[0]); for square images the two are interchangeable.
x_pos = np.arange(0, digit.shape[1]).astype(np.float32)
y_pos = np.arange(0, digit.shape[0]).astype(np.float32)
print(x_pos[:3], x_pos[-3:])
print(y_pos[:3], y_pos[-3:])

# center pixel coords around 0
x_pos -= ((digit.shape[1] - 1) / 2)
y_pos -= ((digit.shape[0] - 1) / 2)
print(x_pos[:3], x_pos[-3:])
print(y_pos[:3], y_pos[-3:])

# normalize pixel coords to [-1, 1], first/last pixel centers at -1/+1
x_pos /= ((digit.shape[1] - 1) / 2)
y_pos /= ((digit.shape[0] - 1) / 2)
print(x_pos[:3], x_pos[-3:])
print(y_pos[:3], y_pos[-3:])

# one entry for every pixel; both arrays have shape (height, width)
xs, ys = np.meshgrid(x_pos, y_pos)
print(xs.shape, xs[:3, :3])
print(ys.shape, ys[:3, :3])

# merge into a (height * width, 3) list of homogeneous pixel coords (x, y, 1)
coordinates = np.stack((xs.flatten(), ys.flatten(), np.ones(xs.size)), axis=1)
print(coordinates.shape, coordinates[:3])

# this variable is fixed for all images: the list of initial pixel locations
coordinates = torch.from_numpy(coordinates).float()
# Rotate the digit about each principal axis in turn and display the result.
axes = [
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
]

# grid_sample expects a (batch, channels, height, width) float tensor
digit = torch.from_numpy(digit).unsqueeze(0).unsqueeze(0).float()
print(digit.size())
height, width = digit.shape[-2:]

for axis_values in axes:
    axis_vec = torch.tensor(axis_values).float()

    for angle_deg in np.linspace(-90, 90, 100):
        # manual deg 2 rad; force float32 so the rotation matrix matches the
        # dtype of `coordinates` (a numpy float64 would otherwise leak in)
        angle_rad = torch.tensor([angle_deg * np.pi / 180], dtype=torch.float32)

        rot = rotation_matrix(axis_vec, angle_rad).to(coordinates.dtype)

        # apply rotation to each pixel coordinate
        new_locations = torch.matmul(coordinates, rot)

        # divide x and y by z for a perspective instead of orthographic cam;
        # done out-of-place so the operation stays safe for autograd.
        # NOTE: z can reach 0 for steep rotations, producing inf/nan samples.
        projected = new_locations[:, :2] / new_locations[:, 2:3]

        # z axis is dropped; the trailing 2 holds the source pixel coords
        flow_grid = projected.reshape((1, height, width, 2))

        # sample original image with this flow field, i.e. reproject.
        # align_corners=True matches the (n - 1) / 2 normalization used for
        # the grid, where -1/+1 are the centers of the border pixels.
        out = torch.nn.functional.grid_sample(
            digit, flow_grid, align_corners=True
        ).squeeze(0).squeeze(0).numpy()

        # display the image; cv2.imshow expects float images in [0, 1],
        # while MNIST intensities are in [0, 255]
        cv2.imshow("ladeeda", out / 255.0)
        cv2.waitKey(50)  # this is required, otherwise the opencv window will close immediately
@qianyizhang
Copy link

nice work

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment