Last active
October 23, 2025 19:10
-
-
Save andres-fr/b5e9167d25cc4d9d799249a2de4f8ebd to your computer and use it in GitHub Desktop.
Combining XDiag with Hutch++
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| """Combining XDiag with Hutch++. | |
| Copyright (C) 2025 aferro (ORCID: 0000-0003-3830-3595) | |
| This program is free software: you can redistribute it and/or modify | |
| it under the terms of the GNU General Public License as published by | |
| the Free Software Foundation, either version 3 of the License, or | |
| (at your option) any later version. | |
| This program is distributed in the hope that it will be useful, | |
| but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| GNU General Public License for more details. | |
| You should have received a copy of the GNU General Public License | |
| along with this program. If not, see <https://www.gnu.org/licenses/>. | |
| """ | |
| from skerch.synthmat import RandomLordMatrix | |
| from skerch.linops import SumLinOp, CompositeLinOp | |
| from skerch.measurements import RademacherNoiseLinOp | |
| import torch | |
| import matplotlib.pyplot as plt | |
| # ############################################################################## | |
| # # HELPERS | |
| # ############################################################################## | |
def relerr(x, xhat):
    """Relative error of an approximation, returned as a Python scalar.

    :param x: Reference tensor. Its norm is the denominator, so an all-zero
      reference yields ``nan`` or ``inf`` instead of raising.
    :param xhat: Approximation of ``x``, broadcast-compatible with it.
    :returns: ``||x - xhat|| / ||x||`` using the default ``Tensor.norm()``.
    """
    residual_norm = (x - xhat).norm()
    return (residual_norm / x.norm()).item()
def hutch(lop, mop):
    """Plain Girard-Hutchinson diagonal estimator.

    :param lop: Square linear operator whose diagonal is estimated. It only
      needs to support ``lop @ mop``.
    :param mop: Measurement linop. A tall Rademacher/Phase matrix, with one
      measurement per column. No per-entry normalization is applied here,
      so entries are expected to have unit modulus.
    :returns: Estimate of ``diag(lop)``, averaged over all measurements.
    """
    # Row i of the elementwise product holds conj(m_ij) * (lop @ m_j)_i, i.e.
    # one single-measurement estimate of lop[i, i] per column j.
    return (mop.conj() * (lop @ mop)).mean(dim=1)
def hutchpp(lop, mop_defl, mop_gh):
    """Hutch++: Rank-deflated Girard-Hutchinson diagonal estimator.

    :param lop: Square linear operator whose diagonal is estimated. Must
      support ``lop @ mop_defl`` and ``Q.H @ lop``.
    :param mop_defl: Deflation measurement linop. Should be tall and noisy.
    :param mop_gh: Measurement linop. A tall Rademacher/Phase matrix.
    :returns: estimates ``(d_top, d_defl)`` where top corresponds to the
      top-rank part that has been deflated, and defl to the remaining GH.
    """
    # Only the orthonormal factor of the sketch is needed; R is discarded.
    Q, _ = torch.linalg.qr(lop @ mop_defl)
    QhA = Q.H @ lop
    # diag(Q @ QhA) computed entrywise, without forming the full product.
    D_top = (Q.T * QhA).sum(0)
    # Residual operator handed to GH. NOTE(review): the boolean flags are
    # presumably SumLinOp's add (True) / subtract (False) markers, making
    # this ``A - Q @ QhA`` -- confirm against the skerch documentation.
    defl = SumLinOp(
        [
            ("A", True, lop),
            ("Q QhA", False, CompositeLinOp([("Q", Q), ("QhA", QhA)])),
        ]
    )
    D_defl = hutch(defl, mop_gh)
    return D_top, D_defl
def _xdiag_core(lop, mop):
    """Shared XDiag computation used by ``xdiag`` and ``xdiagpp``.

    :param lop: Square linear operator. Must support ``lop @ mop`` and
      ``(Psi @ Q.H) @ lop``.
    :param mop: Measurement matrix. ``lop @ mop`` must yield an invertible
      ``R`` factor, since ``R`` is explicitly inverted.
    :returns: ``(Q, PsiQhA, D_psi, D_xd, k)`` where ``Q`` is the orthonormal
      factor of ``lop @ mop``, ``PsiQhA`` is ``Psi @ Q.H @ lop``, ``D_psi``
      and ``D_xd`` are the two diagonal components, and ``k`` is ``len(R)``.
    """
    Q, R = torch.linalg.qr(lop @ mop)
    k = len(R)
    Rinv = torch.linalg.inv(R)
    # Rows of R^-1 scaled to unit norm, then by 1/sqrt(k).
    Sh_k = Rinv / (Rinv.norm(dim=1, keepdim=True) * k**0.5)
    # Psi = I - Sh_k^H @ Sh_k
    identity = torch.eye(k, dtype=Sh_k.dtype, device=Sh_k.device)
    Psi = identity - Sh_k.H @ Sh_k
    PsiQhA = Psi @ Q.H @ lop
    # diag(Q @ PsiQhA) computed entrywise, without forming the full product.
    D_psi = (Q.T * PsiQhA).sum(0)
    D_xd = ((Q @ Sh_k.H) * (mop.conj() * (Sh_k @ R).diag())).sum(1)
    return Q, PsiQhA, D_psi, D_xd, k


