My own implementation of Multihead Attention.
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Tuple

class ScaleDotProductAttention(nn.Module):
    """
    My own implementation of Scale Dot Product Attention (one head) from the paper:
    https://arxiv.org/abs/1706.03762
    """

    def __init__(
        self,
        query_dim: int,
        key_value_dim: int,
        hidden_dim: int,
    ) -> None:
        """
        Init ScaleDotProductAttention (one head).

        :param int query_dim: query tensor embedding dimension.
        :param int key_value_dim: key and value tensors embedding dimension.
        :param int hidden_dim: hidden tensors dimension.
        """
        super(ScaleDotProductAttention, self).__init__()

        self.query_dim = query_dim
        self.key_value_dim = key_value_dim
        self.hidden_dim = hidden_dim

        # learnable projections into the shared hidden space
        self.query_matrix = nn.Linear(query_dim, hidden_dim)
        self.key_matrix = nn.Linear(key_value_dim, hidden_dim)
        self.value_matrix = nn.Linear(key_value_dim, hidden_dim)
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Process query, key and value tensors.

        :param torch.Tensor query: query tensor.
        :param torch.Tensor key: key tensor.
        :param torch.Tensor value: value tensor.
        :return: attention_output and attention_output_weights.
        :rtype: Tuple[torch.Tensor, torch.Tensor]
        """
        # project query, key and value into the hidden space
        Q = self.query_matrix(query)
        K = self.key_matrix(key)
        V = self.value_matrix(value)

        # shape checks (assumes query, key and value share the same sequence length)
        batch_size = query.shape[0]
        seq_len = query.shape[1]
        attn_shape = torch.Size([batch_size, seq_len, self.hidden_dim])
        attn_weights_shape = torch.Size([batch_size, seq_len, seq_len])

        assert Q.shape == attn_shape
        assert K.shape == attn_shape
        assert V.shape == attn_shape

        # scaled dot-product attention: softmax(Q @ K^T / sqrt(hidden_dim)) @ V
        QK = torch.bmm(Q, K.transpose(-1, -2))
        attn_weights = F.softmax(QK / (self.hidden_dim ** 0.5), dim=-1)
        attn = torch.bmm(attn_weights, V)

        assert attn_weights.shape == attn_weights_shape
        assert attn.shape == attn_shape

        return attn, attn_weights
    def n_params(self) -> int:
        """
        Get number of learnable parameters.

        :return: number of learnable parameters.
        :rtype: int
        """
        return sum(p.numel() for p in self.parameters())
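
# Shape walkthrough for a single ScaleDotProductAttention head (illustrative
# toy values, assuming batch_size=2, seq_len=5, query_dim=16, key_value_dim=16,
# hidden_dim=8):
#   query (2, 5, 16) -> Q = query_matrix(query) -> (2, 5, 8)
#   key   (2, 5, 16) -> K = key_matrix(key)     -> (2, 5, 8)
#   value (2, 5, 16) -> V = value_matrix(value) -> (2, 5, 8)
#   QK = Q @ K^T                                -> (2, 5, 5)
#   attn_weights = softmax(QK / sqrt(8))        -> (2, 5, 5)
#   attn = attn_weights @ V                     -> (2, 5, 8)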

class MultiheadAttention(nn.Module):
    """
    My own implementation of Multihead Attention from the paper:
    https://arxiv.org/abs/1706.03762
    """

    def __init__(
        self,
        query_dim: int,
        key_value_dim: int,
        hidden_dim: int,
        num_heads: int,
    ) -> None:
        """
        Init MultiheadAttention.

        :param int query_dim: query tensor embedding dimension.
        :param int key_value_dim: key and value tensors embedding dimension.
        :param int hidden_dim: hidden tensors dimension.
        :param int num_heads: number of attention heads (ScaleDotProductAttention).
        """
        super(MultiheadAttention, self).__init__()

        self.query_dim = query_dim
        self.key_value_dim = key_value_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # independent ScaleDotProductAttention heads
        self.attn_heads = nn.ModuleList()
        for _ in range(num_heads):
            self.attn_heads.append(
                ScaleDotProductAttention(
                    query_dim=query_dim,
                    key_value_dim=key_value_dim,
                    hidden_dim=hidden_dim,
                )
            )

        # output projection from the concatenated heads back to query_dim
        self.multihead_matrix = nn.Linear(num_heads * hidden_dim, query_dim)
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """
        Process query, key and value tensors.

        :param torch.Tensor query: query tensor.
        :param torch.Tensor key: key tensor.
        :param torch.Tensor value: value tensor.
        :return: attention_output.
        :rtype: torch.Tensor
        """
        # run every head and collect its attention output
        heads_output = []
        for head in self.attn_heads:
            attn_head, _ = head(
                query=query,
                key=key,
                value=value,
            )
            heads_output.append(attn_head)

        # concatenate the heads and project back to query_dim
        attn_cat = torch.cat(heads_output, dim=-1)
        attn = self.multihead_matrix(attn_cat)

        # shape checks
        batch_size = query.shape[0]
        seq_len = query.shape[1]

        assert attn_cat.shape == torch.Size([batch_size, seq_len, self.num_heads * self.hidden_dim])
        assert attn.shape == torch.Size([batch_size, seq_len, self.query_dim])

        return attn
    def n_params(self) -> int:
        """
        Get number of learnable parameters.

        :return: number of learnable parameters.
        :rtype: int
        """
        return sum(p.numel() for p in self.parameters())
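
if __name__ == "__main__":
    # Minimal usage sketch. The dimensions below are illustrative toy values
    # chosen for this example, not values prescribed by the classes above.
    torch.manual_seed(42)

    batch_size, seq_len = 2, 5
    query_dim, key_value_dim, hidden_dim, num_heads = 16, 16, 8, 4

    query = torch.randn(batch_size, seq_len, query_dim)
    key = torch.randn(batch_size, seq_len, key_value_dim)
    value = torch.randn(batch_size, seq_len, key_value_dim)

    # single head: returns the attention output and the attention weights
    head = ScaleDotProductAttention(
        query_dim=query_dim,
        key_value_dim=key_value_dim,
        hidden_dim=hidden_dim,
    )
    attn, attn_weights = head(query=query, key=key, value=value)
    print(attn.shape)          # torch.Size([2, 5, 8])
    print(attn_weights.shape)  # torch.Size([2, 5, 5])

    # multihead: concatenates num_heads heads and projects back to query_dim
    mha = MultiheadAttention(
        query_dim=query_dim,
        key_value_dim=key_value_dim,
        hidden_dim=hidden_dim,
        num_heads=num_heads,
    )
    attn = mha(query=query, key=key, value=value)
    print(attn.shape)          # torch.Size([2, 5, 16])
    print(mha.n_params())      # number of learnable parameters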