Implementation of the transformer block used by BERT
#!/usr/bin/env python3
"""Implementation of the transformer block used by BERT.

I saw an excellent implementation of the complete BERT model here:
https://github.com/codertimo/BERT-pytorch

I re-wrote a simplified version of the transformer block below. This was mainly
for my own understanding (so that I could get a grasp of the dimensions and
how the whole attention mechanism works), but I tried to document it pretty
thoroughly so that other people can understand it without having to go too far
into the weeds. The training task at the bottom is just a proof-of-concept,
where the model learns to output the input sequence.
"""

import math
from typing import Optional

import torch
from torch import (
    nn,
    optim,
    Tensor,
)
from torch.nn import functional as F


class GELU(nn.Module):
    """Defines the Gaussian Error Linear Unit (GELU) activation function.

    Input:
        float tensor of any shape
    Output:
        float tensor with the same shape.
    """

    def forward(self, x: Tensor) -> Tensor:
        a = math.sqrt(2 / math.pi)
        b = 0.044715
        return 0.5 * x * (1 + torch.tanh(a * (x + b * torch.pow(x, 3))))
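

# Not part of the original gist: an optional sanity check (never called at
# import time) comparing the tanh approximation above against PyTorch's exact
# GELU, F.gelu. On reasonably recent PyTorch versions the two curves agree to
# within roughly 1e-3.
def _gelu_sanity_check() -> None:
    x = torch.linspace(-4.0, 4.0, steps=201)
    approx = GELU()(x)
    exact = F.gelu(x)
    assert torch.allclose(approx, exact, atol=1e-3)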


class TwoLayerLinear(nn.Module):
    """Defines a module with two linear layers, with dropout.

    Args:
        num_input: int, number of input dimensions.
        num_hidden: int, number of hidden dimensions (between first and second).
        dropout: float, the dropout rate.
    Input:
        x: float, (batch_size, time_steps, num_input)
    Output:
        float, (batch_size, time_steps, num_input)
    """

    def __init__(self,
                 num_input: int,
                 num_hidden: int,
                 dropout: float = 0.1) -> None:
        super(TwoLayerLinear, self).__init__()
        self.num_input = num_input
        self.num_hidden = num_hidden
        self.dropout = dropout
        self.first_layer = nn.Linear(num_input, num_hidden)
        self.second_layer = nn.Linear(num_hidden, num_input)
        self.dropout_layer = nn.Dropout(dropout)
        self.activation = GELU()

    def forward(self, x: Tensor) -> Tensor:
        x = self.first_layer(x)
        x = self.activation(x)
        x = self.dropout_layer(x)
        x = self.second_layer(x)
        return x
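

# Not part of the original gist: a minimal usage sketch showing that the
# feed-forward block preserves the (batch_size, time_steps, num_input) shape
# while expanding to num_hidden internally. Illustrative only, never called at
# import time.
def _two_layer_linear_example() -> None:
    feed_forward = TwoLayerLinear(num_input=20, num_hidden=80)
    x = torch.randn(4, 10, 20)
    assert feed_forward(x).shape == (4, 10, 20)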


class Encoder(nn.Module):
    """Defines a general attention encoder.

    Params:
        num_input: int, number of input dimensions.
        num_heads: int, number of attention heads to encode.
        num_dimensions: int, number of encoding dimensions.
    Input:
        x: float, (batch_size, time_steps, num_input)
    Output:
        float, (batch_size, num_heads, time_steps, num_dimensions)
    """

    def __init__(self,
                 num_input: int,
                 num_heads: int,
                 num_dimensions: int) -> None:
        super(Encoder, self).__init__()
        self.num_input = num_input
        self.num_heads = num_heads
        self.num_dimensions = num_dimensions
        self.layer = nn.Linear(num_input, num_heads * num_dimensions)

    def forward(self, x: Tensor) -> Tensor:
        batch_size, time_steps, _ = x.size()
        shape = (batch_size, time_steps, self.num_heads, self.num_dimensions)
        x = self.layer(x).view(*shape)
        return x.transpose(1, 2)
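

# Not part of the original gist: a quick shape check for the Encoder. A single
# linear layer projects each time step to num_heads * num_dimensions features,
# which are then split into a separate axis per head.
def _encoder_example() -> None:
    encoder = Encoder(num_input=20, num_heads=4, num_dimensions=16)
    x = torch.randn(2, 7, 20)
    # (batch_size, num_heads, time_steps, num_dimensions)
    assert encoder(x).shape == (2, 4, 7, 16)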


class AttentionLayer(nn.Module):
    """Defines a multi-headed scaled dot-product attention model.

    Params:
        num_input: int, number of input dimensions.
        num_heads: int, number of attention heads to use.
        num_hidden: int, number of hidden dimensions.
        num_key_dims: int, number of dimensions in the key encoder. Defaults to
            num_hidden.
        num_value_dims: int, number of dimensions in the value encoder. Defaults
            to num_hidden.
        dropout: float, dropout to apply to attention weights.
    Input:
        x: float, (batch_size, time_steps, num_input)
        mask: bool/byte, (batch_size, time_steps), where masked positions are
            nonzero.
    Output:
        float, (batch_size, time_steps, num_hidden)
    """

    def __init__(self,
                 num_input: int,
                 num_heads: int,
                 num_hidden: int,
                 num_key_dims: Optional[int] = None,
                 num_value_dims: Optional[int] = None,
                 dropout: float = 0.1) -> None:
        super(AttentionLayer, self).__init__()
        num_key_dims = num_key_dims or num_hidden
        num_value_dims = num_value_dims or num_hidden
        self.num_input = num_input
        self.num_heads = num_heads
        self.num_hidden = num_hidden
        self.num_key_dims = num_key_dims
        self.num_value_dims = num_value_dims
        self.dropout = dropout
        self.scale = math.sqrt(num_key_dims)
        self.query_layer = Encoder(num_input, num_heads, num_key_dims)
        self.key_layer = Encoder(num_input, num_heads, num_key_dims)
        self.value_layer = Encoder(num_input, num_heads, num_value_dims)
        self.dropout_layer = nn.Dropout(dropout)
        self.decoder_layer = nn.Linear(num_heads * num_value_dims, num_hidden)

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        batch_size, time_steps, _ = x.size()
        # (batch_size, num_heads, time_steps, num_key_dims)
        query = self.query_layer(x)
        # (batch_size, num_heads, time_steps, num_key_dims)
        key = self.key_layer(x)
        # (batch_size, num_heads, time_steps, num_value_dims)
        value = self.value_layer(x)
        # (batch_size, num_heads, time_steps, time_steps)
        logits = torch.matmul(query, key.transpose(-1, -2)) / self.scale
        if mask is not None:
            # (batch_size, 1, time_steps, time_steps)
            mask = mask.unsqueeze(1).repeat(1, time_steps, 1).unsqueeze(1)
            logits = logits.masked_fill(mask, -1e9)
        softmax_weights = self.dropout_layer(F.softmax(logits, dim=-1))
        # (batch_size, num_heads, time_steps, num_value_dims)
        values = torch.matmul(softmax_weights, value)
        # (batch_size, time_steps, num_heads * num_value_dims)
        values = values.transpose(1, 2).contiguous().view(
            batch_size,
            time_steps,
            self.num_heads * self.num_value_dims,
        )
        return self.decoder_layer(values)
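

# Not part of the original gist: a usage sketch for the attention layer with a
# padding mask. Nonzero (True) mask entries are excluded from attention, and
# the output keeps the (batch_size, time_steps, num_hidden) shape.
def _attention_example() -> None:
    attention = AttentionLayer(num_input=20, num_heads=4, num_hidden=20)
    x = torch.randn(2, 7, 20)
    mask = torch.zeros(2, 7, dtype=torch.bool)
    mask[:, -2:] = True  # treat the last two time steps as padding
    assert attention(x, mask).shape == (2, 7, 20)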


class LayerNorm(nn.Module):
    """Defines a layer normalization layer.

    See "Layer Normalization" (Ba et al., 2016) for more details.

    Args:
        features: int, the number of input features.
        eps: float, epsilon parameter (to avoid divide by zero).
    Input:
        float, (..., features)
    Output:
        float, (..., features)
    """

    def __init__(self, features: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta
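

# Not part of the original gist: a small check that this LayerNorm produces
# (approximately) zero mean over the feature dimension at initialization, when
# gamma is all ones and beta is all zeros. Note it is not numerically identical
# to nn.LayerNorm: it adds eps to the standard deviation rather than to the
# variance, and x.std applies Bessel's correction.
def _layer_norm_example() -> None:
    layer_norm = LayerNorm(features=20)
    y = layer_norm(torch.randn(4, 10, 20))
    assert torch.allclose(y.mean(-1), torch.zeros(4, 10), atol=1e-5)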


class Transformer(nn.Module):
    """Defines the transformer block used by BERT.

    This transformer looks at the left and right context for a word to try to
    disambiguate its meaning. It can be used for a variety of NLP tasks.

    Input:
        x: float, (batch_size, time_steps, num_hidden)
        mask: bool/byte, (batch_size, time_steps), where masked positions are
            nonzero.
    Output:
        float, (batch_size, time_steps, num_hidden)
    """

    def __init__(self,
                 num_hidden: int,
                 num_heads: int,
                 num_linear_hidden: int,
                 dropout: float = 0.1) -> None:
        super(Transformer, self).__init__()
        self.num_hidden = num_hidden
        self.num_heads = num_heads
        self.num_linear_hidden = num_linear_hidden
        self.dropout = dropout
        self.attention = AttentionLayer(num_hidden, num_heads, num_hidden)
        self.attention_norm = LayerNorm(num_hidden)
        self.linear = TwoLayerLinear(num_hidden, num_linear_hidden)
        self.linear_norm = LayerNorm(num_hidden)
        self.dropout_layer = nn.Dropout(dropout)

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        x = self.dropout_layer(self.attention(self.attention_norm(x), mask)) + x
        x = self.dropout_layer(self.linear(self.linear_norm(x))) + x
        return x
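

# Not part of the original gist: a usage sketch for one pre-norm transformer
# block. Each sub-layer (attention, then feed-forward) is applied to a
# layer-normalized input, passed through dropout, and added back residually,
# so the output shape matches the input shape.
def _transformer_example() -> None:
    block = Transformer(num_hidden=20, num_heads=4, num_linear_hidden=80)
    x = torch.randn(2, 7, 20)
    mask = torch.zeros(2, 7, dtype=torch.bool)
    mask[:, -1] = True  # mask out the final (padding) position
    assert block(x, mask).shape == (2, 7, 20)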


if __name__ == '__main__':
    model = Transformer(20, 3, 80)
    optimizer = optim.Adam(model.parameters())
    loss_function = nn.L1Loss()
    for _ in range(1000):
        inputs = torch.randn(128, 10, 20)
        optimizer.zero_grad()
        output = model(inputs)
        loss = loss_function(output, inputs)
        print(loss.item())
        loss.backward()
        optimizer.step()