@jonghwanhyeon
Created August 28, 2024 15:27

from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def scaled_dot_product_attention(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
    # query: (batch, target_length, d_model)
    # key: (batch, source_length, d_model)
    # value: (batch, source_length, d_model)
    # mask: (batch, target_length, source_length)
    scale = np.sqrt(query.size(-1))

    # (batch, target_length, d_model) @ (batch, d_model, source_length)
    # -> (batch, target_length, source_length)
    score = query @ key.transpose(1, 2) / scale
    if mask is not None:
        score.masked_fill_(mask, float("-inf"))

    # (batch, target_length, source_length)
    attention = F.softmax(score, dim=-1)
    if mask is not None:
        attention = attention.masked_fill(mask, 0.0)

    # (batch, target_length, source_length) @ (batch, source_length, d_model)
    # -> (batch, target_length, d_model)
    context = attention @ value

    return context, attention
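

# Sanity-check sketch: assuming no row of `mask` is fully masked, the function above
# should match PyTorch's built-in F.scaled_dot_product_attention. Note the inverted
# mask convention: `mask` here is True at positions to ignore, while torch's boolean
# `attn_mask` is True at positions that are allowed to attend.
def _check_against_torch_sdpa() -> None:
    torch.manual_seed(0)
    query = torch.randn(2, 5, 16)  # (batch, target_length, d_model)
    key = torch.randn(2, 7, 16)  # (batch, source_length, d_model)
    value = torch.randn(2, 7, 16)  # (batch, source_length, d_model)

    mask = torch.zeros(2, 5, 7, dtype=torch.bool)
    mask[:, :, -2:] = True  # treat the last two source positions as padding

    context, _ = scaled_dot_product_attention(query, key, value, mask)
    reference = F.scaled_dot_product_attention(query, key, value, attn_mask=~mask)
    assert torch.allclose(context, reference, atol=1e-5)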


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        assert (self.d_head * num_heads) == self.d_model, "d_model must be divisible by num_heads"

        # Wq, Wk, Wv
        # ┌───────────────┐
        # │ H │ H │ H │ H │
        # │ E │ E │ E │ E │
        # │ A │ A │ A │ A │ (d_model, num_heads * d_head)
        # │ D │ D │ D │ D │
        # │ 1 │ 2 │ 3 │ 4 │
        # └───────────────┘
        #   ↑
        # d_head
        self.W_q = nn.Linear(d_model, self.num_heads * self.d_head)
        self.W_k = nn.Linear(d_model, self.num_heads * self.d_head)
        self.W_v = nn.Linear(d_model, self.num_heads * self.d_head)
        self.W_o = nn.Linear(self.num_heads * self.d_head, d_model)

    def forward(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, padding_mask: Optional[torch.Tensor] = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # query: (batch, target_length, d_model)
        # key: (batch, source_length, d_model)
        # value: (batch, source_length, d_model)
        # padding_mask: (batch, source_length), True at padded source positions
        batch_size, target_length, source_length = query.size(0), query.size(1), key.size(1)

        # (batch, target_length, num_heads, d_head)
        query = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_head)
        # (batch, source_length, num_heads, d_head)
        key = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_head)
        # (batch, source_length, num_heads, d_head)
        value = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_head)

        # (batch, sequence/target_length, num_heads, d_head)
        # -> (batch, num_heads, sequence/target_length, d_head)
        # -> (batch * num_heads, sequence/target_length, d_head)
        query = query.permute(0, 2, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)
        key = key.permute(0, 2, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)
        value = value.permute(0, 2, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head)

        attention_mask = None
        if padding_mask is not None:
            # (batch, source_length) -> (batch * num_heads, 1, source_length),
            # broadcast over the target dimension inside scaled_dot_product_attention
            attention_mask = (
                padding_mask.view(batch_size, 1, 1, source_length)
                .expand(-1, self.num_heads, -1, -1)
                .reshape(batch_size * self.num_heads, 1, source_length)
            )

        # context: (batch * num_heads, target_length, d_head)
        # attention: (batch * num_heads, target_length, source_length)
        context, attention = scaled_dot_product_attention(query, key, value, attention_mask)

        # average the attention weights over heads: (batch, target_length, source_length)
        attention = attention.view(batch_size, self.num_heads, target_length, source_length)
        attention = attention.mean(dim=1)

        # (batch * num_heads, target_length, d_head)
        # -> (batch, num_heads, target_length, d_head)
        # -> (batch, target_length, num_heads, d_head)
        # -> (batch, target_length, num_heads * d_head)
        context = context.view(batch_size, self.num_heads, -1, self.d_head)
        context = context.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.num_heads * self.d_head)

        # (batch, target_length, d_model)
        context = self.W_o(context)

        return context, attention
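

# Minimal usage sketch: self-attention over a padded batch, where `padding_mask` is
# True at padded source positions (shapes chosen purely for illustration).
if __name__ == "__main__":
    torch.manual_seed(0)
    attention_layer = MultiHeadAttention(d_model=64, num_heads=4)

    inputs = torch.randn(2, 10, 64)  # (batch, sequence_length, d_model)
    padding_mask = torch.zeros(2, 10, dtype=torch.bool)
    padding_mask[1, 7:] = True  # the second sequence is padded after position 7

    # context: (batch, sequence_length, d_model), attention: (batch, sequence_length, sequence_length)
    context, attention = attention_layer(inputs, inputs, inputs, padding_mask)
    print(context.shape)  # torch.Size([2, 10, 64])
    print(attention.shape)  # torch.Size([2, 10, 10])

    _check_against_torch_sdpa()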