@norabelrose
Last active January 2, 2023 08:34
Relaxed non-negative SVD
import torch as th

# Sinkhorn-Knopp algorithm for projecting onto doubly stochastic matrices
def sinkhorn_knopp(A: th.Tensor, max_iter: int = 20):
    """Alternately normalize the rows and columns of `A` until both sum to (nearly) one."""
    A = A.clone()
    for _ in range(max_iter):
        A /= A.sum(dim=1, keepdim=True)  # normalize rows
        A /= A.sum(dim=0, keepdim=True)  # normalize columns

    return A
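

# Sanity check (illustrative addition, not part of the original gist): after a
# few Sinkhorn-Knopp iterations on a strictly positive square matrix, the row
# and column sums should both be close to one.
#
# >>> M = sinkhorn_knopp(th.rand(5, 5))
# >>> M.sum(dim=1)  # each entry ~ 1.0
# >>> M.sum(dim=0)  # each entry ~ 1.0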


class RNNSVD(th.nn.Module):
    """Generalized SVD where U is constrained to be doubly stochastic (and therefore non-negative).

    This is a convex relaxation of the naive objective where U is constrained to be both
    orthogonal and non-negative (and therefore a permutation matrix). That problem is likely
    NP-hard, but the convex hull of the permutation matrices is precisely the set of doubly
    stochastic matrices (Birkhoff's theorem), so we optimize over that set instead.
    """

    A: th.Tensor
    log_U: th.nn.Parameter
    log_S: th.nn.Parameter
    Vh: th.nn.Parameter

    def __init__(self, A: th.Tensor):
        super().__init__()

        # Initialize from the ordinary SVD, then push U into the doubly stochastic polytope
        u, s, vh = th.linalg.svd(A, full_matrices=False)
        u = sinkhorn_knopp(u.clamp_min(1e-6))

        # Rescale the singular values to compensate for the projection of U
        s *= th.norm(A) / th.norm(u @ s.diag() @ vh)

        self.register_buffer("A", A)
        self.log_U = th.nn.Parameter(u.log().contiguous())
        self.log_S = th.nn.Parameter(s.log())
        self.Vh = th.nn.Parameter(vh.contiguous())

        # Keep Vh orthogonal throughout optimization
        th.nn.utils.parametrizations.orthogonal(self, "Vh")

    def fit(self, max_iter: int = 10_000) -> float:
        opt = th.optim.LBFGS(
            self.parameters(),
            line_search_fn="strong_wolfe",
            max_iter=max_iter,
        )
        loss = th.inf

        def closure():
            nonlocal loss

            opt.zero_grad()
            A_hat = self.U @ self.S.diag() @ self.Vh
            loss = th.square(self.A - A_hat).sum() / min(self.A.shape)
            loss.backward()

            print(f"{loss=}")
            return loss

        opt.step(closure)  # type: ignore

        # Sort the singular values in descending order, permuting U and Vh to match.
        # Vh is parametrized, so assign through the parametrization rather than
        # mutating `.data`, which would only touch a temporary tensor.
        with th.no_grad():
            self.log_S.data, idx = self.log_S.data.sort(descending=True)
            self.log_U.data = self.log_U.data[:, idx]
            self.Vh = self.Vh[idx, :]

        self.requires_grad_(False)
        return float(loss)

    @property
    def U(self) -> th.Tensor:
        # Softmax makes each column a probability distribution; Sinkhorn-Knopp then
        # balances the rows and columns so the result is (approximately) doubly stochastic.
        return sinkhorn_knopp(self.log_U.softmax(0))

    @property
    def S(self) -> th.Tensor:
        return self.log_S.exp()
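

if __name__ == "__main__":
    # Usage sketch (my addition, not part of the original gist). The matrix size
    # and iteration count are arbitrary choices for illustration; a square,
    # non-negative input keeps the doubly stochastic constraint on U natural.
    th.manual_seed(0)
    A = th.rand(32, 32)

    model = RNNSVD(A)
    final_loss = model.fit(max_iter=500)

    A_hat = model.U @ model.S.diag() @ model.Vh
    print(f"final loss: {final_loss:.4f}")
    print(f"relative reconstruction error: {th.norm(A - A_hat) / th.norm(A):.4f}")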