Gist to verify (1) how packing and padding of sequences works, and (2) how bidirectional RNNs work
# Questions to answer:
# 1. Do the following pipelines give the same hidden states: pad-rnn-hidden vs. pack-rnn-pad-hidden?
# 2. Does the birnn use the hidden states computed in the forward direction to
#    compute the hidden states in the backward direction?
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnnutils
def pprint_dict(d: dict):
    for k, v in d.items():
        print(f"{k}: \n{v}\n")

def get_param_dict(m: nn.Module):
    return {n: p.data for n, p in m.named_parameters()}
def unpack_sequence(packed_sequence, lengths):
    # invert pack_sequence: recover the list of variable-length sequences
    assert isinstance(packed_sequence, rnnutils.PackedSequence)
    head = 0
    trailing_dims = packed_sequence.data.shape[1:]
    unpacked_sequence = [torch.zeros(l, *trailing_dims) for l in lengths]
    # packed data is laid out time-step major: all batch elements present at t=0, then t=1, ...
    # l_idx goes from 0 to maxLen-1; batch_sizes[l_idx] elements are present at that step
    for l_idx, b_size in enumerate(packed_sequence.batch_sizes):
        for b_idx in range(b_size):
            unpacked_sequence[b_idx][l_idx] = packed_sequence.data[head]
            head += 1
    return unpacked_sequence
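# Sanity-check sketch (not part of the original gist; the toy variables below are mine):
# pack two short sequences and confirm that unpack_sequence recovers them exactly.
_toy = [torch.arange(3.).unsqueeze(-1), torch.arange(2.).unsqueeze(-1)]  # lengths 3 and 2, feature dim 1
_toy_packed = rnnutils.pack_sequence(_toy, enforce_sorted=True)
assert all(torch.equal(a, b) for a, b in zip(_toy, unpack_sequence(_toy_packed, [3, 2])))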
def make_rnns():
    # setup the networks
    hidden_dim = 2
    input_dim = 1
    fwrnn = nn.RNN(input_dim, hidden_dim, batch_first=True)
    bwrnn = nn.RNN(input_dim, hidden_dim, batch_first=True)
    birnn = nn.RNN(input_dim, hidden_dim, bidirectional=True, batch_first=True)
    # tie the initial weights: fwrnn mirrors the birnn's forward direction,
    # bwrnn mirrors the birnn's backward (reverse) direction
    with torch.no_grad():
        fwrnn.weight_hh_l0.copy_(birnn.weight_hh_l0)
        fwrnn.weight_ih_l0.copy_(birnn.weight_ih_l0)
        fwrnn.bias_hh_l0.copy_(birnn.bias_hh_l0)
        fwrnn.bias_ih_l0.copy_(birnn.bias_ih_l0)
        bwrnn.weight_hh_l0.copy_(birnn.weight_hh_l0_reverse)
        bwrnn.weight_ih_l0.copy_(birnn.weight_ih_l0_reverse)
        bwrnn.bias_hh_l0.copy_(birnn.bias_hh_l0_reverse)
        bwrnn.bias_ih_l0.copy_(birnn.bias_ih_l0_reverse)
    return fwrnn, bwrnn, birnn
# setup the dataset
bs = 3
minT = 2
maxT = 5
x_dim = 1
X_len = sorted(torch.randint(minT, maxT + 1, (bs,)).tolist(), reverse=True)  # decreasing lengths for enforce_sorted=True
X = [torch.rand(l, x_dim) for l in X_len]

# get rnns
fwrnn, bwrnn, birnn = make_rnns()
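# Illustrative check (not in the original gist): use the helpers defined above to inspect
# the birnn's parameters and confirm the standalone RNNs share its per-direction weights.
pprint_dict(get_param_dict(birnn))
assert torch.equal(fwrnn.weight_ih_l0, birnn.weight_ih_l0)
assert torch.equal(bwrnn.weight_ih_l0, birnn.weight_ih_l0_reverse)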
# experiment 1: do the following pipelines give the same hidden states?
#   pad -> rnn -> hidden
#   pack -> rnn -> hidden -> pad
#   pad -> pack -> rnn -> hidden -> pad
# Answer: yes, but be careful to compare only the first X_len[b] entries of each batch element
def expt_1():
    X_pad = rnnutils.pad_sequence(X, batch_first=True)  # (bs, max(X_len), *)
    h_fw_pad_only = fwrnn(X_pad)[0]  # (bs, max(X_len), *)
    X_pack = rnnutils.pack_sequence(X, enforce_sorted=True)
    h_fw_pack = fwrnn(X_pack)[0]
    h_fw_pack_pad = rnnutils.pad_packed_sequence(h_fw_pack, batch_first=True)[0]  # (bs, max(X_len), *)
    X_pad_pack = rnnutils.pack_padded_sequence(X_pad, X_len, batch_first=True)
    h_fw_pad_pack = fwrnn(X_pad_pack)[0]
    h_fw_pad_pack_pad = rnnutils.pad_packed_sequence(h_fw_pad_pack, batch_first=True)[0]  # (bs, max(X_len), *)
    # are fw(1dir, pad only) and fw(1dir, pack-pad) the same hidden states?
    # A: yes on the first X_len[b] steps of each sequence (compared below); the padded tails
    #    differ, since the pad-only RNN keeps stepping over the padding while the packed
    #    output is zero there
    print([torch.allclose(h_fw_pad_only[b_idx, :l], h_fw_pack_pad[b_idx, :l]) for b_idx, l in enumerate(X_len)])
    print([torch.allclose(h_fw_pad_only[b_idx, :l], h_fw_pad_pack_pad[b_idx, :l]) for b_idx, l in enumerate(X_len)])
    print([torch.allclose(h_fw_pack_pad[b_idx, :l], h_fw_pad_pack_pad[b_idx, :l]) for b_idx, l in enumerate(X_len)])
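# Illustrative follow-up (not part of the original gist; the function name is mine): the
# padded tails are where pad-only and pack-pad outputs diverge -- the pad-only RNN keeps
# updating its hidden state over the zero padding, while pad_packed_sequence leaves those
# positions as zeros.
def expt_1_tails():
    X_pad = rnnutils.pad_sequence(X, batch_first=True)
    h_pad = fwrnn(X_pad)[0]
    h_pack = fwrnn(rnnutils.pack_sequence(X, enforce_sorted=True))[0]
    h_pack_pad = rnnutils.pad_packed_sequence(h_pack, batch_first=True)[0]
    # expected to print False for every sequence shorter than the longest one in the batch
    print([torch.allclose(h_pad[b_idx, l:], h_pack_pad[b_idx, l:]) for b_idx, l in enumerate(X_len) if l < max(X_len)])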
# experiment 2: does the birnn use the hidden states computed in the forward direction to
#   compute the hidden states in the backward direction?
# modified experiment 2: does the birnn concatenate X with the forward hidden states as the
#   input for the backward direction?
#   answer: probably not. For the birnn, the forward weights have the same shape as the
#   backward weights, which could not be the case if the backward inputs had dimensionality
#   (input_dim + hidden_dim) as concatenation would require.
# modified modified experiment 2: is the backward hidden state from the birnn the same as the
#   hidden state from a standalone bwrnn run over the reversed inputs?
#   answer should be "YES" according to Schuster, Mike, and Kuldip K. Paliwal,
#   "Bidirectional recurrent neural networks".
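# Quick shape check (not in the original gist) supporting the reasoning above: the reverse
# input-to-hidden weights have the same shape as the forward ones, i.e. (hidden_dim, input_dim),
# so the backward direction sees only the raw inputs.
print(birnn.weight_ih_l0.shape, birnn.weight_ih_l0_reverse.shape)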
def expt_2():
    X_pack = rnnutils.pack_sequence(X, enforce_sorted=True)
    h_fw_pack = fwrnn(X_pack)[0]
    # plan:
    #   extract the unpacked forward hidden sequences
    #   reverse the inputs along time (per modified experiment 2, the backward direction takes
    #   the raw inputs, not a concatenation with the forward hidden states)
    #   pack the reversed inputs and run them through the bw rnn
    #   also run X directly through the birnn and compare both directions
    h_fw = unpack_sequence(h_fw_pack, X_len)
    X_rev = [X[b].flip(0) for b in range(bs)]  # reverse each sequence on the time dimension
    X_rev_pack = rnnutils.pack_sequence(X_rev, enforce_sorted=True)
    h_bw_pack = bwrnn(X_rev_pack)[0]
    h_bw = unpack_sequence(h_bw_pack, X_len)
    h_bi_pack = birnn(X_pack)[0]
    h_bi_pack_pad = rnnutils.pad_packed_sequence(h_bi_pack, batch_first=True)[0].view(bs, max(X_len), 2, birnn.hidden_size)  # (bs, max(X_len), 2, hidden_dim)
    # forward direction of the birnn vs. the standalone fw rnn
    print([torch.allclose(h_fw[b_idx], h_bi_pack_pad[b_idx, :l, 0]) for b_idx, l in enumerate(X_len)])
    # backward direction of the birnn vs. the standalone bw rnn (flip back to original time order before comparing)
    print([torch.allclose(h_bw[b_idx].flip(0), h_bi_pack_pad[b_idx, :l, 1]) for b_idx, l in enumerate(X_len)])
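# Minimal driver (not part of the original gist): run the experiments when the script is
# executed directly. expt_1_tails is the illustrative helper added above.
if __name__ == "__main__":
    expt_1()
    expt_1_tails()
    expt_2()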