- current models have trouble learning dependencies between distant positions (characters/words): the number of operations needed to relate two positions grows with the distance between them, O(n) or O(log n).
- the Transformer relates any two positions with a constant, O(1), number of operations
- encoder-decoder architecture with residual connections; the encoder and the decoder are each a stack of N identical layers (N = 6 in the paper). A sketch of the stacking helper and the residual sub-layer wrapper follows below.
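A minimal sketch of the two pieces this implies: a helper that stacks a layer N times and a residual wrapper around each sub-layer. The names clones and SublayerConnection, the LayerNorm placement (here applied to the sub-layer input), and the dropout rate are assumptions for illustration, not fixed by these notes.

import copy
import torch.nn as nn

def clones(module, N):
    "Produce N identical (independently parameterized) copies of a module."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class SublayerConnection(nn.Module):
    "Residual connection wrapped around a sub-layer (self-attention or feed-forward); names assumed."
    def __init__(self, size, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # x + Sublayer(x): the residual path lets information and gradients bypass the sub-layer
        return x + self.dropout(sublayer(self.norm(x)))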
- We also modify the self-attention sub-layer in the decoder stack to prevent positions from attending to subsequent positions. This masking, combined with the fact that the output embeddings are offset by one position, **ensures that the predictions for position i can depend only on the known outputs at positions less than i**.
import numpy as np
import torch

def subsequent_mask(size):
    "Mask out subsequent positions (the upper triangle above the diagonal)."
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    # invert: 1 where attending is allowed, 0 where it is masked
    return torch.from_numpy(subsequent_mask) == 0
with size=10 this returns:
tensor([[[ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[ 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[ 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
[ 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[ 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[ 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
[ 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
[ 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
- queries (Q), keys (K), and values (V) are represented as vectors, but packed into matrices so attention is computed for all positions at once
- scaled dot-product attention: Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V. The sqrt(d_k) factor scales the dot products down; otherwise, for large d_k, a dot product of mean-0, variance-1 components has mean 0 and variance d_k, and the resulting large magnitudes push the softmax into regions with very small gradients.
import math
import torch.nn.functional as F

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    # similarity scores between every query and every key, scaled by sqrt(d_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # set masked positions to a large negative value so softmax drives them to ~0
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    # weighted sum of the values, plus the attention weights themselves
    return torch.matmul(p_attn, value), p_attn
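A quick shape check tying subsequent_mask and attention together; the batch size, sequence length, and d_k below are arbitrary illustration values, and the snippet assumes the two functions defined above.

# hypothetical sizes: batch=2, sequence length=10, d_k=64
q = torch.randn(2, 10, 64)
k = torch.randn(2, 10, 64)
v = torch.randn(2, 10, 64)

# causal mask from above; broadcasts over the batch dimension
mask = subsequent_mask(10)

out, p_attn = attention(q, k, v, mask=mask)
print(out.shape)     # torch.Size([2, 10, 64])
print(p_attn.shape)  # torch.Size([2, 10, 10])
# each row of p_attn sums to 1 and is (numerically) zero above the diagonal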
multi-head attention (h heads):
head_i = attention(Q * W_i^{Q}, K * W_i^{K}, V * W_i^{V})
multihead_attention(Q, K, V) = concat(head_0, ..., head_{h-1}) * W^O
the per-head projection sizes are chosen in proportion to the number of heads and the model size, d_k = d_v = d_model / h, so the total computational cost stays within a small constant factor of single-head attention with full dimensionality; see the sketch below
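A minimal sketch of a multi-headed attention module built on the attention() function above and the clones helper sketched earlier; the class name and the dropout default are illustrative assumptions, not fixed by these notes.

class MultiHeadedAttention(nn.Module):
    "Multi-head attention sketch: per-head size d_k = d_model // h."
    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        # four projections: W^Q, W^K, W^V for the inputs and W^O for the output
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1)  # same mask applied to every head
        nbatches = query.size(0)
        # 1) linear projections, then split d_model into h heads of size d_k
        query, key, value = [
            lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for lin, x in zip(self.linears, (query, key, value))
        ]
        # 2) scaled dot-product attention on all heads in parallel
        x, _ = attention(query, key, value, mask=mask, dropout=self.dropout)
        # 3) concatenate the heads and apply the output projection W^O
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)

With d_model = 512 and h = 8 (the paper's base configuration), each head works in 64 dimensions, so the total projection work matches a single 512-dimensional head.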