- The computation of QK^T involves N^2 dot products, where N is the sequence length.
- The softmax(QK^T)V operation requires an N x N attention matrix, making complexity quadratic in N.
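To make the quadratic cost concrete, here is a minimal pure-Python sketch of full attention that counts the query-key dot products it computes. It assumes the standard 1/sqrt(d) scaling of scaled dot-product attention (not mentioned above) and identity projections, so Q, K and V are passed in directly. The function names are illustrative.

```python
import math


def softmax(xs):
    # Subtract the max before exponentiating for numerical stability.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def full_attention(Q, K, V):
    """Naive softmax(Q K^T / sqrt(d)) V over lists of vectors.

    Returns (outputs, n_dots), where n_dots counts query-key dot products.
    """
    d = len(Q[0])
    n_dots = 0
    outputs = []
    for q in Q:                                  # N queries ...
        scores = []
        for k in K:                              # ... times N keys: N^2 dot products
            scores.append(sum(a * b for a, b in zip(q, k)) / math.sqrt(d))
            n_dots += 1
        weights = softmax(scores)                # one row of the N x N attention matrix
        outputs.append([sum(w * v[j] for w, v in zip(weights, V))
                        for j in range(len(V[0]))])
    return outputs, n_dots


# Usage: 8 tokens with 2-dimensional embeddings.
X = [[float(i), float(i % 3)] for i in range(8)]
out, n_dots = full_attention(X, X, X)
print(n_dots)  # 64
```

Doubling N quadruples `n_dots`, and every `weights` list is a full row of the N x N matrix that must be materialized.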
Recurring theme: use a cheap pre-selection step (e.g. clustering or an LSH lookup) to pick out only the most relevant tokens, and spend the expensive attention computation on those.
Reformer approximates the attention mechanism using Locality-Sensitive Hashing (LSH) to reduce its quadratic complexity. This is achieved by finding subsets of tokens to perform attention on, based on their similarity.
- Hash the keys of all tokens into LSH buckets, using random projections so that similar keys are likely to land in the same bucket.
- Perform LSH lookups to group each query with the keys that fall into the same bucket (Reformer shares the query and key projections, so queries and keys are hashed the same way).
- Compute attention only within these buckets, reducing the overall complexity from O(N^2) to O(N log N); the log N factor comes from sorting tokens by bucket.
This approach narrows the scope of attention efficiently while approximating the result of full attention.
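A minimal sketch of the hashing step, assuming the angular LSH scheme described in the Reformer paper: project a vector with a random Gaussian matrix R of shape d x (n_buckets / 2) and take the argmax over the concatenation [xR; -xR]. `make_rotation`, `lsh_hash` and `bucketize` are hypothetical helper names, and only a single hash round is shown (Reformer can run several rounds and combine them).

```python
import random


def make_rotation(d, n_buckets, seed=0):
    # Random Gaussian matrix R of shape d x (n_buckets // 2); n_buckets must be even.
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(n_buckets // 2)] for _ in range(d)]


def lsh_hash(x, R):
    # Project x with R, then take the argmax over [xR, -xR] to get a bucket id.
    proj = [sum(x[i] * R[i][j] for i in range(len(x))) for j in range(len(R[0]))]
    both = proj + [-p for p in proj]
    return max(range(len(both)), key=lambda j: both[j])


def bucketize(vectors, R):
    # Map bucket id -> indices of the vectors that hash into it.
    buckets = {}
    for idx, x in enumerate(vectors):
        buckets.setdefault(lsh_hash(x, R), []).append(idx)
    return buckets


R = make_rotation(d=4, n_buckets=8)
keys = [[1.0, 2.0, -0.5, 0.3], [1.1, 2.1, -0.4, 0.3], [-3.0, 0.2, 1.0, -1.0]]
print(bucketize(keys, R))
```

Only the direction of a vector matters here: scaling a key by a positive constant leaves its bucket unchanged, so vectors separated by a small angle (high cosine similarity) are likely to collide.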
To do this, they maintain a database of the keys, values, and queries for all tokens, which is recomputed during each forward pass.
- Then hash the keys with LSH so lookups are cheap.
- Significantly reduce the number of tokens each query attends to by looking only inside its own hash bucket.
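The lookup-and-attend steps above can be sketched as follows, assuming the bucket ids have already been computed (e.g. by an LSH hash) and identity projections for Q, K and V. `bucketed_attention` is a hypothetical name; it also returns the number of dot products so the saving over full attention is visible.

```python
import math


def bucketed_attention(Q, K, V, q_buckets, k_buckets):
    """Each query attends only to keys that share its LSH bucket.

    q_buckets / k_buckets are precomputed bucket ids.
    Returns (outputs, n_dots).
    """
    d = len(Q[0])
    members = {}
    for j, b in enumerate(k_buckets):
        members.setdefault(b, []).append(j)
    outputs, n_dots = [], 0
    for q, b in zip(Q, q_buckets):
        idx = members.get(b, [])
        if not idx:
            # No key shares this bucket. With shared Q/K hashing this cannot
            # happen, because a token's own key lands in its query's bucket.
            outputs.append([0.0] * len(V[0]))
            continue
        scores = [sum(a * c for a, c in zip(q, K[j])) / math.sqrt(d) for j in idx]
        n_dots += len(idx)
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        outputs.append([sum(wi * V[j][t] for wi, j in zip(w, idx)) / z
                        for t in range(len(V[0]))])
    return outputs, n_dots


X = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [0.1, 0.9]]
V = [[1.0, 0.0], [0.0, 1.0], [3.0, 0.0], [0.0, 3.0]]
buckets = [0, 1, 0, 1]  # pretend LSH grouped tokens {0, 2} and {1, 3}
out, n_dots = bucketed_attention(X, X, V, buckets, buckets)
print(n_dots)  # 8, versus 16 for full attention over 4 tokens
```

The bucket ids are hard-coded for clarity. Grouping with a dictionary is a simplification: Reformer instead sorts tokens by bucket and attends within fixed-size chunks, which batches well on accelerators.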
Employs sparse attention to restrict each token to attending to only a subset of the tokens in the sequence.
Uses a routing mechanism to cluster tokens by content:
- Each query pulls only the keys in its own cluster for self-attention.
- Clusters come from k-means over the key embeddings (queries are routed with the same centroids).
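A minimal sketch of content-based routing, assuming plain batch k-means (Lloyd's algorithm) fit on the key vectors, with each query assigned to its nearest centroid. In practice the centroids are typically updated online during training rather than re-fit per sequence; the batch k-means, the cluster count `k`, and all function names here are illustrative assumptions.

```python
import math
import random


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest(x, centroids):
    # Index of the centroid closest to x (squared Euclidean distance).
    return min(range(len(centroids)), key=lambda c: sq_dist(x, centroids[c]))


def kmeans(points, k, iters=10, seed=0):
    # Lloyd's algorithm: assign points to centroids, then move centroids to cluster means.
    rng = random.Random(seed)
    centroids = [list(p) for p in rng.sample(points, k)]
    for _ in range(iters):
        clusters = [[] for _ in range(k)]
        for p in points:
            clusters[nearest(p, centroids)].append(p)
        for c, members in enumerate(clusters):
            if members:  # keep the old centroid if a cluster ends up empty
                centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    return centroids


def routed_attention(Q, K, V, k=2):
    """Each query attends only to keys routed to the same k-means cluster."""
    centroids = kmeans(K, k)
    key_cluster = [nearest(x, centroids) for x in K]
    d = len(Q[0])
    outputs = []
    for q in Q:
        c = nearest(q, centroids)  # route the query with the same centroids
        idx = [j for j, cj in enumerate(key_cluster) if cj == c]
        if not idx:
            outputs.append([0.0] * len(V[0]))
            continue
        scores = [sum(a * b for a, b in zip(q, K[j])) / math.sqrt(d) for j in idx]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        outputs.append([sum(wi * V[j][t] for wi, j in zip(w, idx)) / z
                        for t in range(len(V[0]))])
    return outputs


# Two well-separated groups of keys; V marks which group each key came from.
K = [[10.0, 10.0], [11.0, 10.0], [-10.0, -10.0], [-11.0, -10.0]]
V = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
out = routed_attention(K, K, V, k=2)
print(out[0])  # [1.0, 0.0]: query 0 only saw keys from its own group
```

With roughly balanced clusters, each query scores about N/k keys instead of N.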
Divide the input into segments (each the size of the full attention window). Pass each segment through the transformer and reuse the hidden states (activations) from the prior segment.
Specifically, each layer's forward pass concatenates the current segment's hidden state with a carry hidden state:
- The carry hidden state is a fixed-size memory containing the cached hidden states (from the layer below) of the last M segments.
- This memory ensures that information from prior segments is included without performing full attention over all tokens.
- Gradients are computed only for the current segment; the carry state is treated as stop-gradient during training.
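A minimal single-layer sketch of segment-level recurrence, assuming identity Q/K/V projections, a causal language-model mask, and a memory size measured in tokens (`mem_len`, i.e. M segments corresponds to M * seg_len tokens). Queries come only from the current segment, while keys and values span [memory; segment]. In an autodiff framework the cached memory would be wrapped in a stop-gradient (e.g. `.detach()` in PyTorch or `jax.lax.stop_gradient` in JAX); plain Python has no gradients, so that step only appears as a comment. Function names are illustrative.

```python
import math


def attend(q, keys, values):
    # Single-head scaled dot-product attention for one query vector.
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[t] for wi, v in zip(w, values)) / z for t in range(len(values[0]))]


def xl_layer(segment, memory):
    # Keys/values span [memory; segment]; queries come only from the segment.
    context = memory + segment
    out = []
    for i, h in enumerate(segment):
        visible = context[: len(memory) + i + 1]  # all of memory + causal prefix of segment
        out.append(attend(h, visible, visible))
    return out


def run_segments(tokens, seg_len, mem_len):
    memory, outputs = [], []
    for start in range(0, len(tokens), seg_len):
        segment = tokens[start:start + seg_len]
        outputs.extend(xl_layer(segment, memory))
        # Cache this layer's inputs for the next segment, keeping at most mem_len tokens.
        # In an autodiff framework this cached copy would be stop-gradient.
        cache = memory + segment
        memory = cache[max(0, len(cache) - mem_len):]
    return outputs


tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]]
no_mem = run_segments(tokens, seg_len=2, mem_len=0)
with_mem = run_segments(tokens, seg_len=2, mem_len=2)
print(no_mem[2])  # [1.0, 1.0]: without memory, the first token of segment 2 sees only itself
```

With `mem_len=2`, the same token also attends to the two cached tokens of the previous segment, while each segment's attention cost stays bounded by seg_len * (mem_len + seg_len).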
Key techniques:
- Fixed Memory Size: The carry memory has a fixed size, making computation independent of total sequence length.
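- Relative Positional Encodings: Absolute positions would repeat in every segment (each one starts at position 0), so attention uses the relative offset between query and key, which stays consistent when states are reused across segments.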
- Gradient Management: Gradients are restricted to the current segment, with memory treated as stop-gradient.
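- Evaluation Speed-Up: Reusing memory means prior segments are never recomputed, whereas a vanilla Transformer evaluated with a sliding window re-encodes its whole context for every new token; this yields large speed-ups.
Key Benefit: Reuse prior context without recomputing its activations.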
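To see why relative encodings are needed, this sketch computes position information for one segment. If absolute position ids restart at 0 in every segment, cached and current tokens get the same ids. The offsets (query position minus key position) over [memory; segment] depend only on `seg_len` and `mem_len`, so the same table is valid for every segment. Transformer-XL turns these offsets into sinusoidal encodings combined with learned bias terms; this sketch stops at the offsets and a causal mask, and the function names are illustrative.

```python
def absolute_positions(seg_index, seg_len, reset_per_segment=True):
    # Position ids for one segment if positions restart at 0 (or continue globally).
    start = 0 if reset_per_segment else seg_index * seg_len
    return list(range(start, start + seg_len))


def relative_offsets(seg_len, mem_len):
    """Query-minus-key offsets for current-segment queries over [memory; segment] keys.

    Memory keys sit at positions -mem_len..-1 and current tokens at 0..seg_len-1,
    so the table does not depend on which segment is being processed.
    """
    key_pos = list(range(-mem_len, seg_len))
    return [[q - k for k in key_pos] for q in range(seg_len)]


def causal_mask(offsets):
    # A key is visible when it is not in the future, i.e. offset >= 0.
    return [[o >= 0 for o in row] for row in offsets]


# With resetting absolute ids, a cached token and a current token collide:
print(absolute_positions(0, 3), absolute_positions(1, 3))  # [0, 1, 2] [0, 1, 2]
print(relative_offsets(seg_len=2, mem_len=2))               # [[2, 1, 0, -1], [3, 2, 1, 0]]
```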
Uses the projection of a transformer as the hidden state of a recurrent neural network.
- Divide tokens into blocks.
- Perform attention only within blocks (cost scales linearly with sequence length).
- Perform cross-attention between the block's transformer projection and the recurrent state at t-1.
- Accumulate the projected blocks into the hidden state of the RNN.
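A minimal sketch of the block-recurrent loop described above, under simplifying assumptions: a single head, identity Q/K/V projections, a caller-supplied matrix `W_out` for the concatenate-and-project step, and a fixed scalar `gate` standing in for the gating used to update the state. All names are hypothetical. Per block the work is block_len^2 for self-attention plus 2 * block_len * S for the two cross-attentions with S state vectors, so the total grows linearly with sequence length for fixed block_len and S.

```python
import math


def attend(q, keys, values):
    # Single-head scaled dot-product attention for one query vector.
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[t] for wi, v in zip(w, values)) / z for t in range(len(values[0]))]


def linear(x, W):
    # x has len(W) entries; W is a list of rows; the output has len(W[0]) entries.
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]


def block_recurrent(tokens, block_len, state, W_out, gate=0.5):
    """Process tokens block by block, carrying a fixed-size recurrent state.

    state: list of S d-dim vectors (the state at t-1).
    W_out: (2d x d) projection applied to [self-attention ; cross-attention].
    gate: hypothetical fixed scalar mixing the old state with its update.
    """
    outputs = []
    for start in range(0, len(tokens), block_len):
        block = tokens[start:start + block_len]
        for i, h in enumerate(block):
            self_out = attend(h, block[: i + 1], block[: i + 1])  # causal, within the block only
            cross_out = attend(h, state, state)                    # read from the state at t-1
            outputs.append(linear(self_out + cross_out, W_out))   # concatenate, then project
        # State update: each state vector cross-attends to the block, then gated mixing.
        updates = [attend(s, block, block) for s in state]
        state = [[gate * a + (1 - gate) * b for a, b in zip(s, u)] for s, u in zip(state, updates)]
    return outputs, state


tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5], [2.0, 0.0], [0.0, 2.0]]
state0 = [[0.1, 0.2], [0.3, 0.1], [0.0, 0.5]]
W_out = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]  # sums the two halves
outs, state = block_recurrent(tokens, 2, state0, W_out)
print(len(outs), len(state))  # 6 3
```

However long the input, the state stays at S vectors, which is what lets it summarize all past blocks at constant cost.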
- Use the accumulated hidden state of the RNN to represent past blocks.
- Perform cross-attention between the current block's self-attention projection at t and the hidden state at t-1 (concatenating the results and linearly projecting them to get hidden_state_t).
- Use sliding-window attention that spans block boundaries.
- Keys and values from previous blocks are cached, so the sliding window for the current block can reach back into prior blocks and keep a larger attention context.
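A minimal sketch of sliding-window attention with a one-block key/value cache, assuming identity projections, a causal window of `window` tokens (including the token itself), and `window <= block_len` so that caching only the previous block covers every query's window. The function name is illustrative; it also returns the positions each query saw, which makes the cross-block reach visible.

```python
import math


def attend(q, keys, values):
    # Single-head scaled dot-product attention for one query vector.
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[t] for wi, v in zip(w, values)) / z for t in range(len(values[0]))]


def sliding_window_attention(tokens, block_len, window):
    """Causal sliding-window attention processed block by block.

    Keys/values of the previous block are cached, so a query near the start of a
    block can still see up to `window` tokens, reaching back across the boundary.
    Returns (outputs, seen) where seen[p] lists the positions token p attended to.
    """
    if not 1 <= window <= block_len:
        raise ValueError("this sketch caches one block, so it needs 1 <= window <= block_len")
    cache = []  # (position, vector) pairs from the previous block
    outputs, seen = [], []
    for start in range(0, len(tokens), block_len):
        block = [(start + i, v) for i, v in enumerate(tokens[start:start + block_len])]
        context = cache + block
        for p, q in block:
            keep = [(pos, v) for pos, v in context if p - window < pos <= p]
            seen.append([pos for pos, _ in keep])
            vecs = [v for _, v in keep]
            outputs.append(attend(q, vecs, vecs))
        cache = block  # only the most recent block is kept
    return outputs, seen


tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5], [2.0, 0.0], [0.0, 2.0]]
outputs, seen = sliding_window_attention(tokens, block_len=3, window=2)
print(seen[3])  # [2, 3]: the first token of block 2 reaches back into block 1
```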
Philosophy: Similar to how humans read a book - we construct a mental model of the story thus far (main characters, relationships, plots, etc) rather than memorizing every single word.