@Mason-McGough
Last active October 2, 2023 09:55
Pointer network attention architecture in PyTorch
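For reference, the class below computes the additive attention scores from the Pointer Networks paper by Vinyals et al. The learned parameters $W_1$, $W_2$ and $v$ correspond to `self.w1`, `self.w2` and `self.v`. Here $e_j$ is the $j$-th encoder state, $d_i$ is the $i$-th decoder state, $\mathcal{P}$ is the input sequence and $C_i$ is the input index pointed to at step $i$. The masking term in the second line is an assumption about what `masked_log_softmax` does, with $m^i_j \in \{0, 1\}$ taken from `mask` and $\varepsilon$ from `eps`.

```latex
\begin{aligned}
u^{i}_{j} &= v^{\top} \tanh\left(W_1 e_j + W_2 d_i\right), \qquad j = 1, \dots, N_e \\
\log p\left(C_i \mid C_1, \dots, C_{i-1}, \mathcal{P}\right)
  &= \log \operatorname{softmax}\left(u^{i} + \log\left(m^{i} + \varepsilon\right)\right)
\end{aligned}
```

In the code, `prod` holds $u^{i}_{j}$ for all decoder steps at once, with shape (B, Nd, Ne).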
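```python
import torch
import torch.nn as nn


def masked_log_softmax(
    x: torch.Tensor,
    mask: torch.Tensor,
    dim: int = -1,
    eps: float = 1e-16
) -> torch.Tensor:
    """
    Assumed helper: masked_log_softmax is called below but not defined in
    this gist. This sketch adds log(mask + eps) to the logits, so masked
    positions receive a large negative offset (about -36.8 for eps=1e-16)
    and near-zero probability after the log-softmax.
    """
    x = x + (mask.float() + eps).log()
    return torch.log_softmax(x, dim=dim)


class PointerNetwork(nn.Module):
    """
    From "Pointer Networks" by Vinyals et al. (2015)

    Adapted from pointer-networks-pytorch by ast0414:
    https://github.com/ast0414/pointer-networks-pytorch

    Args:
        n_hidden: The number of features to expect in the inputs.
    """
    def __init__(
        self,
        n_hidden: int
    ):
        super().__init__()
        self.n_hidden = n_hidden
        self.w1 = nn.Linear(n_hidden, n_hidden, bias=False)  # applied to encoder states
        self.w2 = nn.Linear(n_hidden, n_hidden, bias=False)  # applied to decoder states
        self.v = nn.Linear(n_hidden, 1, bias=False)          # projects to a scalar score

    def forward(
        self,
        x_decoder: torch.Tensor,
        x_encoder: torch.Tensor,
        mask: torch.Tensor,
        eps: float = 1e-16
    ) -> torch.Tensor:
        """
        Args:
            x_decoder: Encoding over the output tokens.
            x_encoder: Encoding over the input tokens.
            mask: Binary mask over the softmax input (1 = allowed, 0 = masked).
        Shape:
            B: batch size, Nd: output (decoder) length,
            Ne: input (encoder) length, C: n_hidden.
            x_decoder: (B, Nd, C)
            x_encoder: (B, Ne, C)
            mask: (B, Nd, Ne)
        Returns:
            Log-probabilities over the input tokens for each decoder step,
            with shape (B, Nd, Ne).
        """
        # (B, Nd, Ne, C) <- (B, Ne, C)
        encoder_transform = self.w1(x_encoder).unsqueeze(1).expand(
            -1, x_decoder.shape[1], -1, -1)
        # (B, Nd, 1, C) <- (B, Nd, C)
        decoder_transform = self.w2(x_decoder).unsqueeze(2)
        # (B, Nd, Ne) <- (B, Nd, Ne, C), (B, Nd, 1, C)
        prod = self.v(torch.tanh(encoder_transform + decoder_transform)).squeeze(-1)
        # (B, Nd, Ne) <- (B, Nd, Ne)
        log_score = masked_log_softmax(prod, mask, dim=-1, eps=eps)
        return log_score
```
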
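The `mask` argument is described only as a binary mask over the softmax input. A common way to build it for pointer networks, sketched here as an assumption rather than the author's method, is to combine a padding mask over the inputs with a mask of the inputs already selected at earlier decoder steps (teacher forcing). `build_pointer_mask` is a hypothetical helper, not part of the gist. It returns the (B, Nd, Ne) shape that `forward` expects. The last lines apply the mask as an additive log(mask + eps) penalty, which is an assumption about how `masked_log_softmax` uses `eps`.

```python
import torch
import torch.nn.functional as F


def build_pointer_mask(lengths: torch.Tensor, targets: torch.Tensor, n_enc: int) -> torch.Tensor:
    """Hypothetical helper: build a (B, Nd, Ne) bool mask for teacher-forced decoding.

    lengths: (B,) number of real (unpadded) input tokens per sequence.
    targets: (B, Nd) index of the input token selected at each decoder step.
    Returns True where pointing is allowed.
    """
    # (B, Ne): True for real input tokens, False for padding
    valid = torch.arange(n_enc, device=lengths.device)[None, :] < lengths[:, None]
    # (B, Nd, Ne): one-hot encoding of the token chosen at each step
    chosen = F.one_hot(targets, num_classes=n_enc)
    # Tokens chosen strictly before step i (exclusive cumulative sum over steps)
    visited = (chosen.cumsum(dim=1) - chosen) > 0
    # Allowed = a real token that has not been selected yet
    return valid[:, None, :] & ~visited


# Two sequences: the first has 3 real inputs, the second has 2 (plus 1 pad)
lengths = torch.tensor([3, 2])
targets = torch.tensor([[2, 0], [1, 0]])
mask = build_pointer_mask(lengths, targets, n_enc=3)

# Stand-in scores with the (B, Nd, Ne) shape of `prod` in forward()
torch.manual_seed(0)
scores = torch.randn(2, 2, 3)
# Additive log(mask + eps) penalty (assumed to mirror masked_log_softmax)
log_p = torch.log_softmax(scores + (mask.float() + 1e-16).log(), dim=-1)
```

The exclusive cumulative sum keeps the computation vectorized over decoder steps. At step i the mask forbids only the tokens chosen at steps 0 to i-1, so the target for step i itself remains allowed.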
@Mason-McGough (author)

Hi there, thank you for your interest in my article. Do you think it could be related to this question? The error message sounds familiar.

In either case, I believe it should be fixed now. Please give it a try and let me know if you still encounter issues.

@Ririkosann

As you said, it was fixed.
Thank you very much!

@Mason-McGough (author)

Happy to help!
