My own implementation of Multihead Attention.
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.optim as optim\n",
"\n",
"from attention import ScaleDotProductAttention, MultiheadAttention"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# init global variables\n",
"BATCH_SIZE = 4\n",
"SEQ_LEN = 30\n",
"QUERY_DIM = 128\n",
"KEY_VALUE_DIM = 64\n",
"\n",
"HIDDEN_DIM = 64\n",
"NUM_HEADS = 8"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# init tensors to work with\n",
"query = torch.randn(BATCH_SIZE, SEQ_LEN, QUERY_DIM)\n",
"key = torch.randn(BATCH_SIZE, SEQ_LEN, KEY_VALUE_DIM)\n",
"value = torch.randn(BATCH_SIZE, SEQ_LEN, KEY_VALUE_DIM)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Scaled Dot-Product Attention (one head)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# init one attention head\n",
"attn_layer = ScaleDotProductAttention(\n",
"    query_dim=QUERY_DIM,\n",
"    key_value_dim=KEY_VALUE_DIM,\n",
"    hidden_dim=HIDDEN_DIM,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of learnable parameters: 16576\n"
]
}
],
"source": [
"# count learnable parameters\n",
"n_params = attn_layer.n_params()\n",
"print(f\"Number of learnable parameters: {n_params}\")"
]
},
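{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sanity check: reproduce the count above by hand; the one-head parameters are just\n",
"# the three nn.Linear projections (weights + biases):\n",
"# query: 128*64 + 64, key: 64*64 + 64, value: 64*64 + 64\n",
"expected = (QUERY_DIM * HIDDEN_DIM + HIDDEN_DIM) + 2 * (KEY_VALUE_DIM * HIDDEN_DIM + HIDDEN_DIM)\n",
"assert n_params == expected == 16576"
]
},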
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# pass our tensors through attention head\n",
"attn, attn_weights = attn_layer(\n",
"    query=query,\n",
"    key=key,\n",
"    value=value,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([4, 30, 64]), torch.Size([4, 30, 30]))"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# validate shapes\n",
"attn.shape, attn_weights.shape"
]
},
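{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sanity check: the head returns one HIDDEN_DIM-dimensional vector per query position,\n",
"# and a SEQ_LEN x SEQ_LEN matrix of attention weights per batch element\n",
"assert attn.shape == (BATCH_SIZE, SEQ_LEN, HIDDEN_DIM)\n",
"assert attn_weights.shape == (BATCH_SIZE, SEQ_LEN, SEQ_LEN)\n",
"\n",
"# each row of the weights is a softmax distribution over key positions\n",
"assert torch.allclose(attn_weights.sum(dim=-1), torch.ones(BATCH_SIZE, SEQ_LEN))"
]
},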
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Multihead Attention"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# init multihead attention layer\n",
"mh_attn = MultiheadAttention(\n",
"    query_dim=QUERY_DIM,\n",
"    key_value_dim=KEY_VALUE_DIM,\n",
"    hidden_dim=HIDDEN_DIM,\n",
"    num_heads=NUM_HEADS,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of learnable parameters: 198272\n"
]
}
],
"source": [
"# count learnable parameters\n",
"n_params = mh_attn.n_params()\n",
"print(f\"Number of learnable parameters: {n_params}\")"
]
},
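{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sanity check: NUM_HEADS independent heads plus the output projection\n",
"# nn.Linear(NUM_HEADS * HIDDEN_DIM, QUERY_DIM) account for the count above\n",
"per_head = (QUERY_DIM * HIDDEN_DIM + HIDDEN_DIM) + 2 * (KEY_VALUE_DIM * HIDDEN_DIM + HIDDEN_DIM)\n",
"expected = NUM_HEADS * per_head + (NUM_HEADS * HIDDEN_DIM * QUERY_DIM + QUERY_DIM)\n",
"assert n_params == expected == 198272"
]
},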
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# pass our tensors through the multihead attention layer\n",
"attn = mh_attn(\n",
"    query=query,\n",
"    key=key,\n",
"    value=value,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([4, 30, 128]), torch.Size([4, 30, 128]))"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# validate shapes\n",
"query.shape, attn.shape"
]
},
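{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sanity check: the multihead output is projected back to QUERY_DIM, so it has\n",
"# the same shape as the query (convenient for residual connections)\n",
"assert attn.shape == query.shape == (BATCH_SIZE, SEQ_LEN, QUERY_DIM)"
]
},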
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Matrix norm minimization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's solve a small optimization problem:<br>\n",
"We want to minimize the batch-mean Frobenius norm of the multihead attention layer output by adapting the MultiheadAttention layer parameters."
]
},
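{
"cell_type": "markdown",
"metadata": {},
"source": [
"In symbols, writing the layer output as $A \\in \\mathbb{R}^{B \\times L \\times D}$ (with $B$ = `BATCH_SIZE`, $L$ = `SEQ_LEN`, $D$ = `QUERY_DIM`), the objective is\n",
"\n",
"$$\\mathcal{L} = \\frac{1}{B} \\sum_{b=1}^{B} \\lVert A_b \\rVert_F = \\frac{1}{B} \\sum_{b=1}^{B} \\sqrt{\\sum_{i=1}^{L} \\sum_{j=1}^{D} A_{b,i,j}^{2}},$$\n",
"\n",
"which is what the criterion `torch.norm(x, dim=[1,2]).mean()` defined below computes."
]
},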
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# init global variables\n",
"N_EPOCHS = 10"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# init tensor to work with\n",
"X = torch.randn(BATCH_SIZE, SEQ_LEN, QUERY_DIM)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# init multihead attention layer\n",
"mh_attn = MultiheadAttention(\n",
"    query_dim=QUERY_DIM,\n",
"    key_value_dim=QUERY_DIM,\n",
"    hidden_dim=HIDDEN_DIM,\n",
"    num_heads=NUM_HEADS,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# init criterion and optimizer\n",
"criterion = lambda x: torch.norm(x, dim=[1,2]).mean()\n",
"optimizer = optim.SGD(mh_attn.parameters(), lr=1e-3, momentum=0.8)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Matrix batch mean norm: 5.052\n"
]
}
],
"source": [
"# check current metric\n",
"attn = mh_attn(\n",
"    query=X,\n",
"    key=X,\n",
"    value=X,\n",
")\n",
"loss = criterion(attn)\n",
"print(f'Matrix batch mean norm: {round(loss.item(), 3)}')"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"EPOCH: 1/10\n",
"Matrix batch mean norm: 5.052\n",
"\n",
"EPOCH: 2/10\n",
"Matrix batch mean norm: 4.956\n",
"\n",
"EPOCH: 3/10\n",
"Matrix batch mean norm: 4.785\n",
"\n",
"EPOCH: 4/10\n",
"Matrix batch mean norm: 4.558\n",
"\n",
"EPOCH: 5/10\n",
"Matrix batch mean norm: 4.29\n",
"\n",
"EPOCH: 6/10\n",
"Matrix batch mean norm: 3.997\n",
"\n",
"EPOCH: 7/10\n",
"Matrix batch mean norm: 3.689\n",
"\n",
"EPOCH: 8/10\n",
"Matrix batch mean norm: 3.379\n",
"\n",
"EPOCH: 9/10\n",
"Matrix batch mean norm: 3.076\n",
"\n",
"EPOCH: 10/10\n",
"Matrix batch mean norm: 2.789\n",
"\n"
]
}
],
"source": [
"# TRAIN\n",
"mh_attn.train()\n",
"\n",
"for i in range(N_EPOCHS):\n",
"    \n",
"    attn = mh_attn(\n",
"        query=X,\n",
"        key=X,\n",
"        value=X,\n",
"    )\n",
"    \n",
"    optimizer.zero_grad()\n",
"    loss = criterion(attn)\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    \n",
"    print(f'EPOCH: {i+1}/{N_EPOCHS}')\n",
"    print(f'Matrix batch mean norm: {round(loss.item(), 3)}')\n",
"    print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we can see, the metric decreases over the epochs."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
The attention module imported by the notebook above:
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Tuple


class ScaleDotProductAttention(nn.Module):
    """
    My own implementation of Scaled Dot-Product Attention (one head) from the paper:
    https://arxiv.org/abs/1706.03762
    """

    def __init__(
        self,
        query_dim: int,
        key_value_dim: int,
        hidden_dim: int,
    ) -> None:
        """
        Init ScaleDotProductAttention (one head).

        :param int query_dim: query tensor embedding dimension.
        :param int key_value_dim: key and value tensors embedding dimension.
        :param int hidden_dim: hidden tensors dimension.
        """
        super(ScaleDotProductAttention, self).__init__()
        self.query_dim = query_dim
        self.key_value_dim = key_value_dim
        self.hidden_dim = hidden_dim

        # learnable projections into the shared hidden space
        self.query_matrix = nn.Linear(query_dim, hidden_dim)
        self.key_matrix = nn.Linear(key_value_dim, hidden_dim)
        self.value_matrix = nn.Linear(key_value_dim, hidden_dim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Process query, key and value tensors.

        :param torch.Tensor query: query tensor.
        :param torch.Tensor key: key tensor.
        :param torch.Tensor value: value tensor.
        :return: attention_output and attention_output_weights.
        :rtype: Tuple[torch.Tensor, torch.Tensor]
        """
        Q = self.query_matrix(query)
        K = self.key_matrix(key)
        V = self.value_matrix(value)

        # assert shapes
        batch_size = query.shape[0]
        seq_len = query.shape[1]
        attn_shape = torch.Size([batch_size, seq_len, self.hidden_dim])
        attn_weights_shape = torch.Size([batch_size, seq_len, seq_len])
        assert Q.shape == attn_shape
        assert K.shape == attn_shape
        assert V.shape == attn_shape

        # scaled dot-product attention: softmax(Q @ K^T / sqrt(hidden_dim)) @ V
        QK = torch.bmm(Q, K.transpose(-1, -2))
        attn_weights = F.softmax(QK / (self.hidden_dim ** 0.5), dim=-1)
        attn = torch.bmm(attn_weights, V)
        assert attn_weights.shape == attn_weights_shape
        assert attn.shape == attn_shape

        return attn, attn_weights

    def n_params(self) -> int:
        """
        Get number of learnable parameters.

        :return: number of learnable parameters.
        :rtype: int
        """
        return sum(p.numel() for p in self.parameters())


class MultiheadAttention(nn.Module):
    """
    My own implementation of Multihead Attention from the paper:
    https://arxiv.org/abs/1706.03762
    """

    def __init__(
        self,
        query_dim: int,
        key_value_dim: int,
        hidden_dim: int,
        num_heads: int,
    ) -> None:
        """
        Init MultiheadAttention.

        :param int query_dim: query tensor embedding dimension.
        :param int key_value_dim: key and value tensors embedding dimension.
        :param int hidden_dim: hidden tensors dimension.
        :param int num_heads: number of attention heads (ScaleDotProductAttention).
        """
        super(MultiheadAttention, self).__init__()
        self.query_dim = query_dim
        self.key_value_dim = key_value_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # independent attention heads
        self.attn_heads = nn.ModuleList()
        for _ in range(num_heads):
            self.attn_heads.append(
                ScaleDotProductAttention(
                    query_dim=query_dim,
                    key_value_dim=key_value_dim,
                    hidden_dim=hidden_dim,
                )
            )

        # output projection back to the query dimension
        self.multihead_matrix = nn.Linear(num_heads * hidden_dim, query_dim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """
        Process query, key and value tensors.

        :param torch.Tensor query: query tensor.
        :param torch.Tensor key: key tensor.
        :param torch.Tensor value: value tensor.
        :return: attention_output.
        :rtype: torch.Tensor
        """
        # run every head and concatenate the results along the feature dimension
        heads_output = []
        for head in self.attn_heads:
            attn_head, _ = head(
                query=query,
                key=key,
                value=value,
            )
            heads_output.append(attn_head)

        attn_cat = torch.cat(heads_output, dim=-1)
        attn = self.multihead_matrix(attn_cat)

        # assert shapes
        batch_size = query.shape[0]
        seq_len = query.shape[1]
        assert attn_cat.shape == torch.Size([batch_size, seq_len, self.num_heads * self.hidden_dim])
        assert attn.shape == torch.Size([batch_size, seq_len, self.query_dim])

        return attn

    def n_params(self) -> int:
        """
        Get number of learnable parameters.

        :return: number of learnable parameters.
        :rtype: int
        """
        return sum(p.numel() for p in self.parameters())
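For reference, a brief sketch of the computation the two modules implement, in the notation of the linked paper (with one difference: in the paper each head works in dimension d_model / h, while here every head projects to its own hidden_dim and the concatenation is mapped back to query_dim by multihead_matrix):

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V, \qquad d_k = \mathrm{hidden\_dim}

\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^{O}, \qquad \mathrm{head}_i = \mathrm{Attention}(Q W_i^{Q}, K W_i^{K}, V W_i^{V})

Here W_i^{Q}, W_i^{K}, W_i^{V} correspond to query_matrix, key_matrix and value_matrix of head i (nn.Linear layers, so bias terms are also learned), and W^{O} corresponds to multihead_matrix.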