Last active
October 2, 2023 09:55
Pointer network attention architecture in PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F


def masked_log_softmax(
    vector: torch.Tensor,
    mask: torch.Tensor,
    dim: int = -1,
    eps: float = 1e-16
) -> torch.Tensor:
    # Helper from the ast0414 repository referenced below: adding
    # log(mask + eps) drives masked-out positions toward zero probability.
    if mask is not None:
        vector = vector + (mask.float() + eps).log()
    return F.log_softmax(vector, dim=dim)


class PointerNetwork(nn.Module):
    """
    From "Pointer Networks" by Vinyals et al. (2015).
    Adapted from pointer-networks-pytorch by ast0414:
    https://github.com/ast0414/pointer-networks-pytorch

    Args:
        n_hidden: The number of features to expect in the inputs.
    """
    def __init__(self, n_hidden: int):
        super().__init__()
        self.n_hidden = n_hidden
        self.w1 = nn.Linear(n_hidden, n_hidden, bias=False)
        self.w2 = nn.Linear(n_hidden, n_hidden, bias=False)
        self.v = nn.Linear(n_hidden, 1, bias=False)

    def forward(
        self,
        x_decoder: torch.Tensor,
        x_encoder: torch.Tensor,
        mask: torch.Tensor,
        eps: float = 1e-16
    ) -> torch.Tensor:
        """
        Args:
            x_decoder: Encoding over the output tokens.
            x_encoder: Encoding over the input tokens.
            mask: Binary mask over the softmax input.

        Shape:
            x_decoder: (B, Nd, C)
            x_encoder: (B, Ne, C)
            mask: (B, Nd, Ne)
        """
        # (B, Nd, Ne, C) <- (B, Ne, C)
        encoder_transform = self.w1(x_encoder).unsqueeze(1).expand(
            -1, x_decoder.shape[1], -1, -1)
        # (B, Nd, 1, C) <- (B, Nd, C)
        decoder_transform = self.w2(x_decoder).unsqueeze(2)
        # (B, Nd, Ne) <- (B, Nd, Ne, C), (B, Nd, 1, C)
        prod = self.v(torch.tanh(encoder_transform + decoder_transform)).squeeze(-1)
        # (B, Nd, Ne) <- (B, Nd, Ne)
        log_score = masked_log_softmax(prod, mask, dim=-1, eps=eps)
        return log_score
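The scoring step above is additive (Bahdanau-style) attention followed by a masked log-softmax. As a sanity check, the following standalone sketch re-creates that computation with plain tensors; the dimension names (B, Nd, Ne, C) and the 1e-16 mask epsilon mirror the class, but the concrete sizes are arbitrary examples:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, Nd, Ne, C = 2, 3, 4, 8  # batch, decoder steps, encoder steps, features
x_encoder = torch.randn(B, Ne, C)
x_decoder = torch.randn(B, Nd, C)
w1 = nn.Linear(C, C, bias=False)
w2 = nn.Linear(C, C, bias=False)
v = nn.Linear(C, 1, bias=False)

enc = w1(x_encoder).unsqueeze(1).expand(-1, Nd, -1, -1)  # (B, Nd, Ne, C)
dec = w2(x_decoder).unsqueeze(2)                          # (B, Nd, 1, C)
scores = v(torch.tanh(enc + dec)).squeeze(-1)             # (B, Nd, Ne)

# Mask out the last encoder position for every decoder step.
mask = torch.ones(B, Nd, Ne)
mask[:, :, -1] = 0

# Adding log(mask + eps) before log_softmax sends masked scores to ~-inf,
# so masked positions get essentially zero probability.
log_p = F.log_softmax(scores + (mask + 1e-16).log(), dim=-1)

print(log_p.shape)                  # (B, Nd, Ne)
print(log_p.exp().sum(-1))          # each row sums to ~1
print(log_p.exp()[:, :, -1].max())  # masked position: ~0 probability
```

Each row of `log_p.exp()` is a valid distribution over encoder positions, which is exactly what a pointer network uses to select input tokens.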
Hi there, thank you for your interest in my article. Do you think it could be related to this question? The error message sounds familiar.
In either case, I believe it should be fixed now. Please give it a try and let me know if you still encounter issues.
As you said, it was fixed.
Thank you very much!
Happy to help!
Hello.
I am writing to ask a question about the program.
I have read your article on Pointer Networks with Transformers.
It is very interesting how it combines a pointer network with a Transformer.
I ran the Google Colab notebook you linked to, but I got a ValueError: Inconsistent coordinate dimensionality during training.
How can I get it to run?