def xdiag(lop, mop):
    """XDiag: Exchangeable diagonal estimator.

    :param lop: Square linear operator whose diagonal is estimated.
    :param mop: Measurement linop. A tall Rademacher/Phase matrix
    :returns: estimates ``(d_psi, d_xd)`` where psi corresponds to the
      Psi-deflated part, and d_xd to the remaining component, corresponding
      to the recycled-exchanged estimate of the deflated component.
    """
    _, _, D_psi, D_xd, _ = _xdiag_core(lop, mop)
    return D_psi, D_xd


def xdiagpp(lop, mop_x, mop_gh):
    """XDiag++: XDiag followed by GH.

    :param lop: Square linear operator whose diagonal is estimated.
    :param mop_x: XDiag measurement linop. A tall Rademacher/Phase matrix.
    :param mop_gh: GH measurement linop. A tall Rademacher/Phase matrix.
    :returns: estimates ``(d_psi, d_xpp)`` where psi corresponds to the
      Psi-deflated part, and d_xpp to the remaining xdiag component,
      combined with the extra GH estimation on top.
    """
    Q, PsiQhA, D_psi, D_xd, num_x = _xdiag_core(lop, mop_x)
    # Residual operator handed to GH. NOTE(review): the boolean flags are
    # presumably SumLinOp's add (True) / subtract (False) markers, making
    # this ``A - Q @ Psi @ QhA`` -- confirm against the skerch documentation.
    psi_defl = SumLinOp(
        [
            ("A", True, lop),
            (
                "Q Psi QhA",
                False,
                CompositeLinOp([("Q", Q), ("Psi QhA", PsiQhA)]),
            ),
        ]
    )
    D_xd_defl = hutch(psi_defl, mop_gh)
    # Average of both residual estimates, weighted by measurement count.
    num_gh = mop_gh.shape[1]
    D_xpp = (num_x * D_xd + num_gh * D_xd_defl) / (num_x + num_gh)
    return D_psi, D_xpp
| # ############################################################################## | |
| # # MAIN SCRIPT | |
| # ############################################################################## | |
if __name__ == "__main__":
    # Benchmark script: samples a random test matrix, estimates its diagonal
    # with each estimator above, prints the relative errors and saves two
    # figures (matrix/spectrum overview and an error bar chart).
    # globals
    DIMS, SVD_DECAY, DIAG_RATIO = 1000, 0.001, 3
    SEED = 12345
    DTYPE, DEVICE = torch.complex128, "cpu"
    MEAS_DIMS, MEAS_EXTRA = 200, 400
    # sample ground truth matrix and fetch diagonal for testing
    A = RandomLordMatrix.exp(
        (DIMS, DIMS),
        rank=1,
        decay=SVD_DECAY,
        diag_ratio=DIAG_RATIO,
        symmetric=False,
        seed=SEED,
        dtype=DTYPE,
        device=DEVICE,
    )[0]
    D = A.diag()
    # sample measurement noisy matrices (distinct seeds for each of them)
    M1 = RademacherNoiseLinOp(
        (DIMS, MEAS_DIMS), seed=SEED + DIMS, blocksize=MEAS_DIMS
    ).to_matrix(DTYPE, DEVICE)
    M2 = RademacherNoiseLinOp(
        (DIMS, MEAS_DIMS), seed=SEED + 2 * DIMS, blocksize=MEAS_DIMS
    ).to_matrix(DTYPE, DEVICE)
    # M stacks both halves: same total budget as M1 followed by M2
    M = torch.hstack([M1, M2])
    M_extra = RademacherNoiseLinOp(
        (DIMS, MEAS_EXTRA), seed=SEED + 3 * DIMS, blocksize=MEAS_DIMS
    ).to_matrix(DTYPE, DEVICE)
    # PLAIN GH:
    D_gh = hutch(A, M)
    err_gh = relerr(D, D_gh)
    print("\n[PLAIN GH]")
    print("Relative error of GH:", err_gh)
    print("Converges towards correct result")
    # DEFLATED GH (HUTCH++)
    D_top, D_defl = hutchpp(A, M1, M2)
    err_top = relerr(D, D_top)
    err_hpp = relerr(D, D_top + D_defl)
    print("\n[DEFLATED GH (Hutch++)]")
    print("Relative error of defl:", err_top)
    print("Relative error of defl->GH:", err_hpp)
    print("Splitting into deflation->GH helps")
    # JUST DEFLATION (the same measurements are passed for deflation and GH)
    D_top_full, D_defl_null = hutchpp(A, M, M)
    err_top_full = relerr(D, D_top_full)
    print("\n[JUST DEFLATION]")
    print("Relative error of just defl:", err_top_full)
    print("Investing all measurements into deflation also helps...")
    # DEFLATED, RECYCLED GH
    err_recy = relerr(D, D_top_full + D_defl_null)
    print("\n[DEFLATED, RECYCLED Hutch++]")
    print("Relative error of defl->GH_biased:", err_recy)
    print("... but then recycling GH does nothing")
    # XDIAG HALF
    D_psi, D_xd = xdiag(A, M1)
    err_xdiag_half = relerr(D, D_psi + D_xd)
    print("\n[XDiag (half)]")
    print("Relative error of XDiag (half):", err_xdiag_half)
    print("Now recycling deflation is as effective as new GH measurements!")
    # XDIAG
    D_psi, D_xd = xdiag(A, M)
    err_xdiag = relerr(D, D_psi + D_xd)
    print("\n[XDiag]")
    print("Relative error of XDiag:", err_xdiag)
    print("And for same number of measurements, XDiag dominates")
    # XDIAG++
    D_psi, D_xpp = xdiagpp(A, M, M_extra)
    err_xdiagpp = relerr(D, D_psi + D_xpp)
    print("\n[XDiag++]")
    print("Relative error of XDiag->GH:", err_xdiagpp)
    print("Same memory as XDiag, but more GH measurements for further gains")
    # PLOT MATRIX: top panel shows a 200x200 fragment (real part), bottom
    # panel shows the singular values of the full matrix
    fig, (ax1, ax2) = plt.subplots(
        2, 1, figsize=(6, 7.5), gridspec_kw={"height_ratios": [4, 0.7]}
    )
    Aplot = A[:200, :200].real
    # symmetric color limits, so that zero maps to the center of "bwr"
    max_abs = abs(Aplot).max()
    ax1.imshow(Aplot, cmap="bwr", vmin=-max_abs, vmax=max_abs)
    ax2.plot(torch.linalg.svdvals(A))
    ax1.set_xticks([])
    ax1.set_yticks([])
    ax2.set_xticks([])
    ax2.set_yticks([])
    fig.suptitle("Fragment of sampled matrix and its singular values")
    fig.tight_layout()
    outpath = "lord_matrix_.png"
    fig.savefig(outpath, dpi=250)
    print("Saved fig to", outpath)
    # PLOT ERRORS (errors and labels are listed in the same order)
    errors = [
        err_gh,
        err_top,
        err_hpp,
        err_top_full,
        err_recy,
        err_xdiag_half,
        err_xdiag,
        err_xdiagpp,
    ]
    labels = [
        f"GH ({M.shape[1]} meas.)",
        f"Deflation ({M1.shape[1]} meas.)",
        f"Hutch++ ({M1.shape[1]} + {M2.shape[1]} meas.)",
        f"Deflation ({M.shape[1]} meas.)",
        f"Recycled Hutch++ ({M.shape[1]} meas.)",
        f"XDiag ({M1.shape[1]} meas.)",
        f"XDiag ({M.shape[1]} meas.)",
        f"XDiag++ ({M.shape[1]} + {M_extra.shape[1]} meas.)",
    ]
    fig, ax = plt.subplots(figsize=(8, 7))
    ax.bar(labels, errors, width=0.5, edgecolor="darkblue", linewidth=1.5)
    ax.grid(axis="y", alpha=0.7, zorder=-1)
    plt.xticks(rotation=90)
    ax.set_ylabel(r"$\frac{|| D - \hat{D} ||_2}{|| D ||_2}$", fontsize=18)
    fig.suptitle(
        f"Diagonal approximation error for a {(DIMS, DIMS)} matrix",
        fontsize=14,
    )
    ax.set_ylim(0, 0.7)
    fig.tight_layout()  # Adjust layout to ensure labels fit
    outpath = "diag_benchmark_.png"
    fig.savefig(outpath, dpi=250)
    print("Saved fig to", outpath)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment