Probabilistic PCA in PyTorch with missing data. See the very nice JMLR paper of Ilin and Raiko (2010), "Practical Approaches to Principal Component Analysis in the Presence of Missing Values." Both files fit the model y_n = W x_n + m + noise, with isotropic noise variance vy, using only the observed entries of Y (NaNs mark missing values): the first file by expectation-maximization, the second by variational Bayes with per-column prior variances on W.
import torch
from torch import nn
from tqdm.auto import trange


class PPCA(nn.Module):
    """Probabilistic PCA with missing data, fit by expectation-maximization."""

    def __init__(self, d, c):
        super().__init__()
        self.d = d
        self.c = c
        self.m = nn.Parameter(torch.zeros(d), requires_grad=False)
        self.w = nn.Parameter(torch.zeros(d, c), requires_grad=False)
        self.vy = nn.Parameter(torch.ones(()), requires_grad=False)
        self.register_buffer("Ic", torch.eye(c))

    def initialize_from(self, Y, O):
        # initialize the mean from the observed entries, and the loadings
        # from the SVD of the zero-imputed, centered data
        with torch.no_grad():
            m = torch.nan_to_num(torch.nanmean(Y, dim=0))
            self.m.copy_(m)
            Yc = torch.nan_to_num(Y - m)
            u, s, vh = torch.linalg.svd(Yc, full_matrices=False)
            self.w.copy_(vh[: self.c].T)

    def forward(self, Y, O=None):
        # E step: posterior mean and covariance of the latents x_n,
        # using only the observed coordinates of each row
        if O is None:
            O = torch.isfinite(Y)
        O_ = O.to(self.w)
        Y = torch.nan_to_num(Y)
        Oww = torch.einsum("ni,ik,il->nkl", O_, self.w, self.w)
        Sigma_xn = self.vy * torch.linalg.inv(self.vy * self.Ic[None] + Oww)
        xbarn = torch.einsum(
            "nlk,ni,ik,ni->nl", Sigma_xn / self.vy, O_, self.w, Y - self.m
        )
        return xbarn, Sigma_xn

    def m_step(self, Y, O, xbarn, Sigma_xn):
        # M step: closed-form updates for the mean, loadings, and noise variance
        O_ = O.to(self.w)
        m = torch.nanmean(torch.addmm(Y, xbarn, self.w.T, alpha=-1), dim=0)
        # per-feature sums over observed rows of E[x x^T] = xbar xbar^T + Sigma_x
        xxTSigmax = torch.baddbmm(Sigma_xn, xbarn[:, :, None], xbarn[:, None, :])
        xxTSigmax = O_.T @ xxTSigmax.reshape(-1, self.c * self.c)
        xxTSigmax = xxTSigmax.reshape(-1, self.c, self.c)
        Ym = torch.nan_to_num(Y - m)
        xbarym = torch.einsum("ni,nk,ni->ik", O_, xbarn, Ym)
        w = torch.linalg.solve(xxTSigmax, xbarym)
        wSw = torch.einsum("ik,nkl,il->ni", w, Sigma_xn, w)
        vy = torch.mean(
            torch.square(torch.addmm(Ym, xbarn, w.T, alpha=-1)[O]) + wSw[O]
        )
        return m, w, vy

    def fit_transform(self, Y, O=None, max_iter=100, show_progress=True, atol=1e-3):
        if O is None:
            O = torch.isfinite(Y)
        self.initialize_from(Y, O)
        pbar = trange(max_iter) if show_progress else range(max_iter)
        for i in pbar:
            xbarn, Sigma_xn = self(Y, O)
            m, w, vy = self.m_step(Y, O, xbarn, Sigma_xn)
            dw = torch.max(torch.abs(w - self.w))
            if show_progress:
                pbar.set_description(f"dw={dw.numpy(force=True):0.5f}")
            if torch.isnan(dw).any() or dw < atol:
                break
            self.m.copy_(m)
            self.w.copy_(w)
            self.vy.copy_(vy)
        return xbarn, Sigma_xn
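A minimal usage sketch for the EM version (illustrative, not part of the gist; the synthetic data, shapes, and hyperparameters below are assumptions):

import torch

torch.manual_seed(0)
n, d, c = 500, 20, 3
X = torch.randn(n, c)                  # latent coordinates
W = torch.randn(d, c)                  # ground-truth loadings
Y = X @ W.T + 0.1 * torch.randn(n, d)  # noisy observations
Y[torch.rand(n, d) < 0.2] = torch.nan  # hide ~20% of entries

ppca = PPCA(d=d, c=c)
xbar, Sigma_x = ppca.fit_transform(Y, max_iter=200)

# impute the missing entries with the posterior predictive mean
Yhat = xbar @ ppca.w.T + ppca.m
Yimputed = torch.where(torch.isfinite(Y), Y, Yhat)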
import torch
from torch import nn
from tqdm.auto import trange


class VBPCA(nn.Module):
    """Variational Bayesian PCA with missing data (Ilin & Raiko, 2010)."""

    def __init__(self, d, c):
        super().__init__()
        self.d = d
        self.c = c
        # variational posterior over the mean: m_i ~ N(mbar[i], mtilde[i])
        self.mbar = nn.Parameter(torch.zeros(d), requires_grad=False)
        self.mtilde = nn.Parameter(torch.ones(d), requires_grad=False)
        # variational posterior over the loadings: row i ~ N(wbar[i], Sigmaw[i])
        self.wbar = nn.Parameter(torch.zeros(d, c), requires_grad=False)
        eyes_dcc = torch.eye(c)[None].repeat(d, 1, 1)
        self.Sigmaw = nn.Parameter(eyes_dcc, requires_grad=False)
        # noise variance and prior variances (per-column for W)
        self.vy = nn.Parameter(torch.ones(()), requires_grad=False)
        self.vm = nn.Parameter(torch.ones(()), requires_grad=False)
        self.vwk = nn.Parameter(torch.ones(c), requires_grad=False)
        self.register_buffer("Ic", torch.eye(c))

    def initialize_from(self, Y, O):
        with torch.no_grad():
            m = torch.nan_to_num(torch.nanmean(Y, dim=0))
            self.mbar.copy_(m)
            Yc = torch.nan_to_num(Y - m)
            u, s, vh = torch.linalg.svd(Yc, full_matrices=False)
            self.wbar.copy_(vh[: self.c].T)

    def forward(self, Y, O=None):
        # E step for the latents, accounting for the loadings' posterior covariance
        if O is None:
            O = torch.isfinite(Y)
        O_ = O.to(self.wbar)
        Y = torch.nan_to_num(Y)
        OwwSw = torch.baddbmm(self.Sigmaw, self.wbar[:, :, None], self.wbar[:, None, :])
        OwwSw = (O_ @ OwwSw.reshape(self.d, self.c * self.c)).reshape(
            O.shape[0], self.c, self.c
        )
        Sigma_xn = self.vy * torch.linalg.inv(self.vy * self.Ic[None] + OwwSw)
        xbarn = torch.einsum(
            "nlk,ni,ik,ni->nl", Sigma_xn / self.vy, O_, self.wbar, Y - self.mbar
        )
        return xbarn, Sigma_xn

    def m_step(self, Y, O, xbarn, Sigma_xn):
        O_ = O.to(self.wbar)
        Oicard = O_.sum(0)
        # -- m: posterior N(mbar, mtilde); features with no observations
        # fall back to the prior (mean 0, variance vm)
        vm_OvvO = torch.where(
            Oicard > 0,
            self.vm / (Oicard * (self.vm + self.vy / Oicard)),
            self.vm / self.vy,
        )
        mbar = vm_OvvO * torch.nansum(
            torch.addmm(Y, xbarn, self.wbar.T, alpha=-1), dim=0
        )
        mtilde = self.vy * vm_OvvO
        # -- W
        # per-feature sums over observed rows of E[x x^T] = xbar xbar^T + Sigma_x
        xxTSigmax = torch.baddbmm(Sigma_xn, xbarn[:, :, None], xbarn[:, None, :])
        xxTSigmax = (O_.T @ xxTSigmax.reshape(-1, self.c * self.c)).reshape(
            -1, self.c, self.c
        )
        # prior precision of row i of W is diag(1 / vwk), so it enters the
        # posterior covariance scaled by vy
        Sigmaw = self.vy * torch.linalg.inv(
            self.vy * torch.diag(1.0 / self.vwk) + xxTSigmax
        )
        Ym = torch.nan_to_num(Y - mbar)
        xbarym = torch.einsum("ni,nk,ni->ik", O_, xbarn, Ym)
        wbar = torch.bmm(Sigmaw, xbarym[:, :, None])[:, :, 0] / self.vy
        # -- variance hyperparameters
        wSxw = torch.einsum("ik,nkl,il->ni", wbar, Sigma_xn, wbar)
        xSwx = torch.einsum("nk,ikl,nl->ni", xbarn, Sigmaw, xbarn)
        trSS = torch.einsum("nkl,ikl->ni", Sigma_xn, Sigmaw)
        vy = torch.mean(
            torch.square(torch.addmm(Ym, xbarn, wbar.T, alpha=-1)[O])
            + wSxw[O]
            + xSwx[O]
            + trSS[O]
        )
        vwk = torch.mean(
            torch.square(wbar) + torch.diagonal(Sigmaw, dim1=-2, dim2=-1), dim=0
        )
        vm = torch.mean(torch.square(mbar) + mtilde, dim=0)
        return mbar, mtilde, wbar, Sigmaw, vy, vwk, vm

    def fit_transform(self, Y, O=None, max_iter=100, show_progress=True, atol=1e-3):
        if O is None:
            O = torch.isfinite(Y)
        self.initialize_from(Y, O)
        pbar = trange(max_iter) if show_progress else range(max_iter)
        for i in pbar:
            xbarn, Sigma_xn = self(Y, O)
            mbar, mtilde, wbar, Sigmaw, vy, vwk, vm = self.m_step(Y, O, xbarn, Sigma_xn)
            dw = torch.max(torch.abs(wbar - self.wbar))
            if show_progress:
                pbar.set_description(f"dw={dw.numpy(force=True):0.5f}")
            if torch.isnan(dw).any() or dw < atol:
                break
            self.mbar.copy_(mbar)
            self.mtilde.copy_(mtilde)
            self.wbar.copy_(wbar)
            self.Sigmaw.copy_(Sigmaw)
            self.vy.copy_(vy)
            self.vwk.copy_(vwk)
            self.vm.copy_(vm)
        return xbarn, Sigma_xn

    def predictive_dists(self, Ynew, Onew=None):
        if Onew is None:
            Onew = torch.isfinite(Ynew)
        # posteriors for the new xs, holding everything else fixed
        # (not sure how kosher this is)
        xbarnew, Sigma_xnew = self(Ynew, Onew)
        # posterior predictive means
        Ynewhat = xbarnew @ self.wbar.T + self.mbar
        # posterior predictive variances (diagonal)
        wSxw = torch.einsum("ik,nkl,il->ni", self.wbar, Sigma_xnew, self.wbar)
        xSwx = torch.einsum("nk,ikl,nl->ni", xbarnew, self.Sigmaw, xbarnew)
        trSS = Sigma_xnew.reshape(-1, self.c**2) @ self.Sigmaw.reshape(self.d, self.c**2).T
        Sigma_Ynew = torch.diag_embed(wSxw + xSwx + trSS + (self.mtilde + self.vy))
        return Ynewhat, Sigma_Ynew
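And a matching sketch for the variational model, continuing the synthetic n, d, c, and Y from the PPCA example above (again illustrative, not part of the gist):

vbpca = VBPCA(d=d, c=c)
xbar, Sigma_x = vbpca.fit_transform(Y, max_iter=200)

# posterior predictive mean and diagonal covariance for some held-back rows;
# Ynew must have the same d columns, with NaNs where entries are unobserved
Ynew = Y[:10]
Ynewhat, Sigma_Ynew = vbpca.predictive_dists(Ynew)
per_entry_var = torch.diagonal(Sigma_Ynew, dim1=-2, dim2=-1)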