@cwindolf · Last active April 11, 2024
Probabilistic PCA in PyTorch with missing data. See the very nice JMLR paper of Ilin and Raiko, "Practical Approaches to Principal Component Analysis in the Presence of Missing Values" (JMLR 11, 2010).
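Both snippets fit the model y_n = W x_n + m + noise, with noise ~ N(0, vy I), using only the observed entries of Y (missing entries are NaN; O is the observation mask). The first class, PPCA, learns point estimates of W, m, and vy by EM; the second, VBPCA, keeps Gaussian posteriors over W and m under an ARD prior (one variance vwk per column of W), which can prune unneeded components and yields predictive variances.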
import torch
from torch import nn
from tqdm.auto import trange


class PPCA(nn.Module):
    """Probabilistic PCA with missing data, fit by EM (Ilin & Raiko, 2010)."""

    def __init__(self, d, c):
        super().__init__()
        self.d = d  # observed dimension
        self.c = c  # latent dimension
        # point estimates: mean m, loadings w, noise variance vy
        self.m = nn.Parameter(torch.zeros(d), requires_grad=False)
        self.w = nn.Parameter(torch.zeros(d, c), requires_grad=False)
        self.vy = nn.Parameter(torch.ones(()), requires_grad=False)
        self.register_buffer("Ic", torch.eye(c))
    def initialize_from(self, Y, O):
        with torch.no_grad():
            # initialize the mean from the observed entries...
            m = torch.nan_to_num(torch.nanmean(Y, dim=0))
            self.m.copy_(m)
            # ...and the loadings from the top-c right singular vectors
            # (unscaled) of the zero-imputed, centered data
            Yc = torch.nan_to_num(Y - m)
            _, _, vh = torch.linalg.svd(Yc, full_matrices=False)
            self.w.copy_(vh[:self.c].T)
    def forward(self, Y, O=None):
        """E step: posterior mean and covariance of the latents x_n."""
        if O is None:
            O = torch.isfinite(Y)
        O_ = O.to(self.w)
        Y = torch.nan_to_num(Y)
        # sum_{i in O_n} w_i w_i^T for each row n
        Oww = torch.einsum("ni,ik,il->nkl", O_, self.w, self.w)
        # Sigma_xn = vy * (vy I + sum_{i in O_n} w_i w_i^T)^{-1}
        Sigma_xn = self.vy * torch.linalg.inv(self.vy * self.Ic[None] + Oww)
        # xbar_n = (Sigma_xn / vy) sum_{i in O_n} w_i (y_ni - m_i)
        xbarn = torch.einsum("nlk,ni,ik,ni->nl", Sigma_xn / self.vy, O_, self.w, Y - self.m)
        return xbarn, Sigma_xn
    def m_step(self, Y, O, xbarn, Sigma_xn):
        O_ = O.to(self.w)
        # mean: average residual y - w xbar over observed entries
        m = torch.nanmean(torch.addmm(Y, xbarn, self.w.T, alpha=-1), dim=0)
        # per-coordinate second moments:
        # xxTSigmax = sum_{n in O_i} (xbarn xbarn^T + Sigma_xn)
        xxTSigmax = torch.baddbmm(Sigma_xn, xbarn[:, :, None], xbarn[:, None, :])
        xxTSigmax = O_.T @ xxTSigmax.reshape(-1, self.c * self.c)
        xxTSigmax = xxTSigmax.reshape(-1, self.c, self.c)
        Ym = torch.nan_to_num(Y - m)
        # sum_{n in O_i} xbar_n (y_ni - m_i), then solve for each row of w
        xbarym = torch.einsum("ni,nk,ni->ik", O_, xbarn, Ym)
        w = torch.linalg.solve(xxTSigmax, xbarym)
        # noise variance: mean squared residual plus posterior uncertainty
        wSw = torch.einsum("ik,nkl,il->ni", w, Sigma_xn, w)
        vy = torch.mean(torch.square(torch.addmm(Ym, xbarn, w.T, alpha=-1)[O]) + wSw[O])
        return m, w, vy
    def fit_transform(self, Y, O=None, max_iter=100, show_progress=True, atol=1e-3):
        if O is None:
            O = torch.isfinite(Y)
        self.initialize_from(Y, O)
        pbar = trange(max_iter, disable=not show_progress)
        for i in pbar:
            xbarn, Sigma_xn = self(Y, O)
            m, w, vy = self.m_step(Y, O, xbarn, Sigma_xn)
            # stop when the loadings stop moving (max abs change)
            dw = torch.max(torch.abs(w - self.w))
            # dw = torch.mean(torch.square(w - self.w))
            pbar.set_description(f"dw={dw.numpy(force=True):0.5f}")
            if torch.isnan(dw).any() or dw < atol:
                break
            self.m.copy_(m)
            self.w.copy_(w)
            self.vy.copy_(vy)
        return xbarn, Sigma_xn
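A minimal usage sketch for the class above (illustrative, not part of the gist; the data, sizes, and missingness rate are made up):

# Simulate low-rank data, hide ~20% of entries as NaN, and fit by EM.
torch.manual_seed(0)
n, d, c = 500, 20, 3
W_true = torch.randn(d, c)
Y = torch.randn(n, c) @ W_true.T + 0.1 * torch.randn(n, d)
Y[torch.rand(n, d) < 0.2] = torch.nan

ppca = PPCA(d, c)
xbarn, Sigma_xn = ppca.fit_transform(Y)
# impute missing entries from the fitted point estimates
Yhat = xbarn @ ppca.w.T + ppca.m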
import torch
from torch import nn
from tqdm.auto import trange


class VBPCA(nn.Module):
    """Variational Bayesian PCA with missing data (Ilin & Raiko, 2010)."""

    def __init__(self, d, c):
        super().__init__()
        self.d = d  # observed dimension
        self.c = c  # latent dimension
        # posterior mean and variance of the bias m
        self.mbar = nn.Parameter(torch.zeros(d), requires_grad=False)
        self.mtilde = nn.Parameter(torch.ones(d), requires_grad=False)
        # posterior mean and per-row covariance of the loadings w
        self.wbar = nn.Parameter(torch.zeros(d, c), requires_grad=False)
        eyes_dcc = torch.eye(c)[None].repeat(d, 1, 1)
        self.Sigmaw = nn.Parameter(eyes_dcc, requires_grad=False)
        # noise variance, prior variance of m, ARD prior variances per column of w
        self.vy = nn.Parameter(torch.ones(()), requires_grad=False)
        self.vm = nn.Parameter(torch.ones(()), requires_grad=False)
        self.vwk = nn.Parameter(torch.ones(c), requires_grad=False)
        self.register_buffer("Ic", torch.eye(c))
    def initialize_from(self, Y, O):
        with torch.no_grad():
            m = torch.nan_to_num(torch.nanmean(Y, dim=0))
            self.mbar.copy_(m)
            # top-c right singular vectors of the zero-imputed, centered data
            Yc = torch.nan_to_num(Y - m)
            _, _, vh = torch.linalg.svd(Yc, full_matrices=False)
            self.wbar.copy_(vh[:self.c].T)
    def forward(self, Y, O=None):
        """E step: posterior mean and covariance of the latents x_n."""
        if O is None:
            O = torch.isfinite(Y)
        O_ = O.to(self.wbar)
        Y = torch.nan_to_num(Y)
        # sum_{i in O_n} (wbar_i wbar_i^T + Sigmaw_i): accounts for loading uncertainty
        OwwSw = torch.baddbmm(self.Sigmaw, self.wbar[:, :, None], self.wbar[:, None, :])
        OwwSw = (O_ @ OwwSw.reshape(self.d, self.c * self.c)).reshape(O.shape[0], self.c, self.c)
        Sigma_xn = self.vy * torch.linalg.inv(self.vy * self.Ic[None] + OwwSw)
        xbarn = torch.einsum("nlk,ni,ik,ni->nl", Sigma_xn / self.vy, O_, self.wbar, Y - self.mbar)
        return xbarn, Sigma_xn
    def m_step(self, Y, O, xbarn, Sigma_xn):
        O_ = O.to(self.wbar)
        Oicard = O_.sum(0)  # |O_i|: number of observed rows per coordinate
        # -- M
        # shrinkage factor vm / (|O_i| vm + vy); when a coordinate is never
        # observed, the nansum below is empty and mtilde falls back to the prior vm
        vm_OvvO = torch.where(Oicard > 0, self.vm / (Oicard * (self.vm + self.vy / Oicard)), self.vm / self.vy)
        # mbar
        mbar = vm_OvvO * torch.nansum(torch.addmm(Y, xbarn, self.wbar.T, alpha=-1), dim=0)
        # mtilde
        mtilde = self.vy * vm_OvvO
        # -- W
        # xxTSigmax = sum_{n in O_i} (xbarn xbarn^T + Sigma_xn)
        xxTSigmax = torch.baddbmm(Sigma_xn, xbarn[:, :, None], xbarn[:, None, :])
        xxTSigmax = (O_.T @ xxTSigmax.reshape(-1, self.c * self.c)).reshape(-1, self.c, self.c)
        # Sigmaw; the ARD prior enters through its precision, diag(1 / vwk)
        Sigmaw = self.vy * torch.linalg.inv(self.vy * torch.diag(1.0 / self.vwk) + xxTSigmax)
        Ym = torch.nan_to_num(Y - mbar)
        xbarym = torch.einsum("ni,nk,ni->ik", O_, xbarn, Ym)
        # wbar
        wbar = torch.bmm(Sigmaw, xbarym[:, :, None])[:, :, 0] / self.vy
        # -- vs
        wSxw = torch.einsum("ik,nkl,il->ni", wbar, Sigma_xn, wbar)
        xSwx = torch.einsum("nk,ikl,nl->ni", xbarn, Sigmaw, xbarn)
        trSS = torch.einsum("nkl,ikl->ni", Sigma_xn, Sigmaw)
        # vy: mean squared residual plus every posterior-variance term,
        # including the bias variance mtilde (per the paper's update)
        vy = torch.mean(
            torch.square(torch.addmm(Ym, xbarn, wbar.T, alpha=-1)[O])
            + wSxw[O]
            + xSwx[O]
            + trSS[O]
            + mtilde.expand_as(Ym)[O]
        )
        # vwk
        vwk = torch.mean(torch.square(wbar) + torch.diagonal(Sigmaw, dim1=-2, dim2=-1), dim=0)
        # vm
        vm = torch.mean(torch.square(mbar) + mtilde, dim=0)
        return mbar, mtilde, wbar, Sigmaw, vy, vwk, vm
    def fit_transform(self, Y, O=None, max_iter=100, show_progress=True, atol=1e-3):
        if O is None:
            O = torch.isfinite(Y)
        self.initialize_from(Y, O)
        pbar = trange(max_iter, disable=not show_progress)
        for i in pbar:
            xbarn, Sigma_xn = self(Y, O)
            mbar, mtilde, wbar, Sigmaw, vy, vwk, vm = self.m_step(Y, O, xbarn, Sigma_xn)
            # stop when the posterior mean loadings stop moving
            dw = torch.max(torch.abs(wbar - self.wbar))
            # dw = torch.mean(torch.square(wbar - self.wbar))
            pbar.set_description(f"dw={dw.numpy(force=True):0.5f}")
            if torch.isnan(dw).any() or dw < atol:
                break
            self.mbar.copy_(mbar)
            self.mtilde.copy_(mtilde)
            self.wbar.copy_(wbar)
            self.Sigmaw.copy_(Sigmaw)
            self.vy.copy_(vy)
            self.vwk.copy_(vwk)
            self.vm.copy_(vm)
        return xbarn, Sigma_xn
    def predictive_dists(self, Ynew, Onew=None):
        if Onew is None:
            Onew = torch.isfinite(Ynew)
        # posteriors for the new xs, holding everything else fixed.
        # not sure how kosher this is.
        xbarnew, Sigma_xnew = self(Ynew, Onew)
        # posterior predictive means
        Ynewhat = xbarnew @ self.wbar.T + self.mbar
        # posterior predictive variances (diagonal only; the off-diagonal
        # part of wbar Sigma_xnew wbar^T is not filled in)
        wSxw = torch.einsum("ik,nkl,il->ni", self.wbar, Sigma_xnew, self.wbar)
        xSwx = torch.einsum("nk,ikl,nl->ni", xbarnew, self.Sigmaw, xbarnew)
        trSS = Sigma_xnew.reshape(-1, self.c ** 2) @ self.Sigmaw.reshape(self.d, self.c ** 2).T
        Sigma_Ynew = torch.diag_embed(wSxw + xSwx + trSS + (self.mtilde + self.vy))
        return Ynewhat, Sigma_Ynew
        # dists = torch.distributions.MultivariateNormal(
        #     Ynewhat, covariance_matrix=Sigma_Ynew
        # )
        # return dists
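A similar illustrative sketch for VBPCA (again not part of the gist; names and sizes are made up): fit on partially observed data, then score rows with the predictive means and variances.

# Simulate low-rank data with ~20% missing entries and fit by VB.
torch.manual_seed(0)
n, d, c = 500, 20, 3
W_true = torch.randn(d, c)
Y = torch.randn(n, c) @ W_true.T + 0.1 * torch.randn(n, d)
Y[torch.rand(n, d) < 0.2] = torch.nan

vbpca = VBPCA(d, c)
vbpca.fit_transform(Y)
# ARD: columns of w whose prior variance shrinks toward 0 are effectively pruned
print(vbpca.vwk)
# predictive means and (diagonal) covariances for new, partially observed rows
Ynewhat, Sigma_Ynew = vbpca.predictive_dists(Y[:10])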