Skip to content

Instantly share code, notes, and snippets.

@cloneofsimo
Last active November 14, 2024 20:47
Show Gist options
  • Save cloneofsimo/fb5b0a5d9dea82c8d8468609204ef8ee to your computer and use it in GitHub Desktop.
Save cloneofsimo/fb5b0a5d9dea82c8d8468609204ef8ee to your computer and use it in GitHub Desktop.
Orthogonal weight update
https://x.com/jxbz/status/1857145985480438073
import torch
def polar_factor_newton_schulz(M, max_iter=50):
    """Approximate the orthogonal polar factor of ``M`` by Newton-Schulz iteration.

    ``M`` is first divided by its Frobenius norm, then the update
    ``X <- 1.5 * X - 0.5 * X @ X.T @ X`` is applied ``max_iter`` times.

    Args:
        M: 2-D floating-point tensor.
        max_iter: Number of Newton-Schulz steps to run.

    Returns:
        A tensor with the same shape as ``M``. An all-zero ``M`` yields an
        all-zero result rather than NaNs.
    """
    # The Frobenius norm bounds the largest singular value, so after scaling
    # every singular value is at most 1. Clamp the norm away from zero first:
    # an all-zero M would otherwise produce 0 / 0 = NaN everywhere.
    norm = M.norm(p='fro').clamp_min(torch.finfo(M.dtype).tiny)
    M_t = M / norm
    for _ in range(max_iter):
        M_t = 1.5 * M_t - 0.5 * M_t @ M_t.T @ M_t
    return M_t
def stiefel_update(W, G, eta, max_iter=10):
    """Take one multiplicative update step on ``W`` driven by ``G``.

    The step direction is ``polar_factor_newton_schulz`` applied to the
    skew-symmetric matrix ``W.T @ G - G.T @ W``. If that factor is exactly
    orthogonal and ``W`` has orthonormal columns, dividing by
    ``sqrt(1 + eta**2)`` keeps the columns of the result orthonormal.

    Args:
        W: 2-D tensor of shape ``(m, n)``.
        G: 2-D tensor with the same shape as ``W``.
        eta: Step size.
        max_iter: Number of Newton-Schulz steps used for the polar factor.

    Returns:
        The updated tensor, same shape as ``W``.
    """
    # Skew-symmetric by construction: M.T == -M.
    M = W.T @ G - G.T @ W
    M_hat = polar_factor_newton_schulz(M, max_iter=max_iter)
    # Build the identity on W's dtype/device so the update also works when W
    # is not on the default device, and no precision is lost for float64.
    identity = torch.eye(W.shape[1], dtype=W.dtype, device=W.device)
    # Plain scalar arithmetic: avoids a float32 CPU tensor whose rounding
    # would limit the accuracy of the normalization for float64 inputs.
    scaling = (1 + eta**2) ** 0.5
    W_new = W @ (identity - eta * M_hat) / scaling
    return W_new
# Experiment: start from a random matrix with orthonormal columns, apply
# repeated stiefel_update steps, and record how far W.T @ W drifts from the
# identity for each (max_iter, update_iters) combination.
m, n = 50, 30
seed = 11
W = torch.randn(m, n, generator=torch.Generator().manual_seed(seed))
# torch.qr is deprecated; torch.linalg.qr in its default "reduced" mode
# returns the same thin Q factor.
W, _ = torch.linalg.qr(W)
low_rank_ratio = 0.01  # set this to 0 to see weird effects (G then has rank at most 4)
gen = torch.Generator().manual_seed(seed + 1)
# Rank-4 product plus full-rank noise scaled by low_rank_ratio.
G = torch.randn((m, 4), generator=gen) @ torch.randn((4, n), generator=gen) + torch.randn((m, n), generator=gen) * low_rank_ratio
W_init = W.clone()
identity = torch.eye(n)
eta = 1/8  # constant step size, hoisted out of the loops
datas = []
for update_iters in [1, 2, 4, 8, 16, 32, 64]:
    for max_iter in [1, 2, 3, 4, 5, 6, 8, 14, 20, 30, 40, 50]:
        # Every configuration restarts from the same orthonormal W.
        W = W_init.clone()
        for _ in range(update_iters):
            W = stiefel_update(W, G, eta, max_iter=max_iter)
        I_approx = W.T @ W
        # Compute the orthogonality error once and reuse it.
        error = torch.norm(I_approx - identity)
        print(f"Error for max_iter={max_iter}, update_iters={update_iters}: {error}")
        # Store a plain Python float so the record is a number, not a
        # 0-dim tensor object, when it is tabulated and saved later.
        datas.append((max_iter, error.item(), update_iters))
import pandas as pd
# Collect the (max_iter, error, update_iters) records into a table and write
# it to a CSV file in the current working directory, without the row index.
df = pd.DataFrame(
    data=datas,
    columns=["max_iter", "error", "update_iters"],
)
df.to_csv(path_or_buf="stiefel_update_error.csv", index=False)
import matplotlib.pyplot as plt
# Draw one curve per outer-update count: error as a function of max_iter.
for update_iters in (1, 2, 4, 8, 16, 32, 64):
    df_sub = df.loc[df["update_iters"] == update_iters]
    plt.plot(
        df_sub["max_iter"],
        df_sub["error"],
        label=f"update_iters={update_iters}",
    )

# Errors span several orders of magnitude, hence the log-scaled y axis.
plt.yscale("log")
plt.xlabel("max_iter")
plt.ylabel("error")
plt.legend()

# Write the PNG first, then open the interactive window.
plt.savefig("stiefel_update_error.png")
plt.show()
@cloneofsimo
Copy link
Author

without low rank (ratio = 0.1)
image

with low rank (ratio = 0.0)
image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment