# LSH attention as described in https://openreview.net/pdf?id=rkgNKkHtvB
# adapted from trax, stripped down to what the paper said is needed to make it work,
# namely that buckets need to be at least 64 with 8 rounds of hashing
# https://github.com/google/trax/blob/master/trax/layers/research/efficient_attention.py#L442

from torch import nn
import torch
def make_unit_length(x, epsilon=1e-6):
    norm = x.norm(p=2, dim=-1, keepdim=True)
    return x.div(norm + epsilon)

def sort_key_val(t1, t2, dim=-1):
    values, indices = t1.sort(dim=dim)
    t2 = t2.expand_as(t1)
    return values, t2.gather(dim, indices)

def batched_index_select(values, indices):
    b = values.shape[0]
    return values[torch.arange(0, b), indices.transpose(0, 1)].transpose(0, 1)
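# Illustrative sanity check of the helpers above (not part of the module):
# sort_key_val sorts t1 and carries t2 along, batched_index_select gathers rows per batch element.
#   t1 = torch.tensor([[3., 1., 2.]])
#   t2 = torch.arange(3).unsqueeze(0)
#   sort_key_val(t1, t2)                                  # -> (tensor([[1., 2., 3.]]), tensor([[1, 2, 0]]))
#   vals = torch.arange(6.).reshape(1, 3, 2)
#   batched_index_select(vals, torch.tensor([[2, 0]]))    # -> rows 2 and 0 of the single batch element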
class LSHAttention(nn.Module):
    def __init__(self,
                 dropout = 0.,
                 bucket_size = 64,
                 n_hashes = 8,
                 allow_duplicate_attention = False,
                 attend_across_buckets = False,
                 rehash_each_round = True,
                 drop_for_hash_rate = 0.0):
        super().__init__()
        if dropout >= 1.0:
            raise ValueError('Dropout rates must be lower than 1.')

        self.dropout = nn.Dropout(dropout)
        self.dropout_for_hash = nn.Dropout(drop_for_hash_rate)

        assert rehash_each_round or allow_duplicate_attention, (
            'The setting {allow_duplicate_attention=False, rehash_each_round=False}'
            ' is not implemented.')

        self.n_hashes = n_hashes
        self.bucket_size = bucket_size

        self._allow_duplicate_attention = allow_duplicate_attention
        self._attend_across_buckets = attend_across_buckets
        self._rehash_each_round = rehash_each_round

    def _sample_rotation(self, shape, vecs):
        device = vecs.device
        return torch.randn(shape, device=device)
    def hash_vectors(self, n_buckets, vecs):
        batch_size = vecs.shape[0]
        device = vecs.device

        # See https://arxiv.org/pdf/1509.02897.pdf
        # We sample a different random rotation for each round of hashing to
        # decrease the probability of hash misses.
        assert n_buckets % 2 == 0

        rot_size = n_buckets

        rotations_shape = (
            vecs.shape[-1],
            self.n_hashes if self._rehash_each_round else 1,
            rot_size // 2)

        random_rotations = self._sample_rotation(rotations_shape, vecs)

        dropped_vecs = self.dropout_for_hash(vecs)
        rotated_vecs = torch.einsum('btf,fhi->bhti', dropped_vecs, random_rotations)

        if self._rehash_each_round:
            rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
            buckets = torch.argmax(rotated_vecs, dim=-1)
            # buckets is now (batch_size, self.n_hashes, seqlen). Next we add offsets so
            # that bucket numbers from different hashing rounds don't overlap.
            offsets = torch.arange(self.n_hashes, device=device)
            offsets = torch.reshape(offsets * n_buckets, (1, -1, 1))
            buckets = torch.reshape(buckets + offsets, (batch_size, -1,))
        else:
            rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
            # In this configuration, we map each item to the top self.n_hashes buckets.
            # Drop the singleton hash-round dimension: (batch, 1, seqlen, rot_size) -> (batch, seqlen, rot_size)
            rotated_vecs = torch.squeeze(rotated_vecs, 1)

            bucket_range = torch.arange(0, rotated_vecs.shape[-1], device=device)
            bucket_range = torch.reshape(bucket_range, (1, -1))
            bucket_range = bucket_range.expand_as(rotated_vecs)

            _, buckets = sort_key_val(rotated_vecs, bucket_range, dim=-1)
            buckets = buckets[..., -self.n_hashes:]
            # lay the rounds out first so the flattened shape matches the rehash branch: (batch_size, n_hashes * seqlen)
            buckets = torch.reshape(buckets.permute(0, 2, 1), (batch_size, -1))

        return buckets
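    # Note on the returned layout (default rehash_each_round=True path): with e.g.
    # seqlen=8, n_buckets=4, n_hashes=2, `buckets` has shape (batch, 16), where the first
    # 8 entries are bucket ids in [0, 4) from round 0 and the next 8 are ids in [4, 8)
    # from round 1, so ids from different rounds never collide.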
    def forward(self, qk, v):
        batch_size, seqlen, _ = qk.shape
        device = qk.device

        n_buckets = seqlen // self.bucket_size
        n_bins = n_buckets

        buckets = self.hash_vectors(n_buckets, qk)
        # We use the same vector as both a query and a key.
        assert int(buckets.shape[1]) == self.n_hashes * seqlen

        ticker = torch.arange(0, self.n_hashes * seqlen, device=device).unsqueeze(0)
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = buckets_and_t.detach()

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = sort_key_val(buckets_and_t, ticker, dim=-1)
        _, undo_sort = sort_key_val(sticker, ticker, dim=-1)

        sbuckets_and_t = sbuckets_and_t.detach()
        sticker = sticker.detach()
        undo_sort = undo_sort.detach()

        st = (sticker % seqlen)
        sqk = batched_index_select(qk, st)
        sv = batched_index_select(v, st)

        # Split off a "bin" axis so that attention only occurs within chunks.
        bq_t = bkv_t = torch.reshape(st, (batch_size, self.n_hashes * n_bins, -1))
        bqk = torch.reshape(sqk, (batch_size, self.n_hashes * n_bins, -1, sqk.shape[-1]))
        bv = torch.reshape(sv, (batch_size, self.n_hashes * n_bins, -1, sv.shape[-1]))
        bq_buckets = bkv_buckets = torch.reshape(sbuckets_and_t // seqlen, (batch_size, self.n_hashes * n_bins, -1))
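        # Shapes at this point (chunk length = bucket_size): bq_t / bkv_t and
        # bq_buckets / bkv_buckets are (batch, n_hashes * n_bins, bucket_size),
        # while bqk and bv are (batch, n_hashes * n_bins, bucket_size, dim).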
        # Hashing operates on unit-length vectors. Unnormalized query vectors are
        # fine because they effectively provide a learnable temperature for the
        # attention softmax, but normalizing keys is needed so that similarity for
        # the purposes of attention correctly corresponds to hash locality.
        bq = bqk
        bk = make_unit_length(bqk)

        # Allow each chunk to attend within itself, and also one chunk back. Chunk
        # boundaries might occur in the middle of a sequence of items from the
        # same bucket, so this increases the chances of attending to relevant items.
        def look_one_back(x):
            x_extra = torch.cat([x[:, -1:, ...], x[:, :-1, ...]], dim=1)
            return torch.cat([x, x_extra], dim=2)

        bk = look_one_back(bk)
        bv = look_one_back(bv)
        bkv_t = look_one_back(bkv_t)
        bkv_buckets = look_one_back(bkv_buckets)
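        # look_one_back doubles the key/value chunk length: each chunk i is concatenated
        # with chunk i - 1 (chunk 0 wraps around to the last chunk), so bk / bv are now
        # (batch, n_hashes * n_bins, 2 * bucket_size, ...) and every query sees its own
        # chunk plus the preceding one.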
        # Dot-product attention, scaled by 1 / sqrt(dim).
        dots = torch.einsum('bhie,bhje->bhij', bq, bk) * (bq.shape[-1] ** -0.5)

        # Causal masking
        mask = bq_t[:, :, :, None] < bkv_t[:, :, None, :]
        dots = dots - 1e9 * mask

        # Mask out attention to self except when no other targets are available.
        self_mask = bq_t[:, :, :, None] == bkv_t[:, :, None, :]
        dots = dots - 1e5 * self_mask

        # Mask out attention to other hash buckets.
        if not self._attend_across_buckets:
            bucket_mask = bq_buckets[:, :, :, None] != bkv_buckets[:, :, None, :]
            dots = dots - 1e7 * bucket_mask
        # Don't double-count query-key pairs across multiple rounds of hashing.
        # There are two possible strategies here. (1) The default is to count how
        # many times a query-key pair is repeated, and to lower its log-prob
        # correspondingly at each repetition. (2) When hard_k is set, the code
        # instead masks all but the first occurrence of each query-key pair.
        if not self._allow_duplicate_attention:
            locs1 = undo_sort // bq_t.shape[-1]
            locs2 = (locs1 + 1) % (self.n_hashes * n_bins)
            if not self._attend_across_buckets:
                locs1 = buckets * (self.n_hashes * n_bins) + locs1
                locs2 = buckets * (self.n_hashes * n_bins) + locs2
            locs = torch.cat([
                torch.reshape(locs1, (batch_size, self.n_hashes, seqlen)),
                torch.reshape(locs2, (batch_size, self.n_hashes, seqlen)),
            ], 1).permute((0, 2, 1))

            slocs = batched_index_select(locs, st)
            b_locs = torch.reshape(slocs, (batch_size, self.n_hashes * n_bins, -1, 2 * self.n_hashes))

            b_locs1 = b_locs[:, :, :, None, :self.n_hashes]

            bq_locs = b_locs1.expand(b_locs.shape[:3] + (2, self.n_hashes))
            bq_locs = torch.reshape(bq_locs, b_locs.shape)
            bkv_locs = look_one_back(b_locs)

            dup_counts = (bq_locs[:, :, :, None, :] == bkv_locs[:, :, None, :, :]).float().sum(dim=-1)
            dup_counts = dup_counts.detach()
            assert dup_counts.shape == dots.shape
            dots = dots - torch.log(dup_counts + 1e-9)
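            # e.g. if a query-key pair lands in the same chunk under two hashing rounds,
            # subtracting log(2) from its logit in both rounds halves its unnormalized
            # weight each time, so across rounds the pair contributes roughly as much as
            # a pair that was matched only once.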
        # Softmax.
        dots_logsumexp = torch.logsumexp(dots, dim=-1, keepdim=True)
        dots = torch.exp(dots - dots_logsumexp)
        dots = self.dropout(dots)

        bo = torch.einsum('buij,buje->buie', dots, bv)
        so = torch.reshape(bo, (batch_size, -1, bo.shape[-1]))
        slogits = torch.reshape(dots_logsumexp, (batch_size, -1,))

        o = batched_index_select(so, undo_sort)
        _, logits = sort_key_val(sticker, slogits, dim=-1)

        if self.n_hashes == 1:
            out = o
        else:
            o = torch.reshape(o, (batch_size, self.n_hashes, seqlen, o.shape[-1]))
            logits = torch.reshape(logits, (batch_size, self.n_hashes, seqlen, 1))
            probs = torch.exp(logits - torch.logsumexp(logits, dim=1, keepdim=True))
            out = torch.sum(o * probs, dim=1)

        assert out.shape == v.shape
        return out
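# A minimal sketch of using LSHAttention on its own (shapes only, assuming qk and v are
# (batch, seqlen, dim) with seqlen divisible by bucket_size):
#   attn = LSHAttention(bucket_size = 64, n_hashes = 8)
#   qk = torch.randn(1, 1024, 64)
#   v = torch.randn(1, 1024, 64)
#   out = attn(qk, v)    # out.shape == (1, 1024, 64)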
class LSHSelfAttention(nn.Module):
    def __init__(self, emb, heads = 8, bucket_size = 64, n_hashes = 8, **kwargs):
        super().__init__()
        self.heads = heads

        self.toqk = nn.Linear(emb, emb * heads)
        self.tov = nn.Linear(emb, emb * heads)
        self.unify_heads = nn.Linear(emb * heads, emb)

        self.bucket_size = bucket_size
        self.lsh_attn = LSHAttention(bucket_size=bucket_size, **kwargs)

    def forward(self, x):
        b, t, e, h = *x.shape, self.heads
        assert t % self.bucket_size == 0, f'Sequence length needs to be divisible by target bucket size - {self.bucket_size}'

        qk = self.toqk(x)
        v = self.tov(x)

        def merge_heads(v):
            return v.view(b, t, h, e).transpose(1, 2).reshape(b * h, t, e)

        def split_heads(v):
            return v.view(b, h, t, e).transpose(1, 2).contiguous()

        qk = merge_heads(qk)
        v = merge_heads(v)

        attn_out = self.lsh_attn(qk, v)

        out = split_heads(attn_out).view(b, t, h * e)
        return self.unify_heads(out)
# example

depth = 6
seqlen = 4096
emb = 64

# class

LSH = LSHSelfAttention(emb, heads = 8, bucket_size = 64, n_hashes = 8)

# forward pass

tokens = torch.randn((1, seqlen, emb))

for i in range(depth):
    tokens = LSH(tokens) + tokens
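# note: the loop above re-applies the same LSH module at every depth, so weights are
# shared across all layers; for untied layers one could instead build a separate
# LSHSelfAttention per layer, e.g.
#   layers = nn.ModuleList([LSHSelfAttention(emb, heads = 8, bucket_size = 64, n_hashes = 8) for _ in range(depth)])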