@ljleb
Created November 28, 2024 16:50
import torch
import math
torch.manual_seed(0)
device = torch.device("cuda:0")
dtype = torch.float64
def randn_q(*s):
    return torch.linalg.qr(randn(*s)).Q

def randn(*s):
    return torch.randn(s, device=device, dtype=dtype)
m, n = 768, 768
a, b = randn(m, n), randn(m, n)
# ensure a and b have the same determinant sign
if a.slogdet()[0] != b.slogdet()[0]:
    a[-1] *= -1
alpha = torch.rand(n, n, device=device, dtype=dtype)
u, s, vt = torch.linalg.svd(a.T @ b, full_matrices=False, driver="gesvd")
# initialize `w` at the isotropic solution (the orthogonal Procrustes optimum for uniform weights)
w = (u @ vt).requires_grad_()
del u, s, vt
def objective():
    return 2 * (((a @ w - b).square() * alpha).mean() + ((a - b @ w.T).square() * (1 - alpha)).mean())
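# Note (added comment, not part of the original gist): the first term weights the
# forward residual `a @ w - b` elementwise by `alpha`, the second weights the
# backward residual `a - b @ w.T` by `1 - alpha`; for orthogonal `w` the two
# unweighted residuals have equal Frobenius norm, since `a - b @ w.T == (a @ w - b) @ w.T`.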
# weight update along the steepest descent direction on the orthogonal group
def next_w(iterations=100):
    w_detached = w.detach()
    # skew-symmetric projection of the gradient: the tangent direction in the Lie algebra
    m_sharp = w_detached.T @ w.grad - w.grad.T @ w_detached
    m_sharp /= torch.linalg.matrix_norm(m_sharp)
    # Newton-Schulz iterations push the singular values of `m_sharp` towards 1,
    # i.e. orthogonalize the descent direction
    for _ in range(iterations):
        m_sharp = 3/2 * m_sharp - 1/2 * m_sharp @ m_sharp.T @ m_sharp
    # rescaling by sqrt(1 + lr**2) keeps `w` orthogonal (up to the accuracy of the
    # iterations above): for skew-symmetric orthogonal `m_sharp`, the unscaled step
    # has Gram matrix (1 + lr**2) * I
    return (w_detached - lr * w_detached @ m_sharp) / math.sqrt(1 + lr**2)
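# Alternative sketch (added, not part of the original gist): an exact geodesic step
# via the matrix exponential of a skew-symmetric direction; `next_w` above is a
# cheaper first-order approximation of this followed by an explicit rescaling.
def next_w_exp(direction, step=1e-2):
    # `direction` is assumed skew-symmetric, so its matrix exponential is orthogonal
    # and the product stays exactly on the orthogonal group
    return w.detach() @ torch.linalg.matrix_exp(-step * direction)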
lr = 1e-2
iterations = 1000
for i in range(iterations):
    w.grad = None
    loss = objective()
    if (i + 1) % 10 == 0:
        print(f"loss: {loss.item():0.6f}")
    loss.backward()
    w = next_w().requires_grad_()
# `w` is the optimized orthogonal map from `a` to `b` according to the weights in `alpha`
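# Sanity check (added, not part of the original gist): `w` should remain very close
# to the orthogonal group, and the objective should have decreased from its value at
# the initial Procrustes solution.
with torch.no_grad():
    eye = torch.eye(n, device=device, dtype=dtype)
    print(f"||w.T @ w - I||_F = {torch.linalg.matrix_norm(w.T @ w - eye).item():.3e}")
    print(f"final loss: {objective().item():0.6f}")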