Last active: October 2, 2023
Pointer network attention architecture in PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F


def masked_log_softmax(x: torch.Tensor, mask: torch.Tensor, dim: int = -1, eps: float = 1e-16) -> torch.Tensor:
    """Log-softmax over `dim` that ignores positions where `mask` is 0.

    Adding log(mask + eps) pushes masked logits toward -inf before the
    softmax, so they receive ~zero probability. (Helper as used in
    ast0414's pointer-networks-pytorch.)
    """
    if mask is not None:
        x = x + torch.log(mask.float() + eps)
    return F.log_softmax(x, dim=dim)


class PointerNetwork(nn.Module):
    """
    From "Pointer Networks" by Vinyals et al. (2015).
    Adapted from pointer-networks-pytorch by ast0414:
    https://github.com/ast0414/pointer-networks-pytorch

    Args:
        n_hidden: The number of features to expect in the inputs.
    """
    def __init__(
        self,
        n_hidden: int
    ):
        super().__init__()
        self.n_hidden = n_hidden
        self.w1 = nn.Linear(n_hidden, n_hidden, bias=False)
        self.w2 = nn.Linear(n_hidden, n_hidden, bias=False)
        self.v = nn.Linear(n_hidden, 1, bias=False)

    def forward(
        self,
        x_decoder: torch.Tensor,
        x_encoder: torch.Tensor,
        mask: torch.Tensor,
        eps: float = 1e-16
    ) -> torch.Tensor:
        """
        Args:
            x_decoder: Encoding over the output tokens.
            x_encoder: Encoding over the input tokens.
            mask: Binary mask over the softmax input.

        Shape:
            x_decoder: (B, Nd, C)
            x_encoder: (B, Ne, C)
            mask: (B, Nd, Ne)
        """
        # (B, Nd, Ne, C) <- (B, Ne, C)
        encoder_transform = self.w1(x_encoder).unsqueeze(1).expand(
            -1, x_decoder.shape[1], -1, -1)
        # (B, Nd, 1, C) <- (B, Nd, C)
        decoder_transform = self.w2(x_decoder).unsqueeze(2)
        # (B, Nd, Ne) <- (B, Nd, Ne, C) broadcast with (B, Nd, 1, C)
        prod = self.v(torch.tanh(encoder_transform + decoder_transform)).squeeze(-1)
        # (B, Nd, Ne) <- (B, Nd, Ne)
        log_score = masked_log_softmax(prod, mask, dim=-1, eps=eps)
        return log_score
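As a sanity check on the shape algebra above, here is a minimal NumPy sketch of the same additive scoring followed by the masked log-softmax. The arrays `w1`, `w2`, and `v` stand in for the module's linear layers, and the random values and dimensions are illustrative only:

```python
import numpy as np

rng = np.random.default_rng(0)
B, Ne, Nd, C = 2, 5, 3, 4  # batch, encoder steps, decoder steps, channels

x_encoder = rng.normal(size=(B, Ne, C))
x_decoder = rng.normal(size=(B, Nd, C))
w1 = rng.normal(size=(C, C))  # stands in for self.w1
w2 = rng.normal(size=(C, C))  # stands in for self.w2
v = rng.normal(size=(C,))     # stands in for self.v

# (B, 1, Ne, C) + (B, Nd, 1, C) broadcasts to (B, Nd, Ne, C)
enc = (x_encoder @ w1)[:, None, :, :]
dec = (x_decoder @ w2)[:, :, None, :]
prod = np.tanh(enc + dec) @ v  # (B, Nd, Ne)

# Mask out the last two encoder positions for every decoder step
mask = np.ones((B, Nd, Ne))
mask[..., -2:] = 0.0

# Masked log-softmax: adding log(mask + eps) sends masked logits to ~-inf
eps = 1e-16
logits = prod + np.log(mask + eps)
log_score = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))

probs = np.exp(log_score)  # (B, Nd, Ne); rows sum to 1, masked entries ~0
```

Each decoder step thus yields a distribution over encoder positions, which is what lets the network "point" back at its input.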
Hi there, thank you for your interest in my article. Do you think it could be related to this question? The error message sounds familiar.
In either case, I believe it should be fixed now. Please give it a try and let me know if you still encounter issues.

As you said, it was fixed. Thank you very much!

Happy to help!