QDWH
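A PyTorch implementation of the QDWH (QR-based dynamically weighted Halley) iteration for computing the polar decomposition x = u h of a square or tall matrix.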
import torch


def _qdwh_qr_step(u, params):
    # One QR-based QDWH iteration: u <- e*u + ((a - e)/sqrt(c)) * Q1 @ Q2^T,
    # where [Q1; Q2] is the reduced QR factor of [sqrt(c)*u; I].
    a_minus_e_by_sqrt_c, sqrt_c, e = params
    M, N = u.shape
    eye_n = torch.eye(N, dtype=u.dtype, device=u.device)
    y = torch.cat((sqrt_c * u, eye_n), dim=0)
    q, _ = torch.linalg.qr(y, mode='reduced')
    q1, q2 = q[:M, :], q[M:, :]
    return e * u + a_minus_e_by_sqrt_c * (q1 @ q2.mT)


def _qdwh_chol_step(u, params):
    # One Cholesky-based QDWH iteration: u <- e*u + (a - e) * u @ (c*u^T u + I)^{-1},
    # computed with a Cholesky factorization and two triangular solves.
    a_minus_e, c, e = params
    M, N = u.shape
    x = c * (u.mT @ u) + torch.eye(N, dtype=u.dtype, device=u.device)
    L = torch.linalg.cholesky(x)
    Y = torch.linalg.solve_triangular(L, u.mT, upper=False, left=True)
    Z = torch.linalg.solve_triangular(L.mT, Y, upper=True, left=True)
    return e * u + a_minus_e * Z.mT


def qdwh(x, max_iterations=10, eps=None):
    # Polar decomposition x = u @ h via QDWH: u has orthonormal columns,
    # h is symmetric positive semi-definite. Requires a tall or square input.
    M, N = x.shape
    if M < N:
        raise ValueError(f"Input matrix must have M >= N, but got shape {x.shape}")
    if eps is None:
        eps = torch.finfo(x.dtype).eps

    # Scale x so its spectral norm is at most 1, using sqrt(||x||_1 * ||x||_inf)
    # as an upper bound on ||x||_2.
    one_norm = torch.linalg.norm(x, ord=1)
    inf_norm = torch.linalg.norm(x, ord=float('inf'))
    alpha_inverse = torch.rsqrt(one_norm) * torch.rsqrt(inf_norm) if one_norm > 0 else torch.tensor(1.0, dtype=x.dtype, device=x.device)
    u = x * alpha_inverse.to(x.dtype)

    # l is a lower bound on the smallest singular value of the scaled u;
    # the weighting coefficients below are derived from it.
    l = torch.tensor(eps, dtype=x.dtype, device=x.device)
    tol_l = 10.0 * eps / 2.0
    tol_norm = tol_l ** (1 / 3)

    # Precompute the dynamically weighted Halley coefficients for each iteration:
    # use the numerically safer QR step while c is large, then switch to Cholesky.
    qr_params_list = []
    chol_params_list = []
    CHOLESKY_CUTOFF = 100
    k = 0
    while l + tol_l < 1 and k < max_iterations:
        k += 1
        l2 = l * l
        safe_l2 = torch.where(l2 > 1e-12, l2, torch.tensor(1e-12, dtype=l2.dtype, device=l2.device))
        term = 4 * (1 / safe_l2 - 1) / safe_l2
        dd = torch.copysign(torch.pow(torch.abs(term), 1 / 3), term)
        sqd = torch.sqrt(torch.relu(1.0 + dd))
        inner_sqrt_arg = 2 - dd + 2 * (2 - l2) / (l2 * sqd) if l2 * sqd != 0 else 2 - dd
        a = sqd + torch.sqrt(torch.relu(inner_sqrt_arg))
        b = (a - 1) ** 2 / 4
        c = a + b - 1
        # Update the lower bound l under the current rational iteration.
        if l2 == 0:
            l_next = l
        else:
            safe_denominator = torch.where(1 + c * l2 != 0, 1 + c * l2, torch.tensor(1.0, dtype=l2.dtype, device=l2.device))
            l_next = l * (a + b * l2) / safe_denominator
        if not torch.isfinite(l_next):
            break
        l = l_next
        e_param = b / c
        a_minus_e = a - e_param
        if c > CHOLESKY_CUTOFF:
            sqrt_c = torch.sqrt(c)
            a_minus_e_by_sqrt_c = a_minus_e / sqrt_c if sqrt_c != 0 else torch.tensor(0.0, dtype=x.dtype, device=x.device)
            qr_params_list.append((a_minus_e_by_sqrt_c, sqrt_c, e_param))
        else:
            chol_params_list.append((a_minus_e, c, e_param))
    if qr_params_list:
        qr_params = torch.stack([torch.stack(p) for p in qr_params_list]).detach()
    else:
        qr_params = torch.empty((0, 3), dtype=x.dtype, device=x.device)
    if chol_params_list:
        chol_params = torch.stack([torch.stack(p) for p in chol_params_list]).detach()
    else:
        chol_params = torch.empty((0, 3), dtype=x.dtype, device=x.device)

    # Apply the precomputed QR-based steps first, then the Cholesky-based steps,
    # checking convergence on the change in u.
    num_iters = 0
    is_converged = False
    for i in range(len(qr_params)):
        if num_iters >= max_iterations:
            break
        u = _qdwh_qr_step(u, qr_params[i])
        num_iters += 1
    for i in range(len(chol_params)):
        if num_iters >= max_iterations:
            break
        u_prev = u
        u = _qdwh_chol_step(u, chol_params[i])
        num_iters += 1
        with torch.no_grad():
            diff = torch.linalg.norm(u - u_prev)
        if diff <= tol_norm:
            is_converged = True
            break

    # Fall back to plain Halley iterations (a=3, b=1, c=3, so e=b/c=1/3 and a-e=8/3)
    # if the weighted steps did not converge within the budget.
    if not is_converged:
        halley_params = (torch.tensor(8 / 3, dtype=x.dtype, device=x.device),
                         torch.tensor(3.0, dtype=x.dtype, device=x.device),
                         torch.tensor(1 / 3, dtype=x.dtype, device=x.device))
        while num_iters < max_iterations:
            u_prev = u
            u = _qdwh_chol_step(u, halley_params)
            num_iters += 1
            with torch.no_grad():
                diff = torch.linalg.norm(u - u_prev)
            if diff <= tol_norm:
                is_converged = True
                break

    # One Newton-Schulz polishing step, then form the symmetric factor h = u^T x.
    u = 1.5 * u - 0.5 * u @ (u.mT @ u)
    h = u.mT @ x
    h = (h + h.mT) / 2.0
    if num_iters >= max_iterations and not is_converged:
        print(f"Warning: QDWH did not converge within {max_iterations} iterations.")
        is_converged = False
    return u, h, num_iters, is_converged


if __name__ == "__main__":
    torch.manual_seed(42)
    M = 1024
    x_test = torch.randn(M, M, dtype=torch.float32)
    print(f"Testing QDWH on a random {M}x{M} matrix (float32)...")
    u, h, iters, converged = qdwh(x_test, max_iterations=20)
    print(f"Converged: {converged} in {iters} iterations.")
    if not converged:
        print("Warning: Test case did not converge.")
    gram = u.mT @ u
    identity = torch.eye(M, dtype=gram.dtype, device=gram.device)
    error_norm = torch.linalg.norm(gram - identity).item()
    print(f"orthogonality error ||U^T U - I||: {error_norm:.3e}")
    recon_err = torch.linalg.norm(x_test - (u @ h)).item()
    print(f"reconstruction error ||X - U H||: {recon_err:.3e}")