@yzhangcs
Created August 22, 2022 10:00
Relation-aware Transformer
# -*- coding: utf-8 -*-

from __future__ import annotations

import copy
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class RelationAwareTransformerEncoder(nn.Module):

    def __init__(
        self,
        layer: nn.Module,
        n_layers: int = 6,
        n_model: int = 1024,
        pre_norm: bool = False,
    ) -> RelationAwareTransformerEncoder:
        super(RelationAwareTransformerEncoder, self).__init__()

        self.n_layers = n_layers
        self.n_model = n_model
        self.pre_norm = pre_norm

        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(n_model) if self.pre_norm else None

    def forward(self, x: torch.Tensor, rels: torch.LongTensor, mask: torch.BoolTensor) -> torch.Tensor:
        x = x.transpose(0, 1)
        for layer in self.layers:
            x = layer(x, rels, mask)
        if self.pre_norm:
            x = self.norm(x)
        return x.transpose(0, 1)
class RelationAwareTransformerEncoderLayer(nn.Module):

    def __init__(
        self,
        n_rels: int,
        n_heads: int = 8,
        n_model: int = 1024,
        n_inner: int = 2048,
        activation: str = 'relu',
        pre_norm: bool = False,
        attn_dropout: float = 0.1,
        ffn_dropout: float = 0.1,
        dropout: float = 0.1,
    ) -> RelationAwareTransformerEncoderLayer:
        super(RelationAwareTransformerEncoderLayer, self).__init__()

        self.attn = RelationAwareMultiHeadAttention(n_rels=n_rels,
                                                    n_heads=n_heads,
                                                    n_model=n_model,
                                                    n_embed=n_model//n_heads,
                                                    dropout=attn_dropout)
        self.attn_norm = nn.LayerNorm(n_model)
        self.ffn = PositionwiseFeedForward(n_model=n_model,
                                           n_inner=n_inner,
                                           activation=activation,
                                           dropout=ffn_dropout)
        self.ffn_norm = nn.LayerNorm(n_model)
        self.dropout = nn.Dropout(dropout)

        self.pre_norm = pre_norm

    def forward(self, x: torch.Tensor, rels: torch.LongTensor, mask: torch.BoolTensor) -> torch.Tensor:
        if self.pre_norm:
            n = self.attn_norm(x)
            x = x + self.dropout(self.attn(n, n, n, rels, mask))
            n = self.ffn_norm(x)
            x = x + self.dropout(self.ffn(n))
        else:
            x = self.attn_norm(x + self.dropout(self.attn(x, x, x, rels, mask)))
            x = self.ffn_norm(x + self.dropout(self.ffn(x)))
        return x
class RelationAwareMultiHeadAttention(nn.Module):

    def __init__(
        self,
        n_rels: int,
        n_heads: int = 8,
        n_model: int = 1024,
        n_embed: int = 128,
        dropout: float = 0.1,
        attn: bool = False
    ) -> RelationAwareMultiHeadAttention:
        super(RelationAwareMultiHeadAttention, self).__init__()

        self.n_rels = n_rels
        self.n_heads = n_heads
        self.n_model = n_model
        self.n_embed = n_embed
        self.scale = n_embed**0.5

        self.rel_embed = nn.Embedding(num_embeddings=n_rels, embedding_dim=n_embed)
        self.wq = nn.Parameter(torch.zeros(n_model, n_heads * n_embed))
        self.wk = nn.Parameter(torch.zeros(n_model, n_heads * n_embed))
        self.wv = nn.Parameter(torch.zeros(n_model, n_heads * n_embed))
        self.wo = nn.Parameter(torch.zeros(n_heads * n_embed, n_model))
        self.bu = nn.Parameter(torch.zeros(n_heads, n_embed))
        self.bv = nn.Parameter(torch.zeros(n_heads, n_embed))
        self.dropout = nn.Dropout(dropout)
        self.attn = attn

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.rel_embed.weight, 2 ** -0.5)
        # borrowed from https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/multihead_attention.py
        nn.init.xavier_uniform_(self.wq, 2 ** -0.5)
        nn.init.xavier_uniform_(self.wk, 2 ** -0.5)
        nn.init.xavier_uniform_(self.wv, 2 ** -0.5)
        nn.init.xavier_uniform_(self.wo)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        rels: torch.LongTensor,
        mask: torch.BoolTensor,
        attn_mask: Optional[torch.BoolTensor] = None
    ) -> torch.Tensor:
        batch_size, _ = mask.shape
        # [seq_len, batch_size, n_heads, n_embed]
        q = F.linear(q, self.wq).view(-1, batch_size, self.n_heads, self.n_embed)
        # [src_len, batch_size * n_heads, n_embed]
        k = F.linear(k, self.wk).view(-1, batch_size * self.n_heads, self.n_embed)
        v = F.linear(v, self.wv).view(-1, batch_size * self.n_heads, self.n_embed)
        # queries plus the content bias, used for content-to-content attention
        # [seq_len, batch_size * n_heads, n_embed]
        qu = (q + self.bu).view(-1, *k.shape[1:])
        # queries plus the relation bias, used for scoring relation embeddings
        # [seq_len, batch_size, n_heads, n_embed]
        qv = q + self.bv
        # token pairs carrying a valid (non-negative) relation, with padding excluded
        rel_mask = rels.ge(0) & mask.unsqueeze(1) & mask.unsqueeze(2)
        rel_indices = torch.where(rel_mask)
        mask = mask.unsqueeze(1).repeat(1, self.n_heads, 1).view(-1, 1, *mask.shape[1:])
        if attn_mask is not None:
            mask = mask & attn_mask
        # content-to-content attention scores
        # [batch_size * n_heads, seq_len, src_len]
        attn = torch.bmm(qu.transpose(0, 1), k.movedim((0, 1), (2, 0)))
        # relation-aware scores, computed only for the valid pairs
        # [n_pairs, n_heads]
        rel_attn = torch.bmm(qv[rel_indices[1], rel_indices[0]], self.rel_embed(rels[rel_mask]).unsqueeze(-1)).squeeze(-1)
        # scatter the pairwise scores back into a dense tensor
        # [batch_size, seq_len, src_len, n_heads]
        rel_attn = attn.new_zeros(batch_size, *attn.shape[-2:], self.n_heads).masked_scatter_(rel_mask.unsqueeze(-1), rel_attn)
        attn = attn + rel_attn.movedim(-1, 1).reshape_as(attn)
        attn = torch.softmax(attn / self.scale + torch.where(mask, 0., float('-inf')), -1)
        attn = self.dropout(attn)
        # [seq_len, batch_size * n_heads, n_embed]
        x = torch.bmm(attn, v.transpose(0, 1)).transpose(0, 1)
        # [seq_len, batch_size, n_model]
        x = F.linear(x.reshape(-1, batch_size, self.n_heads * self.n_embed), self.wo)
        return (x, attn.view(batch_size, self.n_heads, *attn.shape[1:])) if self.attn else x
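
The layer above instantiates PositionwiseFeedForward, which is not defined in this gist. The class below is not part of the original code: it is a minimal stand-in, assuming the usual two-layer position-wise feed-forward block, with the constructor arguments (n_model, n_inner, activation, dropout) taken from the call in RelationAwareTransformerEncoderLayer.__init__.

class PositionwiseFeedForward(nn.Module):
    # Not in the original gist: a minimal sketch of the feed-forward block the
    # encoder layer expects, assuming FFN(x) = W2(act(W1 x)) with dropout in between.

    def __init__(
        self,
        n_model: int = 1024,
        n_inner: int = 2048,
        activation: str = 'relu',
        dropout: float = 0.1,
    ) -> PositionwiseFeedForward:
        super(PositionwiseFeedForward, self).__init__()

        self.w1 = nn.Linear(n_model, n_inner)
        self.activation = nn.ReLU() if activation == 'relu' else nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.w2 = nn.Linear(n_inner, n_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(self.dropout(self.activation(self.w1(x))))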
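
A small usage sketch, assuming the PositionwiseFeedForward stand-in above (or an equivalent module). The shapes follow from the code: x is [batch_size, seq_len, n_model], mask marks non-padding tokens, and rels is a [batch_size, seq_len, seq_len] tensor of relation ids in which negative values mark token pairs without a relation (cf. rels.ge(0) in RelationAwareMultiHeadAttention.forward).

if __name__ == '__main__':
    batch_size, seq_len, n_model, n_rels = 2, 5, 1024, 10
    layer = RelationAwareTransformerEncoderLayer(n_rels=n_rels, n_model=n_model)
    encoder = RelationAwareTransformerEncoder(layer, n_layers=6, n_model=n_model)

    x = torch.randn(batch_size, seq_len, n_model)
    # rels[b, i, j] is the relation id between tokens i and j; -1 means no relation
    rels = torch.randint(-1, n_rels, (batch_size, seq_len, seq_len))
    # all positions are real tokens in this toy example
    mask = torch.ones(batch_size, seq_len, dtype=torch.bool)

    out = encoder(x, rels, mask)
    print(out.shape)  # torch.Size([2, 5, 1024])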