Skip to content

Instantly share code, notes, and snippets.

@gatheluck
Created October 9, 2019 18:13
Show Gist options
  • Save gatheluck/e7cd12ad9494fd4f02b02bff5fcb813f to your computer and use it in GitHub Desktop.
import os
import sys
import torch
import torch.nn.functional as F
import pymesh
class VisCameras:
    """Poses a template camera mesh at given world positions/orientations.

    A template mesh (loaded from an .obj via pymesh) is rotated so it looks
    at a target point with a consistent up direction, optionally rotated by a
    predicted quaternion, then translated to the camera position — producing
    per-batch vertex/face tensors suitable for rendering camera frustums.

    The template is assumed to face +x with +y up (see init_* below).
    """

    # Canonical orientation/scale of the template mesh.
    init_front_vec = torch.tensor([1.0, 0.0, 0.0]).view(1, -1).float()
    init_up_vec = torch.tensor([0.0, 1.0, 0.0]).view(1, -1).float()
    init_scale = torch.tensor([1.0]).view(1, -1).float()
    default_obj_path = os.path.join('camera_aligned.obj')

    def __init__(self, obj_path=None):
        # `is None`, not `== None`: identity test is the PEP 8 idiom and
        # cannot be fooled by a broadcasting __eq__ on tensor-like inputs.
        if obj_path is None:
            self.obj_path = self.default_obj_path
        else:
            self.obj_path = obj_path
        self.reset(self.obj_path)

    def reset(self, obj_path):
        """Reload the template mesh and reset orientation/scale state."""
        self.front_vec = self.init_front_vec
        self.up_vec = self.init_up_vec
        self.scale = self.init_scale
        camera_shape = pymesh.load_mesh(obj_path)
        # Shapes in the comments are for the sample camera_aligned.obj;
        # they vary with whatever mesh is supplied.
        self.verts = torch.from_numpy(camera_shape.vertices)  # e.g. (702, 3)
        self.faces = torch.from_numpy(camera_shape.faces)     # e.g. (1400, 3)
        self.num_vert = self.verts.size(0)
        self.num_face = self.faces.size(0)
        # Center the template on its vertex centroid.
        self.verts -= torch.mean(self.verts, dim=0)

    def _quat_rotate(self, X, q):
        """Rotate points by quaternions.

        Args:
            X: B x N x 3 points
            q: B x 4 unit quaternions in (w, x, y, z) order
        Returns:
            X_rot: B x N x 3 rotated points, computed as q * X * q_conj
        """
        # Broadcast q along the point dimension: (B, 4) -> (B, N, 4).
        ones_x = X[[0], :, :][:, :, [0]] * 0 + 1
        q = torch.unsqueeze(q, 1) * ones_x
        q_conj = torch.cat([q[:, :, [0]], -1 * q[:, :, 1:4]], dim=-1)
        # Promote points to pure quaternions (0, x, y, z).
        X = torch.cat([X[:, :, [0]] * 0, X], dim=-1)
        X_rot = self._hamilton_product(q, self._hamilton_product(X, q_conj))
        return X_rot[:, :, 1:4]

    def _hamilton_product(self, qa, qb):
        """Multiply qa by qb (elementwise Hamilton product).

        Args:
            qa: B x N x 4 quaternions
            qb: B x N x 4 quaternions
        Returns:
            q_mult: B x N x 4 products qa * qb
        """
        qa_0 = qa[:, :, 0]
        qa_1 = qa[:, :, 1]
        qa_2 = qa[:, :, 2]
        qa_3 = qa[:, :, 3]
        qb_0 = qb[:, :, 0]
        qb_1 = qb[:, :, 1]
        qb_2 = qb[:, :, 2]
        qb_3 = qb[:, :, 3]
        # See https://en.wikipedia.org/wiki/Quaternion#Hamilton_product
        q_mult_0 = qa_0 * qb_0 - qa_1 * qb_1 - qa_2 * qb_2 - qa_3 * qb_3
        q_mult_1 = qa_0 * qb_1 + qa_1 * qb_0 + qa_2 * qb_3 - qa_3 * qb_2
        q_mult_2 = qa_0 * qb_2 - qa_1 * qb_3 + qa_2 * qb_0 + qa_3 * qb_1
        q_mult_3 = qa_0 * qb_3 + qa_1 * qb_2 - qa_2 * qb_1 + qa_3 * qb_0
        return torch.stack([q_mult_0, q_mult_1, q_mult_2, q_mult_3], dim=-1)

    def _get_quat_rotate(self, axis, radian):
        """Build a quaternion for a rotation of `radian` about `axis`.

        Args:
            axis: (B, 3) rotation axes (expected unit length).
            radian: (B, 1) rotation angles in radians.
        Returns:
            (B, 4) quaternions in (w, x, y, z) order.
        """
        assert axis.size(0) == radian.size(0)
        assert axis.size(1) == 3
        assert radian.size(1) == 1
        assert len(axis.size()) == len(radian.size()) == 2
        half = radian / 2.0
        w = torch.cos(half).float()
        x = (axis[:, 0].view(-1, 1) * torch.sin(half)).float()
        y = (axis[:, 1].view(-1, 1) * torch.sin(half)).float()
        z = (axis[:, 2].view(-1, 1) * torch.sin(half)).float()
        quat = torch.cat([w, x, y, z], dim=-1)
        return quat

    def _get_quat_between_vec(self, trg_vec, src_vec):
        """Quaternion that rotates src_vec onto trg_vec.

        Args:
            trg_vec: (B, 3) target directions.
            src_vec: (B, 3) source directions.
        Returns:
            (B, 4) quaternions.
        Note: the rotation axis is undefined (zero cross product) when the
        two vectors are parallel; callers must avoid that configuration.
        """
        assert trg_vec.size(0) == src_vec.size(0)
        assert trg_vec.size(1) == src_vec.size(1) == 3
        assert len(trg_vec.size()) == len(src_vec.size()) == 2
        B = trg_vec.size(0)
        trg_vec = F.normalize(trg_vec, dim=-1)
        src_vec = F.normalize(src_vec, dim=-1)
        # rotation axis
        axis = F.normalize(torch.cross(trg_vec, src_vec, dim=-1), dim=-1)  # (B, 3)
        # rotation angle; clamp the dot product against float round-off so
        # acos never sees a value outside [-1, 1] and returns NaN
        dot = torch.bmm(trg_vec.view(B, 1, -1), src_vec.view(B, -1, 1)).view(-1, 1)  # (B, 1)
        radian = -torch.acos(torch.clamp(dot, -1.0, 1.0))  # (B, 1)
        return self._get_quat_rotate(axis, radian)

    def get_verts_and_faces(self, pos, look_at_pos, pred_quat, scale=1.0):
        """Pose the template mesh for each camera in the batch.

        Args:
            pos (B, 3): camera positions.
            look_at_pos (B, 3): points each camera looks at.
            pred_quat (B, 4): extra predicted rotation applied after look-at.
            scale: currently unused; kept for interface compatibility.
        Returns:
            verts (B, num_vert, 3) and faces (B, num_face, 3).
        """
        assert pos.size(0) == look_at_pos.size(0) == pred_quat.size(0)
        assert pos.size(1) == look_at_pos.size(1) == 3
        assert pred_quat.size(1) == 4
        assert len(pos.size()) == len(look_at_pos.size()) == len(pred_quat.size()) == 2
        B = pos.size(0)
        # look at: rotate the template front vector onto the view direction
        trg_vec = look_at_pos - pos
        src_vec = self.init_front_vec.repeat(B, 1)
        up_vec = self.init_up_vec.repeat(B, 1)
        quat_look_at = self._get_quat_between_vec(trg_vec, src_vec).float()  # (B, 4)
        # up: project the canonical up into the plane orthogonal to the view
        # direction via a double cross product. dim=-1 is passed explicitly:
        # the legacy implicit-dim torch.cross uses the *first* size-3
        # dimension, which is the batch dimension whenever B == 3.
        trg_vec = torch.cross(trg_vec, torch.cross(-trg_vec, up_vec, dim=-1), dim=-1)
        src_vec = self._quat_rotate(up_vec.view(B, 1, -1), quat_look_at).view(B, -1).float()
        quat_up = self._get_quat_between_vec(trg_vec, src_vec).float()  # (B, 4)
        verts = self.verts.repeat(B, 1, 1).float()
        verts = self._quat_rotate(verts, quat_look_at)
        verts = self._quat_rotate(verts, quat_up)
        verts = self._quat_rotate(verts, pred_quat)
        # translate
        verts += pos.view(B, 1, 3).repeat(1, verts.size(1), 1)
        faces = self.faces.repeat(B, 1, 1)
        return verts, faces
