"""Probabilistic PCA in PyTorch with missing data.

See the (very nice) JMLR paper by Ilin and Raiko (2010), "Practical Approaches
to Principal Component Analysis in the Presence of Missing Values".
"""
import torch
from torch import nn
from import trange
class PPCA(nn.Module):
    """Probabilistic PCA fitted by EM, with support for missing entries.

    Model: ``y_n = w @ x_n + m + e_n`` with ``x_n ~ N(0, I_c)`` and
    ``e_n ~ N(0, vy * I_d)``. Entries marked unobserved in the mask ``O``
    (by default: the non-finite entries of ``Y``) are left out of every sum.

    Args:
        d: number of observed dimensions (columns of ``Y``).
        c: number of latent components.
    """

    def __init__(self, d, c):
        # nn.Module.__init__ must run before any Parameter or buffer is
        # assigned; without it nn.Module.__setattr__ raises AttributeError.
        super().__init__()
        self.d = d
        self.c = c
        # All parameters are set in closed form by EM, never by autograd.
        self.m = nn.Parameter(torch.zeros(d), requires_grad=False)
        self.w = nn.Parameter(torch.zeros(d, c), requires_grad=False)
        self.vy = nn.Parameter(torch.ones(()), requires_grad=False)
        self.register_buffer("Ic", torch.eye(c))

    @staticmethod
    def _masks(Y, O):
        """Return the observation mask as a bool tensor and in ``Y``'s dtype."""
        if O is None:
            O = torch.isfinite(Y)
        O = O.to(torch.bool)
        return O, O.to(Y.dtype)

    def initialize_from(self, Y, O):
        """Initialise ``m``, ``w`` and ``vy`` from an SVD of the centred data.

        Unobserved entries are filled with their column mean (zero after
        centring) before the SVD. ``w`` is scaled so that ``w @ w.T`` equals
        the rank-``c`` part of the (imputed) sample covariance. ``vy`` is the
        mean squared residual of that reconstruction on observed entries,
        floored at 1e-3 of the observed variance so the E-step never divides
        by (almost) zero.
        """
        O, O_ = self._masks(Y, O)
        with torch.no_grad():
            n = Y.shape[0]
            counts = O_.sum(0)
            Y0 = torch.nan_to_num(Y).masked_fill(~O, 0)
            m = Y0.sum(0) / counts.clamp_min(1)  # 0 for never-observed columns
            Yc = (Y0 - m).masked_fill(~O, 0)
            u, s, vh = torch.linalg.svd(Yc, full_matrices=False)
            # The SVD may return fewer than c components; the rest stay zero.
            k = min(self.c, s.shape[0])
            w = torch.zeros(self.d, self.c, dtype=Y.dtype, device=Y.device)
            w[:, :k] = vh[:k].T * (s[:k] / n ** 0.5)
            resid = (Yc - (u[:, :k] * s[:k]) @ vh[:k]).masked_fill(~O, 0)
            n_obs = counts.sum().clamp_min(1)
            vy = resid.square().sum() / n_obs
            floor = 1e-3 * Yc.square().sum() / n_obs
            vy = torch.maximum(vy, floor).clamp_min(torch.finfo(Y.dtype).eps)
            self.m.copy_(m)
            self.w.copy_(w)
            self.vy.copy_(vy)

    def forward(self, Y, O=None):
        """E-step: posterior ``q(x_n) = N(xbarn[n], Sigma_xn[n])`` for each row.

        Args:
            Y: (n, d) data; unobserved entries may be NaN.
            O: optional (n, d) observation mask; defaults to ``isfinite(Y)``.

        Returns:
            xbarn: (n, c) posterior means.
            Sigma_xn: (n, c, c) posterior covariances.
        """
        O, O_ = self._masks(Y, O)
        # Centred data with unobserved entries zeroed so they drop out of sums.
        Ym = (torch.nan_to_num(Y) - self.m).masked_fill(~O, 0)
        # Per row n: sum over observed i of w_i w_i^T.
        Oww = torch.einsum("ni,ik,il->nkl", O_, self.w, self.w)
        Sigma_xn = self.vy * torch.linalg.inv(self.vy * self.Ic[None] + Oww)
        xbarn = torch.einsum("nlk,ik,ni->nl", Sigma_xn / self.vy, self.w, Ym)
        return xbarn, Sigma_xn

    def m_step(self, Y, O, xbarn, Sigma_xn):
        """M-step: closed-form updates of ``m``, ``w`` and ``vy``.

        The mean update uses the current ``self.w``. The ``w`` update uses the
        new mean, and ``vy`` uses both. A column with no observations gets
        ``m_i = 0``. The module itself is not modified.

        Returns:
            ``(m, w, vy)`` with shapes (d,), (d, c) and ().
        """
        O, O_ = self._masks(Y, O)
        Y0 = torch.nan_to_num(Y)
        counts = O_.sum(0)
        # m_i = mean over observed n of (y_ni - w_i^T xbar_n)
        R = torch.addmm(Y0, xbarn, self.w.T, alpha=-1).masked_fill(~O, 0)
        m = R.sum(0) / counts.clamp_min(1)
        # Per row i of w: sum over observed n of (xbar_n xbar_n^T + Sigma_xn).
        xxTSigmax = torch.baddbmm(Sigma_xn, xbarn[:, :, None], xbarn[:, None, :])
        xxTSigmax = O_.T @ xxTSigmax.reshape(-1, self.c * self.c)
        xxTSigmax = xxTSigmax.reshape(-1, self.c, self.c)
        Ym = (Y0 - m).masked_fill(~O, 0)
        xbarym = torch.einsum("nk,ni->ik", xbarn, Ym)
        # Explicit (d, c, 1) right-hand side: one linear system per row of w.
        w = torch.linalg.solve(xxTSigmax, xbarym.unsqueeze(-1)).squeeze(-1)
        wSw = torch.einsum("ik,nkl,il->ni", w, Sigma_xn, w)
        resid = torch.addmm(Ym, xbarn, w.T, alpha=-1)
        vy = (resid.square() + wSw)[O].mean()
        return m, w, vy

    def fit_transform(self, Y, O=None, max_iter=100, show_progress=True, atol=1e-3):
        """Fit by EM and return the latent posterior for ``Y``.

        Iteration stops when any of these happens:
        - the largest absolute change in ``w`` drops below ``atol``;
        - ``max_iter`` iterations have run;
        - an update yields a non-finite change or ``vy``. That update is
          discarded and the previous parameters are kept.

        Returns:
            ``(xbarn, Sigma_xn)`` from an E-step under the final parameters.
        """
        O, _ = self._masks(Y, O)
        self.initialize_from(Y, O)
        with torch.no_grad():
            pbar = trange(max_iter, disable=not show_progress)
            for _ in pbar:
                xbarn, Sigma_xn = self(Y, O)
                m, w, vy = self.m_step(Y, O, xbarn, Sigma_xn)
                dw = torch.max(torch.abs(w - self.w))
                if not (torch.isfinite(dw) and torch.isfinite(vy)):
                    break
                self.m.copy_(m)
                self.w.copy_(w)
                self.vy.copy_(vy)
                pbar.set_postfix(dw=dw.item(), vy=vy.item())
                if dw < atol:
                    break
            return self(Y, O)
import torch
from torch import nn
from import trange
class VBPCA(torch.nn.Module):
    """Variational Bayesian PCA with missing data (Ilin & Raiko, JMLR 2010).

    Priors: ``m_i ~ N(0, vm)``, ``w_ik ~ N(0, vwk[k])``, ``x_n ~ N(0, I_c)``;
    the noise variance is ``vy``. Factorised posteriors:
    ``q(m_i) = N(mbar_i, mtilde_i)``, ``q(w_i) = N(wbar_i, Sigmaw_i)`` and
    ``q(x_n) = N(xbarn_n, Sigma_xn_n)``. Entries marked unobserved in ``O``
    (by default: the non-finite entries of ``Y``) are left out of every sum.

    Args:
        d: number of observed dimensions (columns of ``Y``).
        c: number of latent components.
    """

    def __init__(self, d, c):
        # nn.Module.__init__ must run before any Parameter or buffer is
        # assigned; without it nn.Module.__setattr__ raises AttributeError.
        super().__init__()
        self.d = d
        self.c = c
        # All parameters are set in closed form, never by autograd.
        self.mbar = nn.Parameter(torch.zeros(d), requires_grad=False)
        self.mtilde = nn.Parameter(torch.ones(d), requires_grad=False)
        self.wbar = nn.Parameter(torch.zeros(d, c), requires_grad=False)
        eyes_dcc = torch.eye(c)[None].repeat(d, 1, 1)
        self.Sigmaw = nn.Parameter(eyes_dcc, requires_grad=False)
        self.vy = nn.Parameter(torch.ones(()), requires_grad=False)
        self.vm = nn.Parameter(torch.ones(()), requires_grad=False)
        self.vwk = nn.Parameter(torch.ones(c), requires_grad=False)
        self.register_buffer("Ic", torch.eye(c))

    @staticmethod
    def _masks(Y, O):
        """Return the observation mask as a bool tensor and in ``Y``'s dtype."""
        if O is None:
            O = torch.isfinite(Y)
        O = O.to(torch.bool)
        return O, O.to(Y.dtype)

    def initialize_from(self, Y, O):
        """Initialise ``mbar``, ``wbar`` and ``vy`` from an SVD of the centred data.

        Unobserved entries are filled with their column mean (zero after
        centring) before the SVD. ``wbar`` is scaled so that ``wbar @ wbar.T``
        equals the rank-``c`` part of the (imputed) sample covariance. ``vy``
        is the mean squared residual on observed entries, floored at 1e-3 of
        the observed variance. All other parameters keep their current values.
        """
        O, O_ = self._masks(Y, O)
        with torch.no_grad():
            n = Y.shape[0]
            counts = O_.sum(0)
            Y0 = torch.nan_to_num(Y).masked_fill(~O, 0)
            m = Y0.sum(0) / counts.clamp_min(1)  # 0 for never-observed columns
            Yc = (Y0 - m).masked_fill(~O, 0)
            u, s, vh = torch.linalg.svd(Yc, full_matrices=False)
            # The SVD may return fewer than c components; the rest stay zero.
            k = min(self.c, s.shape[0])
            w = torch.zeros(self.d, self.c, dtype=Y.dtype, device=Y.device)
            w[:, :k] = vh[:k].T * (s[:k] / n ** 0.5)
            resid = (Yc - (u[:, :k] * s[:k]) @ vh[:k]).masked_fill(~O, 0)
            n_obs = counts.sum().clamp_min(1)
            vy = resid.square().sum() / n_obs
            floor = 1e-3 * Yc.square().sum() / n_obs
            vy = torch.maximum(vy, floor).clamp_min(torch.finfo(Y.dtype).eps)
            self.mbar.copy_(m)
            self.wbar.copy_(w)
            self.vy.copy_(vy)

    def forward(self, Y, O=None):
        """Update ``q(x_n)`` for every row, holding the other posteriors fixed.

        Args:
            Y: (n, d) data; unobserved entries may be NaN.
            O: optional (n, d) observation mask; defaults to ``isfinite(Y)``.

        Returns:
            xbarn: (n, c) posterior means.
            Sigma_xn: (n, c, c) posterior covariances.
        """
        O, O_ = self._masks(Y, O)
        # Centred data with unobserved entries zeroed so they drop out of sums.
        Ym = (torch.nan_to_num(Y) - self.mbar).masked_fill(~O, 0)
        # E[w_i w_i^T] = wbar_i wbar_i^T + Sigmaw_i, summed over observed i.
        OwwSw = torch.baddbmm(self.Sigmaw, self.wbar[:, :, None], self.wbar[:, None, :])
        OwwSw = (O_ @ OwwSw.reshape(self.d, self.c * self.c)).reshape(Y.shape[0], self.c, self.c)
        Sigma_xn = self.vy * torch.linalg.inv(self.vy * self.Ic[None] + OwwSw)
        xbarn = torch.einsum("nlk,ik,ni->nl", Sigma_xn / self.vy, self.wbar, Ym)
        return xbarn, Sigma_xn

    def m_step(self, Y, O, xbarn, Sigma_xn):
        """Update q(m), q(w) and the hyperparameters given ``q(x)``.

        The module itself is not modified.

        Returns:
            ``(mbar, mtilde, wbar, Sigmaw, vy, vwk, vm)`` with shapes
            (d,), (d,), (d, c), (d, c, c), (), (c,), ().
        """
        O, O_ = self._masks(Y, O)
        Y0 = torch.nan_to_num(Y)
        Oicard = O_.sum(0)  # |O_i|: number of observations in column i
        # -- M
        # vm / (|O_i| (vm + vy / |O_i|)) rewritten as vm / (|O_i| vm + vy).
        # This is well defined (= vm / vy) when column i has no observations,
        # in which case q(m_i) falls back to the prior N(0, vm).
        shrink = self.vm / (Oicard * self.vm + self.vy)
        R = torch.addmm(Y0, xbarn, self.wbar.T, alpha=-1).masked_fill(~O, 0)
        mbar = shrink * R.sum(0)
        mtilde = self.vy * shrink
        # -- W
        xxTSigmax = torch.baddbmm(Sigma_xn, xbarn[:, :, None], xbarn[:, None, :])
        xxTSigmax = (O_.T @ xxTSigmax.reshape(-1, self.c * self.c)).reshape(-1, self.c, self.c)
        # The prior precision of w_i is diag(1 / vwk).
        Sigmaw = self.vy * torch.linalg.inv(self.vy * torch.diag(1 / self.vwk) + xxTSigmax)
        Ym = (Y0 - mbar).masked_fill(~O, 0)
        xbarym = torch.einsum("nk,ni->ik", xbarn, Ym)
        wbar = torch.bmm(Sigmaw, xbarym[:, :, None])[:, :, 0] / self.vy
        # -- variances
        wSxw = torch.einsum("ik,nkl,il->ni", wbar, Sigma_xn, wbar)
        xSwx = torch.einsum("nk,ikl,nl->ni", xbarn, Sigmaw, xbarn)
        trSS = torch.einsum("nkl,ikl->ni", Sigma_xn, Sigmaw)
        resid = torch.addmm(Ym, xbarn, wbar.T, alpha=-1)
        # Expected squared error per observed entry, including the posterior
        # variance of the mean (mtilde), averaged over observed entries.
        vy = (resid.square() + wSxw + xSwx + trSS + mtilde)[O].mean()
        vwk = torch.mean(torch.square(wbar) + torch.diagonal(Sigmaw, dim1=-2, dim2=-1), dim=0)
        vm = torch.mean(torch.square(mbar) + mtilde, dim=0)
        return mbar, mtilde, wbar, Sigmaw, vy, vwk, vm

    def fit_transform(self, Y, O=None, max_iter=100, show_progress=True, atol=1e-3):
        """Fit by alternating updates and return the latent posterior for ``Y``.

        Iteration stops when any of these happens:
        - the largest absolute change in ``wbar`` drops below ``atol``;
        - ``max_iter`` iterations have run;
        - an update yields a non-finite change or ``vy``. That update is
          discarded and the previous parameters are kept.

        Returns:
            ``(xbarn, Sigma_xn)`` from an E-step under the final parameters.
        """
        O, _ = self._masks(Y, O)
        self.initialize_from(Y, O)
        with torch.no_grad():
            pbar = trange(max_iter, disable=not show_progress)
            for _ in pbar:
                xbarn, Sigma_xn = self(Y, O)
                mbar, mtilde, wbar, Sigmaw, vy, vwk, vm = self.m_step(Y, O, xbarn, Sigma_xn)
                dw = torch.max(torch.abs(wbar - self.wbar))
                if not (torch.isfinite(dw) and torch.isfinite(vy)):
                    break
                for param, value in (
                    (self.mbar, mbar),
                    (self.mtilde, mtilde),
                    (self.wbar, wbar),
                    (self.Sigmaw, Sigmaw),
                    (self.vy, vy),
                    (self.vwk, vwk),
                    (self.vm, vm),
                ):
                    param.copy_(value)
                pbar.set_postfix(dw=dw.item(), vy=vy.item())
                if dw < atol:
                    break
            return self(Y, O)

    def predictive_dists(self, Ynew, Onew=None):
        """Posterior predictive mean and diagonal covariance for new rows.

        The latents of ``Ynew`` are inferred with ``forward`` while every
        other posterior is held fixed. This is an approximation: the new rows
        do not update q(m) or q(w).

        Returns:
            Ynewhat: (n, d) predictive means ``xbar @ wbar.T + mbar``.
            Sigma_Ynew: (n, d, d) diagonal predictive covariances. Wrap them in
                ``torch.distributions.MultivariateNormal`` if a distribution
                object is needed.
        """
        if Onew is None:
            Onew = torch.isfinite(Ynew)
        xbarnew, Sigma_xnew = self(Ynew, Onew)
        Ynewhat = xbarnew @ self.wbar.T + self.mbar
        # Var(y_ni) = wbar^T Sx wbar + xbar^T Sw xbar + tr(Sw Sx) + mtilde + vy
        wSxw = torch.einsum("ik,nkl,il->ni", self.wbar, Sigma_xnew, self.wbar)
        xSwx = torch.einsum("nk,ikl,nl->ni", xbarnew, self.Sigmaw, xbarnew)
        trSS = Sigma_xnew.reshape(-1, self.c ** 2) @ self.Sigmaw.reshape(self.d, self.c ** 2).T
        Sigma_Ynew = torch.diag_embed(wSxw + xSwx + trSS + (self.mtilde + self.vy))
        return Ynewhat, Sigma_Ynew
