@shaabhishek
Created March 4, 2020 19:21
Gist to verify 1. how packing and padding of sequences work, and 2. how bidirectional RNNs work
# questions to answer
# 1. Do the following pipelines give the same hidden states: pad-rnn-hidden & pack-rnn-pad-hidden?
# 2. Does the birnn use the hidden states computed in the forward direction to
#    compute the hidden states in the backward direction?
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnnutils

def pprint_dict(d: dict):
    for k, v in d.items():
        print(f"{k}: \n{v}\n")


def get_param_dict(m: nn.Module):
    return {n: p.data for n, p in m.named_parameters()}


def unpack_sequence(packed_sequence, lengths):
    """Invert pack_sequence: recover the list of (length, *) tensors from a PackedSequence."""
    assert isinstance(packed_sequence, rnnutils.PackedSequence)
    head = 0
    trailing_dims = packed_sequence.data.shape[1:]
    unpacked_sequence = [torch.zeros(l, *trailing_dims) for l in lengths]
    # l_idx goes from 0 to maxLen-1; batch_sizes[l_idx] is the number of sequences
    # still "alive" at time step l_idx (sequences are assumed sorted by decreasing length)
    for l_idx, b_size in enumerate(packed_sequence.batch_sizes):
        for b_idx in range(b_size):
            unpacked_sequence[b_idx][l_idx] = packed_sequence.data[head]
            head += 1
    return unpacked_sequence
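

# Added illustration (a minimal sketch, not part of the original experiments): how
# pack_sequence lays the data out, and a round-trip check of unpack_sequence above.
# The toy sequences here are arbitrary.
def demo_packing():
    seqs = [torch.arange(4.).unsqueeze(-1), torch.arange(2.).unsqueeze(-1)]  # lengths 4 and 2
    packed = rnnutils.pack_sequence(seqs, enforce_sorted=True)
    # packed.data is time-step-major: [s0_t0, s1_t0, s0_t1, s1_t1, s0_t2, s0_t3]
    # packed.batch_sizes counts how many sequences are still alive at each step: [2, 2, 1, 1]
    print(packed.batch_sizes)
    # unpack_sequence should return the original list of tensors
    print([torch.equal(a, b) for a, b in zip(seqs, unpack_sequence(packed, [4, 2]))])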

# network dimensions (module-level so the experiments below can reshape the birnn output)
input_dim = 1
hidden_dim = 2


def make_rnns():
    # set up the networks: a standalone forward RNN, a standalone "backward" RNN,
    # and a bidirectional RNN
    fwrnn = nn.RNN(input_dim, hidden_dim, batch_first=True)
    bwrnn = nn.RNN(input_dim, hidden_dim, batch_first=True)
    birnn = nn.RNN(input_dim, hidden_dim, bidirectional=True, batch_first=True)
    # tie the initial weights: fwrnn gets the birnn's forward-direction parameters,
    # bwrnn gets the birnn's backward-direction (_reverse) parameters
    with torch.no_grad():
        fwrnn.weight_hh_l0.copy_(birnn.weight_hh_l0)
        fwrnn.weight_ih_l0.copy_(birnn.weight_ih_l0)
        fwrnn.bias_hh_l0.copy_(birnn.bias_hh_l0)
        fwrnn.bias_ih_l0.copy_(birnn.bias_ih_l0)
        bwrnn.weight_hh_l0.copy_(birnn.weight_hh_l0_reverse)
        bwrnn.weight_ih_l0.copy_(birnn.weight_ih_l0_reverse)
        bwrnn.bias_hh_l0.copy_(birnn.bias_hh_l0_reverse)
        bwrnn.bias_ih_l0.copy_(birnn.bias_ih_l0_reverse)
    return fwrnn, bwrnn, birnn
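

# Optional sanity check (added sketch, assuming the helpers above): after make_rnns,
# every fwrnn parameter should equal the birnn parameter of the same name, and every
# bwrnn parameter should equal the corresponding *_reverse parameter. pprint_dict can
# be used to eyeball the values.
def check_tied_weights(fwrnn, bwrnn, birnn):
    bi_params = get_param_dict(birnn)
    fw_params = get_param_dict(fwrnn)
    bw_params = get_param_dict(bwrnn)
    print(all(torch.equal(fw_params[n], bi_params[n]) for n in fw_params))
    print(all(torch.equal(bw_params[n], bi_params[n + "_reverse"]) for n in bw_params))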
# set up the dataset: a batch of variable-length sequences, sorted by decreasing length
# (pack_sequence with enforce_sorted=True requires this ordering)
bs = 3
minT = 2
maxT = 5
x_dim = input_dim  # must match the rnns' input size
X_len = sorted(list(torch.randint(minT, maxT + 1, (bs,)).numpy()), reverse=True)
X = [torch.rand(l, x_dim) for l in X_len]
# get rnns
fwrnn, bwrnn, birnn = make_rnns()
# experiment 1: do the following pipelines give the same hidden states?
#   pad -> rnn -> hidden
#   pack -> rnn -> hidden -> pad
#   pad -> pack -> rnn -> hidden -> pad
# Answer: yes, but be careful to compare only the first X_len[b] time steps of each batch
# element; the entries past a sequence's true length are computed from padding and differ.
def expt_1():
    X_pad = rnnutils.pad_sequence(X, batch_first=True)  # (bs, maxT, *)
    h_fw_pad_only = fwrnn(X_pad)[0]  # (bs, maxT, *)
    X_pack = rnnutils.pack_sequence(X, enforce_sorted=True)
    h_fw_pack = fwrnn(X_pack)[0]
    h_fw_pack_pad = rnnutils.pad_packed_sequence(h_fw_pack, batch_first=True)[0]  # (bs, maxT, *)
    X_pad_pack = rnnutils.pack_padded_sequence(X_pad, X_len, batch_first=True)
    h_fw_pad_pack = fwrnn(X_pad_pack)[0]
    h_fw_pad_pack_pad = rnnutils.pad_packed_sequence(h_fw_pad_pack, batch_first=True)[0]  # (bs, maxT, *)
    # restricted to each sequence's true length, all three pipelines agree (expect all True);
    # past the true length, the pad-only output keeps evolving over the zero padding while
    # pad_packed_sequence fills those positions with zeros, so the full tensors would differ
    print([torch.allclose(h_fw_pad_only[b_idx, :l], h_fw_pack_pad[b_idx, :l]) for b_idx, l in enumerate(X_len)])
    print([torch.allclose(h_fw_pad_only[b_idx, :l], h_fw_pad_pack_pad[b_idx, :l]) for b_idx, l in enumerate(X_len)])
    print([torch.allclose(h_fw_pack_pad[b_idx, :l], h_fw_pad_pack_pad[b_idx, :l]) for b_idx, l in enumerate(X_len)])
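

# Added illustration (a sketch, not part of the original experiment): comparing the full
# padded tensors, including the positions past each sequence's true length, shows why the
# per-length slicing above matters.
def expt_1_full_tensors():
    X_pad = rnnutils.pad_sequence(X, batch_first=True)
    h_fw_pad_only = fwrnn(X_pad)[0]  # the rnn keeps stepping over the zero padding
    X_pack = rnnutils.pack_sequence(X, enforce_sorted=True)
    h_fw_pack_pad = rnnutils.pad_packed_sequence(fwrnn(X_pack)[0], batch_first=True)[0]  # pad positions filled with zeros
    # generally False whenever at least one sequence is shorter than the longest one
    print(torch.allclose(h_fw_pad_only, h_fw_pack_pad))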
# experiment 2: does the birnn use the hidden states computed in the forward direction
# to compute the hidden states in the backward direction?
# modified experiment 2: does the birnn concatenate X with the forward hidden states and
# feed that to the backward direction?
# answer: probably not. The birnn's backward-direction weights have the same shape as its
# forward-direction weights, which could not be the case if the backward direction's inputs
# had dimensionality (input_dim + hidden_dim); the shape check sketched below supports this.
# modified modified experiment 2: is the backward hidden state from the birnn the same as the
# hidden state from a separate bwrnn run over the time-reversed inputs?
# the answer should be "yes" according to: Schuster, Mike, and Kuldip K. Paliwal.
# "Bidirectional recurrent neural networks." IEEE Transactions on Signal Processing (1997).
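
# Added shape check (a small sketch supporting the reasoning above): the birnn's
# backward-direction input weights have shape (hidden_dim, input_dim), not
# (hidden_dim, input_dim + hidden_dim), so the backward pass sees only X, not [X, h_fw].
def check_birnn_weight_shapes():
    print(birnn.weight_ih_l0.shape)          # torch.Size([hidden_dim, input_dim])
    print(birnn.weight_ih_l0_reverse.shape)  # same shape as the forward-direction weights
    print(birnn.weight_ih_l0.shape == birnn.weight_ih_l0_reverse.shape)  # True
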
def expt_2():
    X_pack = rnnutils.pack_sequence(X, enforce_sorted=True)
    h_fw_pack = fwrnn(X_pack)[0]
    h_fw = unpack_sequence(h_fw_pack, X_len)
    # time-reverse each sequence, pack, run through the standalone bw rnn, then reverse
    # its outputs back so that time indices line up with X again
    X_rev = [reversed(_x) for _x in X]  # reversed() flips a tensor along its first (time) dimension
    X_rev_pack = rnnutils.pack_sequence(X_rev, enforce_sorted=True)
    h_bw_rev_pack = bwrnn(X_rev_pack)[0]
    h_bw = [reversed(_h) for _h in unpack_sequence(h_bw_rev_pack, X_len)]
    # also run X directly through the birnn and split its output into the two directions
    h_bi_pack = birnn(X_pack)[0]
    h_bi_pad = rnnutils.pad_packed_sequence(h_bi_pack, batch_first=True)[0].view(bs, -1, 2, hidden_dim)  # (bs, maxT, 2, hidden_dim)
    # forward direction of the birnn vs the standalone fw rnn: expect all True
    print([torch.allclose(h_fw[b_idx], h_bi_pad[b_idx, :l, 0]) for b_idx, l in enumerate(X_len)])
    # backward direction of the birnn vs the standalone bw rnn on reversed inputs: expect all True
    print([torch.allclose(h_bw[b_idx], h_bi_pad[b_idx, :l, 1]) for b_idx, l in enumerate(X_len)])
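

# Minimal driver, added so the script prints the experiment results when run.
# demo_packing(), check_tied_weights(fwrnn, bwrnn, birnn), and check_birnn_weight_shapes()
# from the added sketches above can be called here as well.
if __name__ == "__main__":
    print("experiment 1:")
    expt_1()
    print("experiment 2:")
    expt_2()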