@dayyass
Last active June 17, 2021 15:37
My own implementation of Multihead Attention.
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.optim as optim\n",
"\n",
"from attention import ScaleDotProductAttention, MultiheadAttention"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# init global variables\n",
"BATCH_SIZE = 4\n",
"SEQ_LEN = 30\n",
"QUERY_DIM = 128\n",
"KEY_VALUE_DIM = 64\n",
"\n",
"HIDDEN_DIM = 64\n",
"NUM_HEADS = 8"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# init tensors to work with\n",
"query = torch.randn(BATCH_SIZE, SEQ_LEN, QUERY_DIM)\n",
"key = torch.randn(BATCH_SIZE, SEQ_LEN, KEY_VALUE_DIM)\n",
"value = torch.randn(BATCH_SIZE, SEQ_LEN, KEY_VALUE_DIM)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Scale Dot Product Attention (one head)"
]
},
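{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, each attention head below computes scaled dot-product attention over learned linear projections of the query, key and value tensors, with $d$ = `HIDDEN_DIM`:<br>\n",
"$$\\mathrm{Attention}(Q, K, V) = \\mathrm{softmax}\\left(\\frac{QK^{T}}{\\sqrt{d}}\\right)V$$"
]
},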
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# init one attantion head\n",
"attn_layer = ScaleDotProductAttention(\n",
" query_dim=QUERY_DIM,\n",
" key_value_dim=KEY_VALUE_DIM,\n",
" hidden_dim=HIDDEN_DIM,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of learnable parameters: 16576\n"
]
}
],
"source": [
"# count learnable parameters\n",
"n_params = attn_layer.n_params()\n",
"print(f\"Number of learnable parameters: {n_params}\")"
]
},
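{
"cell_type": "markdown",
"metadata": {},
"source": [
"Sanity check of this count: each of the three `nn.Linear` projections has `in_dim * hidden_dim + hidden_dim` parameters, so $(128 \\cdot 64 + 64) + 2 \\cdot (64 \\cdot 64 + 64) = 8256 + 8320 = 16576$."
]
},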
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# pass our tensors through attention head\n",
"attn, attn_weights = attn_layer(\n",
" query=query,\n",
" key=key,\n",
" value=value,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([4, 30, 64]), torch.Size([4, 30, 30]))"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# validate shapes\n",
"attn.shape, attn_weights.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Multihead Attention"
]
},
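{
"cell_type": "markdown",
"metadata": {},
"source": [
"The multihead layer below runs `NUM_HEADS` independent `ScaleDotProductAttention` heads, concatenates their outputs and projects the result back to `QUERY_DIM` with a final linear layer (`multihead_matrix`, denoted $W^{O}$ here):<br>\n",
"$$\\mathrm{MultiHead}(Q, K, V) = \\mathrm{Concat}(\\mathrm{head}_1, \\ldots, \\mathrm{head}_h)\\,W^{O}$$"
]
},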
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# init multihead attantion layer\n",
"mh_attn = MultiheadAttention(\n",
" query_dim=QUERY_DIM,\n",
" key_value_dim=KEY_VALUE_DIM,\n",
" hidden_dim=HIDDEN_DIM,\n",
" num_heads=NUM_HEADS,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of learnable parameters: 198272\n"
]
}
],
"source": [
"# count learnable parameters\n",
"n_params = mh_attn.n_params()\n",
"print(f\"Number of learnable parameters: {n_params}\")"
]
},
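{
"cell_type": "markdown",
"metadata": {},
"source": [
"Sanity check of this count: $8 \\cdot 16576 = 132608$ parameters for the heads, plus the output projection `nn.Linear(8 * 64, 128)` with $512 \\cdot 128 + 128 = 65664$ parameters, gives $198272$ in total."
]
},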
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# pass our tensor through multihead attantion layer\n",
"attn = mh_attn(\n",
" query=query,\n",
" key=key,\n",
" value=value,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([4, 30, 128]), torch.Size([4, 30, 128]))"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# validate shapes\n",
"query.shape, attn.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Matrix norm minimization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's solve small optimization problem:<br>\n",
"We want to minimize matrix Frobenius mean norm of multihead attantion layer output over batches adapting MultiheadAttention layer parameters."
]
},
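{
"cell_type": "markdown",
"metadata": {},
"source": [
"In symbols, with $A_b$ the multihead attention output for the $b$-th element of the batch, the objective used below is the batch mean of per-matrix Frobenius norms:<br>\n",
"$$\\mathcal{L} = \\frac{1}{B}\\sum_{b=1}^{B} \\lVert A_b \\rVert_F$$"
]
},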
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# init global variables\n",
"N_EPOCHS = 10"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# init tensor to work with\n",
"X = torch.randn(BATCH_SIZE, SEQ_LEN, QUERY_DIM)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# init multihead attantion layer\n",
"mh_attn = MultiheadAttention(\n",
" query_dim=QUERY_DIM,\n",
" key_value_dim=QUERY_DIM,\n",
" hidden_dim=HIDDEN_DIM,\n",
" num_heads=NUM_HEADS,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# init criterion and optimizer\n",
"criterion = lambda x: torch.norm(x, dim=[1,2]).mean()\n",
"optimizer = optim.SGD(mh_attn.parameters(), lr=1e-3, momentum=0.8)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Matrix batch mean norm: 5.052\n"
]
}
],
"source": [
"# check current metric\n",
"attn = mh_attn(\n",
" query=X,\n",
" key=X,\n",
" value=X,\n",
")\n",
"loss = criterion(attn)\n",
"print(f'Matrix batch mean norm: {round(loss.item(), 3)}')"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"EPOCH: 1/10\n",
"Matrix batch mean norm: 5.052\n",
"\n",
"EPOCH: 2/10\n",
"Matrix batch mean norm: 4.956\n",
"\n",
"EPOCH: 3/10\n",
"Matrix batch mean norm: 4.785\n",
"\n",
"EPOCH: 4/10\n",
"Matrix batch mean norm: 4.558\n",
"\n",
"EPOCH: 5/10\n",
"Matrix batch mean norm: 4.29\n",
"\n",
"EPOCH: 6/10\n",
"Matrix batch mean norm: 3.997\n",
"\n",
"EPOCH: 7/10\n",
"Matrix batch mean norm: 3.689\n",
"\n",
"EPOCH: 8/10\n",
"Matrix batch mean norm: 3.379\n",
"\n",
"EPOCH: 9/10\n",
"Matrix batch mean norm: 3.076\n",
"\n",
"EPOCH: 10/10\n",
"Matrix batch mean norm: 2.789\n",
"\n"
]
}
],
"source": [
"# TRAIN\n",
"mh_attn.train()\n",
"\n",
"for i in range(N_EPOCHS):\n",
" \n",
" attn = mh_attn(\n",
" query=X,\n",
" key=X,\n",
" value=X,\n",
" )\n",
" \n",
" optimizer.zero_grad()\n",
" loss = criterion(attn)\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" print(f'EPOCH: {i+1}/{N_EPOCHS}')\n",
" print(f'Matrix batch mean norm: {round(loss.item(), 3)}')\n",
" print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we can see, our metric decreases over epochs."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Tuple


class ScaleDotProductAttention(nn.Module):
    """
    My own implementation of Scale Dot Product Attention (one head) from paper:
    https://arxiv.org/abs/1706.03762
    """

    def __init__(
        self,
        query_dim: int,
        key_value_dim: int,
        hidden_dim: int,
    ) -> None:
        """
        Init ScaleDotProductAttention (one head).

        :param int query_dim: query tensor embedding dimension.
        :param int key_value_dim: key and value tensors embedding dimension.
        :param int hidden_dim: hidden tensors dimension.
        """
        super(ScaleDotProductAttention, self).__init__()

        self.query_dim = query_dim
        self.key_value_dim = key_value_dim
        self.hidden_dim = hidden_dim

        # learned linear projections of query, key and value
        self.query_matrix = nn.Linear(query_dim, hidden_dim)
        self.key_matrix = nn.Linear(key_value_dim, hidden_dim)
        self.value_matrix = nn.Linear(key_value_dim, hidden_dim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Process query, key and value tensors.

        :param torch.Tensor query: query tensor.
        :param torch.Tensor key: key tensor.
        :param torch.Tensor value: value tensor.
        :return: attention_output and attention_output_weights.
        :rtype: Tuple[torch.Tensor, torch.Tensor]
        """
        Q = self.query_matrix(query)
        K = self.key_matrix(key)
        V = self.value_matrix(value)

        # validate shapes of the projected tensors
        batch_size = query.shape[0]
        seq_len = query.shape[1]
        attn_shape = torch.Size([batch_size, seq_len, self.hidden_dim])
        attn_weights_shape = torch.Size([batch_size, seq_len, seq_len])

        assert Q.shape == attn_shape
        assert K.shape == attn_shape
        assert V.shape == attn_shape

        # scaled dot-product attention: softmax(Q K^T / sqrt(d)) V
        QK = torch.bmm(Q, K.transpose(-1, -2))
        attn_weights = F.softmax(QK / (self.hidden_dim ** 0.5), dim=-1)
        attn = torch.bmm(attn_weights, V)

        assert attn_weights.shape == attn_weights_shape
        assert attn.shape == attn_shape

        return attn, attn_weights

    def n_params(self) -> int:
        """
        Get number of learnable parameters.

        :return: number of learnable parameters.
        :rtype: int
        """
        return sum(p.numel() for p in self.parameters())


class MultiheadAttention(nn.Module):
    """
    My own implementation of Multihead Attention from paper:
    https://arxiv.org/abs/1706.03762
    """

    def __init__(
        self,
        query_dim: int,
        key_value_dim: int,
        hidden_dim: int,
        num_heads: int,
    ) -> None:
        """
        Init MultiheadAttention.

        :param int query_dim: query tensor embedding dimension.
        :param int key_value_dim: key and value tensors embedding dimension.
        :param int hidden_dim: hidden tensors dimension.
        :param int num_heads: number of attention heads (ScaleDotProductAttention).
        """
        super(MultiheadAttention, self).__init__()

        self.query_dim = query_dim
        self.key_value_dim = key_value_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # independent attention heads
        self.attn_heads = nn.ModuleList()
        for _ in range(num_heads):
            self.attn_heads.append(
                ScaleDotProductAttention(
                    query_dim=query_dim,
                    key_value_dim=key_value_dim,
                    hidden_dim=hidden_dim,
                )
            )

        # output projection of the concatenated heads back to query_dim
        self.multihead_matrix = nn.Linear(num_heads * hidden_dim, query_dim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """
        Process query, key and value tensors.

        :param torch.Tensor query: query tensor.
        :param torch.Tensor key: key tensor.
        :param torch.Tensor value: value tensor.
        :return: attention_output.
        :rtype: torch.Tensor
        """
        heads_output = []
        for head in self.attn_heads:
            attn_head, _ = head(
                query=query,
                key=key,
                value=value,
            )
            heads_output.append(attn_head)

        attn_cat = torch.cat(heads_output, dim=-1)
        attn = self.multihead_matrix(attn_cat)

        # validate shapes of the concatenated and projected outputs
        batch_size = query.shape[0]
        seq_len = query.shape[1]
        assert attn_cat.shape == torch.Size([batch_size, seq_len, self.num_heads * self.hidden_dim])
        assert attn.shape == torch.Size([batch_size, seq_len, self.query_dim])

        return attn

    def n_params(self) -> int:
        """
        Get number of learnable parameters.

        :return: number of learnable parameters.
        :rtype: int
        """
        return sum(p.numel() for p in self.parameters())
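

# A minimal sanity-check sketch: re-derive the one-head output by hand with the
# layer's own projection matrices and compare it to the forward pass. The tensor
# sizes below are illustrative assumptions, not anything fixed by the module.
if __name__ == "__main__":
    torch.manual_seed(42)

    layer = ScaleDotProductAttention(query_dim=128, key_value_dim=64, hidden_dim=64)
    query = torch.randn(4, 30, 128)
    key = torch.randn(4, 30, 64)
    value = torch.randn(4, 30, 64)

    attn, attn_weights = layer(query=query, key=key, value=value)

    # manual softmax(Q K^T / sqrt(d)) V with the same learned projections
    Q = layer.query_matrix(query)
    K = layer.key_matrix(key)
    V = layer.value_matrix(value)
    weights = F.softmax(Q @ K.transpose(-1, -2) / layer.hidden_dim ** 0.5, dim=-1)

    assert torch.allclose(attn_weights, weights, atol=1e-6)
    assert torch.allclose(attn, weights @ V, atol=1e-6)
    print("manual attention check passed:", attn.shape, attn_weights.shape)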