Pointer network attention architecture in PyTorch
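The snippet assumes torch and torch.nn are already imported and that a masked_log_softmax helper is in scope; neither appears in the gist itself. Below is a minimal sketch of both. The helper follows the convention used in the referenced ast0414 repo (add log(mask + eps) to the logits, then take a log-softmax), but its exact body here is an assumption, not part of the gist.

import torch
import torch.nn as nn
import torch.nn.functional as F


def masked_log_softmax(
    x: torch.Tensor,
    mask: torch.Tensor,
    dim: int = -1,
    eps: float = 1e-16
) -> torch.Tensor:
    # Assumed helper: masked-out positions get log(eps) added to their
    # logits, driving their probability toward zero after the softmax.
    if mask is not None:
        x = x + torch.log(mask.float() + eps)
    return F.log_softmax(x, dim=dim)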
class PointerNetwork(nn.Module):
    """
    From "Pointer Networks" by Vinyals et al. (2015).

    Adapted from pointer-networks-pytorch by ast0414:
    https://github.com/ast0414/pointer-networks-pytorch

    Args:
        n_hidden: The number of features to expect in the inputs.
    """
    def __init__(
        self,
        n_hidden: int
    ):
        super().__init__()
        self.n_hidden = n_hidden
        self.w1 = nn.Linear(n_hidden, n_hidden, bias=False)
        self.w2 = nn.Linear(n_hidden, n_hidden, bias=False)
        self.v = nn.Linear(n_hidden, 1, bias=False)

    def forward(
        self,
        x_decoder: torch.Tensor,
        x_encoder: torch.Tensor,
        mask: torch.Tensor,
        eps: float = 1e-16
    ) -> torch.Tensor:
        """
        Args:
            x_decoder: Encoding over the output tokens.
            x_encoder: Encoding over the input tokens.
            mask: Binary mask over the softmax input.

        Shape:
            x_decoder: (B, Nd, C)
            x_encoder: (B, Ne, C)
            mask: (B, Nd, Ne)
        """
        # (B, Nd, Ne, C) <- (B, Ne, C)
        encoder_transform = self.w1(x_encoder).unsqueeze(1).expand(
            -1, x_decoder.shape[1], -1, -1)
        # (B, Nd, 1, C) <- (B, Nd, C)
        decoder_transform = self.w2(x_decoder).unsqueeze(2)
        # (B, Nd, Ne) <- (B, Nd, Ne, C), (B, Nd, 1, C)
        prod = self.v(torch.tanh(encoder_transform + decoder_transform)).squeeze(-1)
        # (B, Nd, Ne) <- (B, Nd, Ne)
        log_score = masked_log_softmax(prod, mask, dim=-1, eps=eps)
        return log_score
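For reference, each decoder step i scores encoder position j as v . tanh(W1 x_encoder[j] + W2 x_decoder[i]), and the scores are normalized with a masked log-softmax over the encoder axis. A quick smoke test with made-up sizes (B=2 batches, Ne=5 encoder steps, Nd=3 decoder steps, C=16 features; the sizes and variable names below are illustrative, not part of the gist):

# Hypothetical usage with random encodings.
B, Ne, Nd, C = 2, 5, 3, 16
pointer = PointerNetwork(n_hidden=C)
x_encoder = torch.randn(B, Ne, C)
x_decoder = torch.randn(B, Nd, C)
# Mask out the last two encoder positions of the second batch element.
mask = torch.ones(B, Nd, Ne)
mask[1, :, 3:] = 0
log_scores = pointer(x_decoder, x_encoder, mask)  # (B, Nd, Ne)
pointed = log_scores.argmax(dim=-1)               # index of the input token each decoder step points to
print(log_scores.shape, pointed.shape)            # torch.Size([2, 3, 5]) torch.Size([2, 3])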