Skip to content

Instantly share code, notes, and snippets.

@evanatyourservice
Created June 5, 2025 18:02
Show Gist options
  • Save evanatyourservice/22ad13878ed3643f0cb32162781083e6 to your computer and use it in GitHub Desktop.
Save evanatyourservice/22ad13878ed3643f0cb32162781083e6 to your computer and use it in GitHub Desktop.
Batched inverse matrix square roots (A^{-1/2}) using Newton-Schulz iteration
import torch
def compute_H_inv_cubic(A, num_iters=10, *, normalize=False):
    """Approximate the inverse square root ``A^{-1/2}`` of a batch of matrices.

    Runs the Newton-Schulz iteration (degree-3 polynomial in ``X``), starting
    from ``X_0 = I``::

        X_{k+1} = 1.5 * X_k - 0.5 * X_k A X_k X_k

    Every iterate is a polynomial in ``A``, so it commutes with ``A``. The
    update is therefore the matrix form of the scalar Newton step
    ``x <- x * (3 - a * x**2) / 2`` for ``a**-0.5``. Convergence is quadratic
    once the iterate is close to the solution.

    Starting from ``x = 1``, the scalar step converges to ``a**-0.5`` only when
    ``0 < a < 3``. So for symmetric ``A`` without normalization, every
    eigenvalue must lie in that interval. Outside it the iteration stalls,
    converges to the wrong sign, or diverges.

    Args:
        A: Tensor of shape ``(..., n, n)``; leading dimensions are treated as
            batch dimensions.
        num_iters: Number of Newton-Schulz steps.
        normalize: If True, each matrix is divided by its Frobenius norm
            before iterating, and the result is rescaled by ``norm ** -0.5``
            afterwards. For symmetric positive-definite ``A`` this maps all
            eigenvalues into ``(0, 1]``, inside the convergence region,
            whatever the scale of ``A``. An all-zero matrix then yields
            non-finite output.

    Returns:
        Tensor of shape ``A.shape`` approximating ``A^{-1/2}``.
    """
    scale = None
    if normalize:
        # ||A||_F >= spectral norm, so A / ||A||_F has eigenvalues <= 1.
        scale = torch.linalg.matrix_norm(A, ord="fro", keepdim=True)
        A = A / scale

    eye = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device).expand(A.shape)
    # Clone so the result never aliases the stride-0 expanded view (matters
    # when num_iters == 0 and the caller writes into the output).
    X = eye.clone()
    for _ in range(num_iters):
        X = 1.5 * X - 0.5 * (X @ A @ X @ X)

    if scale is not None:
        # (A / s)^{-1/2} = s^{1/2} A^{-1/2}, so undo the scaling.
        X = X / scale.sqrt()
    return X
def compute_H_inv_quintic(A, num_iters=5, *, normalize=False):
    """Approximate the inverse square root ``A^{-1/2}`` of a batch of matrices.

    Runs a higher-order Newton-Schulz iteration (degree-5 polynomial in ``X``),
    starting from ``X_0 = I``::

        Y       = A X_k X_k
        X_{k+1} = X_k (15/8 I - 10/8 Y + 3/8 Y^2)

    The bracket is the series ``(1 - z)^{-1/2} ~ 1 + z/2 + 3 z^2 / 8``
    truncated after the quadratic term and evaluated at ``z = I - Y``. This
    gives third-order convergence near the solution, so fewer steps are
    needed than with ``compute_H_inv_cubic``.

    Starting from ``x = 1``, the scalar version converges to ``a**-0.5`` only
    when ``0 < a < 7/3``. So for symmetric ``A`` without normalization, every
    eigenvalue must lie in that interval; otherwise the iteration diverges.

    Args:
        A: Tensor of shape ``(..., n, n)``; leading dimensions are treated as
            batch dimensions.
        num_iters: Number of iteration steps.
        normalize: If True, each matrix is divided by its Frobenius norm
            before iterating, and the result is rescaled by ``norm ** -0.5``
            afterwards. For symmetric positive-definite ``A`` this maps all
            eigenvalues into ``(0, 1]``, inside the convergence region,
            whatever the scale of ``A``. An all-zero matrix then yields
            non-finite output.

    Returns:
        Tensor of shape ``A.shape`` approximating ``A^{-1/2}``.
    """
    scale = None
    if normalize:
        # ||A||_F >= spectral norm, so A / ||A||_F has eigenvalues <= 1.
        scale = torch.linalg.matrix_norm(A, ord="fro", keepdim=True)
        A = A / scale

    eye = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device).expand(A.shape)
    # Clone so the result never aliases the stride-0 expanded view.
    X = eye.clone()
    for _ in range(num_iters):
        Y = A @ X @ X
        X = X @ (1.875 * eye - 1.25 * Y + 0.375 * (Y @ Y))

    if scale is not None:
        # (A / s)^{-1/2} = s^{1/2} A^{-1/2}, so undo the scaling.
        X = X / scale.sqrt()
    return X
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment