Variational LSTM implementation based on Gal & Ghahramani's paper: https://arxiv.org/abs/1512.05287
from pytorch_lightning.utilities.seed import seed_everything
from torch import device as Device, dtype as DType, Tensor
from torch.nn import Parameter, LSTM, init
from typing import List, Tuple, Optional
import torch.nn as nn
import torch


def reverse(lst: List[Tensor]) -> List[Tensor]:
    return lst[::-1]


def init_stacked_rnn(layer, model):
    """Build the per-layer modules of a stacked (possibly bidirectional) RNN."""
    layers = nn.ModuleList()
    dirs = 2 if model.bidirectional else 1
    # the first layer consumes the raw input; subsequent layers consume the
    # (possibly direction-concatenated) output of the previous layer
    layers.append(layer(input_size=model.input_size, hidden_size=model.hidden_size, dropout=model.dropout))
    for _ in range(model.num_layers - 1):
        layers.append(layer(input_size=model.hidden_size * dirs, hidden_size=model.hidden_size, dropout=model.dropout))
    return layers


def reorder(x):
    """Collapse the first two dimensions into one, e.g. (L, 2, N, H) -> (2L, N, H)."""
    return x.view([-1] + list(x.shape[2:]))


def flatten_states(states: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]:
    """Convert a list of (h, c) pairs into a pair of stacked tensors: (stack of h's, stack of c's)."""
    hs, cs = [], []
    for state in states:
        hs.append(state[0])
        cs.append(state[1])
    return torch.stack(hs), torch.stack(cs)


def double_flatten_states(states: List[List[Tuple[Tensor, Tensor]]]):
    """Flatten a nested list of (h, c) pairs (layers x directions) into a pair of (2L, N, H) tensors."""
    first_flatten: List[Tuple[Tensor, Tensor]] = []
    for inner in states:  # inner: List[Tuple[Tensor, Tensor]]
        first_flatten.append(flatten_states(inner))
    second_flatten = flatten_states(first_flatten)
    return reorder(second_flatten[0]), reorder(second_flatten[1])


class LSTMCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # parameters follow PyTorch's LSTM layout: the 4 gates (input, forget, cell, output)
        # are stacked along dim 0; they are left uninitialised here and are expected to be
        # initialised externally (see init_lstm below)
        self.weight_ih = Parameter(torch.Tensor(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.Tensor(4 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.Tensor(4 * hidden_size))
        self.bias_hh = Parameter(torch.Tensor(4 * hidden_size))

    def forward(self,
                inpt: Tensor,  # N, I
                state: Tuple[Tensor, Tensor],  # All N, H
                input_masks: Tensor,  # 4, N, I
                hidden_masks: Tensor,  # 4, N, H
                ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:  # All N, H
        prev_h, prev_c = state
        input_weights = self.weight_ih.reshape(4, self.hidden_size, self.input_size).permute((0, 2, 1))  # 4, I, H
        input_biases = self.bias_ih.reshape(4, 1, self.hidden_size)  # 4, 1, H
        hidden_weights = self.weight_hh.reshape(4, self.hidden_size, self.hidden_size).permute((0, 2, 1))  # 4, H, H
        hidden_biases = self.bias_hh.reshape(4, 1, self.hidden_size)  # 4, 1, H
        # apply the (variational) dropout masks to the input and the previous hidden state
        masked_input = torch.mul(inpt, input_masks)  # 4, N, I
        masked_hidden = torch.mul(prev_h, hidden_masks)  # 4, N, H
        gates = torch.baddbmm(input_biases, masked_input, input_weights) + \
                torch.baddbmm(hidden_biases, masked_hidden, hidden_weights)  # 4, N, H
        input_gate, forget_gate, candidate_c, output_gate = gates.unbind(0)  # All N, H
        input_gate = torch.sigmoid(input_gate)  # N, H
        forget_gate = torch.sigmoid(forget_gate)  # N, H
        candidate_c = torch.tanh(candidate_c)  # N, H
        output_gate = torch.sigmoid(output_gate)  # N, H
        new_c = (forget_gate * prev_c) + (input_gate * candidate_c)  # N, H
        new_h = output_gate * torch.tanh(new_c)  # N, H
        return new_h, (new_h, new_c)


class CoreLSTMLayer(nn.Module):
    def __init__(self, dropout: float, **cell_kwargs):
        super().__init__()
        assert 0 <= dropout < 1
        self.dropout = dropout
        self.cell = LSTMCell(**cell_kwargs)

    def get_dropout_masks(self, batch_size: int, input_size: int, hidden_size: int, device: Device, dtype: DType):
        input_masks = torch.ones((4, batch_size, input_size), device=device, dtype=dtype)  # 4, N, I
        hidden_masks = torch.ones((4, batch_size, hidden_size), device=device, dtype=dtype)  # 4, N, H
        if self.training and self.dropout > 0:
            keep_prob = 1 - self.dropout
            # sample dropout masks once per forward call & scale (inverted dropout method)
            return torch.bernoulli(input_masks * keep_prob) / keep_prob, torch.bernoulli(hidden_masks * keep_prob) / keep_prob
        else:
            # no dropout & no scaling
            return input_masks, hidden_masks

    def forward(self,
                inputs: List[Tensor],  # N, I of length T
                state: Tuple[Tensor, Tensor]  # All N, H
                ) -> Tuple[List[Tensor], Tuple[Tensor, Tensor]]:  # N, H of length T & All N, H
        batch_size = inputs[0].size(dim=0)
        device, dtype = inputs[0].device, inputs[0].dtype
        input_size, hidden_size = self.cell.input_size, self.cell.hidden_size
        # the same masks are reused at every time step of the sequence,
        # which is the key idea of variational dropout
        input_masks, hidden_masks = self.get_dropout_masks(batch_size, input_size, hidden_size, device, dtype)
        outputs: List[Tensor] = []
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state, input_masks, hidden_masks)
            outputs += [out]
        return outputs, state


class LSTMLayer(nn.Module):
    def __init__(self, **core_kwargs):
        super().__init__()
        self.core_layer = CoreLSTMLayer(**core_kwargs)

    def forward(self,
                inpt: Tensor,  # N, T, I
                state: Tuple[Tensor, Tensor]  # All N, H
                ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:  # N, T, H & All N, H
        inputs = inpt.unbind(1)  # split on time axis
        outputs, state = self.core_layer(inputs, state)
        return torch.stack(outputs, dim=1), state


class ReverseLSTMLayer(nn.Module):
    def __init__(self, **core_kwargs):
        super().__init__()
        self.core_layer = CoreLSTMLayer(**core_kwargs)

    def forward(self,
                inpt: Tensor,  # N, T, I
                state: Tuple[Tensor, Tensor]  # All N, H
                ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:  # N, T, H & All N, H
        # process the sequence back-to-front, then restore the original time order
        inputs = reverse(inpt.unbind(1))  # split on time axis
        outputs, state = self.core_layer(inputs, state)
        return torch.stack(reverse(outputs), dim=1), state


class BiDirLSTMLayer(nn.Module):
    def __init__(self, **core_kwargs):
        super().__init__()
        self.directions = nn.ModuleList(
            [
                LSTMLayer(**core_kwargs),
                ReverseLSTMLayer(**core_kwargs),
            ]
        )

    def forward(self,
                inpt: Tensor,  # N, T, I
                states: Tuple[Tensor, Tensor]  # Both 2, N, H
                ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:  # N, T, 2H & All N, H
        outputs: List[Tensor] = []
        output_states: List[Tuple[Tensor, Tensor]] = []
        for i, direction in enumerate(self.directions):
            state = states[0][i, :, :], states[1][i, :, :]
            out, out_state = direction(inpt, state)
            outputs += [out]
            output_states += [out_state]
        # concatenate forward & backward outputs along the feature axis
        return torch.cat(outputs, -1), output_states


class StackedUniLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = False
        self.dropout = dropout
        self.layers = init_stacked_rnn(LSTMLayer, self)

    def forward(self,
                inpt: Tensor,  # N, T, I
                init_states: Optional[Tuple[Tensor, Tensor]] = None  # Both L, N, H
                ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:  # N, T, H & All L, N, H
        output_states: List[Tuple[Tensor, Tensor]] = []
        output = inpt
        batch_size = inpt.size(dim=0)
        if init_states is None:
            device, dtype = inpt.device, inpt.dtype
            zeros = torch.zeros((self.num_layers, batch_size, self.hidden_size), device=device, dtype=dtype)
            init_states = (zeros, zeros)
        for i, rnn_layer in enumerate(self.layers):
            init_state = init_states[0][i, :, :], init_states[1][i, :, :]
            output, out_state = rnn_layer(output, init_state)
            output_states += [out_state]
        flattened_output_states = flatten_states(output_states)
        return output, flattened_output_states


class StackedBiLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = True
        self.dropout = dropout
        self.layers = init_stacked_rnn(BiDirLSTMLayer, self)

    def forward(self,
                inpt: Tensor,  # N, T, I
                states: Optional[Tuple[Tensor, Tensor]] = None  # Both 2L, N, H
                ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:  # N, T, 2H & All 2L, N, H
        output_states: List[List[Tuple[Tensor, Tensor]]] = []
        output = inpt
        batch_size = inpt.size(dim=0)
        if states is None:
            device, dtype = inpt.device, inpt.dtype
            zeros = torch.zeros((2 * self.num_layers, batch_size, self.hidden_size), device=device, dtype=dtype)
            states = (zeros, zeros)
        for i, rnn_layer in enumerate(self.layers):
            init_state = states[0][2 * i:2 * i + 2, :, :], states[1][2 * i:2 * i + 2, :, :]
            output, out_state = rnn_layer(output, init_state)
            output_states += [out_state]
        flattened_output_states = double_flatten_states(output_states)
        return output, flattened_output_states


def variational_lstm(
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bidirectional: bool,
        dropout: float,
):
    """
    Return a ScriptModule implementation of a Variational LSTM
    (Ref: https://papers.nips.cc/paper_files/paper/2016/hash/076a0c97d09cf1a0ec3e19c7f2529f2b-Abstract.html).
    Inputs are expected batch-first, i.e. shaped (N, T, I).
    When dropout=0, it behaves like PyTorch's LSTM (with dropout=0).
    The dropout probability is tied, i.e. one dropout value is shared by all layers,
    and it is applied to the inputs and hidden states of every layer.
    """
    stack_type = StackedBiLSTM if bidirectional else StackedUniLSTM
    module = stack_type(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        dropout=dropout,
    )
    return torch.jit.script(module)


def init_lstm(lstm):
    """Deterministically initialise an LSTM's parameters (in double precision)."""
    lstm.double()
    seed_everything(0)  # same seed & same parameter shapes/order => identical weights for both implementations
    for name, param in lstm.named_parameters():
        if 'weight_ih' in name:
            init.xavier_uniform_(param)
        elif 'weight_hh' in name:
            init.orthogonal_(param)
        elif 'bias_ih' in name:
            init.zeros_(param)
        elif 'bias_hh' in name:
            # gate order is (input, forget, cell, output) with hidden_size=3; set the forget-gate bias to 1
            param.data = torch.tensor([0] * 3 + [1] * 3 + [0] * 3 * 2,
                                      dtype=torch.double)
    return lstm


def get_pyt_lstm(num_layers, bidirectional):
    pyt_lstm = LSTM(input_size=7, hidden_size=3, num_layers=num_layers,
                    bidirectional=bidirectional, dropout=0, batch_first=True)
    pyt_lstm = init_lstm(pyt_lstm)
    return pyt_lstm


def get_var_lstm(num_layers, bidirectional, dropout=0.):
    var_lstm = variational_lstm(input_size=7, hidden_size=3, num_layers=num_layers,
                                bidirectional=bidirectional, dropout=dropout)
    var_lstm = init_lstm(var_lstm)
    return var_lstm


def test_lstm():
    """Check that the Variational LSTM matches PyTorch's LSTM when dropout = 0."""
    for num_layers in [1, 2, 3]:
        for bidirectional in [True, False]:
            pyt_lstm = get_pyt_lstm(num_layers, bidirectional)
            var_lstm = get_var_lstm(num_layers, bidirectional)
            seed_everything(0)
            inp = torch.randn(64, 4, 7, dtype=torch.double)  # N=64, T=4, I=7
            nl = 2 * num_layers if bidirectional else num_layers
            state = (torch.randn(nl, 64, 3, dtype=torch.double),
                     torch.randn(nl, 64, 3, dtype=torch.double))
            pyt_out, (pyt_h, pyt_c) = pyt_lstm(inp, state)
            var_out, (var_h, var_c) = var_lstm(inp, state)
            assert torch.allclose(pyt_out, var_out, rtol=0, atol=1e-15)
            assert torch.allclose(pyt_h, var_h, rtol=0, atol=1e-15)
            assert torch.allclose(pyt_c, var_c, rtol=0, atol=1e-15)
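

# Illustrative usage sketch: with dropout > 0 and the module kept in train() mode,
# every forward pass samples fresh dropout masks (reused across all time steps of a
# sequence), so repeated passes yield Monte-Carlo dropout estimates in the spirit of
# Gal & Ghahramani's paper. `mc_dropout_demo` is a hypothetical helper built on the
# functions above and is not exercised by test_lstm.
def mc_dropout_demo(num_samples: int = 10):
    var_lstm = get_var_lstm(num_layers=2, bidirectional=False, dropout=0.25)
    var_lstm.train()  # keep dropout active so each forward pass is stochastic
    seed_everything(0)
    inp = torch.randn(8, 5, 7, dtype=torch.double)  # N=8, T=5, I=7 (init_lstm cast the model to double)
    with torch.no_grad():
        samples = torch.stack([var_lstm(inp)[0] for _ in range(num_samples)])  # num_samples, N, T, H
    # the sample mean approximates the predictive mean; the std is a simple uncertainty proxy
    return samples.mean(dim=0), samples.std(dim=0)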


if __name__ == '__main__':
    test_lstm()