Created
October 31, 2020 18:04
-
-
Save DuaneNielsen/9e0b3ab7d5a614880f7bd478a92b0f42 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from matplotlib import pyplot as plt | |
def cross_matrix(axis):
    """
    Return the skew-symmetric (cross-product) matrix of a 3-vector.

    For axis = (x, y, z) the result is::

        [ 0, -z,  y]
        [ z,  0, -x]
        [-y,  x,  0]

    so that ``cross_matrix(a) @ b`` equals the cross product ``a x b``.

    :param axis: tensor whose last dimension has size 3; any leading
        dimensions are treated as a batch.
    :return: tensor of shape ``axis.shape[:-1] + (3, 3)`` with the same
        dtype and device as ``axis``.
    """
    # Assemble the matrix with stack instead of scattering into a fresh
    # default-dtype buffer: this keeps axis's dtype/device, lets autograd
    # flow through axis, and works for batched (..., 3) inputs.
    x, y, z = axis.unbind(-1)
    zero = torch.zeros_like(x)
    return torch.stack(
        [
            torch.stack([zero, -z, y], dim=-1),
            torch.stack([z, zero, -x], dim=-1),
            torch.stack([-y, x, zero], dim=-1),
        ],
        dim=-2,
    )


def rodriguez(axis, theta):
    """
    Rotation matrix for a rotation of ``theta`` radians about ``axis``.

    Implements Rodrigues' rotation formula
    ``R = I + sin(theta) K + (1 - cos(theta)) K @ K``, where ``K`` is the
    skew-symmetric cross-product matrix of the normalised axis.

    :param axis: 3-vector (tensor or sequence) pointing along the rotation
        axis; it is normalised here, so it need not be unit length, but it
        must be non-zero.
    :param theta: rotation angle in radians, as a Python number or a scalar
        tensor.
    :return: 3x3 rotation matrix with the dtype and device of ``axis``
        (integer axes are promoted to the default floating dtype).
    :raises ValueError: if ``axis`` has zero length, since the rotation
        would be undefined (the original code silently produced NaNs).
    """
    axis = torch.as_tensor(axis)
    if not axis.is_floating_point():
        # norm() is only defined for floating/complex tensors.
        axis = axis.to(torch.get_default_dtype())
    norm = axis.norm()
    if norm == 0:
        raise ValueError("rotation axis must have non-zero length")
    axis = axis / norm
    # Accept plain Python numbers as well as tensors for the angle.
    theta = torch.as_tensor(theta, dtype=axis.dtype, device=axis.device)
    k = cross_matrix(axis)
    eye = torch.eye(3, dtype=axis.dtype, device=axis.device)
    return eye + torch.sin(theta) * k + (1 - torch.cos(theta)) * k.matmul(k)


def vector(x, y, z):
    """Pack the three components into a float32 tensor of shape (3,)."""
    components = (x, y, z)
    return torch.tensor(components, dtype=torch.float32)


def translation(axis, theta, v):
    """
    Build a 4x4 homogeneous rigid-body transform.

    Despite its name, the result contains a rotation as well as a
    translation. The upper-left 3x3 block is the rotation of ``theta``
    radians about ``axis`` (computed by ``rodriguez``), and the first three
    entries of the last column hold ``v``. Applied to a homogeneous point
    ``[p, 1]``, it yields ``[R p + v, 1]``.

    :param axis: rotation axis (non-zero 3-vector).
    :param theta: rotation angle in radians.
    :param v: translation 3-vector (tensor or sequence).
    :return: 4x4 tensor with the same dtype and device as the rotation.
    """
    R = rodriguez(axis, theta)
    # Match R's dtype/device: a default float32 eye() would silently
    # truncate a float64 rotation/translation when assigned into it.
    T = torch.eye(4, dtype=R.dtype, device=R.device)
    T[:3, :3] = R
    T[:3, 3] = torch.as_tensor(v, dtype=R.dtype, device=R.device)
    return T


def main():
    """Animate a diamond-shaped body spinning about z while moving along y = x."""
    axis = vector(x=0, y=0, z=1)
    # Renamed from ``time`` to avoid shadowing the stdlib module name.
    times = torch.linspace(0, 10, 100)

    # Closed outline of the body in homogeneous coordinates, one column per
    # vertex (the first vertex is repeated to close the polygon).
    body = torch.tensor([
        [0, 1, 0, 1],
        [1, 0, 0, 1],
        [0, -1, 0, 1],
        [-1, 0, 0, 1],
        [0, 1, 0, 1],
    ], dtype=torch.float).T

    # Fixed view covering the whole trajectory. The centre moves from
    # (t0, t0) to (t1, t1), and every vertex stays 1 unit from the centre.
    # The previous ``plt.xlim(x * 10 for x in plt.xlim())`` only produced
    # (0, 10) because cla() resets the limits, which clipped the body.
    margin = 2.0
    lo = times.min().item() - margin
    hi = times.max().item() + margin

    plt.ion()
    for t in times:
        v = torch.tensor([t.item(), t.item(), 0])
        T = translation(axis, t, v)
        pos = torch.matmul(T, body)
        # cla() at the end of each frame resets the limits, so set them here.
        plt.xlim(lo, hi)
        plt.ylim(lo, hi)
        plt.plot(pos[0], pos[1])
        plt.draw()
        plt.pause(0.05)
        plt.cla()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment