@dayyass
Last active June 17, 2021 15:37
My own implementation of Multihead Attention.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple


class ScaleDotProductAttention(nn.Module):
    """
    My own implementation of Scaled Dot-Product Attention (one head) from the paper:
    https://arxiv.org/abs/1706.03762
    """

    def __init__(
        self,
        query_dim: int,
        key_value_dim: int,
        hidden_dim: int,
    ) -> None:
        """
        Init ScaleDotProductAttention (one head).

        :param int query_dim: query tensor embedding dimension.
        :param int key_value_dim: key and value tensors embedding dimension.
        :param int hidden_dim: hidden tensors dimension.
        """
        super(ScaleDotProductAttention, self).__init__()
        self.query_dim = query_dim
        self.key_value_dim = key_value_dim
        self.hidden_dim = hidden_dim

        # learnable projections for query, key and value
        self.query_matrix = nn.Linear(query_dim, hidden_dim)
        self.key_matrix = nn.Linear(key_value_dim, hidden_dim)
        self.value_matrix = nn.Linear(key_value_dim, hidden_dim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Process query, key and value tensors.

        :param torch.Tensor query: query tensor.
        :param torch.Tensor key: key tensor.
        :param torch.Tensor value: value tensor.
        :return: attention_output and attention_output_weights.
        :rtype: Tuple[torch.Tensor, torch.Tensor]
        """
        # project inputs into the shared hidden space
        Q = self.query_matrix(query)
        K = self.key_matrix(key)
        V = self.value_matrix(value)

        # sanity-check shapes (assumes query, key and value share the same sequence length)
        batch_size = query.shape[0]
        seq_len = query.shape[1]
        attn_shape = torch.Size([batch_size, seq_len, self.hidden_dim])
        attn_weights_shape = torch.Size([batch_size, seq_len, seq_len])
        assert Q.shape == attn_shape
        assert K.shape == attn_shape
        assert V.shape == attn_shape

        # scaled dot-product attention: softmax(Q K^T / sqrt(hidden_dim)) V
        QK = torch.bmm(Q, K.transpose(-1, -2))
        attn_weights = F.softmax(QK / (self.hidden_dim ** 0.5), dim=-1)
        attn = torch.bmm(attn_weights, V)

        assert attn_weights.shape == attn_weights_shape
        assert attn.shape == attn_shape

        return attn, attn_weights

    def n_params(self) -> int:
        """
        Get number of learnable parameters.

        :return: number of learnable parameters.
        :rtype: int
        """
        return sum(p.numel() for p in self.parameters())
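

# A minimal usage sketch (not part of the original gist): run a single
# ScaleDotProductAttention head on random self-attention inputs.
# All shapes below are illustrative assumptions only.
def _demo_scale_dot_product_attention() -> None:
    batch_size, seq_len, emb_dim, hidden_dim = 2, 5, 16, 8
    head = ScaleDotProductAttention(
        query_dim=emb_dim,
        key_value_dim=emb_dim,
        hidden_dim=hidden_dim,
    )
    x = torch.randn(batch_size, seq_len, emb_dim)
    # self-attention: query, key and value are the same tensor
    attn, attn_weights = head(query=x, key=x, value=x)
    assert attn.shape == torch.Size([batch_size, seq_len, hidden_dim])
    assert attn_weights.shape == torch.Size([batch_size, seq_len, seq_len])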


class MultiheadAttention(nn.Module):
    """
    My own implementation of Multihead Attention from the paper:
    https://arxiv.org/abs/1706.03762
    """

    def __init__(
        self,
        query_dim: int,
        key_value_dim: int,
        hidden_dim: int,
        num_heads: int,
    ) -> None:
        """
        Init MultiheadAttention.

        :param int query_dim: query tensor embedding dimension.
        :param int key_value_dim: key and value tensors embedding dimension.
        :param int hidden_dim: hidden tensors dimension.
        :param int num_heads: number of attention heads (ScaleDotProductAttention).
        """
        super(MultiheadAttention, self).__init__()
        self.query_dim = query_dim
        self.key_value_dim = key_value_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # independent attention heads
        self.attn_heads = nn.ModuleList()
        for _ in range(num_heads):
            self.attn_heads.append(
                ScaleDotProductAttention(
                    query_dim=query_dim,
                    key_value_dim=key_value_dim,
                    hidden_dim=hidden_dim,
                )
            )

        # output projection back to the query embedding dimension
        self.multihead_matrix = nn.Linear(num_heads * hidden_dim, query_dim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """
        Process query, key and value tensors.

        :param torch.Tensor query: query tensor.
        :param torch.Tensor key: key tensor.
        :param torch.Tensor value: value tensor.
        :return: attention_output.
        :rtype: torch.Tensor
        """
        # run every head and concatenate the outputs along the last dimension
        heads_output = []
        for head in self.attn_heads:
            attn_head, _ = head(
                query=query,
                key=key,
                value=value,
            )
            heads_output.append(attn_head)
        attn_cat = torch.cat(heads_output, dim=-1)
        attn = self.multihead_matrix(attn_cat)

        # sanity-check shapes
        batch_size = query.shape[0]
        seq_len = query.shape[1]
        assert attn_cat.shape == torch.Size([batch_size, seq_len, self.num_heads * self.hidden_dim])
        assert attn.shape == torch.Size([batch_size, seq_len, self.query_dim])

        return attn

    def n_params(self) -> int:
        """
        Get number of learnable parameters.

        :return: number of learnable parameters.
        :rtype: int
        """
        return sum(p.numel() for p in self.parameters())
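

# A minimal usage sketch (not part of the original gist), assuming self-attention
# over randomly generated tensors; every shape and hyperparameter below is an
# illustrative assumption, not a value prescribed by the gist.
if __name__ == "__main__":
    torch.manual_seed(42)
    batch_size, seq_len, emb_dim, hidden_dim, num_heads = 2, 5, 16, 8, 4
    mha = MultiheadAttention(
        query_dim=emb_dim,
        key_value_dim=emb_dim,
        hidden_dim=hidden_dim,
        num_heads=num_heads,
    )
    x = torch.randn(batch_size, seq_len, emb_dim)
    out = mha(query=x, key=x, value=x)
    print(out.shape)       # expected: torch.Size([2, 5, 16])
    print(mha.n_params())  # total number of learnable parameters

    _demo_scale_dot_product_attention()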