Skip to content

Instantly share code, notes, and snippets.

@KyanainsGate
Created December 1, 2022 22:11
Show Gist options
  • Save KyanainsGate/b931fc753fe762c11245e920b47a4166 to your computer and use it in GitHub Desktop.
Save KyanainsGate/b931fc753fe762c11245e920b47a4166 to your computer and use it in GitHub Desktop.
PyTorch3D demonstration for tensor learning
"""
Python: 3.9.5
torch: 1.12.1+cu113
pytorch3d: 0.7.1
"""
import torch
from torch import nn
from pytorch3d.transforms.rotation_conversions import quaternion_to_matrix, matrix_to_quaternion
def xyzquat_to_bx3x4(bx7: torch.Tensor) -> torch.Tensor:
    """Turn batched ``[x, y, z, quaternion]`` poses into ``[R | t]`` matrices.

    Args:
        bx7: Tensor of shape ``[B, 7]``. Columns 0-2 hold the translation and
            the remaining four columns hold the quaternion handed to
            ``quaternion_to_matrix``.

    Returns:
        Tensor of shape ``[B, 3, 4]``: the rotation matrix with the
        translation appended as the last column.
    """
    translation = bx7[:, 0:3]  # [B, 3]
    quaternion = bx7[:, 3:]  # [B, 4]
    rotation = quaternion_to_matrix(quaternion)  # [B, 3, 3]
    # Add a trailing axis so the translation becomes a [B, 3, 1] column.
    translation_column = translation[:, :, None]
    return torch.cat((rotation, translation_column), dim=2)
def bx3x4_to_xyzquat(bx3x4: torch.Tensor) -> torch.Tensor:
    """Turn batched ``[R | t]`` matrices back into ``[x, y, z, quaternion]`` rows.

    Args:
        bx3x4: Tensor of shape ``[B, 3, 4]`` whose first three columns are the
            rotation matrix and whose last column is the translation.

    Returns:
        Tensor of shape ``[B, 7]``: translation followed by the quaternion
        produced by ``matrix_to_quaternion``.
    """
    rotation = bx3x4[:, :3, :3]  # [B, 3, 3]
    translation = bx3x4[:, :3, 3]  # [B, 3]
    quaternion = matrix_to_quaternion(rotation)  # [B, 4]
    return torch.cat((translation, quaternion), dim=1)
class PoseTensorSuccess(nn.Module):
    """Learnable pose stored as ``[x, y, z, quaternion]`` rows.

    The ``[R | t]`` matrix is rebuilt from the parameter on every
    ``forward()`` call, so each training step gets a fresh autograd graph.
    """

    def __init__(self, x_y_z_quat: torch.Tensor):
        """Register ``x_y_z_quat`` as the trainable parameter ``pose_param``."""
        super().__init__()
        # nn.Parameter tracks gradients by default.
        self.pose_param = nn.Parameter(x_y_z_quat)

    def forward(self):
        """Return the current pose converted by ``xyzquat_to_bx3x4``."""
        pose_matrix = xyzquat_to_bx3x4(self.pose_param)
        return pose_matrix
class PoseTensorFail(nn.Module):
    """Counter-example: caches the pose matrix at construction time.

    The matrix is computed once in ``__init__`` and ``forward()`` keeps
    returning that same tensor. Its autograd graph is therefore shared by
    every training step: after the first ``backward()`` frees the graph's
    buffers, a second ``backward()`` is expected to raise, and the cached
    matrix never reflects updates to ``pose_param``. Kept intentionally as a
    demonstration of what not to do.
    """

    def __init__(self, x_y_z_quat: torch.Tensor):
        """Register the parameter and eagerly build the cached pose matrix."""
        super().__init__()
        self.pose_param = nn.Parameter(x_y_z_quat)
        # Deliberate mistake: the conversion runs only here, not in forward().
        self.bx3x4 = xyzquat_to_bx3x4(self.pose_param)

    def forward(self):
        """Return the matrix cached in ``__init__`` (never recomputed)."""
        cached_matrix = self.bx3x4
        return cached_matrix
if __name__ == '__main__':
    # Demo: fit a learnable pose to a target pose by minimising the MSE
    # between their [R | t] matrices.
    # Each pose is one row of [x, y, z, quaternion(4)].
    init_pose = torch.tensor([[-4.6580, 0.7586, 0.1399, 0.0582, 0.7342, 0.0534, 0.6744]])
    tgt_pose = torch.tensor([[-4.9384, 0.7019, 0.6722, 0.0434, 0.6955, 0.0446, 0.7158]])

    model = PoseTensorSuccess(init_pose)
    # Swap in the line below to reproduce the failure case.
    # model = PoseTensorFail(init_pose)

    learning_rate = 1e-3
    optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
    loss_fn = torch.nn.MSELoss(reduction='mean')
    torch.autograd.set_detect_anomaly(True)  # Debug purpose

    # The target does not depend on the parameters, so convert it once
    # instead of on every iteration.
    tgt_bx3x4 = xyzquat_to_bx3x4(tgt_pose)

    for t in range(5000):
        # Clear gradients from the previous step; without this PyTorch sums
        # gradients across iterations and every update uses stale values.
        optimizer.zero_grad()
        with torch.set_grad_enabled(True):
            y_pred = model()
            loss = loss_fn(input=y_pred, target=tgt_bx3x4)
        if t % 1000 == 0:
            print(t, loss.item())
            # detach() gives a graph-free view for logging.
            print("y_pred\n", bx3x4_to_xyzquat(y_pred).detach())
        loss.backward()
        optimizer.step()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment