@evanatyourservice
Created June 5, 2025 17:36
QDWH polar decomposition in PyTorch
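QDWH (QR-based dynamically weighted Halley; Nakatsukasa, Bai, and Gygi, 2010) computes the polar decomposition X = U H, where U has orthonormal columns and H is symmetric positive semi-definite, by iterating a dynamically weighted Halley map that pushes every singular value of X toward 1. As a reading aid for the code below (the variables a, b, c, l correspond to a_k, b_k, c_k, l_k here), the recurrence is

    X_{k+1} = X_k (a_k I + b_k X_k^T X_k) (I + c_k X_k^T X_k)^{-1},
    b_k = (a_k - 1)^2 / 4,    c_k = a_k + b_k - 1,

where a_k is chosen from l_k, a running lower bound on the smallest singular value of the iterate, which advances as l_{k+1} = l_k (a_k + b_k l_k^2) / (1 + c_k l_k^2). Setting e_k = b_k / c_k, the update is equivalently e_k X_k + (a_k - e_k) X_k (I + c_k X_k^T X_k)^{-1}, which is exactly the form the two step functions below evaluate.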
import torch


def _qdwh_qr_step(u, params):
    """One QR-based QDWH iteration; preferred while c is large and forming
    the Gram matrix U^T U would be numerically risky."""
    a_minus_e_by_sqrt_c, sqrt_c, e = params
    M, N = u.shape
    eye_n = torch.eye(N, dtype=u.dtype, device=u.device)
    # Thin QR of the stacked matrix [sqrt(c) * U; I].
    y = torch.cat((sqrt_c * u, eye_n), dim=0)
    q, _ = torch.linalg.qr(y, mode='reduced')
    q1, q2 = q[:M, :], q[M:, :]
    # U_next = e * U + ((a - e) / sqrt(c)) * Q1 @ Q2^T
    return e * u + a_minus_e_by_sqrt_c * (q1 @ q2.mT)
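
# Why the QR trick above works: if [sqrt(c) U; I] = [Q1; Q2] R is a thin QR
# factorization, then R^T R = I + c U^T U, and
#   Q1 @ Q2^T = sqrt(c) U R^{-1} R^{-T} = sqrt(c) U (I + c U^T U)^{-1},
# so the step equals e * U + (a - e) * U (I + c U^T U)^{-1} without ever
# forming the Gram matrix or an explicit inverse.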
def _qdwh_chol_step(u, params):
    """One Cholesky-based QDWH iteration; cheaper than the QR variant and
    safe once c is modest."""
    a_minus_e, c, e = params
    M, N = u.shape
    # Z = I + c * U^T U is symmetric positive definite, so factor Z = L L^T.
    x = c * (u.mT @ u) + torch.eye(N, dtype=u.dtype, device=u.device)
    L = torch.linalg.cholesky(x)
    # Two triangular solves give Z^{-1} U^T without forming an explicit inverse.
    Y = torch.linalg.solve_triangular(L, u.mT, upper=False, left=True)
    Z = torch.linalg.solve_triangular(L.mT, Y, upper=True, left=True)
    # U_next = e * U + (a - e) * U (I + c U^T U)^{-1}
    return e * u + a_minus_e * Z.mT
def qdwh(x, max_iterations=10, eps=None):
    """Polar decomposition X = U H via QDWH.

    Returns (u, h, num_iters, is_converged), where u has orthonormal columns
    and h = u^T x is symmetric positive semi-definite.
    """
    M, N = x.shape
    if M < N:
        raise ValueError(f"Input matrix must have M >= N, but got shape {x.shape}")
    if eps is None:
        eps = torch.finfo(x.dtype).eps
    # Scale so the spectral norm of u is at most 1, using
    # ||X||_2 <= sqrt(||X||_1 * ||X||_inf).
    one_norm = torch.linalg.norm(x, ord=1)
    inf_norm = torch.linalg.norm(x, ord=float('inf'))
    if one_norm > 0:
        alpha_inverse = torch.rsqrt(one_norm) * torch.rsqrt(inf_norm)
    else:
        alpha_inverse = torch.tensor(1.0, dtype=x.dtype, device=x.device)
    u = x * alpha_inverse.to(x.dtype)
    # l is a conservative lower bound on the smallest singular value of u.
    l = torch.tensor(eps, dtype=x.dtype, device=x.device)
    tol_l = 10.0 * eps / 2.0
    tol_norm = tol_l ** (1 / 3)

    # Precompute the dynamic weights (a, b, c) for each planned iteration.
    qr_params_list = []
    chol_params_list = []
    CHOLESKY_CUTOFF = 100  # use QR steps while c is large, Cholesky after
    k = 0
    while l + tol_l < 1 and k < max_iterations:
        k += 1
        l2 = l * l
        # d = (4 * (1 - l^2) / l^4)^(1/3), guarded against division by ~0.
        safe_l2 = torch.where(l2 > 1e-12, l2,
                              torch.tensor(1e-12, dtype=l2.dtype, device=l2.device))
        term = 4 * (1 / safe_l2 - 1) / safe_l2
        dd = torch.copysign(torch.pow(torch.abs(term), 1 / 3), term)
        sqd = torch.sqrt(torch.relu(1.0 + dd))
        # a = sqrt(1 + d) + sqrt(2 - d + 2 * (2 - l^2) / (l^2 * sqrt(1 + d)))
        if l2 * sqd != 0:
            inner_sqrt_arg = 2 - dd + 2 * (2 - l2) / (l2 * sqd)
        else:
            inner_sqrt_arg = 2 - dd
        a = sqd + torch.sqrt(torch.relu(inner_sqrt_arg))
        b = (a - 1) ** 2 / 4
        c = a + b - 1
        # Advance the lower bound: l <- l * (a + b * l^2) / (1 + c * l^2).
        if l2 == 0:
            l_next = l
        else:
            denom = 1 + c * l2
            safe_denominator = torch.where(denom != 0, denom,
                                           torch.tensor(1.0, dtype=l2.dtype, device=l2.device))
            l_next = l * (a + b * l2) / safe_denominator
        if not torch.isfinite(l_next):
            break
        l = l_next
        e_param = b / c
        a_minus_e = a - e_param
        if c > CHOLESKY_CUTOFF:
            sqrt_c = torch.sqrt(c)
            if sqrt_c != 0:
                a_minus_e_by_sqrt_c = a_minus_e / sqrt_c
            else:
                a_minus_e_by_sqrt_c = torch.tensor(0.0, dtype=x.dtype, device=x.device)
            qr_params_list.append((a_minus_e_by_sqrt_c, sqrt_c, e_param))
        else:
            chol_params_list.append((a_minus_e, c, e_param))

    if qr_params_list:
        qr_params = torch.stack([torch.stack(p) for p in qr_params_list]).detach()
    else:
        qr_params = torch.empty((0, 3), dtype=x.dtype, device=x.device)
    if chol_params_list:
        chol_params = torch.stack([torch.stack(p) for p in chol_params_list]).detach()
    else:
        chol_params = torch.empty((0, 3), dtype=x.dtype, device=x.device)

    # Run the scheduled QR steps first, then the cheaper Cholesky steps.
    num_iters = 0
    is_converged = False
    for i in range(len(qr_params)):
        if num_iters >= max_iterations:
            break
        u = _qdwh_qr_step(u, qr_params[i])
        num_iters += 1
    for i in range(len(chol_params)):
        if num_iters >= max_iterations:
            break
        u_prev = u
        u = _qdwh_chol_step(u, chol_params[i])
        num_iters += 1
        with torch.no_grad():
            diff = torch.linalg.norm(u - u_prev)
        if diff <= tol_norm:
            is_converged = True
            break
    if not is_converged:
        # Finish with plain Halley iterations: a = 3, b = 1, c = 3,
        # packed as (a - b/c, c, b/c) = (8/3, 3, 1/3).
        halley_params = (torch.tensor(8 / 3, dtype=x.dtype, device=x.device),
                         torch.tensor(3.0, dtype=x.dtype, device=x.device),
                         torch.tensor(1 / 3, dtype=x.dtype, device=x.device))
        while num_iters < max_iterations:
            u_prev = u
            u = _qdwh_chol_step(u, halley_params)
            num_iters += 1
            with torch.no_grad():
                diff = torch.linalg.norm(u - u_prev)
            if diff <= tol_norm:
                is_converged = True
                break
    # One Newton-Schulz step to polish the orthogonality of u.
    u = 1.5 * u - 0.5 * u @ (u.mT @ u)
    # Symmetrize h = u^T x against round-off.
    h = u.mT @ x
    h = (h + h.mT) / 2.0
    if num_iters >= max_iterations and not is_converged:
        print(f"Warning: QDWH did not converge within {max_iterations} iterations.")
    return u, h, num_iters, is_converged
if __name__ == "__main__":
    torch.manual_seed(42)
    M = 1024
    x_test = torch.randn(M, M, dtype=torch.float32)
    print(f"Testing QDWH on a random {M}x{M} matrix (float32)...")
    u, h, iters, converged = qdwh(x_test, max_iterations=20)
    print(f"Converged: {converged} in {iters} iterations.")
    if not converged:
        print("Warning: Test case did not converge.")
    # u should have orthonormal columns, and x_test should equal u @ h.
    gram = u.mT @ u
    identity = torch.eye(M, dtype=gram.dtype, device=gram.device)
    error_norm = torch.linalg.norm(gram - identity).item()
    print(f"orthogonality error ||U^T U - I||: {error_norm:.3e}")
    recon_err = torch.linalg.norm(x_test - (u @ h)).item()
    print(f"reconstruction error ||X - U H||: {recon_err:.3e}")