Created August 12, 2024 17:48
Tree Attention Decoding
import torch
from torch import einsum
import torch.distributed as dist
def tree_attn_decode(q, k, v):
    """
    Algorithm 3 proposed in Tree Attention
    https://arxiv.org/abs/2408.04093
    """
    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    # scale queries

    scale = q.shape[-1] ** -0.5
    q = q * scale

    # each machine (rank) takes care of one chunk of the kv sequence within the world of many machines

    k = k.chunk(world_size, dim = -2)
    v = v.chunk(world_size, dim = -2)
    k, v = k[rank], v[rank]

    # first calculate the local attention output and log-sum-exp for this rank's chunk

    sim = einsum('... i d, ... j d -> ... i j', q, k)

    local_max = sim.amax(dim = -1, keepdim = True)
    sim = sim - local_max
    lse = sim.logsumexp(dim = -1, keepdim = True)

    attn = sim.softmax(dim = -1)
    out = einsum('... i j, ... j d -> ... i d', attn, v)

    den = lse.exp()
    num = out * den

    # first all reduce (max) to get the global max across ranks

    global_max = local_max.clone()

    if dist.is_initialized():
        dist.all_reduce(global_max, dist.ReduceOp.MAX)

    # renormalize the numerator and denominator to the global max

    renorm_factor = (local_max - global_max).exp()

    den = den * renorm_factor
    num = num * renorm_factor

    # second and third all reduce (sum)

    if dist.is_initialized():
        dist.all_reduce(den)
        dist.all_reduce(num)

    return num / den
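# note on why the renormalization above is sound:
#
# over its own chunk of keys j, each rank r computes
#     num_r = sum_j exp(s_j - m_r) * v_j    and    den_r = sum_j exp(s_j - m_r)
# where m_r is that rank's local max score. multiplying both by exp(m_r - m), with m the
# global max, turns them into sum_j exp(s_j - m) * v_j and sum_j exp(s_j - m), so after the
# two sum all-reduces, num / den equals softmax(s) @ v over the full kv sequence, computed
# without ever exponentiating an unshifted score (numerically stable)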
# regular attention for testing

def regular_decode(q, k, v):
    scale = q.shape[-1] ** -0.5
    q = q * scale
    sim = einsum('... i d, ... j d -> ... i j', q, k)
    attn = sim.softmax(dim = -1)
    return einsum('... i j, ... j d -> ... i d', attn, v)
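# quick single-process sanity check (works because tree_attn_decode falls back to rank 0 /
# world size 1 and skips the all-reduces when torch.distributed is not initialized).
# illustrative only, shapes chosen arbitrarily:
#
#   q = torch.randn(1, 1, 512)
#   k = torch.randn(1, 128, 512)
#   v = torch.randn(1, 128, 512)
#   assert torch.allclose(tree_attn_decode(q, k, v), regular_decode(q, k, v), atol = 1e-5)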
# for testing the above tree decoding function
# `pip install click` as a requirement, besides `torch`

import os
import click
import torch.multiprocessing as mp
def setup(
    rank,
    world_size,
    use_cuda
):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    backend = "nccl" if use_cuda else "gloo"
    dist.init_process_group(backend, rank = rank, world_size = world_size)

    if use_cuda:
        torch.cuda.set_device(rank)

def cleanup():
    dist.destroy_process_group()
def start(
    rank,
    world_size,
    seq_len,
    use_cuda,
):
    setup(rank, world_size, use_cuda)

    is_main = rank == 0

    # inputs

    q = torch.randn(1, 1, 512)
    k = torch.randn(1, seq_len, 512)
    v = torch.randn(1, seq_len, 512)

    if use_cuda:
        q, k, v = tuple(t.cuda(rank) for t in (q, k, v))

    # all reduce (sum) is an easy way of forcing q, k, v to be the same across all ranks

    dist.all_reduce(q)
    dist.all_reduce(k)
    dist.all_reduce(v)

    # outputs

    out = regular_decode(q, k, v)
    tree_out = tree_attn_decode(q, k, v)

    # if not the main rank, early return

    if not is_main:
        return cleanup()

    # if main, validate that the output is the same with the kv sequence split across machines vs without

    tree_out = tree_out.cpu()
    out = out.cpu()

    output_atol = 1e-2 if use_cuda else 1e-5

    assert torch.allclose(tree_out, out, atol = output_atol), '🟥 output is not the same'
    print('✅ output is the same between tree and non-tree attention decoding')

    cleanup()
@click.command()
@click.option('--world-size', default = 8, help = 'number of machines / processes')
@click.option('--use-cuda', is_flag = True, help = 'whether to test with CUDA and NCCL')
@click.option('--seq-len', default = 31, help = 'sequence length to test')
def test(
    world_size: int,
    use_cuda: bool,
    seq_len: int,
):
    assert not use_cuda or world_size <= torch.cuda.device_count(), f'world size {world_size} cannot exceed the number of cuda devices {torch.cuda.device_count()}'

    mp.spawn(
        start,
        args = (world_size, seq_len, use_cuda),
        nprocs = world_size,
        join = True
    )
if __name__ == '__main__':
    test()
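# example usage (the filename tree_attn_decode.py is illustrative; save the gist under any name):
#
#   $ python tree_attn_decode.py --world-size 8
#   $ python tree_attn_decode.py --world-size 2 --use-cuda --seq-len 1024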