Skip to content

Instantly share code, notes, and snippets.

@andres-fr
Last active October 23, 2025 19:10
Show Gist options
  • Select an option

  • Save andres-fr/b5e9167d25cc4d9d799249a2de4f8ebd to your computer and use it in GitHub Desktop.

Select an option

Save andres-fr/b5e9167d25cc4d9d799249a2de4f8ebd to your computer and use it in GitHub Desktop.
Combining XDiag with Hutch++
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Combining XDiag with Hutch++.
Copyright (C) 2025 aferro (ORCID: 0000-0003-3830-3595)
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from skerch.synthmat import RandomLordMatrix
from skerch.linops import SumLinOp, CompositeLinOp
from skerch.measurements import RademacherNoiseLinOp
import torch
import matplotlib.pyplot as plt
# ##############################################################################
# # HELPERS
# ##############################################################################
def relerr(x, xhat):
    """Relative error of ``xhat`` with respect to the reference ``x``.

    :param x: Reference tensor (e.g. the ground-truth diagonal).
    :param xhat: Approximation of ``x``, same shape.
    :returns: Python float ``norm(x - xhat) / norm(x)``, using ``Tensor.norm``
      with its default settings.
    """
    residual_norm = (x - xhat).norm()
    reference_norm = x.norm()
    ratio = residual_norm / reference_norm
    return ratio.item()
def hutch(lop, mop):
    """Plain Girard-Hutchinson diagonal estimator.

    Every column ``w`` of ``mop`` is a probe. The entrywise product
    ``conj(w) * (lop @ w)`` is formed for all probes at once, and the result
    is averaged over the probe (column) axis.

    :param lop: Operator whose diagonal is estimated; must support
      ``lop @ mop``.
    :param mop: Measurement linop. A tall Rademacher/Phase matrix whose
      columns are the probe vectors.
    :returns: Estimate of ``diag(lop)``.
    """
    sketch = lop @ mop
    probe_products = mop.conj() * sketch
    return probe_products.mean(dim=1)
def hutchpp(lop, mop_defl, mop_gh):
    """Hutch++: Rank-deflated Girard-Hutchinson diagonal estimator.

    An orthonormal basis ``Q`` of the sketch ``lop @ mop_defl`` is computed.
    The diagonal of ``Q @ (Q^H @ lop)`` is obtained directly; the diagonal of
    the remaining (deflated) operator is estimated with plain GH using
    ``mop_gh``.

    :param lop: Operator whose diagonal is estimated; must support
      ``lop @ x`` and ``Q.H @ lop``.
    :param mop_defl: Deflation measurement linop. Should be tall and noisy.
    :param mop_gh: Measurement linop. A tall Rademacher/Phase matrix.
    :returns: estimates ``(d_top, d_defl)`` where top corresponds to the
      top-rank part that has been deflated, and defl to the remaining GH.
      Their sum is the Hutch++ estimate of ``diag(lop)``.
    """
    # Only the orthonormal factor is needed; the triangular factor is unused.
    Q, _ = torch.linalg.qr(lop @ mop_defl)
    QhA = Q.H @ lop
    # diag(Q @ QhA)[i] = sum_k Q[i, k] * QhA[k, i], computed without ever
    # forming the n x n product Q @ QhA.
    D_top = (Q.T * QhA).sum(0)
    # Deflated operator built lazily from ``lop`` and the low-rank term
    # ``Q @ QhA``. NOTE(review): the boolean flags presumably select
    # add/subtract (i.e. A - Q QhA); confirm against skerch's SumLinOp.
    defl = SumLinOp(
        [
            ("A", True, lop),
            ("Q QhA", False, CompositeLinOp([("Q", Q), ("QhA", QhA)])),
        ]
    )
    D_defl = hutch(defl, mop_gh)
    return D_top, D_defl
def xdiag(lop, mop):
    """XDiag: Exchangeable diagonal estimator.

    Equivalent to averaging, over every probe ``w_i`` (column ``i`` of
    ``mop``), the leave-one-out estimator
    ``diag(P_i A) + conj(w_i) * ((A - P_i A) w_i)``. Here ``P_i`` is the
    orthogonal projector onto the sketch with probe ``i`` removed. Only a
    single QR factorization of the full sketch is needed.

    :param lop: Operator whose diagonal is estimated; must support
      ``lop @ x`` and ``x @ lop``.
    :param mop: Measurement linop. A tall Rademacher/Phase matrix with ``k``
      columns.
    :returns: estimates ``(d_psi, d_xd)`` where psi corresponds to the
      Psi-deflated part, and d_xd to the remaining component, corresponding
      to the recycled-exchanged estimate of the deflated component. Their
      sum is the XDiag estimate of ``diag(lop)``.
    """
    Q, R = torch.linalg.qr(lop @ mop)
    k = R.shape[0]
    eye = torch.eye(k, dtype=R.dtype, device=R.device)
    # R is upper triangular: a triangular solve is cheaper and numerically
    # better behaved than a general-purpose inverse.
    Rinv = torch.linalg.solve_triangular(R, eye, upper=True)
    # Unit-normalized rows of R^{-1} (i.e. conjugated columns of R^{-H}),
    # scaled by 1/sqrt(k).
    Sh_k = Rinv / (Rinv.norm(dim=1, keepdim=True) * k**0.5)
    Psi = eye - Sh_k.H @ Sh_k
    PsiQhA = Psi @ Q.H @ lop
    # diag(Q @ PsiQhA) without forming the n x n product.
    D_psi = (Q.T * PsiQhA).sum(0)
    # diag(Sh_k @ R) in O(k^2) instead of building the full k x k product.
    diag_ShR = (Sh_k * R.T).sum(1)
    D_xd = ((Q @ Sh_k.H) * (mop.conj() * diag_ShR)).sum(1)
    return D_psi, D_xd