if __name__ == "__main__":
    import soft_renderer as sr

    def save_as_obj(verts, faces, output_path, verbose=False):
        """Save batch element 0 of a mesh batch as a Wavefront .obj file.

        Args:
            verts: (B, V, 3) vertex batch; only verts[0] is written.
            faces: (B, F, 3) face batch; only faces[0] is written.
            output_path: destination path; its extension is forced to .obj.
            verbose: print progress messages when True.
        Returns:
            None. Also returns early (no file written) without CUDA, since
            soft_renderer is only supported under CUDA.
        """
        assert len(verts.shape) == 3
        assert len(faces.shape) == 3
        # Guard clause instead of `is not True`: truthiness test is the
        # idiomatic form, and the early return flattens the function.
        if not torch.cuda.is_available():
            return None  # soft renderer is only supported under cuda
        if verbose:
            print("saving as obj...")
        # prepare for output: force the .obj extension and create the dir
        output_path = os.path.splitext(output_path)[0] + '.obj'
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        if verbose:
            print("output_path: {}".format(output_path))
        # make mesh from the first batch element and write it out
        mesh = sr.Mesh(verts[0, :, :], faces[0, :, :])
        mesh.save_obj(output_path)

    # Smoke test: pose two cameras looking at the origin and dump them.
    camera = VisCameras(os.path.join('data', 'camera_aligned.obj'))
    pos = torch.tensor([[1., 2., 3.], [3., 1., 4.]]).view(2, -1).float()
    look_at_pos = torch.tensor([[0., 0., 0.], [0., 0., 0.]]).view(2, -1).float()
    pred_quat = torch.tensor([[1., 0., 0., 0.], [1., 0., 0., 0.]]).view(2, -1).float()
    verts, faces = camera.get_verts_and_faces(pos, look_at_pos, pred_quat)
    for i in range(2):
        save_as_obj(verts[i, :, :].unsqueeze(0), faces[i, :, :].unsqueeze(0),
                    os.path.join('data', 'camera_out_{}'.format(i)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment