@shawntan
Created September 7, 2020 18:19

import torch
from torch import nn


class RNNCell(nn.Module):
    """Composes a (left, right) pair of hidden states into a single state."""

    def __init__(self, hidden_size, dropout, activation=None):
        super(RNNCell, self).__init__()
        self.hidden_size = hidden_size
        # Default to tanh so the nn.Sequential below never receives None.
        self.activation = activation if activation is not None else nn.Tanh()
        linear_transform = nn.Linear(hidden_size * 2, hidden_size)
        torch.nn.init.orthogonal_(linear_transform.weight)
        torch.nn.init.zeros_(linear_transform.bias)
        self.input_t = nn.Sequential(
            linear_transform,
            self.activation,
            nn.Dropout(dropout),
        )

    def forward(self, vi, hi):
        # Concatenate the two children along the feature dimension.
        x = torch.cat([vi, hi], dim=-1)
        return self.input_t(x)
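
# --- Usage sketch (illustrative; the names and sizes here are assumptions) ---
# cell = RNNCell(hidden_size=8, dropout=0.0, activation=nn.Tanh())
# left, right = torch.randn(3, 8), torch.randn(3, 8)
# out = cell(left, right)  # -> shape (3, 8)
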
class Inside(nn.Module):
    """CKY-style inside pass: composes every span bottom-up into a chart."""

    def __init__(self, compose):
        super(Inside, self).__init__()
        self.compose = compose
        # Scores each candidate split point; softmax runs over the split axis.
        self.compose_attn = nn.Sequential(
            nn.Linear(self.compose.hidden_size, 1),
            nn.Softmax(dim=1)
        )

    def forward(self, X: torch.Tensor, lengths: torch.LongTensor):
        batch_idxs = torch.arange(X.size(1), dtype=torch.long, device=X.device)
        pos_idxs = torch.arange(X.size(0), dtype=torch.long, device=X.device)
        # True for positions within each sequence's actual length.
        length_mask = pos_idxs[:, None] < lengths[None, :]
        # Broadcast to (length, length, batch) so it indexes like the chart.
        table_mask = length_mask.expand(X.size(0), -1, -1)
        table = self.fill_table(X)
        # Root state of each sequence: the span (0, length - 1).
        final_state = table[0, lengths - 1, batch_idxs]
        # Strict upper triangle selects every multi-token span (i < j).
        upper_tri_mask = torch.triu(torch.ones((X.size(0), X.size(0)),
                                               dtype=torch.bool,
                                               device=X.device), diagonal=1)
        context = table[upper_tri_mask]
        context_mask = table_mask[upper_tri_mask]
        return final_state, context, context_mask
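
    # --- Shape notes (added for illustration) ---
    # final_state : (batch, hidden)            root span (0, length - 1)
    # context     : (n_spans, batch, hidden)   all chart cells with i < j
    # context_mask: (n_spans, batch)           True where the cell's end
    #               position is inside the sequence; n_spans = L * (L - 1) / 2.
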
    def fill_table(self, X: torch.FloatTensor):
        # Setup
        length, batch_size, hidden_size = X.size()
        idxs = torch.arange(length, dtype=torch.long, device=X.device)
        # Chart: embs[i, j] will hold the representation of span (i..j).
        embs = torch.empty((length, length, batch_size, hidden_size),
                           device=X.device, dtype=torch.float)
        # Initialise the diagonal with the single-token inputs.
        embs[idxs, idxs] = X
        for d in range(1, length):  # d: span width minus one
            span_idxs = idxs[:d]
            # Left/right sub-span indices for every possible split point.
            l_i = idxs[:-d, None]                     # (length - d, 1)
            r_j = idxs[d:, None]                      # (length - d, 1)
            l_j = l_i + span_idxs[None, :]            # (length - d, d)
            r_i = r_j + span_idxs[None, :] - (d - 1)  # (length - d, d)
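            # Worked example (illustrative): with length = 4 and d = 2,
            #   l_i = [[0], [1]]        span starts
            #   r_j = [[2], [3]]        span ends
            #   l_j = [[0, 1], [1, 2]]  left sub-span end per split
            #   r_i = [[1, 2], [2, 3]]  right sub-span start per split
            # so span (0, 2) is built from (0,0)+(1,2) and (0,1)+(2,2).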
            # Gather left/right child embeddings from the chart.
            l_hid = embs[l_i, l_j]  # (length - d, d, batch, hidden)
            r_hid = embs[r_i, r_j]  # (length - d, d, batch, hidden)
            # Compose every candidate split, then attend over the splits.
            h = self.compose(l_hid, r_hid)
            a = self.compose_attn(h)
            if d > 1:
                # Attention-weighted sum over the d split points.
                n_h = torch.matmul(a.permute(0, 2, 3, 1),
                                   h.permute(0, 2, 1, 3))[:, :, 0]
            else:
                # Only one split point: drop the split axis, keep the batch.
                n_h = h[:, 0]
            # Fill the table for all spans of this width.
            embs[l_i[:, 0], r_j[:, 0]] = n_h
        return embs
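
# --- Chart layout sketch (an assumed reading of fill_table, for illustration) ---
# embs has shape (length, length, batch, hidden); cell (i, j) with i <= j holds
# the representation of tokens i..j. The strict lower triangle is never
# written, so it still contains uninitialised values from torch.empty.
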
if __name__ == "__main__":
    in_emb = torch.randn(20, 2, 20, requires_grad=True)
    lengths = torch.tensor([5, 20])
    inside = Inside(compose=RNNCell(20, 0.5, nn.Tanh()))
    # forward returns (final_state, context, context_mask).
    final_state, context, context_mask = inside(in_emb, lengths)
    print(final_state.size())
    final_state.sum().backward()
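# Expected output: torch.Size([2, 20]) (one root state per batch element).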