import torch
from torch import nn


class RNNCell(nn.Module):
    def __init__(self, hidden_size, dropout, activation=None):
        super(RNNCell, self).__init__()
        self.hidden_size = hidden_size
        # `activation` is used as a layer inside the nn.Sequential below,
        # so it should be an nn.Module instance (e.g. nn.Tanh()).
        self.activation = activation
        linear_transform = nn.Linear(hidden_size * 2, hidden_size)
        torch.nn.init.orthogonal_(linear_transform.weight)
        torch.nn.init.zeros_(linear_transform.bias)
        self.input_t = nn.Sequential(
            linear_transform,
            self.activation,
            nn.Dropout(dropout),
        )

    def forward(self, vi, hi):
        # Concatenate the left (vi) and right (hi) child states along the
        # feature dimension and project back down to hidden_size.
        inp = torch.cat([vi, hi], dim=-1)
        output = self.input_t(inp)
        return output


class Inside(nn.Module):
    def __init__(self, compose):
        super(Inside, self).__init__()
        self.compose = compose
        # Scores each candidate split of a span; the softmax over dim=1 (the
        # split dimension) turns the scores into attention weights.
        self.compose_attn = nn.Sequential(
            nn.Linear(self.compose.hidden_size, 1),
            nn.Softmax(dim=1)
        )

    def forward(self, X: torch.Tensor, lengths: torch.LongTensor):
        # X: (length, batch, hidden); lengths: (batch,)
        batch_idxs = torch.arange(X.size(1), dtype=torch.long, device=X.device)
        pos_idxs = torch.arange(X.size(0), dtype=torch.long, device=X.device)
        # length_mask[j, b] is True when position j is inside sequence b.
        length_mask = pos_idxs[:, None] < lengths[None, :]
        # table_mask[i, j, b]: True when the span ending at j lies within
        # sequence b's true length.
        table_mask = length_mask.expand(X.size(0), -1, -1)
        table = self.fill_table(X)
        # table[i, j] holds the representation of span (i, j); the root of
        # sequence b is the span (0, lengths[b] - 1).
        final_state = table[0, lengths - 1, batch_idxs]
        # Spans of width >= 2 live strictly above the diagonal of the chart.
        upper_tri_mask = torch.triu(torch.ones((X.size(0), X.size(0)),
                                               dtype=torch.bool,
                                               device=X.device), diagonal=1)
        context = table[upper_tri_mask]            # (num_spans, batch, hidden)
        context_mask = table_mask[upper_tri_mask]  # (num_spans, batch)
        return final_state, context, context_mask

    def fill_table(self, X: torch.FloatTensor):
        # Setup
        length, batch_size, hidden_size = X.size()
        idxs = torch.arange(length, dtype=torch.long, device=X.device)
        # Chart: embs[i, j] will hold the representation of span (i, j).
        embs = torch.empty((length, length, batch_size, hidden_size),
                           device=X.device, dtype=torch.float)
        # Initialise the diagonal with the leaf embeddings (width-1 spans).
        embs[idxs, idxs] = X
        for d in range(1, length):  # d : offset between span start and end
            span_idxs = idxs[:d]
            # Every span (i, i + d) has d candidate splits: left child
            # (i, i + k) and right child (i + k + 1, i + d) for k in [0, d).
            l_i = idxs[:-d, None]                     # (length - d, 1)
            r_j = idxs[d:, None]                      # (length - d, 1)
            l_j = l_i + span_idxs[None, :]            # (length - d, d)
            r_i = r_j + span_idxs[None, :] - (d - 1)  # (length - d, d)
            # Extract the child representations from the chart
            l_hid = embs[l_i, l_j]
            r_hid = embs[r_i, r_j]
            # Compose every (left, right) pair ...
            h = self.compose(l_hid, r_hid)   # (length - d, d, batch, hidden)
            # ... and attend over the d candidate splits of each span.
            a = self.compose_attn(h)         # (length - d, d, batch, 1)
            if d > 1:
                # Attention-weighted sum over splits: (length - d, batch, hidden)
                n_h = torch.matmul(a.permute(0, 2, 3, 1),
                                   h.permute(0, 2, 1, 3))[:, :, 0]
            else:
                # A single candidate split; just drop the split dimension.
                n_h = h[:, 0]
            # Fill the chart entries for all spans with this offset.
            embs[l_i[:, 0], r_j[:, 0]] = n_h
        return embs
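

# A minimal sketch (not part of the original gist) of how the per-span
# outputs of Inside.forward might be consumed downstream: dot-product
# attention from the root state over every span representation, with spans
# that fall outside a sequence's true length masked out. `attend_over_spans`
# is a hypothetical helper introduced purely for illustration.
def attend_over_spans(final_state, context, context_mask):
    # final_state: (batch, hidden); context: (spans, batch, hidden);
    # context_mask: (spans, batch) boolean validity mask.
    scores = torch.einsum("bh,sbh->sb", final_state, context)
    scores = scores.masked_fill(~context_mask, float("-inf"))
    weights = torch.softmax(scores, dim=0)  # attention distribution over spans
    return torch.einsum("sb,sbh->bh", weights, context)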


if __name__ == "__main__":
    # Toy example: 2 sequences of max length 20 with hidden size 20.
    in_emb = torch.randn(20, 2, 20)
    lengths = torch.tensor([5, 20])
    inside = Inside(compose=RNNCell(20, 0.5, nn.Tanh()))
    # forward returns a tuple, so unpack it before inspecting shapes.
    final_state, context, context_mask = inside(in_emb, lengths)
    print(final_state.size())  # (2, 20): one root representation per sequence
    final_state.sum().backward()