@jainraj · Last active November 6, 2023 13:06
Variational LSTM implemented following Gal & Ghahramani's paper: https://arxiv.org/abs/1512.05287
from pytorch_lightning.utilities.seed import seed_everything
from torch import device as Device, dtype as DType, Tensor
from torch.nn import Parameter, LSTM, init
from typing import List, Tuple, Optional
import torch.nn as nn
import torch
def reverse(lst: List[Tensor]) -> List[Tensor]:
return lst[::-1]
def init_stacked_rnn(layer, model):
    """Build a ModuleList of stacked RNN layers: the first layer takes the model's
    input size, deeper layers take hidden_size * num_directions as their input."""
    layers = nn.ModuleList()
dirs = 2 if model.bidirectional else 1
layers.append(layer(input_size=model.input_size, hidden_size=model.hidden_size, dropout=model.dropout))
for _ in range(model.num_layers - 1):
layers.append(layer(input_size=model.hidden_size * dirs, hidden_size=model.hidden_size, dropout=model.dropout))
return layers
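# For example, a 3-layer bidirectional model built this way has per-layer input sizes
# input_size, 2 * hidden_size, 2 * hidden_size (the doubling comes from the two directions).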
def reorder(x):
    """Merge the first two dimensions, e.g. (num_layers, 2, N, H) -> (2 * num_layers, N, H)."""
    return x.view([-1] + list(x.shape[2:]))
def flatten_states(states: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]:
"""We have list[(h, c)]. We want to convert to (stack_list[h], stack_list[c])"""
hs, cs = [], []
for state in states:
hs.append(state[0])
cs.append(state[1])
return torch.stack(hs), torch.stack(cs)
def double_flatten_states(states: List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, Tensor]:
    """Flatten nested per-layer, per-direction (h, c) states into two (2L, N, H) tensors."""
    first_flatten: List[Tuple[Tensor, Tensor]] = []
for inner in states: # inner: List[Tuple[Tensor, Tensor]]
first_flatten.append(flatten_states(inner))
second_flatten = flatten_states(first_flatten)
return reorder(second_flatten[0]), reorder(second_flatten[1])
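# Worked shape example (a sketch, not taken from the gist's tests): for a 2-layer
# bidirectional stack with batch size N and hidden size H, each inner list holds
# [(h_fwd, c_fwd), (h_bwd, c_bwd)] with tensors of shape (N, H); double_flatten_states
# stacks these into two tensors of shape (4, N, H), matching torch.nn.LSTM's
# (num_layers * num_directions, N, H) state layout.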
class LSTMCell(nn.Module):
def __init__(self, input_size: int, hidden_size: int):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.weight_ih = Parameter(torch.Tensor(4 * hidden_size, input_size))
self.weight_hh = Parameter(torch.Tensor(4 * hidden_size, hidden_size))
self.bias_ih = Parameter(torch.Tensor(4 * hidden_size))
self.bias_hh = Parameter(torch.Tensor(4 * hidden_size))
def forward(self,
inpt: Tensor, # N, I
state: Tuple[Tensor, Tensor], # All N, H
input_masks: Tensor, # 4, N, I
hidden_masks: Tensor, # 4, N, H
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: # All N, H
prev_h, prev_c = state
input_weights = self.weight_ih.reshape(4, self.hidden_size, self.input_size).permute((0, 2, 1)) # 4, I, H
input_biases = self.bias_ih.reshape(4, 1, self.hidden_size) # 4, 1, H
hidden_weights = self.weight_hh.reshape(4, self.hidden_size, self.hidden_size).permute((0, 2, 1)) # 4, H, H
hidden_biases = self.bias_hh.reshape(4, 1, self.hidden_size) # 4, 1, H
masked_input = torch.mul(inpt, input_masks) # 4, N, I
masked_hidden = torch.mul(prev_h, hidden_masks) # 4, N, H
gates = torch.baddbmm(input_biases, masked_input, input_weights) + \
torch.baddbmm(hidden_biases, masked_hidden, hidden_weights) # 4, N, H
input_gate, forget_gate, candidate_c, output_gate = gates.unbind(0) # All N, H
input_gate = torch.sigmoid(input_gate) # N, H
forget_gate = torch.sigmoid(forget_gate) # N, H
candidate_c = torch.tanh(candidate_c) # N, H
output_gate = torch.sigmoid(output_gate) # N, H
new_c = (forget_gate * prev_c) + (input_gate * candidate_c) # N, H
new_h = output_gate * torch.tanh(new_c) # N, H
return new_h, (new_h, new_c)
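# Single-step usage sketch (shapes are illustrative assumptions; weights would need to
# be initialised first, e.g. via init_lstm below):
#   cell = LSTMCell(input_size=7, hidden_size=3)
#   x = torch.randn(64, 7)
#   h = c = torch.zeros(64, 3)
#   ones_i, ones_h = torch.ones(4, 64, 7), torch.ones(4, 64, 3)
#   out, (h, c) = cell(x, (h, c), ones_i, ones_h)  # all-ones masks reduce to a plain LSTM step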
class CoreLSTMLayer(nn.Module):
def __init__(self, dropout: float, **cell_kwargs):
super().__init__()
assert 0 <= dropout < 1
self.dropout = dropout
self.cell = LSTMCell(**cell_kwargs)
    def get_dropout_masks(self, batch_size: int, input_size: int, hidden_size: int, device: Device, dtype: DType):
        """Sample one set of per-gate dropout masks, to be reused at every time step."""
        input_masks = torch.ones((4, batch_size, input_size), device=device, dtype=dtype)  # 4, N, I
hidden_masks = torch.ones((4, batch_size, hidden_size), device=device, dtype=dtype) # 4, N, H
if self.training and self.dropout > 0:
keep_prob = 1 - self.dropout
# sample dropout once for the sequence & scale (inverted dropout method)
return torch.bernoulli(input_masks * keep_prob) / keep_prob, torch.bernoulli(hidden_masks * keep_prob) / keep_prob
else:
# no dropout & no scaling
return input_masks, hidden_masks
def forward(self,
inputs: List[Tensor], # N, I of length T
state: Tuple[Tensor, Tensor] # All N, H
) -> Tuple[List[Tensor], Tuple[Tensor, Tensor]]: # N, T, H & All N, H
batch_size = inputs[0].size(dim=0)
device, dtype = inputs[0].device, inputs[0].dtype
input_size, hidden_size = self.cell.input_size, self.cell.hidden_size
input_masks, hidden_masks = self.get_dropout_masks(batch_size, input_size, hidden_size, device, dtype)
outputs: List[Tensor] = []
for i in range(len(inputs)):
out, state = self.cell(inputs[i], state, input_masks, hidden_masks)
outputs += [out]
return outputs, state
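# Note: the dropout masks above are sampled once per forward call and reused at every
# time step of the sequence; this per-sequence reuse (rather than per-step resampling)
# is what makes the dropout "variational" in the sense of Gal & Ghahramani.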
class LSTMLayer(nn.Module):
def __init__(self, **core_kwargs):
super().__init__()
self.core_layer = CoreLSTMLayer(**core_kwargs)
def forward(self,
inpt: Tensor, # N, T, I
state: Tuple[Tensor, Tensor] # All N, H
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: # N, T, H & All N, H
inputs = inpt.unbind(1) # split on time axis
outputs, state = self.core_layer(inputs, state)
return torch.stack(outputs, dim=1), state
class ReverseLSTMLayer(nn.Module):
def __init__(self, **core_kwargs):
super().__init__()
self.core_layer = CoreLSTMLayer(**core_kwargs)
def forward(self,
inpt: Tensor, # N, T, I
state: Tuple[Tensor, Tensor] # All N, H
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: # N, T, H & All N, H
inputs = reverse(inpt.unbind(1)) # split on time axis
outputs, state = self.core_layer(inputs, state)
return torch.stack(reverse(outputs), dim=1), state
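# ReverseLSTMLayer processes the sequence back-to-front and then re-reverses its outputs,
# so they line up time-step-for-time-step with the forward direction's outputs.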
class BiDirLSTMLayer(nn.Module):
def __init__(self, **core_kwargs):
super().__init__()
self.directions = nn.ModuleList(
[
LSTMLayer(**core_kwargs),
ReverseLSTMLayer(**core_kwargs),
]
)
def forward(self,
inpt: Tensor, # N, T, I
states: Tuple[Tensor, Tensor] # Both 2, N, H
) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]: # N, T, 2H & All N, H
outputs: List[Tensor] = []
output_states: List[Tuple[Tensor, Tensor]] = []
for i, direction in enumerate(self.directions):
state = states[0][i, :, :], states[1][i, :, :]
out, out_state = direction(inpt, state)
outputs += [out]
output_states += [out_state]
return torch.cat(outputs, -1), output_states
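# The forward and reverse outputs are concatenated on the feature axis, so the per-layer
# output size is 2 * hidden_size, as with a bidirectional torch.nn.LSTM.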
class StackedUniLSTM(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, dropout):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bidirectional = False
self.dropout = dropout
self.layers = init_stacked_rnn(LSTMLayer, self)
def forward(self,
inpt: Tensor, # N, T, I
init_states: Optional[Tuple[Tensor, Tensor]] = None # Both L, N, H
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: # N, T, H & All L, N, H
output_states: List[Tuple[Tensor, Tensor]] = []
output = inpt
batch_size = inpt.size(dim=0)
if init_states is None:
device, dtype = inpt.device, inpt.dtype
zeros = torch.zeros((self.num_layers, batch_size, self.hidden_size), device=device, dtype=dtype)
init_states = (zeros, zeros)
for i, rnn_layer in enumerate(self.layers):
init_state = init_states[0][i, :, :], init_states[1][i, :, :]
output, out_state = rnn_layer(output, init_state)
output_states += [out_state]
flattened_output_states = flatten_states(output_states)
return output, flattened_output_states
class StackedBiLSTM(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, dropout):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bidirectional = True
self.dropout = dropout
self.layers = init_stacked_rnn(BiDirLSTMLayer, self)
def forward(self,
inpt: Tensor, # N, T, I
states: Optional[Tuple[Tensor, Tensor]] = None # Both 2L, N, H
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: # N, T, 2H & All 2L, N, H
output_states: List[List[Tuple[Tensor, Tensor]]] = []
output = inpt
batch_size = inpt.size(dim=0)
if states is None:
device, dtype = inpt.device, inpt.dtype
zeros = torch.zeros((2 * self.num_layers, batch_size, self.hidden_size), device=device, dtype=dtype)
states = (zeros, zeros)
for i, rnn_layer in enumerate(self.layers):
init_state = states[0][2 * i:2 * i + 2, :, :], states[1][2 * i:2 * i + 2, :, :]
output, out_state = rnn_layer(output, init_state)
output_states += [out_state]
flattened_output_states = double_flatten_states(output_states)
return output, flattened_output_states
def variational_lstm(
input_size: int,
hidden_size: int,
num_layers: int,
bidirectional: bool,
dropout: float,
):
"""
Returns a ScriptModule implementation of Variational LSTM
(Ref: https://papers.nips.cc/paper_files/paper/2016/hash/076a0c97d09cf1a0ec3e19c7f2529f2b-Abstract.html)
When dropout=0, it behaves like PyTorch's LSTM (with dropout=0)
Dropout probability is tied i.e., one dropout value for all layers
"""
stack_type = StackedBiLSTM if bidirectional else StackedUniLSTM
module = stack_type(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
)
return torch.jit.script(module)
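# Usage sketch (all shapes and values below are assumptions for illustration):
#   lstm = variational_lstm(input_size=7, hidden_size=3, num_layers=2,
#                           bidirectional=True, dropout=0.25)
#   x = torch.randn(64, 4, 7)       # (batch, time, features), batch_first layout
#   out, (h, c) = lstm(x)           # out: (64, 4, 6); h and c: (4, 64, 3)
# Dropout is only applied while the module is in training mode, so for Monte Carlo
# dropout at inference time the module would have to be kept in train() mode.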
def init_lstm(lstm):
lstm.double()
seed_everything(0)
for name, param in lstm.named_parameters():
if 'weight_ih' in name:
init.xavier_uniform_(param)
elif 'weight_hh' in name:
init.orthogonal_(param)
elif 'bias_ih' in name:
init.zeros_(param)
elif 'bias_hh' in name:
param.data = torch.tensor([0] * 3 + [1] * 3 + [0] * 3 * 2,
dtype=torch.double)
return lstm
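# Note: the bias_hh initialisation above hard-codes hidden_size=3 (12 entries in
# i, f, g, o order) and sets the forget-gate block to 1, a common LSTM heuristic;
# it is only meant for the test setup below.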
def get_pyt_lstm(num_layers, bidirectional):
pyt_lstm = LSTM(input_size=7, hidden_size=3, num_layers=num_layers,
bidirectional=bidirectional, dropout=0, batch_first=True)
pyt_lstm = init_lstm(pyt_lstm)
return pyt_lstm
def get_var_lstm(num_layers, bidirectional, dropout=0.):
var_lstm = variational_lstm(input_size=7, hidden_size=3, num_layers=num_layers,
bidirectional=bidirectional, dropout=dropout)
var_lstm = init_lstm(var_lstm)
return var_lstm
def test_lstm():
"""Test Variational LSTM with PyTorch's LSTM when dropout = 0."""
for num_layers in [1, 2, 3]:
for bidirectional in [True, False]:
pyt_lstm = get_pyt_lstm(num_layers, bidirectional)
var_lstm = get_var_lstm(num_layers, bidirectional)
seed_everything(0)
inp = torch.randn(64, 4, 7, dtype=torch.double)
nl = 2 * num_layers if bidirectional else num_layers
state = (torch.randn(nl, 64, 3, dtype=torch.double),
torch.randn(nl, 64, 3, dtype=torch.double))
pyt_out, (pyt_h, pyt_c) = pyt_lstm(inp, state)
var_out, (var_h, var_c) = var_lstm(inp, state)
assert torch.allclose(pyt_out, var_out, rtol=0, atol=1e-15)
assert torch.allclose(pyt_h, var_h, rtol=0, atol=1e-15)
assert torch.allclose(pyt_c, var_c, rtol=0, atol=1e-15)
if __name__ == '__main__':
test_lstm()