Created
October 12, 2018 16:41
-
-
Save yangyushi/3841d8909829f800e6f095e5e0419947 to your computer and use it in GitHub Desktop.
Find the best rotation and dilatation relating two sets of points, following the paper "A solution for the best rotation to relate two sets of vectors" (doi:10.1107/S0567739476001873).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from mpl_toolkits.mplot3d import Axes3D | |
from scipy.optimize import least_squares | |
def get_best_rotation(r1, r2):
    """
    Calculate the best orthogonal transform relating two sets of vectors.

    Finds the orthogonal matrix ``rot`` that minimises the summed squared
    distance between the rows of ``r1 @ rot`` and the rows of ``r2``; see the
    paper [A solution for the best rotation to relate two sets of vectors]
    for detail. All points are treated equally, i.e. every weight w_n in the
    paper takes the same value.

    No translation is fitted, so centre both point sets beforehand if they
    do not share an origin.

    Args:
        r1: array-like of shape (n_points, n_dim), the vectors to transform.
        r2: array-like of shape (n_points, n_dim), the target vectors, paired
            row by row with ``r1``.

    Returns:
        An (n_dim, n_dim) orthogonal matrix meant to be applied from the
        right (``r1 @ rot``). Its determinant is not constrained, so it can
        contain a reflection when that fits the data better.
    """
    r1 = np.asarray(r1)
    r2 = np.asarray(r2)
    # Correlation matrix between the two point sets.
    R = r2.T @ r1
    # With R = W @ diag(s) @ Vt, the orthogonal polar factor W @ Vt equals
    # R @ (R.T @ R)^(-1/2), the quantity built from the eigen-decomposition
    # of R.T @ R in the paper. The SVD form never divides by an eigenvalue,
    # so rank-deficient input (e.g. coplanar points in 3D) stays finite, and
    # it cannot produce complex values the way a general eigen-solver can.
    w, _, vt = np.linalg.svd(R)
    # Transpose so the result multiplies row vectors from the right.
    return (w @ vt).T
def get_best_dilatation_rotation(r1, r2, init_guess=None):
    """
    Numerically fit a per-axis dilatation followed by a rotation.

    Searches for a diagonal matrix ``L`` and an orthogonal matrix ``R`` such
    that ``(r1 @ L) @ R`` approximates ``r2`` in the least-squares sense.
    The scale factors are optimised with ``scipy.optimize.least_squares``;
    for every trial set of scale factors the rotation is obtained from
    ``get_best_rotation``.

    Args:
        r1: array of shape (n_points, n_dim), the points to transform.
        r2: array of shape (n_points, n_dim), the target points, paired row
            by row with ``r1``.
        init_guess: optional sequence of ``n_dim`` starting scale factors;
            defaults to all ones.

    Returns:
        A tuple ``(L, R)``: the (n_dim, n_dim) diagonal dilatation matrix and
        the (n_dim, n_dim) rotation matrix.
    """
    n_dim = r1.shape[1]
    if init_guess is None:
        init_guess = np.ones(n_dim)

    def residuals(scales):
        # Scaling each column is the same as r1 @ diag(scales).
        scaled = r1 * scales
        rotation = get_best_rotation(scaled, r2)
        # Hand the solver every residual rather than a single norm: the
        # minimised sum of squares is unchanged, but least_squares can then
        # build a full Jacobian instead of a single-row one.
        return (r2 - scaled @ rotation).ravel()

    result = least_squares(residuals, init_guess)
    L = np.diag(result.x)
    R = get_best_rotation(r1 @ L, r2)
    return L, R