def xdiagpp(lop, mop_x, mop_gh):
    """XDiag++: XDiag followed by GH.

    Runs XDiag with ``mop_x``. It then estimates the diagonal of the
    Psi-deflated operator with plain GH on the fresh probes ``mop_gh``, and
    averages that estimate with the recycled XDiag correction. Each of the
    two is weighted by the number of probes it used.

    :param lop: Operator whose diagonal is estimated; must support
      ``lop @ x`` and ``x @ lop``.
    :param mop_x: XDiag measurement linop. A tall Rademacher/Phase matrix.
    :param mop_gh: GH measurement linop. A tall Rademacher/Phase matrix.
    :returns: estimates ``(d_psi, d_xpp)`` where psi corresponds to the
      Psi-deflated part, and d_xpp to the remaining xdiag component,
      combined with the extra GH estimation on top.
    """
    Q, R = torch.linalg.qr(lop @ mop_x)
    k = R.shape[0]
    eye = torch.eye(k, dtype=R.dtype, device=R.device)
    # R is upper triangular: prefer a triangular solve over a general inverse.
    Rinv = torch.linalg.solve_triangular(R, eye, upper=True)
    # Unit-normalized rows of R^{-1}, scaled by 1/sqrt(k).
    Sh_k = Rinv / (Rinv.norm(dim=1, keepdim=True) * k**0.5)
    Psi = eye - Sh_k.H @ Sh_k
    PsiQhA = Psi @ Q.H @ lop
    # diag(Q @ PsiQhA) without forming the n x n product.
    D_psi = (Q.T * PsiQhA).sum(0)
    # diag(Sh_k @ R) in O(k^2) instead of building the full k x k product.
    diag_ShR = (Sh_k * R.T).sum(1)
    D_xd = ((Q @ Sh_k.H) * (mop_x.conj() * diag_ShR)).sum(1)
    # Lazy operator from ``lop`` and ``Q @ PsiQhA``. NOTE(review): the flags
    # presumably mean add/subtract (A - Q Psi Q^H A); confirm against skerch.
    psi_defl = SumLinOp(
        [
            ("A", True, lop),
            (
                "Q Psi QhA",
                False,
                CompositeLinOp([("Q", Q), ("Psi QhA", PsiQhA)]),
            ),
        ]
    )
    D_xd_defl = hutch(psi_defl, mop_gh)
    # Probe-count-weighted average of the recycled (k probes) and the fresh
    # GH (num_gh probes) estimates.
    num_gh = mop_gh.shape[1]
    D_xpp = (k * D_xd + num_gh * D_xd_defl) / (k + num_gh)
    return D_psi, D_xpp
