Relation-aware Transformer
# -*- coding: utf-8 -*-

from __future__ import annotations

import copy
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class RelationAwareTransformerEncoder(nn.Module):

    def __init__(
        self,
        layer: nn.Module,
        n_layers: int = 6,
        n_model: int = 1024,
        pre_norm: bool = False,
    ) -> RelationAwareTransformerEncoder:
        super(RelationAwareTransformerEncoder, self).__init__()

        self.n_layers = n_layers
        self.n_model = n_model
        self.pre_norm = pre_norm

        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(n_model) if self.pre_norm else None

    def forward(self, x: torch.Tensor, rels: torch.LongTensor, mask: torch.BoolTensor) -> torch.Tensor:
        # the layers operate on [seq_len, batch_size, n_model]
        x = x.transpose(0, 1)
        for layer in self.layers:
            x = layer(x, rels, mask)
        if self.pre_norm:
            x = self.norm(x)
        return x.transpose(0, 1)


class RelationAwareTransformerEncoderLayer(nn.Module):

    def __init__(
        self,
        n_rels: int,
        n_heads: int = 8,
        n_model: int = 1024,
        n_inner: int = 2048,
        activation: str = 'relu',
        pre_norm: bool = False,
        attn_dropout: float = 0.1,
        ffn_dropout: float = 0.1,
        dropout: float = 0.1,
    ) -> RelationAwareTransformerEncoderLayer:
        super(RelationAwareTransformerEncoderLayer, self).__init__()

        self.attn = RelationAwareMultiHeadAttention(n_rels=n_rels,
                                                    n_heads=n_heads,
                                                    n_model=n_model,
                                                    n_embed=n_model//n_heads,
                                                    dropout=attn_dropout)
        self.attn_norm = nn.LayerNorm(n_model)
        self.ffn = PositionwiseFeedForward(n_model=n_model,
                                           n_inner=n_inner,
                                           activation=activation,
                                           dropout=ffn_dropout)
        self.ffn_norm = nn.LayerNorm(n_model)
        self.dropout = nn.Dropout(dropout)

        self.pre_norm = pre_norm

    def forward(self, x: torch.Tensor, rels: torch.LongTensor, mask: torch.BoolTensor) -> torch.Tensor:
        if self.pre_norm:
            n = self.attn_norm(x)
            x = x + self.dropout(self.attn(n, n, n, rels, mask))
            n = self.ffn_norm(x)
            x = x + self.dropout(self.ffn(n))
        else:
            x = self.attn_norm(x + self.dropout(self.attn(x, x, x, rels, mask)))
            x = self.ffn_norm(x + self.dropout(self.ffn(x)))
        return x


class RelationAwareMultiHeadAttention(nn.Module):

    def __init__(
        self,
        n_rels: int,
        n_heads: int = 8,
        n_model: int = 1024,
        n_embed: int = 128,
        dropout: float = 0.1,
        attn: bool = False
    ) -> RelationAwareMultiHeadAttention:
        super(RelationAwareMultiHeadAttention, self).__init__()

        self.n_rels = n_rels
        self.n_heads = n_heads
        self.n_model = n_model
        self.n_embed = n_embed
        self.scale = n_embed**0.5

        self.rel_embed = nn.Embedding(num_embeddings=n_rels, embedding_dim=n_embed)
        self.wq = nn.Parameter(torch.zeros(n_model, n_heads * n_embed))
        self.wk = nn.Parameter(torch.zeros(n_model, n_heads * n_embed))
        self.wv = nn.Parameter(torch.zeros(n_model, n_heads * n_embed))
        self.wo = nn.Parameter(torch.zeros(n_heads * n_embed, n_model))
        self.bu = nn.Parameter(torch.zeros(n_heads, n_embed))
        self.bv = nn.Parameter(torch.zeros(n_heads, n_embed))
        self.dropout = nn.Dropout(dropout)
        self.attn = attn

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.rel_embed.weight, 2 ** -0.5)
        # borrowed from https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/multihead_attention.py
        nn.init.xavier_uniform_(self.wq, 2 ** -0.5)
        nn.init.xavier_uniform_(self.wk, 2 ** -0.5)
        nn.init.xavier_uniform_(self.wv, 2 ** -0.5)
        nn.init.xavier_uniform_(self.wo)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        rels: torch.LongTensor,
        mask: torch.BoolTensor,
        attn_mask: Optional[torch.BoolTensor] = None
    ) -> torch.Tensor:
        batch_size, _ = mask.shape
        # [seq_len, batch_size, n_heads, n_embed]
        q = F.linear(q, self.wq).view(-1, batch_size, self.n_heads, self.n_embed)
        # [src_len, batch_size * n_heads, n_embed]
        k = F.linear(k, self.wk).view(-1, batch_size * self.n_heads, self.n_embed)
        v = F.linear(v, self.wv).view(-1, batch_size * self.n_heads, self.n_embed)
        # [seq_len, batch_size * n_heads, n_embed]
        qu = (q + self.bu).view(-1, *k.shape[1:])
        # [seq_len, batch_size, n_heads, n_embed]
        qv = q + self.bv

        # keep only valid (non-padding) token pairs that carry a relation
        rel_mask = rels.ge(0) & mask.unsqueeze(1) & mask.unsqueeze(2)
        rel_indices = torch.where(rel_mask)
        mask = mask.unsqueeze(1).repeat(1, self.n_heads, 1).view(-1, 1, *mask.shape[1:])
        if attn_mask is not None:
            mask = mask & attn_mask
        # content-based attention scores
        # [batch_size * n_heads, seq_len, src_len]
        attn = torch.bmm(qu.transpose(0, 1), k.movedim((0, 1), (2, 0)))
        # relation-dependent scores, computed only for the pairs selected by rel_mask
        # [n_pairs, n_heads]
        rel_attn = torch.bmm(qv[rel_indices[1], rel_indices[0]], self.rel_embed(rels[rel_mask]).unsqueeze(-1)).squeeze(-1)
        # scatter the relation scores back into a dense tensor
        # [batch_size, seq_len, src_len, n_heads]
        rel_attn = attn.new_zeros(batch_size, *attn.shape[-2:], self.n_heads).masked_scatter_(rel_mask.unsqueeze(-1), rel_attn)
        attn = attn + rel_attn.movedim(-1, 1).reshape_as(attn)
        attn = torch.softmax(attn / self.scale + torch.where(mask, 0., float('-inf')), -1)
        attn = self.dropout(attn)
        # [seq_len, batch_size * n_heads, n_embed]
        x = torch.bmm(attn, v.transpose(0, 1)).transpose(0, 1)
        # [seq_len, batch_size, n_model]
        x = F.linear(x.reshape(-1, batch_size, self.n_heads * self.n_embed), self.wo)

        return (x, attn.view(batch_size, self.n_heads, *attn.shape[1:])) if self.attn else x
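
Note: PositionwiseFeedForward is used by RelationAwareTransformerEncoderLayer but is not defined in this gist; it presumably comes from the author's own codebase. The sketch below is only a minimal stand-in consistent with the keyword arguments used above (n_model, n_inner, activation, dropout), not the author's actual implementation.

class PositionwiseFeedForward(nn.Module):

    def __init__(
        self,
        n_model: int = 1024,
        n_inner: int = 2048,
        activation: str = 'relu',
        dropout: float = 0.1,
    ) -> PositionwiseFeedForward:
        super(PositionwiseFeedForward, self).__init__()

        # assumed structure: two linear projections with a non-linearity in between
        self.w1 = nn.Linear(n_model, n_inner)
        self.activation = nn.ReLU() if activation == 'relu' else nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.w2 = nn.Linear(n_inner, n_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [..., n_model] -> [..., n_inner] -> [..., n_model]
        return self.w2(self.dropout(self.activation(self.w1(x))))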
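
A hypothetical smoke test (not part of the gist) showing how the modules fit together. The input conventions are inferred from the forward passes: mask is True for real tokens, and rels holds one relation id per token pair, with negative values marking pairs that have no relation; the shapes and the -1 convention are assumptions.

if __name__ == '__main__':
    n_rels, n_model, batch_size, seq_len = 4, 1024, 2, 5
    layer = RelationAwareTransformerEncoderLayer(n_rels=n_rels, n_model=n_model)
    encoder = RelationAwareTransformerEncoder(layer, n_layers=2, n_model=n_model)

    x = torch.randn(batch_size, seq_len, n_model)
    # relation ids per (i, j) token pair; -1 marks pairs that carry no relation (assumption)
    rels = torch.randint(-1, n_rels, (batch_size, seq_len, seq_len))
    # True for real tokens; treat the last two positions of the second sequence as padding
    mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
    mask[1, -2:] = False

    y = encoder(x, rels, mask)
    print(y.shape)  # torch.Size([2, 5, 1024])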