PyTorch implementation of the Set Transformer (Lee et al., 2019): https://api.semanticscholar.org/CorpusID:59222677
from typing import Optional, cast

import torch
class MAB(torch.nn.Module):
    """Multihead Attention Block (MAB): multihead attention followed by a
    position-wise feed-forward layer, each with a residual connection and
    LayerNorm."""

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self._multihead = torch.nn.MultiheadAttention(
            embed_dim=input_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self._feedforward = torch.nn.Sequential(
            torch.nn.Linear(input_dim, input_dim),
            torch.nn.GELU(),
            torch.nn.Linear(input_dim, input_dim),
        )
        self._layernorm_1 = torch.nn.LayerNorm(input_dim)
        self._layernorm_2 = torch.nn.LayerNorm(input_dim)

    def forward(
        self,
        X: torch.Tensor,
        Y: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # `mask` marks valid positions of Y with True, while
        # MultiheadAttention's key_padding_mask marks positions to
        # ignore with True, so invert it before passing it on.
        if mask is not None:
            mask = ~mask
        H, _ = self._multihead(X, Y, Y, key_padding_mask=mask)
        H = self._layernorm_1(X + H)
        H = self._layernorm_2(H + self._feedforward(H))
        return cast(torch.Tensor, H)
class SAB(torch.nn.Module):
    """Set Attention Block (SAB): a MAB in which the set attends to
    itself, i.e. SAB(X) = MAB(X, X)."""

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self._mab = MAB(input_dim, num_heads, dropout)

    def forward(
        self,
        X: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return cast(torch.Tensor, self._mab(X, X, mask))
class PMA(torch.nn.Module):
    """Pooling by Multihead Attention (PMA): a set of learned seed vectors
    attends to the feed-forward-transformed input set, pooling a
    variable-sized set into `num_seeds` vectors: PMA(Z) = MAB(S, rFF(Z))."""

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        num_seeds: int = 1,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self._mab = MAB(input_dim, num_heads, dropout)
        self._S = torch.nn.Parameter(torch.randn((num_seeds, input_dim)))
        self._feedforward = torch.nn.Sequential(
            torch.nn.Linear(input_dim, input_dim),
            torch.nn.GELU(),
            torch.nn.Linear(input_dim, input_dim),
        )

    def forward(
        self,
        Z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Broadcast the learned seed vectors across the batch dimension.
        S = self._S.unsqueeze(0).expand(Z.size(0), -1, -1)
        return cast(torch.Tensor, self._mab(S, self._feedforward(Z), mask))
class Encoder(torch.nn.Module):
    """Set Transformer encoder: two stacked SABs."""

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self._sab_1 = SAB(input_dim, num_heads, dropout)
        self._sab_2 = SAB(input_dim, num_heads, dropout)

    def forward(
        self,
        X: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        H = self._sab_1(X, mask)
        H = self._sab_2(H, mask)
        return cast(torch.Tensor, H)
class Decoder(torch.nn.Module):
    """Set Transformer decoder: rFF(SAB(PMA(Z))). PMA pools the encoded
    set into `num_seeds` vectors, a SAB models interactions among the
    pooled vectors, and a feed-forward layer produces the output."""

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        num_seeds: int = 1,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self._sab = SAB(input_dim, num_heads, dropout)
        self._pma = PMA(input_dim, num_heads, num_seeds, dropout)
        self._feedforward = torch.nn.Sequential(
            torch.nn.Linear(input_dim, input_dim),
            torch.nn.GELU(),
            torch.nn.Linear(input_dim, input_dim),
        )

    def forward(
        self,
        Z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        H = self._pma(Z, mask)
        # The pooled seed vectors are all valid, so no mask is needed here.
        H = self._sab(H)
        H = self._feedforward(H)
        return cast(torch.Tensor, H)
class SetTransformer(torch.nn.Module):
    """Permutation-invariant encoder-decoder over sets. Maps a batch of
    (optionally padded) sets of shape (batch, set_size, input_dim) to
    fixed-size vectors of shape (batch, num_seeds * input_dim)."""

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        num_seeds: int = 1,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self._encoder = Encoder(input_dim, num_heads, dropout)
        self._decoder = Decoder(input_dim, num_heads, num_seeds, dropout)
        self._output_dim = num_seeds * input_dim

    def forward(
        self,
        X: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        Z = self._encoder(X, mask)
        H = self._decoder(Z, mask)
        # Flatten the num_seeds pooled vectors into one vector per set.
        return cast(torch.Tensor, H.view(X.size(0), -1))

    def get_output_dim(self) -> int:
        return self._output_dim
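
A minimal usage sketch (the shapes, dummy data, and variable names below are illustrative assumptions, not part of the gist): build a boolean mask with True for valid set elements and False for padding, then pool each padded set into a fixed-size vector.

# Usage sketch: shapes and dummy data are illustrative assumptions.
batch_size, max_set_size, input_dim = 4, 10, 64

model = SetTransformer(input_dim=input_dim, num_heads=8, num_seeds=1)

# A batch of padded sets and their true lengths.
X = torch.randn(batch_size, max_set_size, input_dim)
lengths = torch.tensor([10, 7, 3, 5])
# True marks valid elements; False marks padding.
mask = torch.arange(max_set_size).unsqueeze(0) < lengths.unsqueeze(1)

pooled = model(X, mask)  # (batch_size, num_seeds * input_dim)
assert pooled.shape == (batch_size, model.get_output_dim())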