Created July 23, 2020 06:09
T5 relative positional embedding
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
class RelativePositionBias(nn.Module):
    """T5-style learned relative position bias.

    Every (query position, key position) pair is mapped to a bucket index
    derived from their signed distance; each bucket owns one learned scalar
    per attention head. The result is meant to be added to attention logits.

    Args:
        bidirectional: if True, half of the buckets are used for each sign
            of the relative position; if False, only non-positive relative
            positions get distinct buckets.
        num_buckets: total number of buckets (rows of the embedding table).
        max_distance: distance at and beyond which all positions share the
            last bucket.
        n_heads: number of attention heads (columns of the embedding table).
    """

    def __init__(self, bidirectional=True, num_buckets=32, max_distance=128, n_heads=2):
        super().__init__()
        self.bidirectional = bidirectional
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.n_heads = n_heads
        # One learned bias per (bucket, head).
        self.relative_attention_bias = nn.Embedding(self.num_buckets, self.n_heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """Translate relative positions to bucket numbers for relative attention.

        Adapted from Mesh Tensorflow.

        The relative position is defined as memory_position - query_position,
        i.e. the distance in tokens from the attending position to the
        attended-to position. If bidirectional=False, then positive relative
        positions are invalid and are all mapped to bucket 0.

        We use smaller buckets for small absolute relative_position and larger
        buckets for larger absolute relative_positions. All relative positions
        >= max_distance map to the same bucket. All relative positions
        <= -max_distance map to the same bucket. This should allow for more
        graceful generalization to longer sequences than the model has been
        trained on.

        Args:
            relative_position: an integer Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing
            integer (torch.long) values in the range [0, num_buckets)
        """
        ret = 0
        n = -relative_position
        if bidirectional:
            # Split the buckets in two halves, one per sign of n.
            num_buckets //= 2
            ret += (n < 0).to(torch.long) * num_buckets
            n = torch.abs(n)
        else:
            # Unidirectional: negative n (key after query) collapses to 0.
            n = torch.max(n, torch.zeros_like(n))
        # now n is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = n < max_exact

        # The other half of the buckets are for logarithmically bigger bins
        # in positions up to max_distance.
        # n is clamped to >= 1 only to avoid casting log(0) = -inf to an
        # integer; those entries are "small" and discarded by the where().
        val_if_large = max_exact + (
            torch.log(n.float().clamp(min=1.0) / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def compute_bias(self, qlen, klen):
        """Compute binned relative position bias.

        Args:
            qlen: number of query positions.
            klen: number of key (memory) positions.

        Returns:
            Tensor of shape (1, num_heads, qlen, klen).
        """
        device = self.relative_attention_bias.weight.device
        context_position = torch.arange(qlen, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(klen, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (qlen, klen)
        # Example for qlen = klen = 4 (rows: query, columns: key):
        #         k
        #      0  1  2  3
        #  q  -1  0  1  2
        #     -2 -1  0  1
        #     -3 -2 -1  0
        rp_bucket = self._relative_position_bucket(
            relative_position,  # shape (qlen, klen)
            bidirectional=self.bidirectional,
            num_buckets=self.num_buckets,
            # Forward the configured value; otherwise the default (128) is
            # silently used regardless of how the module was constructed.
            max_distance=self.max_distance,
        )
        values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, qlen, klen)
        return values

    def forward(self, qlen, klen):
        """Return the relative position bias, shape (1, num_heads, qlen, klen)."""
        return self.compute_bias(qlen, klen)
# torch>=1.5.0 F.multi_head_attention_forward
if attn_mask is not None:
# add relative positional embedding to attn_mask # shape (N*num_heads, L, S)
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
attn_output_weights = attn_output_weights.masked_fill(
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
attn_output_weights = softmax(
attn_output_weights, dim=-1)
attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
It's missing the max_distance in the call to _relative_position_bucket (:

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment