Well-documented Transformer
import torch
import torch.nn as nn

from typing import Optional
from math import (
    pi,
    sqrt,
)
from torch import Tensor
from torch.nn.functional import softmax


__all__ = (
    'GELU',
    'LayerNorm',
    'Attention',
    'MultiHeadedAttention',
    'SublayerConnection',
    'PositionWiseFeedForward',
    'Transformer',
)
class GELU(nn.Module):

    def forward(self, x: Tensor):
        return 0.5 * x * (1 + torch.tanh(sqrt(2 / pi) * (x + 0.044715 * torch.pow(x, 3))))
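# note (added): the formula above is the tanh approximation of GELU, x * Phi(x), as used in the original
# BERT implementation; torch.nn.GELU computes the exact erf-based form by default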
class LayerNorm(nn.Module):

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        # params
        self.dim = dim
        self.eps = eps
        # layers
        self.alpha = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x: Tensor):
        mu = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        return self.alpha * (x - mu) / (sigma + self.eps) + self.beta
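# note (added): Tensor.std defaults to the unbiased estimator (division by n - 1), and eps is added to the
# standard deviation rather than to the variance, so this differs slightly from torch.nn.LayerNorm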
class Attention(nn.Module):

    def forward(self,
                Q: Tensor,
                K: Tensor,
                V: Tensor,
                mask: Optional[Tensor] = None,
                dropout: Optional[nn.Module] = None
                ):
        """
            Q: (b x ? x L x dim_Q)
            K: (b x ? x L x dim_K)
            V: (b x ? x L x dim_V)
            ?: 1 (squeezed) or h (multi-head)
            mask: (b x ? x L x L)
            dropout: nn.Module
            assuming dim_Q = dim_K
        """
        dim_Q = Q.size(-1)
        # A: (b x ? x L x L)
        A = torch.matmul(Q, K.transpose(-2, -1)) / sqrt(dim_Q)
        # apply the mask (the logit of a padding position should be minus infinity)
        if mask is not None:
            A = A.masked_fill(mask == 0, -1e9)  # tip: use `mask == 0`, since `mask is False` does not broadcast
        # normalize the logits into probability weights via softmax (masked positions get weight ~0)
        # P: (b x ? x L x L)
        P = softmax(A, dim=-1)
        # apply dropout (with the given dropout module)
        if dropout is not None:
            P = dropout(P)
        # (b x ? x L x L) @ (b x ? x L x dim_V) -> (b x ? x L x dim_V)
        x = torch.matmul(P, V)
        return x, P
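# illustrative shape walk-through (added; the sizes are arbitrary): with b=2, h=4, L=10 and
# dim_Q = dim_K = dim_V = 16, Q, K, V are (2 x 4 x 10 x 16) and a mask of shape (2 x 1 x 10 x 10)
# broadcasts over the head dimension, giving weights P of shape (2 x 4 x 10 x 10) and output x of
# shape (2 x 4 x 10 x 16)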
class MultiHeadedAttention(nn.Module):

    def __init__(self,
                 num_heads: int,
                 dim_model: int,
                 dropout_prob: float = 0.1
                 ):
        """
            dim_K should be equal to dim_model / num_heads
            we assume dim_Q = dim_K = dim_V
        """
        super().__init__()
        assert dim_model % num_heads == 0
        # params
        self.dim_model = dim_model
        self.num_heads = num_heads
        self.dropout_prob = dropout_prob
        # per-head dim_K after splitting dim_model
        self.dim_K = dim_model // num_heads
        # layers
        self.W_Q = nn.Linear(dim_model, dim_model)
        self.W_K = nn.Linear(dim_model, dim_model)
        self.W_V = nn.Linear(dim_model, dim_model)
        self.W_M = nn.Linear(dim_model, dim_model)
        self.attention = Attention()
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self,
                Q: Tensor,
                K: Tensor,
                V: Tensor,
                mask: Optional[Tensor] = None
                ):
        b = Q.size(0)
        # 1) do all the linear projections in a batch from dim_model, then split into (num_heads x dim_K)
        # [process]
        #   (1) linear(W): (b x L x dim_model) -> (b x L x dim_model)
        #   (2) view: (b x L x dim_model) -> (b x L x num_heads x dim_K)
        #   (3) transpose: (b x L x num_heads x dim_K) -> (b x num_heads x L x dim_K)
        Q = self.W_Q(Q).view(b, -1, self.num_heads, self.dim_K).transpose(1, 2)
        K = self.W_K(K).view(b, -1, self.num_heads, self.dim_K).transpose(1, 2)
        V = self.W_V(V).view(b, -1, self.num_heads, self.dim_K).transpose(1, 2)
        # 2) apply attention to the projected vectors in the batch
        # note that attention only cares about the last two dimensions
        # x: (b x num_heads x L x dim_K)
        x, _ = self.attention(Q, K, V, mask=mask, dropout=self.dropout)
        # 3) "concat" those heads using view
        # [process]
        #   (1) transpose: (b x num_heads x L x dim_K) -> (b x L x num_heads x dim_K)
        #   (2) contiguous: make the underlying memory contiguous again (required before view, no shape change)
        #   (3) view: (b x L x num_heads x dim_K) -> (b x L x dim_model)
        x = x.transpose(1, 2).contiguous().view(b, -1, self.dim_model)
        # 4) apply the final linear
        # x: (b x L x dim_model)
        x = self.W_M(x)
        return x
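# note (added): the mask is applied inside Attention on scores of shape (b x num_heads x L x L),
# so a (b x L x L) mask should be unsqueezed to (b x 1 x L x L) by the caller to broadcast over heads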
class SublayerConnection(nn.Module):

    def __init__(self, dim: int = 256, dropout_prob: float = 0.1):
        super().__init__()
        # params
        self.dim = dim
        self.dropout_prob = dropout_prob
        # layers
        self.layernorm = LayerNorm(dim)
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x: Tensor, sublayer: nn.Module):
        r = self.layernorm(x)
        r = sublayer(r)
        r = self.dropout(r)
        return x + r
class PositionWiseFeedForward(nn.Module):

    def __init__(self,
                 dim_model: int = 256,
                 dim_ff: int = 1024,
                 dropout_prob: float = 0.1
                 ):
        super().__init__()
        # params
        self.dim_model = dim_model
        self.dim_ff = dim_ff
        self.dropout_prob = dropout_prob
        # layers
        self.W_1 = nn.Linear(dim_model, dim_ff)
        self.W_2 = nn.Linear(dim_ff, dim_model)
        self.dropout = nn.Dropout(p=dropout_prob)
        self.gelu = GELU()

    def forward(self, x: Tensor):
        x = self.W_1(x)  # (b x L x dim_model) -> (b x L x dim_ff)
        x = self.gelu(x)
        x = self.dropout(x)
        x = self.W_2(x)  # (b x L x dim_ff) -> (b x L x dim_model)
        return x
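# note (added): both linear layers act on the last dimension only, so the same two-layer network is applied
# independently at every position of the (b x L x dim_model) input, hence "position-wise"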
class Transformer(nn.Module):

    def __init__(self,
                 dim_model: int = 256,
                 num_heads: int = 4,
                 dim_ff: int = 1024,
                 dropout_prob: float = 0.1
                 ):
        super().__init__()
        # params
        self.dim_model = dim_model
        self.num_heads = num_heads
        self.dim_ff = dim_ff
        self.dropout_prob = dropout_prob
        # layers
        self.attention = MultiHeadedAttention(num_heads=num_heads, dim_model=dim_model, dropout_prob=dropout_prob)
        self.attention_sublayer = SublayerConnection(dim=dim_model, dropout_prob=dropout_prob)
        self.pwff = PositionWiseFeedForward(dim_model=dim_model, dim_ff=dim_ff, dropout_prob=dropout_prob)
        self.pwff_sublayer = SublayerConnection(dim=dim_model, dropout_prob=dropout_prob)
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x: Tensor, mask: Optional[Tensor] = None):
        # a lambda is needed so the mask reaches the attention forward dynamically
        # (the sublayer module itself also has parameters, namely the layernorm)
        # x: (b x L x dim_model)
        # mask: (b x 1 x L x L), or any shape broadcastable to (b x num_heads x L x L);
        #       positions set to False (or 0) are ignored
        x = self.attention_sublayer(x, lambda z: self.attention.forward(z, z, z, mask=mask))
        x = self.pwff_sublayer(x, self.pwff)
        x = self.dropout(x)
        return x
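# minimal usage sketch (added; not part of the original file; the sizes and the padding-mask construction
# below are arbitrary and only illustrate how the pieces above fit together):
if __name__ == '__main__':
    b, L = 2, 10
    model = Transformer(dim_model=256, num_heads=4, dim_ff=1024, dropout_prob=0.1)
    x = torch.randn(b, L, 256)  # token representations (this block adds no embedding or positional encoding)
    lengths = torch.tensor([10, 7])  # pretend the second sequence has 3 padding positions at the end
    valid = torch.arange(L).unsqueeze(0) < lengths.unsqueeze(1)  # (b x L), True for real tokens
    mask = (valid.unsqueeze(1) & valid.unsqueeze(2)).unsqueeze(1)  # (b x 1 x L x L), broadcasts over heads
    y = model(x, mask=mask)
    print(y.shape)  # torch.Size([2, 10, 256])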