def test_2d():
    """Demo: fit a rotation to noisy, randomly rotated 2D points and plot it."""
    angle = np.random.uniform(-np.pi, np.pi)
    cos_t, sin_t = np.cos(angle), np.sin(angle)
    rot_true = np.array([[cos_t, -sin_t], [sin_t, cos_t]])

    source = (np.random.random((25, 2)) - 0.5) * 40
    target = source @ rot_true
    # Perturb the target with uniform noise scaled by its largest coordinate.
    target += np.random.random(target.shape) * np.max(target) / 25

    rot = get_best_rotation(source, target)

    layers = (
        (source, dict(color='tomato', facecolor='w', label='r1')),
        (target, dict(color='teal', facecolor='w', label='r2', s=40)),
        (source @ rot, dict(color='tomato', label='r1 rotated', s=20)),
    )
    for points, style in layers:
        plt.scatter(*points.T, **style)
    plt.legend()
    plt.show()
def test_3d():
    """Demo: fit a rotation to noisy, randomly transformed 3D points and plot it."""
    # The Q factor of a random matrix serves as the "true" orthogonal transform.
    rot_true, _ = np.linalg.qr(np.random.random((3, 3)))

    source = (np.random.random((50, 3)) - 0.5) * 40
    target = source @ rot_true
    # Perturb the target with uniform noise scaled by its largest coordinate.
    target += np.random.random(target.shape) * target.max() / 25

    rot = get_best_rotation(source, target)

    figure = plt.figure()
    axes = figure.add_subplot(111, projection='3d')
    layers = (
        (source, dict(color='tomato', facecolor='w', label='r1')),
        (target, dict(color='teal', facecolor='w', label='r2', s=40)),
        (source @ rot, dict(color='tomato', label='r1 rotated', s=20)),
    )
    for points, style in layers:
        axes.scatter(*points.T, **style)
    plt.legend()
    plt.show()
def test_2d_dilatation():
    """Demo: recover a 2D dilatation and rotation from noisy points and plot it."""
    angle = np.random.uniform(-np.pi, np.pi)
    cos_t, sin_t = np.cos(angle), np.sin(angle)
    rot_true = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
    dila_true = np.identity(2) * np.array([3, 5])

    source = (np.random.random((25, 2)) - 0.5) * 40
    target = (source @ dila_true) @ rot_true
    # Perturb the target with uniform noise scaled by its largest coordinate.
    target += np.random.random(target.shape) * np.max(target) / 25

    dila, rot = get_best_dilatation_rotation(source, target, init_guess=[1, 1])
    # Show how far the fitted dilatation is from the one used to build the data.
    print(dila - dila_true)

    layers = (
        (source, dict(color='tomato', facecolor='w', label='r1')),
        (target, dict(color='teal', facecolor='w', label='r2', s=40)),
        ((source @ dila) @ rot,
         dict(color='tomato', label='r1 dilated & rotated', s=20)),
    )
    for points, style in layers:
        plt.scatter(*points.T, **style)
    plt.legend()
    plt.show()
def test_3d_dilatation():
    """Demo: recover a 3D dilatation and rotation from noisy points and plot it."""
    # The Q factor of a random matrix serves as the "true" orthogonal transform.
    rot_true, _ = np.linalg.qr(np.random.random((3, 3)))
    dila_true = np.identity(3) * np.array([3, 5, 10])

    source = (np.random.random((50, 3)) - 0.5) * 40
    target = (source @ dila_true) @ rot_true
    # Perturb the target with uniform noise scaled by its largest coordinate.
    target += np.random.random(target.shape) * np.max(target) / 25

    dila, rot = get_best_dilatation_rotation(source, target)
    # Show how far the fitted dilatation is from the one used to build the data.
    print(dila - dila_true)

    figure = plt.figure()
    axes = figure.add_subplot(111, projection='3d')
    layers = (
        (source, dict(color='tomato', facecolor='w', label='r1')),
        (target, dict(color='teal', facecolor='w', label='r2', s=40)),
        ((source @ dila) @ rot,
         dict(color='tomato', label='r1 dilated & rotated', s=20)),
    )
    for points, style in layers:
        axes.scatter(*points.T, **style)
    plt.legend()
    plt.show()
if __name__ == '__main__':
    # Run every visual demo in turn; each one blocks until its plot window
    # is closed.
    for demo in (test_2d_dilatation, test_3d_dilatation, test_3d, test_2d):
        demo()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment