transformer fast decoder
from common import *
# explicit imports used below (also provided by common's star import)
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# https://scale.com/blog/pytorch-improvements
#   "Making Pytorch Transformer Twice as Fast on Sequence Generation"
# https://towardsdatascience.com/how-to-code-the-transformer-in-pytorch-24db27c8f9ec
# position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear
class FeedForward(nn.Module):
    def __init__(self, dim, ff_dim=2048, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(dim, ff_dim)
        self.linear2 = nn.Linear(ff_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.dropout(F.relu(self.linear1(x)))
        x = self.linear2(x)
        return x
# layer normalization
# note: this uses the (unbiased) std and adds eps to the std itself,
# so it is close to, but not identical to, nn.LayerNorm
class Norm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(dim))
        self.bias  = nn.Parameter(torch.zeros(dim))
        self.eps   = eps

    def forward(self, x):
        z = (x - x.mean(dim=-1, keepdim=True)) / (x.std(dim=-1, keepdim=True) + self.eps)
        x = self.alpha * z + self.bias
        return x
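# Added example (not in the original gist): a quick sanity check of Norm.
# At initialization alpha=1 and bias=0, so the output should be approximately
# zero-mean with unit scale along the last dimension. Helper name is made up.
def run_check_norm():
    norm = Norm(dim=16)
    x = torch.rand(2, 5, 16) * 10 + 3
    y = norm(x)
    print(y.mean(dim=-1).abs().max())   # ~0
    print(y.std(dim=-1).mean())         # ~1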
class MultiHeadAttention(nn.Module):
    def __init__(self, dim, num_head, dropout=0.1):
        super().__init__()
        self.dim      = dim
        self.d_k      = dim // num_head
        self.num_head = num_head
        self.dropout  = dropout

        self.q   = nn.Linear(dim, dim)
        self.v   = nn.Linear(dim, dim)
        self.k   = nn.Linear(dim, dim)
        self.out = nn.Linear(dim, dim)

    def attention(self, q, k, v, mask):
        score = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)  # (batch_size, num_head, Lq, Lk)
        score = torch.clamp_min(score, -6e3)
        if mask is not None:
            mask = mask.unsqueeze(1)  # broadcast over heads
            #score = score.masked_fill(mask == 0, -6e6)
            #score = score.masked_fill(mask == 0, -half('inf'))
            score = score.masked_fill(mask == 0, -6e4)  # -65504 is the most negative fp16 value
            # https://github.com/NVIDIA/apex/issues/93
            # how to use fp16 training with masked operations
        score = F.softmax(score, dim=-1)

        if self.dropout > 0:
            score = F.dropout(score, self.dropout, training=self.training)

        value = torch.matmul(score, v)
        return value

    def forward(self, q, k, v, mask=None):
        batch_size, T, dim = q.shape

        # perform the linear projections and split into num_head heads
        k = self.k(k).reshape(batch_size, -1, self.num_head, self.d_k)
        q = self.q(q).reshape(batch_size, -1, self.num_head, self.d_k)
        v = self.v(v).reshape(batch_size, -1, self.num_head, self.d_k)

        # transpose to get dimensions batch_size x num_head x T x d_k
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)

        # apply attention, then concatenate heads and put through the final linear layer
        value = self.attention(q, k, v, mask)
        value = value.transpose(1, 2).contiguous().reshape(batch_size, -1, self.dim)
        value = self.out(value)
        return value
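# Added example (not in the original gist): a minimal shape check for
# MultiHeadAttention. The helper name and sizes are made up for illustration.
# q is (batch, Lq, dim), k/v are (batch, Lk, dim); an optional mask of shape
# (batch, Lq, Lk) is broadcast over heads inside attention().
def run_check_multi_head_attention():
    attn = MultiHeadAttention(dim=8, num_head=2)
    q    = torch.rand(3, 5, 8)
    kv   = torch.rand(3, 7, 8)
    mask = torch.ones(3, 5, 7)        # 1 = attend, 0 = blocked
    out  = attn(q, kv, kv, mask)
    print(out.shape)                  # torch.Size([3, 5, 8])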
#---
class TransformerEncodeLayer(nn.Module):
    def __init__(self, dim, ff_dim, num_head, dropout=0.1):
        super().__init__()
        self.norm1 = Norm(dim)
        self.norm2 = Norm(dim)
        self.attn = MultiHeadAttention(dim, num_head, dropout=0.1)
        self.ff   = FeedForward(dim, ff_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, x_mask):
        x1 = self.attn(x, x, x, x_mask)  # self-attention
        x1 = x + self.dropout1(x1)
        x  = self.norm1(x1)

        x2 = self.ff(x)
        x2 = x + self.dropout2(x2)
        x  = self.norm2(x2)
        return x
class TransformerEncode(nn.Module):
    def __init__(self, dim, ff_dim, num_head, num_layer):
        super().__init__()
        self.num_layer = num_layer
        self.layer = nn.ModuleList([
            TransformerEncodeLayer(dim, ff_dim, num_head) for i in range(num_layer)
        ])
        self.norm = Norm(dim)  # note: defined but not applied in forward()

    def forward(self, x, x_mask):
        for i in range(self.num_layer):
            x = self.layer[i](x, x_mask)
        return x
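# Added example (not in the original gist): a minimal usage sketch for the
# encoder stack; the helper name and shapes are made up for illustration.
def run_check_transformer_encode():
    encoder = TransformerEncode(dim=8, ff_dim=32, num_head=2, num_layer=2)
    x = torch.rand(3, 10, 8)
    x_mask = torch.ones(3, 10, 10)    # e.g. a padding mask, 1 = attend
    print(encoder(x, x_mask).shape)   # torch.Size([3, 10, 8])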
#---
class TransformerDecodeLayer(nn.Module):
    def __init__(self, dim, ff_dim, num_head, dropout=0.1):
        super().__init__()
        self.norm1 = Norm(dim)
        self.norm2 = Norm(dim)
        self.norm3 = Norm(dim)
        self.attn1 = MultiHeadAttention(dim, num_head, dropout=0.1)  # masked self-attention
        self.attn2 = MultiHeadAttention(dim, num_head, dropout=0.1)  # cross-attention over encoder memory
        self.ff    = FeedForward(dim, ff_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, mem, x_mask, mem_mask):
        x1 = self.attn1(x, x, x, x_mask)  # self-attention
        x1 = x + self.dropout1(x1)
        x  = self.norm1(x1)

        if mem is not None:
            x2 = self.attn2(x, mem, mem, mem_mask)  # attend to encoder output
            x2 = x + self.dropout2(x2)
            x  = self.norm2(x2)

        x3 = self.ff(x)
        x3 = x + self.dropout3(x3)
        x  = self.norm3(x3)
        return x

    # fast decoding path: compute the output only for the newest (last) position.
    # no causal mask is needed because the last query may attend to every earlier
    # position anyway; dropout is skipped (this path is meant for eval-time generation).
    def forward_last_one(self, x, mem, mem_mask):
        x_one = x[:, [-1]]

        x1    = self.attn1(x_one, x, x)  # self-attention of the last token over the whole prefix
        x_one = x_one + x1
        x_one = self.norm1(x_one)

        if mem is not None:
            x2    = self.attn2(x_one, mem, mem, mem_mask)  # attend to encoder output
            x_one = x_one + x2
            x_one = self.norm2(x_one)

        x3    = self.ff(x_one)
        x_one = x_one + x3
        x_one = self.norm3(x_one)
        return x_one
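# Added example (not in the original gist): check that forward_last_one()
# matches the last position of the full forward() pass for a single layer.
# The helper name and shapes are made up for illustration.
def run_check_decode_layer_last_one():
    layer = TransformerDecodeLayer(dim=8, ff_dim=32, num_head=2)
    layer.eval()   # disable dropout so both paths are deterministic

    x   = torch.rand(2, 5, 8)
    mem = torch.rand(2, 7, 8)
    x_mask = torch.from_numpy(1 - np.triu(np.ones((2, 5, 5)), k=1).astype(np.uint8))

    full = layer(x, mem, x_mask, None)
    last = layer.forward_last_one(x, mem, None)
    print((full[:, -1:] - last).abs().max())   # should be ~0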
# ------------------------------------------------------
# https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
# https://stackoverflow.com/questions/46452020/sinusoidal-embedding-attention-is-all-you-need
class PositionEncode1D(nn.Module):
    def __init__(self, dim, max_length):
        super().__init__()
        assert (dim % 2 == 0)
        self.max_length = max_length

        d = torch.exp(torch.arange(0., dim, 2) * (-math.log(10000.0) / dim))
        position = torch.arange(0., max_length).unsqueeze(1)
        pos = torch.zeros(1, max_length, dim)
        pos[0, :, 0::2] = torch.sin(position * d)
        pos[0, :, 1::2] = torch.cos(position * d)
        self.register_buffer('pos', pos)

    def forward(self, x):
        batch_size, T, dim = x.shape
        x = x + self.pos[:, :T]
        return x
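# Added example (not in the original gist): a minimal usage sketch; the helper
# name and shapes are made up for illustration.
def run_check_position_encode_1d():
    pos_encode = PositionEncode1D(dim=8, max_length=100)
    x = torch.zeros(2, 10, 8)
    print(pos_encode(x).shape)        # torch.Size([2, 10, 8])
    print(pos_encode(x)[0, 0, :4])    # position 0 pattern: sin(0)=0, cos(0)=1 -> [0, 1, 0, 1]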
#https://github.com/wzlxjtu/PositionalEncoding2D/blob/master/positionalembedding2d.py
class PositionEncode2D(nn.Module):
    def __init__(self, dim, width, height):
        super().__init__()
        assert (dim % 4 == 0)
        self.width  = width
        self.height = height

        dim = dim // 2  # half of the channels encode width, half encode height
        d = torch.exp(torch.arange(0., dim, 2) * -(math.log(10000.0) / dim))
        position_w = torch.arange(0., width ).unsqueeze(1)
        position_h = torch.arange(0., height).unsqueeze(1)

        pos = torch.zeros(1, dim * 2, height, width)
        pos[0, 0:dim:2,     :, :] = torch.sin(position_w * d).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
        pos[0, 1:dim:2,     :, :] = torch.cos(position_w * d).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
        pos[0, dim + 0: :2, :, :] = torch.sin(position_h * d).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
        pos[0, dim + 1: :2, :, :] = torch.cos(position_h * d).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
        self.register_buffer('pos', pos)

    def forward(self, x):
        batch_size, C, H, W = x.shape
        x = x + self.pos[:, :, :H, :W]
        return x
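# Added example (not in the original gist): a minimal usage sketch for the 2D
# encoding; the helper name and shapes are made up for illustration.
def run_check_position_encode_2d():
    pos_encode = PositionEncode2D(dim=8, width=16, height=16)
    x = torch.zeros(2, 8, 12, 12)     # any H, W up to (height, width) works
    print(pos_encode(x).shape)        # torch.Size([2, 8, 12, 12])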
# ------------------------------------
'''
mask
array([[[0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]], dtype=uint8)
'''
# def triangle_mask(size):
#     mask = np.triu(np.ones((1, size, size)), k=1).astype('uint8')
#     mask = torch.autograd.Variable(torch.from_numpy(mask) == 0)
#     return mask
# #triangle_mask(10)
#
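# Added helper (not in the original gist, which keeps triangle_mask commented
# out): build the causal mask in the convention actually used by the decoder
# below, where 1 = "may attend" and 0 = "blocked future position". The name
# causal_mask is made up for illustration.
def causal_mask(batch_size, size):
    mask = 1 - np.triu(np.ones((batch_size, size, size)), k=1).astype(np.uint8)
    return torch.from_numpy(mask)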
# https://github.com/alexmt-scale/causal-transformer-decoder/blob/master/causal_transformer_decoder/model.py
class TransformerDecode(nn.Module):
    def __init__(self, dim, ff_dim, num_head, num_layer):
        super().__init__()
        self.num_layer = num_layer
        self.layer = nn.ModuleList([
            TransformerDecodeLayer(dim, ff_dim, num_head) for i in range(num_layer)
        ])
        self.norm = Norm(dim)  # note: defined but not applied in forward()

    def forward(self, x, mem, x_mask=None, mem_mask=None):
        for i in range(self.num_layer):
            x = self.layer[i](x, mem, x_mask, mem_mask)
        return x

    # fast (incremental) decoding: compute only the newest position in each layer
    # and cache the per-layer outputs of all previous positions.
    # note: when cache is None, x is expected to hold a single (the first) token.
    def forward_last_one(self, x, mem, mem_mask=None, cache=None):
        xx = []
        for i in range(self.num_layer):
            x = self.layer[i].forward_last_one(x, mem, mem_mask)  # (batch_size, 1, dim)
            xx.append(x)
            if cache is not None:
                x = torch.cat([cache[i], x], dim=1)  # full layer-i output sequence, input to layer i+1

        if cache is not None:
            new_cache = torch.cat([cache, torch.stack(xx, dim=0)], dim=2)
        else:
            new_cache = torch.stack(xx, dim=0)  # (num_layer, batch_size, length, dim)
        return x, new_cache
# check ################################################################
# https://github.com/alexmt-scale/causal-transformer-decoder/blob/master/tests/test_consistency.py
def run_check_fast_decode():
    batch_size = 2
    length = 6
    dim = 4
    num_head = 2
    ff_dim = dim * num_head
    num_layer = 1

    decoder = TransformerDecode(dim, ff_dim, num_head, num_layer)
    decoder.eval()

    #----
    mem     = torch.rand(batch_size, 5, dim)
    first_x = torch.rand(batch_size, 1, dim)

    #---- standard decoding: re-run the full decoder at every step
    x1 = first_x
    for t in range(length - 1):
        # causal mask for autoregressive decoding: 1 = attend, 0 = future position
        mask = 1 - np.triu(np.ones((batch_size, (t + 1), (t + 1))), k=1).astype(np.uint8)
        mask = torch.from_numpy(mask)
        y  = decoder(x1, mem, x_mask=mask)
        x1 = torch.cat([x1, y[:, -1:]], dim=1)
    print(x1)
    print(x1.shape)

    #---- fast decoding: compute only the last position, re-using the cache
    cache = None
    x2 = first_x
    for t in range(length - 1):
        y, cache = decoder.forward_last_one(x2, mem, cache=cache)
        x2 = torch.cat([x2, y[:, -1:]], dim=1)
    print(x2)
    print(x2.shape)

    # the two decoding paths should agree up to floating-point error
    print(torch.eq(x1, x2))
    diff = torch.abs(x1 - x2)
    print(diff)
    print(diff.max(), diff.min())
# main #################################################################
if __name__ == '__main__':
    run_check_fast_decode()