memory_efficient_attention(jax) deterministic test
# SELF-ATTENTION DOES NOT NEED O(n^2) MEMORY: https://arxiv.org/pdf/2112.05682.pdf
# https://github.com/google-research/google-research/blob/master/memory_efficient_attention/memory_efficient_attention.ipynb
# you may need to set 'export XLA_PYTHON_CLIENT_PREALLOCATE=false'
# https://github.com/google/jax/issues/7118#issuecomment-950183972
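# (Not in the original gist) The same flag can also be set from Python, as long
# as it runs before JAX initializes its backend, e.g.:
#   import os; os.environ.setdefault('XLA_PYTHON_CLIENT_PREALLOCATE', 'false')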
import functools
import jax
import jax.numpy as jnp
import math
import numpy as np
from jax import jit
from tqdm import tqdm
def _query_chunk_attention(query, key, value, precision, key_chunk_size=4096):
  """Multi-head dot product attention with a limited number of queries."""
  num_kv, num_heads, k_features = key.shape
  v_features = value.shape[-1]
  key_chunk_size = min(key_chunk_size, num_kv)
  query = query / jnp.sqrt(k_features)

  @functools.partial(jax.checkpoint, prevent_cse=False)
  def summarize_chunk(query, key, value):
    attn_weights = jnp.einsum('qhd,khd->qhk', query, key, precision=precision)
    # Subtract the per-query max before exponentiating for numerical stability.
    max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
    max_score = jax.lax.stop_gradient(max_score)
    exp_weights = jnp.exp(attn_weights - max_score)
    exp_values = jnp.einsum('vhf,qhv->qhf', value,
                            exp_weights, precision=precision)
    return (exp_values, exp_weights.sum(axis=-1),
            max_score.reshape((query.shape[0], num_heads)))

  def chunk_scanner(chunk_idx):
    key_chunk = jax.lax.dynamic_slice(
        key, (chunk_idx, 0, 0),
        slice_sizes=(key_chunk_size, num_heads, k_features))
    value_chunk = jax.lax.dynamic_slice(
        value, (chunk_idx, 0, 0),
        slice_sizes=(key_chunk_size, num_heads, v_features))
    return summarize_chunk(query, key_chunk, value_chunk)

  chunk_values, chunk_weights, chunk_max = jax.lax.map(
      chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

  # Rescale each chunk's partial sums to a common global max, then combine.
  global_max = jnp.max(chunk_max, axis=0, keepdims=True)
  max_diffs = jnp.exp(chunk_max - global_max)
  chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
  chunk_weights *= max_diffs

  all_values = chunk_values.sum(axis=0)
  all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
  return all_values / all_weights
@jit
def attention(query, key, value, precision=jax.lax.Precision.DEFAULT,
              query_chunk_size=1024):
  """Memory-efficient multi-head dot product attention."""
  num_q, num_heads, q_features = query.shape

  def chunk_scanner(chunk_idx, _):
    query_chunk = jax.lax.dynamic_slice(
        query, (chunk_idx, 0, 0),
        slice_sizes=(min(query_chunk_size, num_q), num_heads, q_features))
    return (chunk_idx + query_chunk_size,
            _query_chunk_attention(query_chunk, key, value, precision=precision))

  # Scan over query chunks so the full attention matrix is never materialized.
  _, res = jax.lax.scan(
      chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size))
  return res.reshape(num_q, num_heads, value.shape[-1])
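
# (Not in the original gist) A minimal naive reference, included only for
# comparison: it materializes the full (num_q, num_heads, num_kv) weight
# matrix, so its memory use grows as O(n^2) in sequence length, which is
# exactly what the chunked version above avoids. The name `naive_attention`
# and the tensor layout (sequence, heads, features) are assumptions matching
# the code above.
@jit
def naive_attention(query, key, value, precision=jax.lax.Precision.DEFAULT):
  query = query / jnp.sqrt(query.shape[-1])
  attn_weights = jnp.einsum('qhd,khd->qhk', query, key, precision=precision)
  attn_weights = jax.nn.softmax(attn_weights, axis=-1)
  return jnp.einsum('qhk,khv->qhv', attn_weights, value, precision=precision)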
if __name__ == '__main__':
  key = jax.random.PRNGKey(0)
  dtype = np.float16
  shape = (1024, 16, 16)

  key, subkey = jax.random.split(key)
  q = jax.random.normal(subkey, shape=shape, dtype=dtype)
  key, subkey = jax.random.split(key)
  k = jax.random.normal(subkey, shape=shape, dtype=dtype)
  key, subkey = jax.random.split(key)
  v = jax.random.normal(subkey, shape=shape, dtype=dtype)

  # Determinism test: rerunning the jitted attention on identical inputs
  # should reproduce the reference output bit-for-bit.
  ref = attention(q, k, v)
  for _ in tqdm(range(10000)):
    r = attention(q, k, v)
    not_same_value_count = int((ref != r).sum())
    assert not_same_value_count == 0, f'{not_same_value_count}'
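
  # (Not in the original gist) Optional correctness cross-check against the
  # naive_attention sketch above; the chunked and naive paths round differently
  # in float16, so compare with a loose, illustrative tolerance rather than
  # bitwise equality.
  naive = naive_attention(q, k, v)
  assert jnp.allclose(ref, naive, atol=1e-2), 'chunked and naive attention diverged'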