# Gist by @thomasahle (created April 28, 2023):
# https://gist.github.com/thomasahle/0a53cecd24403b8c3f3634249f4da725
import numpy as np
def quartic(y, n1, n2, ip):
    """Solve for x2 in [0, 1] in the equation

        n1 x2 + n2 (1 - x2) + 2 ip sqrt(x2 (1 - x2)) == y

    Writing x2 = cos(t)**2 (with sin(t) >= 0), the left side equals
    |cos(t) u + sin(t) v|**2 for any vectors u, v with |u|**2 = n1,
    |v|**2 = n2 and u @ v = ip.

    Squaring the equation gives a quadratic in x2. Taking -sign(ip) in front
    of the square-root term selects the root that also satisfies the
    unsquared equation, i.e. the one where y - n2 - (n1 - n2) x2 has the
    same sign as ip.

    Args:
        y: target value; must satisfy n2 <= y <= n1.
        n1: coefficient of x2 (n1 >= n2).
        n2: coefficient of 1 - x2.
        ip: coefficient of the cross term.

    Returns:
        x2, a value in [0, 1] solving the equation.

    Raises:
        AssertionError: if y is outside [n2, n1], or if the numerical
            self-check on the computed root fails.
    """
    assert n2 <= y <= n1
    denom = 4 * ip**2 + (n1 - n2) ** 2
    if denom == 0:
        # ip == 0 and n1 == n2 (hence y == n1 == n2): the left side equals y
        # for every x2. Pick x2 = 1 (cos = 1, sin = 0, i.e. no rotation)
        # instead of evaluating 0 / 0.
        return 1.0
    d = np.sign(ip) * (ip**4 + ip**2 * (n1 - y) * (y - n2)) ** 0.5
    x2 = (2 * ip**2 + (n1 - n2) * (y - n2) - 2 * d) / denom
    # Rounding can land the root a hair outside [0, 1]. That would make
    # sqrt(x2 * (1 - x2)) below, and a caller's sqrt(x2) / sqrt(1 - x2), NaN.
    x2 = min(max(x2, 0.0), 1.0)
    assert np.isclose(n1 * x2 + n2 * (1 - x2) + 2 * ip * (x2 * (1 - x2)) ** 0.5, y)
    return x2
def optimal_rotation(X):
    """Return an orthogonal R such that all columns of X @ R share one norm.

    Every column of X @ R ends with squared norm equal to the mean squared
    column norm of X. Right-multiplying by an orthogonal matrix preserves the
    sum of squared column norms, so equal norms minimize the largest one.

    The matrix is built from at most d - 1 plane (Givens) rotations. Each
    step pairs the column with the largest squared norm with the one with
    the smallest, and rotates them so that the larger lands exactly on the
    target. Because the total is preserved, the last column ends on the
    target as well.

    Args:
        X: 2-D array-like of shape (n, d). It is not modified.

    Returns:
        R: (d, d) orthogonal matrix, a product of rotations (det(R) = 1).

    Raises:
        AssertionError: from the numerical self-checks here or in ``quartic``.
    """
    X = np.array(X)  # private copy: columns of X are rotated in place below
    if not np.issubdtype(X.dtype, np.inexact):
        # With an integer dtype, the in-place column updates and the
        # colnorm2s[hi] = target_norm assignment would silently truncate.
        X = X.astype(np.float64)
    d = X.shape[1]  # width comes from X itself, not from any global
    R = np.eye(d)
    if d < 2:
        return R
    colnorm2s = np.einsum("ij,ij->j", X, X)
    # The float mean can round just outside [min, max], and quartic asserts
    # n2 <= y <= n1, so clamp it.
    target_norm = np.clip(np.mean(colnorm2s), colnorm2s.min(), colnorm2s.max())
    for _ in range(d - 1):
        # Pair the column furthest above the target with the one furthest below.
        lo, hi = np.argmin(colnorm2s), np.argmax(colnorm2s)
        if colnorm2s[hi] == colnorm2s[lo]:
            # All columns already share the target norm. Going on would pair a
            # column with itself (lo == hi) and negate it, i.e. a reflection.
            break
        n1, n2, ip = colnorm2s[hi], colnorm2s[lo], X[:, hi] @ X[:, lo]
        # cos2 = cos(t)**2 of the angle that gives the hi column the target norm.
        cos2 = quartic(target_norm, n1, n2, ip)
        cos, sin = cos2**0.5, (1 - cos2) ** 0.5
        rot = np.array([[cos, -sin], [sin, cos]])
        pair = [hi, lo]
        # Update X too, so the next inner product is taken on rotated columns.
        X[:, pair] = X[:, pair] @ rot
        R[:, pair] = R[:, pair] @ rot
        colnorm2s[hi] = target_norm
        colnorm2s[lo] = n2 * cos2 + n1 * (1 - cos2) - 2 * ip * cos * sin
        # Self-check: the bookkept norms match the actual rotated columns.
        new_n1, new_n2 = np.einsum("ij,ij->j", X[:, pair], X[:, pair])
        assert np.isclose(new_n1, target_norm)
        assert np.isclose(new_n2, colnorm2s[lo])
    return R
# Example usage: equalize the column norms of a random Gaussian matrix.
if __name__ == "__main__":
    n, d = 100, 13
    rng = np.random.default_rng(0)  # fixed seed so the printed output is reproducible
    X = rng.standard_normal((n, d))
    R = optimal_rotation(X)
    assert np.allclose(R @ R.T, np.eye(d)), "R not a rotation"
    old_norms = np.linalg.norm(X, axis=0)
    new_norms = np.linalg.norm(X @ R, axis=0)
    # The whole point: after rotating, every column has the same norm.
    assert np.allclose(new_norms, new_norms.mean())
    print(f"{old_norms.max()=}")
    print(f"{new_norms.max()=}")
# Comments on this gist can be posted on GitHub (sign-in required).