import torch
import torch.nn as nn
import torch.nn.functional as F

# helpers

def make_unit_length(x, epsilon=1e-6):
    norm = x.norm(p=2, dim=-1, keepdim=True)
    return x.div(norm + epsilon)

def sort_key_val(t1, t2, dim=-1):
    values, indices = t1.sort(dim=dim)
    t2 = t2.expand_as(t1)
    return values, t2.gather(dim, indices)

def batched_index_select(values, indices):
    b = values.shape[0]
    return values[torch.arange(0, b), indices.transpose(0, 1)].transpose(0, 1)
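
# Illustrative sanity checks (not part of the original gist): make_unit_length
# L2-normalizes along the last dim, sort_key_val sorts t1 and reorders t2 with the
# same permutation, and batched_index_select gathers rows per batch element.
# The tensor sizes below are arbitrary assumptions used only for this demo.
_t = torch.randn(2, 4, 8)
assert torch.allclose(make_unit_length(_t).norm(dim=-1), torch.ones(2, 4), atol=1e-4)

_keys = torch.tensor([[3, 1, 2]])
_vals = torch.tensor([[10, 20, 30]])
assert sort_key_val(_keys, _vals)[1].tolist() == [[20, 30, 10]]

_idx = torch.tensor([[2, 0], [1, 1]])
assert batched_index_select(torch.arange(6).reshape(2, 3, 1), _idx).shape == (2, 2, 1)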
# reversible net helper classes

class ReversibleBlock(nn.Module):
    def __init__(self, f_block, g_block, dim = 1):
        super().__init__()
        self.dim = dim
        self.f_block = f_block
        self.g_block = g_block

    def forward(self, x):
        x1, x2 = torch.chunk(x, 2, dim=self.dim)
        y1, y2 = None, None

        with torch.no_grad():
            y1 = x1 + self.f_block(x2)
            y2 = x2 + self.g_block(y1)

        return torch.cat([y1, y2], dim=self.dim)

    def backward_pass(self, y, dy):
        # reconstruct the block's inputs from its outputs and accumulate gradients,
        # so that activations never need to be stored during the forward pass
        y1, y2 = torch.chunk(y, 2, dim=self.dim)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=self.dim)
        del dy

        y1.requires_grad = True
        y2.requires_grad = True

        with torch.enable_grad():
            gy1 = self.g_block(y1)
            gy1.backward(dy2)

        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f_block(x2)
            fx2.backward(dx1)

        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

        x = torch.cat([x1, x2.detach()], dim=self.dim)
        dx = torch.cat([dx1, dx2], dim=self.dim)

        return x, dx
class _ReversibleModuleFunction(torch.autograd.function.Function):
    @staticmethod
    def forward(ctx, x, reversible_blocks):
        for block in reversible_blocks:
            x = block(x)
        ctx.y = x.detach()
        ctx.reversible_blocks = reversible_blocks
        return x

    @staticmethod
    def backward(ctx, dy):
        y = ctx.y
        del ctx.y
        for i in range(len(ctx.reversible_blocks) - 1, -1, -1):
            y, dy = ctx.reversible_blocks[i].backward_pass(y, dy)
        del ctx.reversible_blocks
        return dy, None

class ReversibleSequence(nn.Module):
    def __init__(self, reversible_blocks):
        super().__init__()
        self.reversible_blocks = reversible_blocks

    def forward(self, x):
        x = _ReversibleModuleFunction.apply(x, self.reversible_blocks)
        return x
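
# Illustrative sketch (an assumption, not from the original gist): the reversible
# residual block computes y1 = x1 + f(x2) and y2 = x2 + g(y1), so its inputs can be
# reconstructed from its outputs alone, which is exactly what backward_pass exploits.
# The tiny Linear blocks below are stand-ins used only for this check.
_f, _g = nn.Linear(16, 16), nn.Linear(16, 16)
_block = ReversibleBlock(_f, _g, dim=-1)
_x = torch.randn(2, 32)                    # last dim is split into x1, x2
_y = _block(_x)
with torch.no_grad():
    _y1, _y2 = torch.chunk(_y, 2, dim=-1)
    _x2 = _y2 - _g(_y1)                    # invert the second residual
    _x1 = _y1 - _f(_x2)                    # then the first
assert torch.allclose(torch.cat([_x1, _x2], dim=-1), _x, atol=1e-5)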
# lsh attention

class LSHAttention(nn.Module):
    def __init__(self,
                 dropout = 0.,
                 bucket_size = 64,
                 n_hashes = 8,
                 allow_duplicate_attention = False,
                 attend_across_buckets = False,
                 rehash_each_round = True,
                 drop_for_hash_rate = 0.0):
        super().__init__()
        if dropout >= 1.0:
            raise ValueError('Dropout rates must be lower than 1.')

        self.dropout = nn.Dropout(dropout)
        self.dropout_for_hash = nn.Dropout(drop_for_hash_rate)

        assert rehash_each_round or allow_duplicate_attention, (
            'The setting {allow_duplicate_attention=False, rehash_each_round=False}'
            ' is not implemented.')

        self.n_hashes = n_hashes
        self.bucket_size = bucket_size

        self._allow_duplicate_attention = allow_duplicate_attention
        self._attend_across_buckets = attend_across_buckets
        self._rehash_each_round = rehash_each_round

    def _sample_rotation(self, shape, vecs):
        device = vecs.device
        return torch.randn(shape, device=device)
    def hash_vectors(self, n_buckets, vecs):
        batch_size = vecs.shape[0]
        device = vecs.device

        # See https://arxiv.org/pdf/1509.02897.pdf
        # We sample a different random rotation for each round of hashing to
        # decrease the probability of hash misses.
        assert n_buckets % 2 == 0

        rot_size = n_buckets
        rotations_shape = (
            vecs.shape[-1],
            self.n_hashes if self._rehash_each_round else 1,
            rot_size // 2)

        random_rotations = self._sample_rotation(rotations_shape, vecs)

        dropped_vecs = self.dropout_for_hash(vecs)
        rotated_vecs = torch.einsum('btf,fhi->bhti', dropped_vecs, random_rotations)

        if self._rehash_each_round:
            rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
            buckets = torch.argmax(rotated_vecs, dim=-1)
            # buckets is now (batch_size, self.n_hashes, seqlen). Next we add offsets
            # so that bucket numbers from different hashing rounds don't overlap.
            offsets = torch.arange(self.n_hashes, device=device)
            offsets = torch.reshape(offsets * n_buckets, (1, -1, 1))
            buckets = torch.reshape(buckets + offsets, (batch_size, -1,))
        else:
            rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
            # In this configuration, we map each item to the top self.n_hashes buckets
            rotated_vecs = torch.squeeze(rotated_vecs, 0)
            bucket_range = torch.arange(0, rotated_vecs.shape[-1], device=device)
            bucket_range = torch.reshape(bucket_range, (1, -1))
            bucket_range = bucket_range.expand_as(rotated_vecs)

            _, buckets = sort_key_val(rotated_vecs, bucket_range, dim=-1)
            buckets = buckets[:, -self.n_hashes:]

            # move the hash-round axis to the end before flattening
            buckets = torch.reshape(torch.movedim(buckets, 0, -1), (-1,))

        return buckets
    def forward(self, qk, v):
        batch_size, seqlen, _ = qk.shape
        device = qk.device

        n_buckets = seqlen // self.bucket_size
        n_bins = n_buckets

        buckets = self.hash_vectors(n_buckets, qk)
        # We use the same vector as both a query and a key.
        assert int(buckets.shape[1]) == self.n_hashes * seqlen

        ticker = torch.arange(0, self.n_hashes * seqlen, device=device).unsqueeze(0)
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = buckets_and_t.detach()

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = sort_key_val(buckets_and_t, ticker, dim=-1)
        _, undo_sort = sort_key_val(sticker, ticker, dim=-1)

        sbuckets_and_t = sbuckets_and_t.detach()
        sticker = sticker.detach()
        undo_sort = undo_sort.detach()

        st = (sticker % seqlen)
        sqk = batched_index_select(qk, st)
        sv = batched_index_select(v, st)

        # Split off a "bin" axis so that attention only occurs within chunks.
        bq_t = bkv_t = torch.reshape(st, (batch_size, self.n_hashes * n_bins, -1))
        bqk = torch.reshape(sqk, (batch_size, self.n_hashes * n_bins, -1, sqk.shape[-1]))
        bv = torch.reshape(sv, (batch_size, self.n_hashes * n_bins, -1, sv.shape[-1]))
        bq_buckets = bkv_buckets = torch.reshape(sbuckets_and_t // seqlen, (batch_size, self.n_hashes * n_bins, -1))

        # Hashing operates on unit-length vectors. Unnormalized query vectors are
        # fine because they effectively provide a learnable temperature for the
        # attention softmax, but normalizing keys is needed so that similarity for
        # the purposes of attention correctly corresponds to hash locality.
        bq = bqk
        bk = make_unit_length(bqk)

        # Allow each chunk to attend within itself, and also one chunk back. Chunk
        # boundaries might occur in the middle of a sequence of items from the
        # same bucket, so this increases the chances of attending to relevant items.
        def look_one_back(x):
            x_extra = torch.cat([x[:, -1:, ...], x[:, :-1, ...]], dim=1)
            return torch.cat([x, x_extra], dim=2)

        bk = look_one_back(bk)
        bv = look_one_back(bv)
        bkv_t = look_one_back(bkv_t)
        bkv_buckets = look_one_back(bkv_buckets)

        # Dot-product attention, scaled by 1 / sqrt(d).
        dots = torch.einsum('bhie,bhje->bhij', bq, bk) * (bq.shape[-1] ** -0.5)

        # Causal masking
        mask = bq_t[:, :, :, None] < bkv_t[:, :, None, :]
        dots = dots - 1e9 * mask

        # Mask out attention to self except when no other targets are available.
        self_mask = bq_t[:, :, :, None] == bkv_t[:, :, None, :]
        dots = dots - 1e5 * self_mask

        # Mask out attention to other hash buckets.
        if not self._attend_across_buckets:
            bucket_mask = bq_buckets[:, :, :, None] != bkv_buckets[:, :, None, :]
            dots = dots - 1e7 * bucket_mask

        # Don't double-count query-key pairs across multiple rounds of hashing.
        # There are two possible strategies here. (1) The default is to count how
        # many times a query-key pair is repeated, and to lower its log-prob
        # correspondingly at each repetition. (2) When hard_k is set, the code
        # instead masks all but the first occurrence of each query-key pair.
        if not self._allow_duplicate_attention:
            locs1 = undo_sort // bq_t.shape[-1]
            locs2 = (locs1 + 1) % (self.n_hashes * n_bins)
            if not self._attend_across_buckets:
                locs1 = buckets * (self.n_hashes * n_bins) + locs1
                locs2 = buckets * (self.n_hashes * n_bins) + locs2
            locs = torch.cat([
                torch.reshape(locs1, (batch_size, self.n_hashes, seqlen)),
                torch.reshape(locs2, (batch_size, self.n_hashes, seqlen)),
            ], 1).permute((0, 2, 1))

            slocs = batched_index_select(locs, st)
            b_locs = torch.reshape(slocs, (batch_size, self.n_hashes * n_bins, -1, 2 * self.n_hashes))

            b_locs1 = b_locs[:, :, :, None, :self.n_hashes]

            bq_locs = b_locs1.expand(b_locs.shape[:3] + (2, self.n_hashes))
            bq_locs = torch.reshape(bq_locs, b_locs.shape)
            bkv_locs = look_one_back(b_locs)

            dup_counts = (bq_locs[:, :, :, None, :] == bkv_locs[:, :, None, :, :]).float().sum(dim=-1)
            dup_counts = dup_counts.detach()
            assert dup_counts.shape == dots.shape
            dots = dots - torch.log(dup_counts + 1e-9)

        # Softmax.
        dots_logsumexp = torch.logsumexp(dots, dim=-1, keepdim=True)
        dots = torch.exp(dots - dots_logsumexp)
        dots = self.dropout(dots)

        bo = torch.einsum('buij,buje->buie', dots, bv)
        so = torch.reshape(bo, (batch_size, -1, bo.shape[-1]))
        slogits = torch.reshape(dots_logsumexp, (batch_size, -1,))

        o = batched_index_select(so, undo_sort)
        _, logits = sort_key_val(sticker, slogits, dim=-1)

        if self.n_hashes == 1:
            out = o
        else:
            # combine the rounds of hashing, weighting each round by its softmax normalizer
            o = torch.reshape(o, (batch_size, self.n_hashes, seqlen, o.shape[-1]))
            logits = torch.reshape(logits, (batch_size, self.n_hashes, seqlen, 1))
            probs = torch.exp(logits - torch.logsumexp(logits, dim=1, keepdim=True))
            out = torch.sum(o * probs, dim=1)

        assert out.shape == v.shape
        return out
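
# Hedged usage sketch (the shapes below are assumptions for illustration):
# LSHAttention takes shared query/key vectors qk and values v of shape
# (batch, seqlen, dim), with seqlen divisible by bucket_size and an even number
# of buckets, and returns an output with the same shape as v.
_lsh = LSHAttention(bucket_size=16, n_hashes=2)
_qk = torch.randn(1, 64, 32)
_v = torch.randn(1, 64, 32)
assert _lsh(_qk, _v).shape == _v.shape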
class LSHSelfAttention(nn.Module):
    def __init__(self, emb, heads = 8, bucket_size = 64, n_hashes = 8, **kwargs):
        super().__init__()
        self.heads = heads

        self.toqk = nn.Linear(emb, emb * heads)
        self.tov = nn.Linear(emb, emb * heads)
        self.unify_heads = nn.Linear(emb * heads, emb)

        self.bucket_size = bucket_size
        self.lsh_attn = LSHAttention(bucket_size=bucket_size, n_hashes=n_hashes, **kwargs)

    def forward(self, x):
        b, t, e, h = *x.shape, self.heads
        assert t % self.bucket_size == 0, f'Sequence length needs to be divisible by the bucket size - {self.bucket_size}'

        qk = self.toqk(x)
        v = self.tov(x)

        def merge_heads(v):
            return v.view(b, t, h, e).transpose(1, 2).reshape(b * h, t, e)

        def split_heads(v):
            return v.view(b, h, t, e).transpose(1, 2).contiguous()

        qk = merge_heads(qk)
        v = merge_heads(v)
        attn_out = self.lsh_attn(qk, v)

        out = split_heads(attn_out).view(b, t, h * e)
        return self.unify_heads(out)
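
# Hedged usage sketch (the sizes are assumptions): LSHSelfAttention projects the
# input to per-head qk and v, runs LSHAttention on each head, and merges the heads
# back to the embedding dimension, so the output shape matches the input shape.
_self_attn = LSHSelfAttention(emb=32, heads=4, bucket_size=16, n_hashes=2)
_inp = torch.randn(1, 64, 32)
assert _self_attn(_inp).shape == _inp.shape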
# feedforward with chunking

class FeedForward(nn.Module):
    def __init__(self, emb, mult = 4):
        super().__init__()
        self.emb = emb
        self.proj_in = nn.Linear(emb, emb * mult)
        self.proj_out = nn.Linear(emb * mult, emb)

    def forward(self, x):
        x = self.proj_in(x)
        x = F.gelu(x)
        x = self.proj_out(x)
        return x

class WithLayerNorm(nn.Module):
    def __init__(self, emb, fn):
        super().__init__()
        self.emb = emb
        self.norm = nn.LayerNorm(emb)
        self.fn = fn

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

class Chunk(nn.Module):
    def __init__(self, chunks, fn, dim = -1):
        super().__init__()
        self.dim = dim
        self.chunks = chunks
        self.fn = fn

    def forward(self, x):
        chunks = x.chunk(self.chunks, dim = self.dim)
        return torch.cat([self.fn(c) for c in chunks], dim = self.dim)
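
# Hedged sketch (module sizes are assumptions): Chunk splits the input along a
# dimension, applies the wrapped module to each piece, and concatenates the
# results. With dim=-2 (the sequence dimension) this trades peak memory for extra
# passes through the feedforward, which is how the Reformer below uses it.
_ff = Chunk(4, WithLayerNorm(8, FeedForward(8)), dim=-2)
_seq = torch.randn(1, 16, 8)
assert _ff(_seq).shape == _seq.shape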
# reformer auto-regressive lm

class Reformer(nn.Module):
    def __init__(self, emb, depth, max_seq_len, num_tokens = 10000, heads = 8, bucket_size = 64, n_hashes = 8, ff_chunks = 100):
        super().__init__()
        self.emb = emb
        self.depth = depth

        self.token_emb = nn.Embedding(num_tokens, emb)
        self.pos_emb = nn.Embedding(max_seq_len, emb)

        blocks = []
        for _ in range(depth):
            f = WithLayerNorm(emb, LSHSelfAttention(emb, heads, bucket_size, n_hashes))
            g = Chunk(ff_chunks, WithLayerNorm(emb, FeedForward(emb)), dim = -2)
            blocks.append(ReversibleBlock(f, g, dim=-1))

        self.layers = ReversibleSequence(nn.ModuleList(blocks))
        self.to_logits = nn.Linear(emb, num_tokens)

    def forward(self, x):
        x = self.token_emb(x) + self.pos_emb(torch.arange(0, x.shape[1], device=x.device))
        x = torch.cat([x, x], dim = -1)
        x = self.layers(x)
        x = torch.stack(x.chunk(2, dim=-1)).sum(dim=0)
        return self.to_logits(x)
# testing

num_tokens = 10000

r = Reformer(
    emb = 512,
    depth = 12,
    max_seq_len = 1024,
    num_tokens = num_tokens,
    heads = 8,
    bucket_size = 64,
    n_hashes = 8,
    ff_chunks = 200
)

x = torch.randint(0, num_tokens, (1, 1024)).long()
y = r(x)