Created
January 29, 2020 12:16
-
-
Save mukheshpugal/30894f69f2fc75954c17db428e93330d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Exp_SE3(Function):
    """
    Implements the SE(3) exponential map as a torch autograd Function.

    forward() maps a 6-vector through ``Pose.exp_se3`` and returns the top
    three rows of the resulting matrix flattened in row-major order, i.e.
    12 values: [m00, m01, m02, m03, m10, ..., m23].

    backward() applies a 12x6 Jacobian whose rows follow exactly that same
    row-major layout.
    """

    @staticmethod
    def forward(ctx, vec: "Tensor[6]") -> "Tensor[12]":
        """
        :param ctx: autograd context; the matrix from Pose.exp_se3 is saved
            on it for backward
        :param vec: Tensor[6]. The Jacobian in backward() treats vec[:3] as
            the translational and vec[3:] as the rotational component.
            NOTE(review): confirm this matches Pose.exp_se3's ordering.
        :return: Tensor[12], ``mat[:3].reshape(-1)`` (row-major) where
            ``mat = Pose.exp_se3(vec)``; presumably mat is a 4x4
            homogeneous transform -- TODO confirm
        """
        mat = Pose.exp_se3(vec)
        # save_for_backward (not a plain ctx attribute) is the documented way
        # to keep tensors for backward; it also detects in-place edits of the
        # returned view, which would otherwise silently corrupt backward.
        ctx.save_for_backward(mat)
        return mat[:3].reshape(-1)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Vector-Jacobian product for forward().

        Assuming Pose._hat(v) is the 3x3 skew-symmetric cross-product matrix
        of v, the Jacobian used here is that of (I + delta^) @ mat with
        respect to delta = (rho, phi) at delta = 0 (a left perturbation of
        the current pose):
            d m_c / d delta = [0 | -hat(m_c)]   for rotation columns c < 3
            d t   / d delta = [I | -hat(t)]     for the translation column
        where m_c = mat[:3, c] and t = mat[:3, 3].

        NOTE(review): this left-perturbation Jacobian equals
        d exp_se3(vec) / d vec only at vec == 0; elsewhere it is the usual
        on-manifold approximation -- confirm this is intended.

        :param ctx: autograd context holding the saved matrix
        :param grad_output: Tensor[12], gradient w.r.t. forward()'s output
        :return: Tensor[6], gradient w.r.t. vec
        """
        (mat,) = ctx.saved_tensors
        # jac[r, c, :] holds d mat[r, c] / d delta, so reshaping
        # (3, 4, 6) -> (12, 6) puts row r*4 + c first-to-last in the same
        # row-major order forward() used for reshape(-1). Stacking by
        # column instead would pair Jacobian rows with the wrong outputs.
        # Allocated with mat's dtype/device so float64 / CUDA inputs work.
        jac = mat.new_zeros(3, 4, 6)
        for col in range(4):
            # Rotational part for every column (including translation).
            jac[:, col, 3:] = -Pose._hat(mat[:3, col])
        # Only the translation column moves 1:1 with the translational part.
        jac[:, 3, :3] = torch.eye(3, dtype=mat.dtype, device=mat.device)
        return grad_output @ jac.reshape(12, 6)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment