Created
October 9, 2019 18:13
-
-
Save gatheluck/e7cd12ad9494fd4f02b02bff5fcb813f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sys | |
import torch | |
import torch.nn.functional as F | |
import pymesh | |
class VisCameras():
    """Build batched triangle meshes of a camera model for visualization.

    The camera mesh is loaded from an .obj file with ``pymesh`` and centered at
    its vertex centroid. ``get_verts_and_faces`` places one copy per batch
    element:
    - It rotates the copy so the model's front axis (``init_front_vec``) points
      from ``pos`` towards ``look_at_pos``.
    - It rolls the copy so the model's up axis (``init_up_vec``) matches
      ``init_up_vec`` projected onto the plane orthogonal to that viewing
      direction.
    - It applies an extra quaternion rotation and translates the copy to ``pos``.

    Quaternions are stored as (w, x, y, z).
    """

    # Model-space axes of the camera mesh, each a (1, 3) float tensor.
    init_front_vec = torch.tensor([1.0, 0.0, 0.0]).view(1, -1).float()
    init_up_vec = torch.tensor([0.0, 1.0, 0.0]).view(1, -1).float()
    init_scale = torch.tensor([1.0]).view(1, -1).float()
    default_obj_path = 'camera_aligned.obj'

    def __init__(self, obj_path=None):
        """Load the camera mesh from ``obj_path`` (``default_obj_path`` when None)."""
        self.obj_path = self.default_obj_path if obj_path is None else obj_path
        self.reset(self.obj_path)

    def reset(self, obj_path):
        """(Re)load the camera mesh from ``obj_path`` and center it at its centroid."""
        # Clone so instance state never aliases the shared class-level tensors.
        self.front_vec = self.init_front_vec.clone()
        self.up_vec = self.init_up_vec.clone()
        self.scale = self.init_scale.clone()

        camera_shape = pymesh.load_mesh(obj_path)
        # torch.tensor copies the data; torch.from_numpy would share memory with
        # the mesh's array, so the in-place centering below would write into it.
        self.verts = torch.tensor(camera_shape.vertices)  # (num_vert, 3)
        self.faces = torch.tensor(camera_shape.faces)  # (num_face, 3)
        self.num_vert = self.verts.size(0)
        self.num_face = self.faces.size(0)
        self.verts -= torch.mean(self.verts, dim=0)

    def _quat_rotate(self, X, q):
        """Rotate points by quaternions.

        Args:
            X: B X N X 3 points
            q: B X 4 quaternions (w, x, y, z)
        Returns:
            X_rot: B X N X 3 (rotated points)
        """
        # (1, N, 1) tensor of ones, used to broadcast q over the N points.
        ones_x = X[[0], :, :][:, :, [0]] * 0 + 1
        q = torch.unsqueeze(q, 1) * ones_x
        q_conj = torch.cat([q[:, :, [0]], -1 * q[:, :, 1:4]], dim=-1)
        # Embed the points as pure quaternions (0, x, y, z) and compute q * X * q^-1.
        X = torch.cat([X[:, :, [0]] * 0, X], dim=-1)
        X_rot = self._hamilton_product(q, self._hamilton_product(X, q_conj))
        return X_rot[:, :, 1:4]

    def _hamilton_product(self, qa, qb):
        """Multiply qa by qb.

        Args:
            qa: B X N X 4 quaternions
            qb: B X N X 4 quaternions
        Returns:
            q_mult: B X N X 4
        """
        qa_0 = qa[:, :, 0]
        qa_1 = qa[:, :, 1]
        qa_2 = qa[:, :, 2]
        qa_3 = qa[:, :, 3]

        qb_0 = qb[:, :, 0]
        qb_1 = qb[:, :, 1]
        qb_2 = qb[:, :, 2]
        qb_3 = qb[:, :, 3]

        # See https://en.wikipedia.org/wiki/Quaternion#Hamilton_product
        q_mult_0 = qa_0 * qb_0 - qa_1 * qb_1 - qa_2 * qb_2 - qa_3 * qb_3
        q_mult_1 = qa_0 * qb_1 + qa_1 * qb_0 + qa_2 * qb_3 - qa_3 * qb_2
        q_mult_2 = qa_0 * qb_2 - qa_1 * qb_3 + qa_2 * qb_0 + qa_3 * qb_1
        q_mult_3 = qa_0 * qb_3 + qa_1 * qb_2 - qa_2 * qb_1 + qa_3 * qb_0

        return torch.stack([q_mult_0, q_mult_1, q_mult_2, q_mult_3], dim=-1)

    def _get_quat_rotate(self, axis, radian):
        """Quaternion (w, x, y, z) for a rotation of ``radian`` about ``axis``.

        axis: (B, 3), expected to be unit length (it is not normalized here)
        radian: (B, 1)
        Returns a (B, 4) float tensor.
        """
        assert axis.size(0) == radian.size(0)
        assert axis.size(1) == 3
        assert radian.size(1) == 1
        assert len(axis.size()) == len(radian.size()) == 2

        w = torch.cos(radian / 2.0).float()
        x = (axis[:, 0].view(-1, 1) * torch.sin(radian / 2.0)).float()
        y = (axis[:, 1].view(-1, 1) * torch.sin(radian / 2.0)).float()
        z = (axis[:, 2].view(-1, 1) * torch.sin(radian / 2.0)).float()

        quat = torch.cat([w, x, y, z], dim=-1)
        return quat

    @staticmethod
    def _orthogonal_vec(vec):
        """Return a unit vector orthogonal to each row of ``vec`` (B, 3).

        Crosses each row with the basis axis it is least aligned with, so the
        cross product is never degenerate for a non-zero row. A zero row yields
        a zero vector.
        """
        idx = vec.abs().argmin(dim=-1)  # (B,)
        basis = torch.eye(3, dtype=vec.dtype, device=vec.device)[idx]  # (B, 3)
        return F.normalize(torch.cross(vec, basis, dim=-1), dim=-1)

    def _get_quat_between_vec(self, trg_vec, src_vec, fallback_axis=None, eps=1e-6):
        """Quaternion (B, 4) rotating the direction of ``src_vec`` onto ``trg_vec``.

        Args:
            trg_vec: (B, 3) target direction (need not be unit length).
            src_vec: (B, 3) source direction (need not be unit length).
            fallback_axis: optional (B, 3) or (1, 3) axis for degenerate rows,
                where the vectors are (anti)parallel or one of them is zero.
                For antiparallel rows it must be orthogonal to ``src_vec``.
                Defaults to an arbitrary axis orthogonal to ``src_vec``.
            eps: a row counts as degenerate when |trg x src| < eps after
                normalization.
        """
        import math

        assert trg_vec.size(0) == src_vec.size(0)
        assert trg_vec.size(1) == src_vec.size(1) == 3
        assert len(trg_vec.size()) == len(src_vec.size()) == 2

        trg_vec = F.normalize(trg_vec, dim=-1)
        src_vec = F.normalize(src_vec, dim=-1)

        cross = torch.cross(trg_vec, src_vec, dim=-1)  # (B, 3)
        # Clamp: rounding can push |dot| slightly above 1, where acos is NaN.
        dot = (trg_vec * src_vec).sum(dim=-1, keepdim=True).clamp(-1.0, 1.0)  # (B, 1)
        radian = -torch.acos(dot)  # (B, 1)

        # With a (near) zero cross product the axis is undefined. Normalizing it
        # would give a zero quaternion (antiparallel case) or a non-unit one
        # (zero-length input). Pick a valid axis instead and snap the angle to
        # pi (antiparallel) or 0 (parallel / zero-length input).
        degenerate = cross.norm(dim=-1, keepdim=True) < eps  # (B, 1)
        if fallback_axis is None:
            fallback_axis = self._orthogonal_vec(src_vec)
        axis = F.normalize(torch.where(degenerate, fallback_axis.to(cross), cross), dim=-1)
        snapped = torch.where(dot < 0, torch.full_like(radian, -math.pi), torch.zeros_like(radian))
        radian = torch.where(degenerate, snapped, radian)

        return self._get_quat_rotate(axis, radian)

    def get_verts_and_faces(self, pos, look_at_pos, pred_quat, scale=1.0):
        """Place one camera mesh per batch element.

        Args:
            pos (B, 3): camera positions.
            look_at_pos (B, 3): points each camera looks at from ``pos``.
            pred_quat (B, 4): extra (w, x, y, z) rotation applied after orienting.
            scale: float, or a tensor with B elements (e.g. (B,) or (B, 1)),
                multiplying the centered mesh before rotation and translation.

        Returns:
            verts (B, N, 3) and faces (B, F, 3), on ``pos.device``.
        """
        assert pos.size(0) == look_at_pos.size(0) == pred_quat.size(0)
        assert pos.size(1) == look_at_pos.size(1) == 3
        assert pred_quat.size(1) == 4
        assert len(pos.size()) == len(look_at_pos.size()) == len(pred_quat.size()) == 2
        B = pos.size(0)
        device = pos.device

        # look at: rotate the model's front axis onto the viewing direction.
        # Rotating about the up axis keeps up fixed when the view is exactly
        # opposite to the front axis.
        view_dir = look_at_pos.float() - pos.float()
        front = self.init_front_vec.to(device).repeat(B, 1)
        up = self.init_up_vec.to(device).repeat(B, 1)
        quat_look_at = self._get_quat_between_vec(view_dir, front, fallback_axis=up).float()  # (B, 4)

        # up: d x (up x d) = up*|d|^2 - d*(d.up) is the component of up orthogonal
        # to the view direction d. `dim` must be explicit: without it torch.cross
        # picks the first size-3 dim, i.e. the batch dim when B == 3.
        up_target = torch.cross(view_dir, torch.cross(-view_dir, up, dim=-1), dim=-1)
        up_rotated = self._quat_rotate(up.view(B, 1, -1), quat_look_at).view(B, -1).float()
        # Both vectors are orthogonal to view_dir, so rolling about view_dir is valid.
        quat_up = self._get_quat_between_vec(up_target, up_rotated, fallback_axis=view_dir).float()  # (B, 4)

        verts = self.verts.to(device).float().repeat(B, 1, 1)
        scale = torch.as_tensor(scale, dtype=verts.dtype, device=device)
        if scale.dim() > 0:
            scale = scale.reshape(-1, 1, 1)  # per-camera scale broadcast over points
        verts = verts * scale

        verts = self._quat_rotate(verts, quat_look_at)
        verts = self._quat_rotate(verts, quat_up)
        verts = self._quat_rotate(verts, pred_quat)

        # translate
        verts = verts + pos.float().view(B, 1, 3)
        faces = self.faces.to(device).repeat(B, 1, 1)

        return verts, faces
if __name__ == "__main__":
    import soft_renderer as sr

    def save_as_obj(verts, faces, output_path, verbose=False):
        """Write batch element 0 of (verts, faces) as a Wavefront .obj file.

        Args:
            verts: (B, N, 3) vertex tensor; only index 0 of the batch is saved.
            faces: (B, F, 3) face tensor; only index 0 of the batch is saved.
            output_path: destination path; its extension is replaced by ``.obj``
                and its parent directory is created if needed.
            verbose: print progress messages when True.

        Returns None. Nothing is written when CUDA is unavailable.
        """
        assert len(verts.shape) == 3
        assert len(faces.shape) == 3

        if not torch.cuda.is_available():
            return None  # soft renderer is only supported under cuda

        if verbose:
            print("saving as obj...")

        # prepare for output
        output_path = os.path.splitext(output_path)[0] + '.obj'  # replace extension by .obj
        output_dir = os.path.dirname(output_path)
        # os.makedirs('') raises, so only create a directory when the path has one.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        if verbose:
            print("output_path: {}".format(output_path))

        # make mesh
        mesh = sr.Mesh(verts[0, :, :], faces[0, :, :])
        mesh.save_obj(output_path)

    camera = VisCameras(os.path.join('data', 'camera_aligned.obj'))
    pos = torch.tensor([[1., 2., 3.], [3., 1., 4.]]).view(2, -1).float()
    look_at_pos = torch.tensor([[0., 0., 0.], [0., 0., 0.]]).view(2, -1).float()
    pred_quat = torch.tensor([[1., 0., 0., 0.], [1., 0., 0., 0.]]).view(2, -1).float()

    verts, faces = camera.get_verts_and_faces(pos, look_at_pos, pred_quat)

    # one .obj per camera in the batch
    for i in range(verts.size(0)):
        save_as_obj(verts[i, :, :].unsqueeze(0), faces[i, :, :].unsqueeze(0), os.path.join('data', 'camera_out_{}'.format(i)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment