Created December 26, 2019 05:38
-
-
Save Alescontrela/d179e8067860868ff95dada6911099cb to your computer and use it in GitHub Desktop.
Using SO(3) rotation manifolds to build a Christmas tree.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Quick script to demonstrate SO(3) rotation manifolds via axis-angle. | |
Merry christmas :^) | |
@Author: Alejandro Escontrela | |
""" | |
import numpy as np | |
from matplotlib import pyplot as plt | |
from mpl_toolkits.mplot3d import Axes3D | |
from matplotlib.patches import FancyArrowPatch | |
from mpl_toolkits.mplot3d import proj3d | |
class Arrow3D(FancyArrowPatch):
    """Helper class to plot an arrow in 3D.

    ``FancyArrowPatch`` only works in 2D display coordinates. This class
    therefore stores the 3D endpoints and re-projects them onto the screen
    each time the 3D axes are drawn.

    Args:
        xs, ys, zs: Two-element sequences with the x/y/z coordinates of the
            arrow's tail (index 0) and head (index 1).
        *args, **kwargs: Forwarded to ``FancyArrowPatch`` (arrow style,
            colour, line width, ...).
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # The 2D positions are placeholders; ``_project`` overwrites them
        # before the patch is drawn.
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def _project(self):
        """Project the 3D endpoints with the owning axes' projection matrix.

        Updates the patch's 2D positions and returns the projected depths.
        """
        xs3d, ys3d, zs3d = self._verts3d
        # Read the projection matrix from the axes. ``renderer.M`` is
        # deprecated and no longer set by newer Matplotlib releases.
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return zs

    def do_3d_projection(self, renderer=None):
        """Depth-sorting hook that ``Axes3D.draw`` calls on patches.

        Newer Matplotlib (>= 3.5) calls this hook on every patch child and
        raises AttributeError if it is missing. Returns the smallest
        projected depth of the arrow.
        """
        return np.min(self._project())

    def draw(self, renderer):
        """Project the endpoints, then draw as an ordinary 2D arrow patch."""
        self._project()
        super().draw(renderer)
# Vector expressed in the global (world) frame. Rotated copies of it form the
# branches of the tree drawn by this script.
vec_g = np.asarray((3.0, 0.0, 0.6), dtype=np.float64)
def exp_map(w: np.ndarray, theta: float) -> np.ndarray:
    """Obtain the exponential map for SO(3) for a given axis-angle rotation.

    Computes ``R = expm([zeta]_x)``, where ``zeta = w * theta`` is the
    rotation vector and ``[.]_x`` is its skew-symmetric (cross-product)
    matrix.

    Args:
        w: Rotation axis with three elements (any shape that flattens to 3).
            Normally a unit vector. If it is not, the result is still the
            exact matrix exponential: a rotation by ``||w|| * theta`` about
            ``w / ||w||``.
        theta: Rotation angle in radians. ``0`` is allowed and gives the
            identity.

    Returns:
        The 3x3 rotation matrix.
    """
    # Obtain the 3DoF axis-angle (rotation) vector.
    zeta = np.asarray(w, dtype=float).ravel() * theta
    # Skew-symmetric representation of zeta: zeta_sk @ v == np.cross(zeta, v).
    zx, zy, zz = zeta
    zeta_sk = np.array([
        [0.0, -zz, zy],
        [zz, 0.0, -zx],
        [-zy, zx, 0.0],
    ])
    # Rodrigues' formula is the closed form of the matrix exponential on
    # SO(3):
    #   R = I + (sin(phi) / phi) K + ((1 - cos(phi)) / phi**2) K @ K
    # with K = zeta_sk and phi = ||zeta||, the true rotation angle.
    # Both coefficients are written with np.sinc. It equals 1 at phi == 0, so
    # there is no 0/0. It also avoids the cancellation in 1 - cos(phi) for
    # tiny angles:
    #   sin(phi) / phi           == sinc(phi / pi)
    #   (1 - cos(phi)) / phi**2  == 0.5 * sinc(phi / (2 pi))**2
    phi = np.linalg.norm(zeta)
    coeff_linear = np.sinc(phi / np.pi)
    coeff_quadratic = 0.5 * np.sinc(phi / (2 * np.pi)) ** 2
    return (
        np.identity(3)
        + coeff_linear * zeta_sk
        + coeff_quadratic * (zeta_sk @ zeta_sk)
    )
# Plot the "branches": copies of vec_g rotated about a fixed axis. Their tips
# trace a circle, so together the arrows form a cone.
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
# The rotation axis (a unit vector).
w = np.array([1, 0, 0])
# Define rotation via axis-angle representation. This representation is
# minimal: it has 3 DoF. The unit axis w contributes 2 (the unit-vector
# constraint removes one of its three components) and the angle theta
# contributes 1.
n_increments = 50  # Number of arrows drawn over one full turn.
# endpoint=False: theta = 2 * pi is the same rotation as theta ~ 0, so
# including it would draw the first arrow twice. The start stays slightly
# above 0 so that theta == 0 is never passed to exp_map.
for theta in np.linspace(1e-3, 2 * np.pi, n_increments, endpoint=False):
    # The exponential map acts as the retraction here. It maps an element of
    # the tangent space (the rotation vector w * theta) onto the manifold
    # SO(3); the logarithm map goes the other way. In practice, optimization
    # on a manifold takes steps in the flat tangent space, which is easier to
    # work in, and then retracts the result onto the manifold.
    R = exp_map(w, theta)
    # The vector, rotated by the given axis-angle rotation.
    vec_n = R @ vec_g
    # Insert it into the plot as an arrow from the origin.
    ax.add_artist(Arrow3D(
        [0, vec_n[0]], [0, vec_n[1]], [0, vec_n[2]], mutation_scale=20,
        lw=3, arrowstyle="-|>", color="g", alpha=0.6))
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
# Plot the "lights" and "ornaments". Both lie on conical helices wound around
# the x-axis, with a radius that grows linearly from r_i at x = 0 to r_f at
# x = vec_g[0]. Visual of what's happening:
# https://en.wikipedia.org/wiki/File:ComplexSinInATimeAxe.gif
omega_ornaments = 15  # Angular frequency (rad per unit of x) of the ornaments.
omega_lights = 20  # Angular frequency (rad per unit of x) of the lights.
r_i = 0.0
# Radius of the circle the tip of vec_g traces when rotated about the x-axis,
# which is the axis these helices wind around. np.hypot is correct for any
# vec_g; max(vec_g[1], vec_g[2]) is only right when one component is 0.
r_f = np.hypot(vec_g[1], vec_g[2])
ts = np.linspace(0, vec_g[0], 100)
# Cone radius at each x position: linear interpolation from r_i to r_f.
rs = r_i + (r_f - r_i) * ts / vec_g[0]
# Ornaments: one plot call per point, so each ornament takes the next colour
# from Matplotlib's default colour cycle.
phi_orn = omega_ornaments * ts
for t, y_orn, z_orn in zip(ts, rs * np.cos(phi_orn), rs * np.sin(phi_orn)):
    ax.plot([t], [y_orn], [z_orn], 'o')
# Lights: a single dashed red helix.
phi_light = omega_lights * ts
ys_lights = rs * np.cos(phi_light)
zs_lights = rs * np.sin(phi_light)
ax.plot(ts, ys_lights, zs_lights, '--', color='r', alpha=0.8)
# The star on the christmas tree :'^ )
# markeredgewidth (not linewidth) sets the stroke width of an 'x' marker.
# linewidth only affects the connecting line, which a single point never has.
ax.plot([0], [0], [0], 'x', markersize=12, color='white', markeredgewidth=5)
ax.set_facecolor('xkcd:salmon')
# Hide grid lines
ax.grid(False)
# Hide axes ticks
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
# Axis limits: at least [-1, 1] on every axis, widened to fit the tree. This
# assumes, as in the rotation loop, that vec_g is rotated about the x-axis.
# Under that assumption the arrows span x in [0, vec_g[0]], and they reach out
# to the (y, z) radius of vec_g on both sides of the axis.
sweep_radius = np.hypot(vec_g[1], vec_g[2])
ax.set_xlim((min(-1, vec_g[0]), max(1, vec_g[0])))
ax.set_ylim((min(-1, -sweep_radius), max(1, sweep_radius)))
ax.set_zlim((min(-1, -sweep_radius), max(1, sweep_radius)))
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.