@yzhangcs
Created June 4, 2019 08:53
partition transformer
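A partitioned Transformer encoder in PyTorch: the model dimension is split into a content half and a position half, sinusoidal position encodings are concatenated onto the input rather than added to it, and every projection in the attention and feed-forward sublayers acts on the two halves separately (the `*_c` and `*_p` weights below). A minimal usage sketch is appended after the classes.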
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn

class Transformer(nn.Module):

    def __init__(self, n_layers, n_heads, n_model, n_embed, n_inner,
                 input_dropout=0.1, attn_dropout=0.1, ffn_dropout=0.1):
        super(Transformer, self).__init__()

        self.layers = nn.ModuleList([Layer(n_heads, n_model, n_embed, n_inner,
                                           attn_dropout, ffn_dropout)
                                     for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(n_model)
        self.dropout = nn.Dropout(input_dropout)

    def forward(self, x, mask):
        x = torch.cat((x, self.init_pos(x)), dim=-1)
        x = self.layer_norm(x)
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x, mask)

        return x

    @classmethod
    def init_pos(cls, x):
        seq_len, n_model = x[0].shape
        pos = x.new_tensor(range(seq_len)).unsqueeze(-1)
        pos = pos / 10000 ** (x.new_tensor(range(n_model)) // 2 * 2 / n_model)
        pos[:, 0::2] = pos[:, 0::2].sin()
        pos[:, 1::2] = pos[:, 1::2].cos()
        pos = pos.unsqueeze(0).expand_as(x)

        return pos
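
# init_pos produces the standard sinusoidal encodings,
#     pos[t, 2i] = sin(t / 10000^(2i/d)),  pos[t, 2i+1] = cos(t / 10000^(2i/d)),
# computed over the input feature size d = n_model // 2 and then concatenated
# (not summed) onto the content features in Transformer.forward().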
class Layer(nn.Module):

    def __init__(self, n_heads, n_model, n_embed, n_inner,
                 attn_dropout=0.1, ffn_dropout=0.1):
        super(Layer, self).__init__()

        self.attn = MultiHeadAttention(n_heads, n_model, n_embed, attn_dropout)
        self.ffn = PosWiseFFN(n_model, n_inner, ffn_dropout)

    def forward(self, x, mask):
        x = self.attn(x, x, x, mask)
        x = self.ffn(x)

        return x
class ScaledDotProductAttention(nn.Module):

    def __init__(self, scale, dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()

        self.scale = scale
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask):
        attn = (q @ k.transpose(-1, -2)) / self.scale
        attn = attn.masked_fill_(~mask.unsqueeze(1), float('-inf'))
        attn = self.softmax(attn)
        attn = self.dropout(attn)
        x = attn @ v

        return x
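
# ScaledDotProductAttention expects `mask` to be a boolean tensor that is True
# at real tokens and False at padding, with one row per (head, batch) pair;
# `~mask.unsqueeze(1)` broadcasts over the query dimension, so padded keys get
# -inf scores and zero attention weight after the softmax.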
class MultiHeadAttention(nn.Module):

    def __init__(self, n_heads, n_model, n_embed, dropout=0.1):
        super(MultiHeadAttention, self).__init__()

        self.n_heads = n_heads
        self.n_model = n_model
        self.n_embed = n_embed

        self.wq_c = nn.Linear(n_model//2, n_heads*n_embed//2, False)
        self.wk_c = nn.Linear(n_model//2, n_heads*n_embed//2, False)
        self.wv_c = nn.Linear(n_model//2, n_heads*n_embed//2, False)
        self.wq_p = nn.Linear(n_model//2, n_heads*n_embed//2, False)
        self.wk_p = nn.Linear(n_model//2, n_heads*n_embed//2, False)
        self.wv_p = nn.Linear(n_model//2, n_heads*n_embed//2, False)
        self.attn = ScaledDotProductAttention(n_embed**0.5, dropout)
        self.layer_norm = nn.LayerNorm(n_model)
        self.wo_c = nn.Linear(n_heads*n_embed//2, n_model//2, False)
        self.wo_p = nn.Linear(n_heads*n_embed//2, n_model//2, False)
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.orthogonal_(self.wq_c.weight)
        nn.init.orthogonal_(self.wk_c.weight)
        nn.init.orthogonal_(self.wv_c.weight)
        nn.init.orthogonal_(self.wq_p.weight)
        nn.init.orthogonal_(self.wk_p.weight)
        nn.init.orthogonal_(self.wv_p.weight)

    def forward(self, q, k, v, mask):
        residual = q
        batch_size, seq_len, _ = q.shape
        q_c, q_p = q.chunk(2, dim=-1)
        k_c, k_p = k.chunk(2, dim=-1)
        v_c, v_p = v.chunk(2, dim=-1)
        # [batch_size, seq_len, n_heads, n_embed // 2]
        q_c = self.wq_c(q_c).view(batch_size, seq_len, self.n_heads, -1)
        # [batch_size, seq_len, n_heads, n_embed // 2]
        q_p = self.wq_p(q_p).view(batch_size, seq_len, self.n_heads, -1)
        # [batch_size, seq_len, n_heads, n_embed // 2]
        k_c = self.wk_c(k_c).view(batch_size, seq_len, self.n_heads, -1)
        # [batch_size, seq_len, n_heads, n_embed // 2]
        k_p = self.wk_p(k_p).view(batch_size, seq_len, self.n_heads, -1)
        # [batch_size, seq_len, n_heads, n_embed // 2]
        v_c = self.wv_c(v_c).view(batch_size, seq_len, self.n_heads, -1)
        # [batch_size, seq_len, n_heads, n_embed // 2]
        v_p = self.wv_p(v_p).view(batch_size, seq_len, self.n_heads, -1)
        # [batch_size, seq_len, n_heads, n_embed]
        q = torch.cat((q_c, q_p), dim=-1)
        # [batch_size, seq_len, n_heads, n_embed]
        k = torch.cat((k_c, k_p), dim=-1)
        # [batch_size, seq_len, n_heads, n_embed]
        v = torch.cat((v_c, v_p), dim=-1)
        # [n_heads * batch_size, seq_len, n_embed]
        q = q.permute(2, 0, 1, 3).reshape(-1, seq_len, self.n_embed)
        # [n_heads * batch_size, seq_len, n_embed]
        k = k.permute(2, 0, 1, 3).reshape(-1, seq_len, self.n_embed)
        # [n_heads * batch_size, seq_len, n_embed]
        v = v.permute(2, 0, 1, 3).reshape(-1, seq_len, self.n_embed)
        # [n_heads * batch_size, seq_len, n_embed]
        x = self.attn(q, k, v, mask.repeat(self.n_heads, 1))
        x = x.view(self.n_heads, batch_size, seq_len, self.n_embed)
        x_c, x_p = x.chunk(2, dim=-1)
        # [batch_size, seq_len, n_heads * n_embed // 2]
        x_c = x_c.permute(1, 2, 0, 3).reshape(batch_size, seq_len, -1)
        # [batch_size, seq_len, n_heads * n_embed // 2]
        x_p = x_p.permute(1, 2, 0, 3).reshape(batch_size, seq_len, -1)
        # [batch_size, seq_len, n_model // 2]
        x_c = self.dropout(self.wo_c(x_c))
        # [batch_size, seq_len, n_model // 2]
        x_p = self.dropout(self.wo_p(x_p))
        # [batch_size, seq_len, n_model]
        x = torch.cat((x_c, x_p), dim=-1)
        x = self.layer_norm(x + residual)

        return x
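
# In MultiHeadAttention the content and position halves are projected by
# separate weights, each head's query/key/value is the concatenation of a
# content part and a position part, the heads are folded into the batch
# dimension before attention, and the two halves of each head's output are
# routed back through separate output projections (wo_c, wo_p).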
class PosWiseFFN(nn.Module):

    def __init__(self, n_model, n_inner, p=0.1):
        super(PosWiseFFN, self).__init__()

        self.w1_c = nn.Linear(n_model//2, n_inner//2)
        self.w1_p = nn.Linear(n_model//2, n_inner//2)
        self.activation = nn.ReLU()
        self.w2_c = nn.Linear(n_inner//2, n_model//2)
        self.w2_p = nn.Linear(n_inner//2, n_model//2)
        self.layer_norm = nn.LayerNorm(n_model)
        self.dropout = nn.Dropout(p)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.orthogonal_(self.w1_c.weight)
        nn.init.orthogonal_(self.w1_p.weight)
        nn.init.orthogonal_(self.w2_c.weight)
        nn.init.orthogonal_(self.w2_p.weight)
        nn.init.zeros_(self.w1_c.bias)
        nn.init.zeros_(self.w1_p.bias)
        nn.init.zeros_(self.w2_c.bias)
        nn.init.zeros_(self.w2_p.bias)

    def forward(self, x):
        residual = x
        x_c, x_p = x.chunk(2, dim=-1)
        x_c = self.w1_c(x_c)
        x_p = self.w1_p(x_p)
        x_c = self.activation(x_c)
        x_p = self.activation(x_p)
        x_c = self.dropout(x_c)
        x_p = self.dropout(x_p)
        x_c = self.w2_c(x_c)
        x_p = self.w2_p(x_p)
        x_c = self.dropout(x_c)
        x_p = self.dropout(x_p)
        x = torch.cat((x_c, x_p), dim=-1)
        x = self.layer_norm(x + residual)

        return x
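
# A minimal usage sketch: every size below (layers, heads, dims, batch and
# sequence lengths) is an arbitrary illustrative choice, not taken from the
# classes above.
if __name__ == '__main__':
    torch.manual_seed(1)
    n_model, n_embed, n_inner = 512, 64, 1024
    model = Transformer(n_layers=2, n_heads=8, n_model=n_model,
                        n_embed=n_embed, n_inner=n_inner)
    # forward() concatenates sinusoidal positions onto the input, so the input
    # itself carries only the content half (n_model // 2) of the features
    x = torch.randn(2, 10, n_model // 2)
    # boolean padding mask: True at real tokens, False at padding
    mask = torch.ones(2, 10, dtype=torch.bool)
    mask[1, 8:] = False
    out = model(x, mask)
    print(out.shape)  # torch.Size([2, 10, 512])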