"""Ring attention for PyTorch.
See https://github.com/nshepperd/flash_attn_jax/blob/main/src/flash_attn_jax/ring_attention.py.
"""
import flash_attn.flash_attn_interface as fai
import torch
import torch.distributed as dist
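

# ppermute: exchange a list of tensors around the process ring (the analogue of
# jax.lax.ppermute). `perm` is a list of (src_rank, dst_rank) pairs; each rank
# posts all of its sends and receives as one batched P2P call and waits for
# completion before returning the received tensors.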
def ppermute(xs, perm, group=None):
    rank = dist.get_rank(group)
    ys = [torch.empty_like(x) for x in xs]
    ops = []
    for src, dst in perm:
        for x, y in zip(xs, ys):
            if src == rank:
                ops.append(dist.P2POp(dist.isend, x, dst, group))
            if dst == rank:
                ops.append(dist.P2POp(dist.irecv, y, src, group))
    reqs = dist.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()
    return ys
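

# Thin wrapper around the low-level FlashAttention forward kernel. Returns only
# the attention output and the per-row softmax logsumexp, which is what the ring
# forward needs to merge partial results across steps.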
def _flash_fwd(q, k, v, causal):
    ret = fai._flash_attn_forward(
        q=q,
        k=k,
        v=v,
        dropout_p=0.0,
        softmax_scale=k.shape[-1]**-0.5,
        causal=causal,
        window_size=(-1, 0) if causal else (-1, -1),
        softcap=0.0,
        alibi_slopes=None,
        return_softmax=False,
    )
    return ret[0], ret[5]  # out, lse
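

# Thin wrapper around the low-level FlashAttention backward kernel. Returns the
# gradients w.r.t. q, k, and v for one block of keys/values.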
def _flash_bwd(do, q, k, v, o, lse, causal):
    ret = fai._flash_attn_backward(
        dout=do,
        q=q,
        k=k,
        v=v,
        out=o,
        softmax_lse=lse,
        dq=torch.empty_like(q),
        dk=torch.empty_like(k),
        dv=torch.empty_like(v),
        dropout_p=0.0,
        softmax_scale=k.shape[-1]**-0.5,
        causal=causal,
        window_size=(-1, 0) if causal else (-1, -1),
        softcap=0.0,
        alibi_slopes=None,
        deterministic=False,
        rng_state=None,
    )
    return ret[0], ret[1], ret[2]  # dq, dk, dv
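

# Ring attention forward pass: each rank keeps its query shard fixed and passes
# its key/value shard (plus its chunk index) one step around the ring per
# iteration. Partial outputs are merged with the streaming logsumexp update
# lse = logaddexp(lse1, lse2), o = o1 * exp(lse1 - lse) + o2 * exp(lse2 - lse).
# In the causal case, blocks with q_ix < k_ix are fully masked and skipped, the
# diagonal block uses the causal kernel, and blocks below the diagonal are dense.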
def _ring_fwd(q, k, v, causal=False, group=None):
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)
    perm = [(i, (i + 1) % world_size) for i in range(world_size)]
    n, s, h, d = q.shape
    q_ix = torch.tensor(rank, device=q.device)
    k_ix = torch.tensor(rank, device=q.device)
    o = torch.zeros_like(q, dtype=torch.float32)
    lse = torch.full((n, h, s), float("-inf"), device=q.device, dtype=torch.float32)
    for _ in range(world_size):
        o1, lse1 = o, lse
        if not causal:
            o2, lse2 = _flash_fwd(q, k, v, causal=False)
        else:
            if q_ix < k_ix:
                o2 = torch.zeros_like(q)
                lse2 = torch.full((n, h, s), float("-inf"), device=q.device, dtype=torch.float32)
            elif q_ix == k_ix:
                o2, lse2 = _flash_fwd(q, k, v, causal=True)
            else:
                o2, lse2 = _flash_fwd(q, k, v, causal=False)
        lse = torch.logaddexp(lse1, lse2)
        o = o1 * torch.exp(lse1 - lse).mT[..., None] + o2 * torch.exp(lse2 - lse).mT[..., None]
        k, v, k_ix = ppermute([k, v, k_ix], perm, group)
    return o.to(q.dtype), lse
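

# Ring attention backward pass: the key/value shards travel around the ring
# together with their gradient accumulators while each rank reruns the
# FlashAttention backward for its local queries against the visiting block,
# skipping fully masked blocks in the causal case. After the loop the
# accumulated dk/dv end up one ring step short of the rank that owns the
# corresponding k/v shard, so a final ppermute shifts them into place.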
def _ring_bwd(do, q, k, v, o, lse, causal=False, group=None):
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)
    perm = [(i, (i + 1) % world_size) for i in range(world_size)]
    ix = torch.tensor(rank, device=q.device)
    dq = torch.zeros_like(q, dtype=torch.float32)
    dk = torch.zeros_like(k, dtype=torch.float32)
    dv = torch.zeros_like(v, dtype=torch.float32)
    k2, v2, dk2, dv2, ix2 = k, v, dk, dv, ix
    for _ in range(world_size):
        dk2_, dv2_, k2_, v2_, ix2_ = ppermute([dk2, dv2, k2, v2, ix2], perm, group)
        if not causal:
            dqa, dka, dva = _flash_bwd(do, q, k2, v2, o, lse, causal=False)
            dq += dqa
            dk2_ += dka
            dv2_ += dva
        else:
            if ix == ix2:
                dqa, dka, dva = _flash_bwd(do, q, k2, v2, o, lse, causal=True)
            elif ix > ix2:
                dqa, dka, dva = _flash_bwd(do, q, k2, v2, o, lse, causal=False)
            if ix >= ix2:
                dq += dqa
                dk2_ += dka
                dv2_ += dva
        k2, v2, dk2, dv2, ix2 = k2_, v2_, dk2_, dv2_, ix2_
    dk2, dv2 = ppermute([dk2, dv2], perm, group)
    return dq.to(q.dtype), dk2.to(k.dtype), dv2.to(v.dtype)
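

# Autograd wrapper using the newer torch.autograd.Function style, where
# setup_context is separate from forward. The forward saves q, k, v, the output,
# and the logsumexp, which the block-wise FlashAttention backward needs.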
class _RingAttention(torch.autograd.Function):
    @staticmethod
    def setup_context(ctx, inputs, output):
        q, k, v, causal, group = inputs
        o, lse = output
        ctx.causal = causal
        ctx.group = group
        ctx.save_for_backward(q, k, v, o, lse)

    @staticmethod
    def forward(q, k, v, causal, group):
        return _ring_fwd(q, k, v, causal=causal, group=group)

    @staticmethod
    def backward(ctx, do, _):
        q, k, v, o, lse = ctx.saved_tensors
        dq, dk, dv = _ring_bwd(do, q, k, v, o, lse, causal=ctx.causal, group=ctx.group)
        return dq, dk, dv, None, None
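

# Public entry point. q, k, and v are this rank's shards, laid out as
# (batch, local_seq, heads, head_dim) like flash_attn_func; the full sequence is
# the concatenation of all ranks' shards along the sequence dimension, in rank
# order, and the return value is the local shard of the attention output.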
def ring_attn(q, k, v, causal=False, group=None):
    o, lse = _RingAttention.apply(q, k, v, causal, group)
    return o
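

# Usage sketch (assumes the full sequence length divides evenly across ranks):
# starting from full-sequence tensors of shape (batch, seq, heads, head_dim),
# each rank takes its slice along the sequence dimension and calls ring_attn on
# it, e.g.
#
#   ws, r = dist.get_world_size(), dist.get_rank()
#   q_local = q_full.chunk(ws, dim=1)[r]
#   k_local = k_full.chunk(ws, dim=1)[r]
#   v_local = v_full.chunk(ws, dim=1)[r]
#   o_local = ring_attn(q_local, k_local, v_local, causal=True)

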
#!/usr/bin/env python3
import flash_attn
from ring_attn import ppermute, ring_attn
import torch
from torch import distributed as dist
from torch.distributed import nn as dnn
import torch_dist_utils as du


def main():
    du.init_distributed()
    device = du.get_device()
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # test ppermute
    du.print0("Testing ppermute...")
    x = torch.arange(rank * 4, (rank + 1) * 4, device=device)
    perm = [(i, (i + 1) % world_size) for i in range(world_size)]
    y = ppermute([x], perm)
    with du.do_in_order():
        print(f"Rank {rank}: x = {x}, y = {y}")
    q = torch.randn(4, 10, 8, 64, device=device, dtype=torch.bfloat16)
    k = torch.randn(4, 10, 4, 64, device=device, dtype=torch.bfloat16)
    v = torch.randn(4, 10, 4, 64, device=device, dtype=torch.bfloat16)
    do = torch.randn(4, 10, 8, 64, device=device, dtype=torch.bfloat16)
    q_all = torch.cat(dnn.all_gather(q), dim=1)
    k_all = torch.cat(dnn.all_gather(k), dim=1)
    v_all = torch.cat(dnn.all_gather(v), dim=1)
    do_all = torch.cat(dnn.all_gather(do), dim=1)

    # non-causal
    du.print0("Testing non-causal ring attention...")
    q_all_ = q_all.clone().requires_grad_()
    k_all_ = k_all.clone().requires_grad_()
    v_all_ = v_all.clone().requires_grad_()
    o_ref = flash_attn.flash_attn_func(q_all_, k_all_, v_all_, causal=False)
    o_ref.backward(do_all)
    q_ = q.clone().requires_grad_()
    k_ = k.clone().requires_grad_()
    v_ = v.clone().requires_grad_()
    o = ring_attn(q_, k_, v_, causal=False)
    o.backward(do)
    o_all = torch.cat(dnn.all_gather(o), dim=1)
    dq_all = torch.cat(dnn.all_gather(q_.grad), dim=1)
    dk_all = torch.cat(dnn.all_gather(k_.grad), dim=1)
    dv_all = torch.cat(dnn.all_gather(v_.grad), dim=1)
    error_o = torch.sqrt(torch.mean((o_all - o_ref) ** 2))
    error_dq = torch.sqrt(torch.mean((q_all_.grad - dq_all) ** 2))
    error_dk = torch.sqrt(torch.mean((k_all_.grad - dk_all) ** 2))
    error_dv = torch.sqrt(torch.mean((v_all_.grad - dv_all) ** 2))
    with du.do_in_order():
        print(f"Rank {rank}: error o = {error_o}")
        print(f"Rank {rank}: error dq = {error_dq}")
        print(f"Rank {rank}: error dk = {error_dk}")
        print(f"Rank {rank}: error dv = {error_dv}")

    # causal
    du.print0("Testing causal ring attention...")
    q_all_ = q_all.clone().requires_grad_()
    k_all_ = k_all.clone().requires_grad_()
    v_all_ = v_all.clone().requires_grad_()
    o_ref = flash_attn.flash_attn_func(q_all_, k_all_, v_all_, causal=True)
    o_ref.backward(do_all)
    q_ = q.clone().requires_grad_()
    k_ = k.clone().requires_grad_()
    v_ = v.clone().requires_grad_()
    o = ring_attn(q_, k_, v_, causal=True)
    o.backward(do)
    o_all = torch.cat(dnn.all_gather(o), dim=1)
    dq_all = torch.cat(dnn.all_gather(q_.grad), dim=1)
    dk_all = torch.cat(dnn.all_gather(k_.grad), dim=1)
    dv_all = torch.cat(dnn.all_gather(v_.grad), dim=1)
    error_o = torch.sqrt(torch.mean((o_all - o_ref) ** 2))
    error_dq = torch.sqrt(torch.mean((q_all_.grad - dq_all) ** 2))
    error_dk = torch.sqrt(torch.mean((k_all_.grad - dk_all) ** 2))
    error_dv = torch.sqrt(torch.mean((v_all_.grad - dv_all) ** 2))
    with du.do_in_order():
        print(f"Rank {rank}: error o = {error_o}")
        print(f"Rank {rank}: error dq = {error_dq}")
        print(f"Rank {rank}: error dk = {error_dk}")
        print(f"Rank {rank}: error dv = {error_dv}")


if __name__ == "__main__":
    main()
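
# A minimal launch sketch, assuming torch_dist_utils initializes from the usual
# torchrun environment variables and ring_attn.py is importable from the working
# directory:
#   torchrun --nproc_per_node=2 <path to this script>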