HMM on GPU with PyTorch in 100 lines
'''
This proof-of-concept follows [PRML](https://www.microsoft.com/en-us/research/people/cmbishop/prml-book/)'s idea.
The code extends a plain HMM so that the transition and emission matrices depend on the features `xs`: each step gets its own matrices computed from its feature vector.
To recover a plain HMM, set all `x` to the same value.
`HMM.predict()` uses formula (13.44) in PRML, which conditions on the whole observed sequence of `y`s.
If you have no observed `y`s and only the `x`s, you can use `model.trans(x).view(T, N, self.H, self.H).softmax(dim=3)` as the transition matrices to get a predicted sequence (see the sketch at the end of this file).
`gamma` here is the posterior probability of the hidden states.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
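
# For reference, the scaled forward-backward recursions that the helpers below
# implement (a summary in PRML-style notation, assuming y_n is the observation,
# z_n the hidden state, and c_n the per-step normalizer; batched over N sequences):
#
#   c_n * alpha_hat_n(z_n)      = p(y_n | z_n) * sum_{z_{n-1}} alpha_hat_{n-1}(z_{n-1}) * p(z_n | z_{n-1})
#   c_{n+1} * beta_hat_n(z_n)   = sum_{z_{n+1}} beta_hat_{n+1}(z_{n+1}) * p(y_{n+1} | z_{n+1}) * p(z_{n+1} | z_n)
#   gamma_n(z_n)                = alpha_hat_n(z_n) * beta_hat_n(z_n)
#   xi_n(z_{n-1}, z_n)          = alpha_hat_{n-1}(z_{n-1}) * p(y_n | z_n) * p(z_n | z_{n-1}) * beta_hat_n(z_n) / c_n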

def my_div(a, b):
    # Guard against division by (near-)zero: nudge tiny denominators by 1e-6.
    b += (b.abs() < 1e-8) * 1e-6
    return a / b

def alpha(a, T, E):
    # T: [N, H, H], transition matrix from step n-1 to step n
    # a: [N, H], alpha at step n-1
    # E: [N, H], probability of the observed y at step n under each hidden state
    return E * torch.einsum('nj,njk->nk', a, T)

def alpha_hat_and_c(a, T, E):
    # T: [N, H, H], transition matrix from step n-1 to step n
    # a: [N, H], alpha_hat at step n-1
    # E: [N, H], probability of the observed y at step n under each hidden state
    rhs = alpha(a, T, E)
    return my_div(rhs, rhs.sum(dim=1).view(-1, 1)), rhs.sum(dim=1)

def beta_hat(b, T, E, c_t_plus_one):
    # Same arguments as `beta`, rescaled by the normalizer c at step n+1.
    rhs = beta(b, T, E)
    return my_div(rhs, c_t_plus_one.view(-1, 1))

def beta(b, T, E):
    # T: [N, H, H], transition matrix from step n to step n+1
    # b: [N, H], beta at step n+1
    # E: [N, H], probability of the observed y at step n+1 under each hidden state
    return torch.einsum('nj,nij,nj->ni', b, T, E)

class HMM(nn.Module):
    def __init__(self, H, C, D):
        ''' H: number of hidden states
            C: number of emission symbols
            D: feature dimension
        '''
        super().__init__()
        self.H, self.C, self.D = H, C, D
        self.trans = nn.Linear(D, H*H)
        self.emit = nn.Linear(D, H*C)
        # TODO: init as nn.Linear
        self.init = nn.Parameter(torch.rand(H), requires_grad=True)

    @property
    def device(self):
        return self.trans.weight.device

    def forward(self, x, y):
        ''' x: [N, T, D]
            y: [N, T]
            Returns the negative expected complete-data log-likelihood
            (the Baum-Welch Q-function), averaged over the batch.
        '''
        x, y = x.transpose(0, 1), y.transpose(0, 1)
        T, N, D = x.shape
        log_trans = self.trans(x).view(T, N, self.H, self.H).log_softmax(dim=3)
        log_emit = self.emit(x).view(T, N, self.H, self.C).log_softmax(dim=3)
        log_emit_prob = torch.einsum('tnhc,tnc->tnh', log_emit, F.one_hot(y, num_classes=self.C).float())
        with torch.no_grad():
            trans, emit_prob = log_trans.exp(), log_emit_prob.exp()
            # gamma: [T, N, H]; xi: [T-1, N, H, H]
            gamma, xi = self.gamma_and_xi_stable(trans, emit_prob)  # Baum-Welch E-step
        log_prob = torch.einsum('nh,h->', gamma[0], self.init.log_softmax(dim=0))
        if len(x) > 1:
            log_prob += torch.einsum('tnjk,tnjk->', xi, log_trans[1:])
        log_prob += torch.einsum('tnk,tnk->', gamma, log_emit_prob)
        return -log_prob / N

    def alpha_hats(self, Ts, Es):
        # Forward recursion: returns alpha_hat [T, N, H] and the normalizers c [T, N].
        N = Ts.shape[1]
        last_alpha_hat = torch.stack([self.init.softmax(dim=0)]*N)
        alpha_hats, cs = [], []
        for Tr, Em in zip(Ts, Es):
            last_alpha_hat, c = alpha_hat_and_c(last_alpha_hat, Tr, Em)
            alpha_hats.append(last_alpha_hat)
            cs.append(c)
        return torch.stack(alpha_hats), torch.stack(cs)

    def gamma_and_xi_stable(self, Trs, Ems):
        # E-step quantities: gamma [T, N, H] and xi [T-1, N, H, H].
        alpha_hats, cs = self.alpha_hats(Trs, Ems)
        beta_hats = self.beta_hats(cs, Trs, Ems).to(self.device)
        xi = torch.einsum('tni,tnj,tnij,tnj,tn->tnij', alpha_hats[:-1], Ems[1:], Trs[1:], beta_hats[1:], my_div(1, cs[1:]))
        return alpha_hats * beta_hats, xi

    def beta_hats(self, c, Trs, Ems):
        # Backward recursion (recursive over time): returns beta_hat [T, N, H].
        N = Trs.shape[1]
        if len(Trs) == 1:
            return torch.ones(1, N, self.H, device=self.device)
        else:
            future_betas = self.beta_hats(c[1:], Trs[1:], Ems[1:])
            next_beta = future_betas[0]
            return torch.cat((beta_hat(next_beta, Trs[1], Ems[1], c[1]).view(1, N, self.H), future_betas))

    def predict(self, x, y):
        ''' x: [N, T, D]
            y: [N, T]
            Returns the most likely emission at each step, [N, T], from the
            predictive distribution sum_{z, z'} alpha_hat(z) p(z'|z) p(y'|z').
        '''
        N, T, D = x.shape
        x, y = x.transpose(0, 1), y.transpose(0, 1)
        trans = self.trans(x).view(T, N, self.H, self.H).softmax(dim=3)
        emit = self.emit(x).view(T, N, self.H, self.C).softmax(dim=3)
        emit_prob = torch.einsum('tnhc,tnc->tnh', emit, F.one_hot(y, num_classes=self.C).float())
        alpha_hats, _ = self.alpha_hats(trans, emit_prob)
        predicted = []
        for Tr, Em, ah in zip(trans, emit, alpha_hats):
            predicted.append(torch.einsum('ni,nij,njk->nk', ah, Tr, Em).argmax(dim=1))
        return torch.stack(predicted).transpose(0, 1)

if __name__ == '__main__':
    H, C, T, D, N = 4, 3, 7, 10, 128
    device = 'cuda'
    xs = torch.randn(N, T, D, device=device)
    ys = torch.randint(C, (N, T), device=device)
    lr, num_epochs = 5e-3, 100
    model = HMM(H, C, D).to(device)
    model.train()
    optim = torch.optim.Adam(lr=lr, params=model.parameters())
    for epoch in range(num_epochs):
        optim.zero_grad()
        loss = model(xs, ys)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.)
        optim.step()  # Baum-Welch M-step
    test_xs, test_ys = torch.randn(N, T, D, device=device), torch.randint(C, (N, T), device=device)
    predicted_ys = model.predict(test_xs, test_ys)
    from sklearn.metrics import accuracy_score
    print(accuracy_score(test_ys.flatten().tolist(), predicted_ys.flatten().tolist()))
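
    # A minimal sketch of the x-only prediction mentioned in the module docstring:
    # with no observed `ys`, the per-step transition and emission matrices can be
    # read straight from `model.trans` / `model.emit`. The greedy decoding below is
    # just one possible way to turn them into a sequence; it is illustrative only.
    with torch.no_grad():
        x_only = test_xs.transpose(0, 1)                                  # [T, N, D]
        trans_only = model.trans(x_only).view(T, N, H, H).softmax(dim=3)  # [T, N, H, H]
        emit_only = model.emit(x_only).view(T, N, H, C).softmax(dim=3)    # [T, N, H, C]
        state = model.init.softmax(dim=0).argmax().repeat(N)              # [N] start states
        batch = torch.arange(N, device=device)
        ys_from_x = []
        for t in range(T):
            state = trans_only[t, batch, state].argmax(dim=1)             # most likely next state
            ys_from_x.append(emit_only[t, batch, state].argmax(dim=1))    # most likely symbol
        print('x-only predictions shape:', torch.stack(ys_from_x).transpose(0, 1).shape)
    # To recover a plain HMM (see the docstring), make every step share one feature, e.g.:
    # const_xs = xs[:, :1, :].expand(N, T, D)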