Relaxed non-negative SVD

import torch as th
import torch.nn.functional as F


# Sinkhorn-Knopp algorithm for projecting onto doubly stochastic matrices
def sinkhorn_knopp(A: th.Tensor, max_iter: int = 20):
    A = A.clone()
    for _ in range(max_iter):
        # Alternately normalize the rows and the columns to sum to one
        A /= A.sum(dim=1, keepdim=True)
        A /= A.sum(dim=0, keepdim=True)

    return A


class RNNSVD(th.nn.Module):
    """Generalized SVD where U is constrained to be doubly stochastic (and therefore non-negative).

    This is a convex relaxation of the naive objective where U is constrained to be both orthogonal
    and non-negative, which would force U to be a permutation matrix. Optimizing under that combined
    constraint directly is likely intractable, but the convex hull of the permutation matrices is the
    set of doubly stochastic matrices (Birkhoff's theorem), so we relax the constraint to that set.
    """

    A: th.Tensor
    log_U: th.nn.Parameter
    log_S: th.nn.Parameter
    Vh: th.nn.Parameter

    def __init__(self, A: th.Tensor):
        super().__init__()

        m, n = A.shape
        k = min(m, n)

        # Initialize from the ordinary SVD, pushing U toward the doubly stochastic set
        u, s, vh = th.linalg.svd(A, full_matrices=False)
        u = sinkhorn_knopp(u.clamp_min(1e-6))

        # Rescale the singular values so the initial reconstruction has roughly the right norm
        s *= th.norm(A) / th.norm(u @ s.diag() @ vh)

        self.register_buffer("A", A)

        # Store U and S in log space so they remain positive; constrain Vh to be orthogonal
        self.log_U = th.nn.Parameter(u.log().contiguous())
        self.log_S = th.nn.Parameter(s.log())
        self.Vh = th.nn.Parameter(vh.contiguous())
        th.nn.utils.parametrizations.orthogonal(self, "Vh")

    def fit(self, max_iter: int = 10_000) -> float:
        # Full-batch L-BFGS with strong Wolfe line search
        opt = th.optim.LBFGS(
            self.parameters(),
            line_search_fn="strong_wolfe",
            max_iter=max_iter,
        )
        loss = th.inf

        def closure():
            nonlocal loss

            opt.zero_grad()

            # Squared reconstruction error, normalized by the number of singular values
            A_hat = self.U @ self.S.diag() @ self.Vh
            loss = th.square(self.A - A_hat).sum() / min(self.A.shape)
            loss.backward()

            print(f"{loss=}")
            return loss

        opt.step(closure)  # type: ignore

        # Sort the singular values in descending order and permute U and Vh to match
        self.log_S.data, idx = self.log_S.data.sort(descending=True)
        self.log_U.data = self.log_U.data[:, idx]

        # Vh is parametrized, so assign through the attribute (which routes through the
        # parametrization's right_inverse) rather than mutating .data on a temporary tensor
        with th.no_grad():
            self.Vh = self.Vh[idx, :]

        self.requires_grad_(False)
        return float(loss)

    @property
    def U(self) -> th.Tensor:
        # Softmax along dim 0 makes each column sum to one; Sinkhorn-Knopp then pushes
        # the result toward the doubly stochastic set
        return sinkhorn_knopp(self.log_U.softmax(0))

    @property
    def S(self) -> th.Tensor:
        # Exponentiate so the singular values are always positive
        return self.log_S.exp()
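
A minimal usage sketch (not part of the gist itself): fit the factorization to a small random square matrix and check the reconstruction. The 40x40 size is an arbitrary choice; a square target keeps the doubly stochastic constraint on U meaningful.

# Hypothetical example, assuming RNNSVD and torch (as th) are already in scope
A = th.rand(40, 40)

model = RNNSVD(A)
final_loss = model.fit()

# Recover the factors and measure the relative reconstruction error
U, S, Vh = model.U, model.S, model.Vh
A_hat = U @ S.diag() @ Vh
print(final_loss, (A - A_hat).norm() / A.norm())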