# ##############################################################################
# # MAIN: DIAGONAL ESTIMATION BENCHMARK
# ##############################################################################
if __name__ == "__main__":
    # Benchmark: compare diagonal estimators on a synthetic matrix, print
    # their relative errors and save two figures (matrix + error bars).
    # globals
    DIMS, SVD_DECAY, DIAG_RATIO = 1000, 0.001, 3
    SEED = 12345
    DTYPE, DEVICE = torch.complex128, "cpu"
    MEAS_DIMS, MEAS_EXTRA = 200, 400
    # sample ground truth matrix and fetch diagonal for testing
    A = RandomLordMatrix.exp(
        (DIMS, DIMS),
        rank=1,
        decay=SVD_DECAY,
        diag_ratio=DIAG_RATIO,
        symmetric=False,
        seed=SEED,
        dtype=DTYPE,
        device=DEVICE,
    )[0]
    D = A.diag()
    # sample noisy measurement matrices (distinct seeds for each one)
    M1 = RademacherNoiseLinOp(
        (DIMS, MEAS_DIMS), seed=SEED + DIMS, blocksize=MEAS_DIMS
    ).to_matrix(DTYPE, DEVICE)
    M2 = RademacherNoiseLinOp(
        (DIMS, MEAS_DIMS), seed=SEED + 2 * DIMS, blocksize=MEAS_DIMS
    ).to_matrix(DTYPE, DEVICE)
    M = torch.hstack([M1, M2])
    # NOTE(review): blocksize=MEAS_DIMS differs from this matrix's column
    # count (MEAS_EXTRA); kept as-is since it may affect the sampled values.
    M_extra = RademacherNoiseLinOp(
        (DIMS, MEAS_EXTRA), seed=SEED + 3 * DIMS, blocksize=MEAS_DIMS
    ).to_matrix(DTYPE, DEVICE)
    # PLAIN GH:
    D_gh = hutch(A, M)
    err_gh = relerr(D, D_gh)
    print("\n[PLAIN GH]")
    print("Relative error of GH:", err_gh)
    print("Converges towards correct result")
    # DEFLATED GH (HUTCH++)
    D_top, D_defl = hutchpp(A, M1, M2)
    err_top = relerr(D, D_top)
    err_hpp = relerr(D, D_top + D_defl)
    print("\n[DEFLATED GH (Hutch++)]")
    print("Relative error of defl:", err_top)
    print("Relative error of defl->GH:", err_hpp)
    print("Splitting into deflation->GH helps")
    # JUST DEFLATION (same probes reused for deflation and GH on purpose)
    D_top_full, D_defl_null = hutchpp(A, M, M)
    err_top_full = relerr(D, D_top_full)
    print("\n[JUST DEFLATION]")
    print("Relative error of just defl:", err_top_full)
    print("Investing all measurements into deflation also helps...")
    # DEFLATED, RECYCLED GH
    err_recy = relerr(D, D_top_full + D_defl_null)
    print("\n[DEFLATED, RECYCLED Hutch++]")
    print("Relative error of defl->GH_biased:", err_recy)
    print("... but then recycling GH does nothing")
    # XDIAG HALF
    D_psi, D_xd = xdiag(A, M1)
    err_xdiag_half = relerr(D, D_psi + D_xd)
    print("\n[XDiag (half)]")
    print("Relative error of XDiag (half):", err_xdiag_half)
    print("Now recycling deflation is as effective as new GH measurements!")
    # XDIAG
    D_psi, D_xd = xdiag(A, M)
    err_xdiag = relerr(D, D_psi + D_xd)
    print("\n[XDiag]")
    print("Relative error of XDiag:", err_xdiag)
    print("And for same number of measurements, XDiag dominates")
    # XDIAG++
    D_psi, D_xpp = xdiagpp(A, M, M_extra)
    err_xdiagpp = relerr(D, D_psi + D_xpp)
    print("\n[XDiag++]")
    print("Relative error of XDiag->GH:", err_xdiagpp)
    print("Same memory as XDiag, but more GH measurements for further gains")
    # PLOT MATRIX
    fig, (ax1, ax2) = plt.subplots(
        2, 1, figsize=(6, 7.5), gridspec_kw={"height_ratios": [4, 0.7]}
    )
    Aplot = A[:200, :200].real
    # Symmetric color range so that zero maps to the neutral color.
    max_abs = Aplot.abs().max().item()
    ax1.imshow(Aplot, cmap="bwr", vmin=-max_abs, vmax=max_abs)
    ax2.plot(torch.linalg.svdvals(A))
    for axis in (ax1, ax2):
        axis.set_xticks([])
        axis.set_yticks([])
    fig.suptitle("Fragment of sampled matrix and its singular values")
    fig.tight_layout()
    outpath = "lord_matrix_.png"
    fig.savefig(outpath, dpi=250)
    plt.close(fig)
    print("Saved fig to", outpath)
    # PLOT ERRORS
    # (label, error) pairs kept together so the bars can never be mislabeled.
    results = [
        (f"GH ({M.shape[1]} meas.)", err_gh),
        (f"Deflation ({M1.shape[1]} meas.)", err_top),
        (f"Hutch++ ({M1.shape[1]} + {M2.shape[1]} meas.)", err_hpp),
        (f"Deflation ({M.shape[1]} meas.)", err_top_full),
        (f"Recycled Hutch++ ({M.shape[1]} meas.)", err_recy),
        (f"XDiag ({M1.shape[1]} meas.)", err_xdiag_half),
        (f"XDiag ({M.shape[1]} meas.)", err_xdiag),
        (f"XDiag++ ({M.shape[1]} + {M_extra.shape[1]} meas.)", err_xdiagpp),
    ]
    labels = [label for label, _ in results]
    errors = [err for _, err in results]
    fig, ax = plt.subplots(figsize=(8, 7))
    ax.bar(labels, errors, width=0.5, edgecolor="darkblue", linewidth=1.5)
    ax.grid(axis="y", alpha=0.7, zorder=-1)
    ax.tick_params(axis="x", labelrotation=90)
    ax.set_ylabel(r"$\frac{|| D - \hat{D} ||_2}{|| D ||_2}$", fontsize=18)
    fig.suptitle(
        f"Diagonal approximation error for a {(DIMS, DIMS)} matrix",
        fontsize=14,
    )
    # Keep the original 0.7 cap, but never clip a bar that exceeds it.
    ax.set_ylim(0, max(0.7, 1.05 * max(errors)))
    fig.tight_layout()  # Adjust layout to ensure labels fit
    outpath = "diag_benchmark_.png"
    fig.savefig(outpath, dpi=250)
    plt.close(fig)
    print("Saved fig to", outpath)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment