@altescy
Created June 22, 2022 11:06
PyTorch implementation of SetTransformer: https://api.semanticscholar.org/CorpusID:59222677
from typing import Optional, cast

import torch


class MAB(torch.nn.Module):
    """Multihead Attention Block: multihead attention plus a residual
    feed-forward layer, each followed by layer normalization."""

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self._multihead = torch.nn.MultiheadAttention(
            embed_dim=input_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self._feedforward = torch.nn.Sequential(
            torch.nn.Linear(input_dim, input_dim),
            torch.nn.GELU(),
            torch.nn.Linear(input_dim, input_dim),
        )
        self._layernorm_1 = torch.nn.LayerNorm(input_dim)
        self._layernorm_2 = torch.nn.LayerNorm(input_dim)

    def forward(
        self,
        X: torch.Tensor,
        Y: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # mask: (batch_size, num_keys) boolean tensor with True for valid
        # positions. MultiheadAttention's key_padding_mask expects True for
        # positions to ignore, so the mask is inverted here.
        if mask is not None:
            mask = ~mask
        H, _ = self._multihead(X, Y, Y, key_padding_mask=mask)
        H = self._layernorm_1(X + H)
        H = self._layernorm_2(H + self._feedforward(H))
        return cast(torch.Tensor, H)


class SAB(torch.nn.Module):
    """Set Attention Block: self-attention over the set, i.e. MAB(X, X)."""

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self._mab = MAB(input_dim, num_heads, dropout)

    def forward(
        self,
        X: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return cast(torch.Tensor, self._mab(X, X, mask))


class PMA(torch.nn.Module):
    """Pooling by Multihead Attention: attends from learnable seed vectors to
    the feed-forward-transformed input set, i.e. MAB(S, rFF(Z))."""

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        num_seeds: int = 1,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self._mab = MAB(input_dim, num_heads, dropout)
        self._S = torch.nn.Parameter(torch.randn((num_seeds, input_dim)))
        self._feedforward = torch.nn.Sequential(
            torch.nn.Linear(input_dim, input_dim),
            torch.nn.GELU(),
            torch.nn.Linear(input_dim, input_dim),
        )

    def forward(
        self,
        Z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Broadcast the learnable seed vectors over the batch dimension.
        S = self._S.unsqueeze(0).expand(Z.size(0), -1, -1)
        return cast(torch.Tensor, self._mab(S, self._feedforward(Z), mask))


class Encoder(torch.nn.Module):
    """Set Transformer encoder: a stack of two SABs."""

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self._sab_1 = SAB(input_dim, num_heads, dropout)
        self._sab_2 = SAB(input_dim, num_heads, dropout)

    def forward(
        self,
        X: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        H = self._sab_1(X, mask)
        H = self._sab_2(H, mask)
        return cast(torch.Tensor, H)


class Decoder(torch.nn.Module):
    """Set Transformer decoder: rFF(SAB(PMA(Z)))."""

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        num_seeds: int = 1,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self._sab = SAB(input_dim, num_heads, dropout)
        self._pma = PMA(input_dim, num_heads, num_seeds, dropout)
        self._feedforward = torch.nn.Sequential(
            torch.nn.Linear(input_dim, input_dim),
            torch.nn.GELU(),
            torch.nn.Linear(input_dim, input_dim),
        )

    def forward(
        self,
        Z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        H = self._pma(Z, mask)
        # The PMA output contains only the seed vectors, which are all valid,
        # so no mask is needed for the following self-attention.
        H = self._sab(H)
        H = self._feedforward(H)
        return cast(torch.Tensor, H)


class SetTransformer(torch.nn.Module):
    """Permutation-invariant set encoder producing a fixed-size representation
    of shape (batch_size, num_seeds * input_dim)."""

    def __init__(
        self,
        input_dim: int,
        num_heads: int,
        num_seeds: int = 1,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self._encoder = Encoder(input_dim, num_heads, dropout)
        self._decoder = Decoder(input_dim, num_heads, num_seeds, dropout)
        self._output_dim = num_seeds * input_dim

    def forward(
        self,
        X: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # X: (batch_size, set_size, input_dim)
        # mask: (batch_size, set_size) boolean, True for valid set elements.
        Z = self._encoder(X, mask)
        H = self._decoder(Z, mask)
        return cast(torch.Tensor, H.view(X.size(0), -1))

    def get_output_dim(self) -> int:
        return self._output_dim
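
A minimal usage sketch: the hyperparameters, tensor shapes, and set lengths below are illustrative assumptions, not part of the gist. It pads a batch of variable-sized sets, builds a boolean mask (True = valid element), and obtains one fixed-size vector per set.

# Usage sketch (hyperparameters and shapes are illustrative assumptions).
import torch

model = SetTransformer(input_dim=64, num_heads=4, num_seeds=1, dropout=0.1)

batch_size, max_set_size = 8, 10
X = torch.randn(batch_size, max_set_size, 64)  # padded batch of sets
lengths = torch.randint(1, max_set_size + 1, (batch_size,))  # true size of each set
mask = torch.arange(max_set_size).unsqueeze(0) < lengths.unsqueeze(1)  # True = valid

H = model(X, mask)
print(H.shape)  # torch.Size([8, 64]) == (batch_size, model.get_output_dim())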