partition transformer: a Transformer encoder whose hidden states are split into content and position halves, each with its own attention and feed-forward projections.
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
class Transformer(nn.Module):

    def __init__(self, n_layers, n_heads, n_model, n_embed, n_inner,
                 input_dropout=0.1, attn_dropout=0.1, ffn_dropout=0.1):
        super(Transformer, self).__init__()

        self.layers = nn.ModuleList([Layer(n_heads, n_model, n_embed, n_inner,
                                           attn_dropout, ffn_dropout)
                                     for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(n_model)
        self.dropout = nn.Dropout(input_dropout)

    def forward(self, x, mask):
        # concatenate the content features with sinusoidal position encodings,
        # so each token vector is partitioned into a content half and a
        # position half of n_model // 2 dimensions each
        x = torch.cat((x, self.init_pos(x)), dim=-1)
        x = self.layer_norm(x)
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x, mask)

        return x

    @classmethod
    def init_pos(cls, x):
        # sinusoidal position encodings with the same shape as x
        # (here n_pos is the size of x's last dimension, i.e. n_model // 2)
        seq_len, n_pos = x[0].shape
        pos = x.new_tensor(range(seq_len)).unsqueeze(-1)
        pos = pos / 10000 ** (x.new_tensor(range(n_pos)) // 2 * 2 / n_pos)
        pos[:, 0::2] = pos[:, 0::2].sin()
        pos[:, 1::2] = pos[:, 1::2].cos()
        pos = pos.unsqueeze(0).expand_as(x)

        return pos
class Layer(nn.Module):

    def __init__(self, n_heads, n_model, n_embed, n_inner,
                 attn_dropout=0.1, ffn_dropout=0.1):
        super(Layer, self).__init__()

        self.attn = MultiHeadAttention(n_heads, n_model, n_embed, attn_dropout)
        self.ffn = PosWiseFFN(n_model, n_inner, ffn_dropout)

    def forward(self, x, mask):
        x = self.attn(x, x, x, mask)
        x = self.ffn(x)

        return x
class ScaledDotProductAttention(nn.Module):

    def __init__(self, scale, dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()

        self.scale = scale
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask):
        attn = (q @ k.transpose(-1, -2)) / self.scale
        # mask out padded positions before the softmax
        attn = attn.masked_fill_(~mask.unsqueeze(1), float('-inf'))
        attn = self.softmax(attn)
        attn = self.dropout(attn)
        x = attn @ v

        return x
class MultiHeadAttention(nn.Module):

    def __init__(self, n_heads, n_model, n_embed, dropout=0.1):
        super(MultiHeadAttention, self).__init__()

        self.n_heads = n_heads
        self.n_model = n_model
        self.n_embed = n_embed

        # separate projections for the content (_c) and position (_p) halves
        self.wq_c = nn.Linear(n_model//2, n_heads*n_embed//2, False)
        self.wk_c = nn.Linear(n_model//2, n_heads*n_embed//2, False)
        self.wv_c = nn.Linear(n_model//2, n_heads*n_embed//2, False)
        self.wq_p = nn.Linear(n_model//2, n_heads*n_embed//2, False)
        self.wk_p = nn.Linear(n_model//2, n_heads*n_embed//2, False)
        self.wv_p = nn.Linear(n_model//2, n_heads*n_embed//2, False)
        self.attn = ScaledDotProductAttention(n_embed**0.5, dropout)
        self.layer_norm = nn.LayerNorm(n_model)
        self.wo_c = nn.Linear(n_heads*n_embed//2, n_model//2, False)
        self.wo_p = nn.Linear(n_heads*n_embed//2, n_model//2, False)
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.orthogonal_(self.wq_c.weight)
        nn.init.orthogonal_(self.wk_c.weight)
        nn.init.orthogonal_(self.wv_c.weight)
        nn.init.orthogonal_(self.wq_p.weight)
        nn.init.orthogonal_(self.wk_p.weight)
        nn.init.orthogonal_(self.wv_p.weight)

    def forward(self, q, k, v, mask):
        residual = q
        batch_size, seq_len, _ = q.shape
        # split each input into its content and position halves
        q_c, q_p = q.chunk(2, dim=-1)
        k_c, k_p = k.chunk(2, dim=-1)
        v_c, v_p = v.chunk(2, dim=-1)
        # [batch_size, seq_len, n_heads, n_embed // 2]
        q_c = self.wq_c(q_c).view(batch_size, seq_len, self.n_heads, -1)
        # [batch_size, seq_len, n_heads, n_embed // 2]
        q_p = self.wq_p(q_p).view(batch_size, seq_len, self.n_heads, -1)
        # [batch_size, seq_len, n_heads, n_embed // 2]
        k_c = self.wk_c(k_c).view(batch_size, seq_len, self.n_heads, -1)
        # [batch_size, seq_len, n_heads, n_embed // 2]
        k_p = self.wk_p(k_p).view(batch_size, seq_len, self.n_heads, -1)
        # [batch_size, seq_len, n_heads, n_embed // 2]
        v_c = self.wv_c(v_c).view(batch_size, seq_len, self.n_heads, -1)
        # [batch_size, seq_len, n_heads, n_embed // 2]
        v_p = self.wv_p(v_p).view(batch_size, seq_len, self.n_heads, -1)
        # [batch_size, seq_len, n_heads, n_embed]
        q = torch.cat((q_c, q_p), dim=-1)
        # [batch_size, seq_len, n_heads, n_embed]
        k = torch.cat((k_c, k_p), dim=-1)
        # [batch_size, seq_len, n_heads, n_embed]
        v = torch.cat((v_c, v_p), dim=-1)
        # [n_heads * batch_size, seq_len, n_embed]
        q = q.permute(2, 0, 1, 3).reshape(-1, seq_len, self.n_embed)
        # [n_heads * batch_size, seq_len, n_embed]
        k = k.permute(2, 0, 1, 3).reshape(-1, seq_len, self.n_embed)
        # [n_heads * batch_size, seq_len, n_embed]
        v = v.permute(2, 0, 1, 3).reshape(-1, seq_len, self.n_embed)
        # [n_heads * batch_size, seq_len, n_embed]
        x = self.attn(q, k, v, mask.repeat(self.n_heads, 1))
        x = x.view(self.n_heads, batch_size, seq_len, self.n_embed)
        # split the attention output back into content and position halves
        x_c, x_p = x.chunk(2, dim=-1)
        # [batch_size, seq_len, n_heads * n_embed // 2]
        x_c = x_c.permute(1, 2, 0, 3).reshape(batch_size, seq_len, -1)
        # [batch_size, seq_len, n_heads * n_embed // 2]
        x_p = x_p.permute(1, 2, 0, 3).reshape(batch_size, seq_len, -1)
        # [batch_size, seq_len, n_model // 2]
        x_c = self.dropout(self.wo_c(x_c))
        # [batch_size, seq_len, n_model // 2]
        x_p = self.dropout(self.wo_p(x_p))
        # [batch_size, seq_len, n_model]
        x = torch.cat((x_c, x_p), dim=-1)
        # residual connection followed by layer normalization
        x = self.layer_norm(x + residual)

        return x
class PosWiseFFN(nn.Module):

    def __init__(self, n_model, n_inner, p=0.1):
        super(PosWiseFFN, self).__init__()

        # separate position-wise feed-forward networks for the
        # content (_c) and position (_p) halves
        self.w1_c = nn.Linear(n_model//2, n_inner//2)
        self.w1_p = nn.Linear(n_model//2, n_inner//2)
        self.activation = nn.ReLU()
        self.w2_c = nn.Linear(n_inner//2, n_model//2)
        self.w2_p = nn.Linear(n_inner//2, n_model//2)
        self.layer_norm = nn.LayerNorm(n_model)
        self.dropout = nn.Dropout(p)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.orthogonal_(self.w1_c.weight)
        nn.init.orthogonal_(self.w1_p.weight)
        nn.init.orthogonal_(self.w2_c.weight)
        nn.init.orthogonal_(self.w2_p.weight)
        nn.init.zeros_(self.w1_c.bias)
        nn.init.zeros_(self.w1_p.bias)
        nn.init.zeros_(self.w2_c.bias)
        nn.init.zeros_(self.w2_p.bias)

    def forward(self, x):
        residual = x
        x_c, x_p = x.chunk(2, dim=-1)
        x_c = self.w1_c(x_c)
        x_p = self.w1_p(x_p)
        x_c = self.activation(x_c)
        x_p = self.activation(x_p)
        x_c = self.dropout(x_c)
        x_p = self.dropout(x_p)
        x_c = self.w2_c(x_c)
        x_p = self.w2_p(x_p)
        x_c = self.dropout(x_c)
        x_p = self.dropout(x_p)
        x = torch.cat((x_c, x_p), dim=-1)
        # residual connection followed by layer normalization
        x = self.layer_norm(x + residual)

        return x
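A minimal usage sketch follows; it is not part of the gist, and all hyperparameter values below are illustrative assumptions. The input is only the content half of size n_model // 2, since forward() concatenates the sinusoidal position half internally, so the output has n_model features per token.

# Usage sketch (assumed hyperparameters; the gist defines no entry point of its own).
if __name__ == '__main__':
    batch_size, seq_len = 2, 10
    n_model = 512                                   # full model size (content + position halves)
    model = Transformer(n_layers=2, n_heads=8, n_model=n_model,
                        n_embed=64, n_inner=1024)
    # content features: n_model // 2; the position half is added inside forward()
    x = torch.randn(batch_size, seq_len, n_model // 2)
    # True for real tokens, False for padding
    mask = torch.ones(batch_size, seq_len, dtype=torch.bool)
    out = model(x, mask)
    print(out.shape)                                # torch.Size([2, 10, 512])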