Created June 30, 2025 11:36
sequential and templated ring/ulysses/unified attention implementation
import torch

torch.manual_seed(42)


# Reference implementation: full attention through the cuDNN SDPA kernel, returning both
# the attention output and the log-sum-exp (LSE) needed for merging partial results.
def torch_sdpa(query, key, value):
    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_cudnn_attention(
            query=query,
            key=key,
            value=value,
            attn_bias=None,
            compute_log_sumexp=True,
        )
    )
    return out, lse


# Sequential emulation of ring attention: each simulated rank keeps its query shard fixed,
# walks over the key/value shards in ring order, and folds every partial result into a
# running output using the log-sum-exp merge.
def ring_sdpa_sequential(partial_queries, partial_keys, partial_values, *, world_size: int = 1, convert_to_fp32: bool = True):
    outputs, lses = [], []
    for rank in range(world_size):
        query, key, value = partial_queries[rank], partial_keys[rank], partial_values[rank]
        next_rank = (rank + 1) % world_size
        prev_out = prev_lse = None
        for i in range(world_size):
            if i > 0:
                key, value = partial_keys[next_rank], partial_values[next_rank]
                next_rank = (next_rank + 1) % world_size
            out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
                torch.ops.aten._scaled_dot_product_cudnn_attention(
                    query=query,
                    key=key,
                    value=value,
                    attn_bias=None,
                    compute_log_sumexp=True,
                )
            )
            if convert_to_fp32:
                out = out.to(torch.float32)
                lse = lse.to(torch.float32)
            # Numerically stable merge of softmax partials, see:
            # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
            lse = lse.unsqueeze(-1)
            if prev_out is not None:
                out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
                lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
            prev_out = out
            prev_lse = lse
        out = out.to(query.dtype)
        lse = lse.squeeze(-1)
        outputs.append(out)
        lses.append(lse)
    return outputs, lses
device = "cuda" | |
dtype = torch.bfloat16 | |
world_size = 4 | |
batch_size = 1 | |
image_sequence_length = 4096 | |
text_sequence_length = 512 | |
sequence_length = image_sequence_length + text_sequence_length | |
num_attention_heads = 24 | |
attention_head_dim = 128 | |
query = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype) | |
key = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype) | |
value = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype) | |
partial_queries = query.chunk(world_size, dim=2) | |
partial_keys = key.chunk(world_size, dim=2) | |
partial_values = value.chunk(world_size, dim=2) | |
torch_sdpa_out, torch_sdpa_lse = torch_sdpa(query, key, value) | |
ring_sdpa_out, ring_sdpa_lse = ring_sdpa_sequential(partial_queries, partial_keys, partial_values, world_size=world_size) | |
all_ring_sdpa_out = torch.cat(ring_sdpa_out, dim=2) | |
all_ring_sdpa_lse = torch.cat(ring_sdpa_lse, dim=2) | |
assert torch_sdpa_out.shape == all_ring_sdpa_out.shape, "Output shapes do not match!" | |
assert torch_sdpa_lse.shape == all_ring_sdpa_lse.shape, "LSE shapes do not match!" | |
assert torch.allclose(all_ring_sdpa_out, torch_sdpa_out, atol=1e-3, rtol=1e-3), "Outputs do not match!" | |
assert torch.allclose(all_ring_sdpa_lse, torch_sdpa_lse, atol=1e-3, rtol=1e-3), "LSE values do not match!" |
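
The sigmoid/logsigmoid update in ring_sdpa_sequential is the standard numerically stable way to fold a new softmax partial into a running one: with weights proportional to exp(lse_i), the merged output is the weighted average of the partial outputs and the merged LSE is log(exp(lse_1) + exp(lse_2)). The small self-contained check below is illustrative only (it is not part of the gist; all names and sizes in it are made up) and verifies that identity against a single softmax over the concatenated key blocks.

import torch

torch.manual_seed(0)

# Two key/value blocks attended to by the same 4 query rows (float64 for a tight comparison).
scores1 = torch.randn(4, 8, dtype=torch.float64)
scores2 = torch.randn(4, 8, dtype=torch.float64)
v1 = torch.randn(8, 16, dtype=torch.float64)
v2 = torch.randn(8, 16, dtype=torch.float64)

# Per-block partial outputs and log-sum-exps.
lse1 = torch.logsumexp(scores1, dim=-1, keepdim=True)
lse2 = torch.logsumexp(scores2, dim=-1, keepdim=True)
out1 = torch.softmax(scores1, dim=-1) @ v1
out2 = torch.softmax(scores2, dim=-1) @ v2

# The same merge used in the ring loop above.
out = out1 - torch.sigmoid(lse2 - lse1) * (out1 - out2)
lse = lse1 - torch.nn.functional.logsigmoid(lse1 - lse2)

# Reference: one softmax over the concatenated key blocks.
full_scores = torch.cat([scores1, scores2], dim=-1)
ref_out = torch.softmax(full_scores, dim=-1) @ torch.cat([v1, v2], dim=0)
ref_lse = torch.logsumexp(full_scores, dim=-1, keepdim=True)

assert torch.allclose(out, ref_out)
assert torch.allclose(lse, ref_lse)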
import torch

torch.manual_seed(42)


def torch_sdpa(query, key, value):
    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_cudnn_attention(
            query=query,
            key=key,
            value=value,
            attn_bias=None,
            compute_log_sumexp=True,
        )
    )
    return out, lse


# Sequential emulation of Ulysses attention: an all-to-all turns sequence-sharded inputs into
# head-sharded inputs, every rank runs full-sequence attention over its head shard, and a
# second all-to-all restores the original sequence sharding.
def ulysses_sdpa_sequential(partial_queries, partial_keys, partial_values, *, world_size: int = 1):
    B, H, S_LOCAL, D = partial_queries[0].shape
    H_LOCAL = H // world_size
    outputs, lses = [], []

    for partials in [partial_queries, partial_keys, partial_values]:
        for rank in range(world_size):
            x_local = partials[rank]
            # (B, H, S // world_size, D) -> (world_size, S // world_size, B, H // world_size, D)
            partials[rank] = x_local.reshape(B, world_size, H_LOCAL, S_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
        x = all_to_all_single_sequential(partials, world_size)
        for rank in range(world_size):
            x_local = x[rank]
            # (S, B, H // world_size, D) -> (B, H // world_size, S, D)
            partials[rank] = x_local.permute(1, 2, 0, 3).contiguous()

    for rank in range(world_size):
        query_local, key_local, value_local = partial_queries[rank], partial_keys[rank], partial_values[rank]
        out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
            torch.ops.aten._scaled_dot_product_cudnn_attention(
                query=query_local,
                key=key_local,
                value=value_local,
                attn_bias=None,
                compute_log_sumexp=True,
            )
        )
        outputs.append(out)
        lses.append(lse)

    for rank in range(world_size):
        out_local = outputs[rank]
        lse_local = lses[rank]
        # (B, H // world_size, S, D) -> (B, H // world_size, world_size, S // world_size, D) -> (world_size, H // world_size, B, S // world_size, D)
        outputs[rank] = out_local.reshape(B, H_LOCAL, world_size, S_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
        lses[rank] = lse_local.reshape(B, H_LOCAL, world_size, S_LOCAL).permute(2, 1, 0, 3).contiguous()

    outputs = all_to_all_single_sequential(outputs, world_size)
    lses = all_to_all_single_sequential(lses, world_size)

    for rank in range(world_size):
        out_local = outputs[rank]
        lse_local = lses[rank]
        # (H, B, S // world_size, D) -> (B, H, S // world_size, D)
        outputs[rank] = out_local.permute(1, 0, 2, 3).contiguous()
        lses[rank] = lse_local.permute(1, 0, 2).contiguous()

    return outputs, lses


# Sequential stand-in for dist.all_to_all_single: rank i receives chunk i from every rank
# and concatenates the received chunks along the leading dimension.
def all_to_all_single_sequential(partials, world_size):
    output_partials = []
    for i in range(world_size):
        received_chunks = [p[i] for p in partials]
        output_partials.append(torch.cat(received_chunks, dim=0))
    return output_partials
device = "cuda" | |
dtype = torch.bfloat16 | |
world_size = 4 | |
batch_size = 1 | |
image_sequence_length = 4096 | |
text_sequence_length = 512 | |
sequence_length = image_sequence_length + text_sequence_length | |
num_attention_heads = 24 | |
attention_head_dim = 128 | |
query = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype) | |
key = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype) | |
value = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype) | |
partial_queries = list(query.chunk(world_size, dim=2)) | |
partial_keys = list(key.chunk(world_size, dim=2)) | |
partial_values = list(value.chunk(world_size, dim=2)) | |
torch_sdpa_out, torch_sdpa_lse = torch_sdpa(query, key, value) | |
ulysses_sdpa_out, ulysses_sdpa_lse = ulysses_sdpa_sequential(partial_queries, partial_keys, partial_values, world_size=world_size) | |
all_ulysses_sdpa_out = torch.cat(ulysses_sdpa_out, dim=2) | |
all_ulysses_sdpa_lse = torch.cat(ulysses_sdpa_lse, dim=2) | |
assert torch_sdpa_out.shape == all_ulysses_sdpa_out.shape, "Output shapes do not match!" | |
assert torch_sdpa_lse.shape == all_ulysses_sdpa_lse.shape, "LSE shapes do not match!" | |
assert torch.allclose(all_ulysses_sdpa_out, torch_sdpa_out, atol=1e-3, rtol=1e-3), "Outputs do not match!" | |
assert torch.allclose(all_ulysses_sdpa_lse, torch_sdpa_lse, atol=1e-3, rtol=1e-3), "LSEs do not match!" |
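
To see what all_to_all_single_sequential stands in for, the toy walk-through below (illustrative only, not part of the gist; dimensions and names are made up) traces the same reshape/permute pattern with tiny sizes: before the exchange each simulated rank holds every head of one sequence shard, and afterwards it holds one head shard of the full sequence.

import torch

world_size, B, H, S, D = 2, 1, 4, 8, 3
H_LOCAL, S_LOCAL = H // world_size, S // world_size

x = torch.arange(B * H * S * D, dtype=torch.float32).reshape(B, H, S, D)
shards = list(x.chunk(world_size, dim=2))  # sequence-sharded inputs, one per simulated rank

for rank in range(world_size):
    # (B, H, S_LOCAL, D) -> (world_size, S_LOCAL, B, H_LOCAL, D), as in ulysses_sdpa_sequential
    shards[rank] = shards[rank].reshape(B, world_size, H_LOCAL, S_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()

# Sequential stand-in for all_to_all_single: rank r gathers chunk r from every rank.
exchanged = [torch.cat([s[r] for s in shards], dim=0) for r in range(world_size)]

for rank in range(world_size):
    head_shard = exchanged[rank].permute(1, 2, 0, 3)  # (B, H_LOCAL, S, D)
    # Rank `rank` now owns heads [rank * H_LOCAL, (rank + 1) * H_LOCAL) over the whole sequence.
    assert torch.equal(head_shard, x[:, rank * H_LOCAL:(rank + 1) * H_LOCAL])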
import torch

torch.manual_seed(42)


def torch_sdpa(query, key, value):
    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_cudnn_attention(
            query=query,
            key=key,
            value=value,
            attn_bias=None,
            compute_log_sumexp=True,
        )
    )
    return out, lse


def ring_sdpa_sequential(partial_queries, partial_keys, partial_values, *, ring_size: int = 1, convert_to_fp32: bool = True):
    outputs, lses = [], []
    for rank in range(ring_size):
        query, key, value = partial_queries[rank], partial_keys[rank], partial_values[rank]
        next_rank = (rank + 1) % ring_size
        prev_out = prev_lse = None
        for i in range(ring_size):
            if i > 0:
                key, value = partial_keys[next_rank], partial_values[next_rank]
                next_rank = (next_rank + 1) % ring_size
            out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
                torch.ops.aten._scaled_dot_product_cudnn_attention(
                    query=query,
                    key=key,
                    value=value,
                    attn_bias=None,
                    compute_log_sumexp=True,
                )
            )
            if convert_to_fp32:
                out = out.to(torch.float32)
                lse = lse.to(torch.float32)
            # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
            lse = lse.unsqueeze(-1)
            if prev_out is not None:
                out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
                lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
            prev_out = out
            prev_lse = lse
        out = out.to(query.dtype)
        lse = lse.squeeze(-1)
        outputs.append(out)
        lses.append(lse)
    return outputs, lses


# Unified (2D) emulation: a Ulysses all-to-all over heads within each ring group, ring
# attention over the ring dimension, then the inverse all-to-all to restore the layout.
def unified_ulysses_ring_sdpa_sequential(partial_queries, partial_keys, partial_values, *, ulysses_size: int = 1, ring_size: int = 1):
    B, H, S_LOCAL, D = partial_queries[0][0].shape
    H_LOCAL = H // ulysses_size
    outputs, lses = [], []

    for partials in [partial_queries, partial_keys, partial_values]:
        for ring_rank in range(ring_size):
            for rank in range(ulysses_size):
                x_local = partials[ring_rank][rank]
                partials[ring_rank][rank] = x_local.reshape(B, ulysses_size, H_LOCAL, S_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
            x = all_to_all_single_sequential(partials[ring_rank], ulysses_size)
            for rank in range(ulysses_size):
                x_local = x[rank]
                partials[ring_rank][rank] = x_local.permute(1, 2, 0, 3).contiguous()

    # Transpose the nested lists from [ring_rank][ulysses_rank] to [ulysses_rank][ring_rank].
    partial_queries = [list(x) for x in zip(*partial_queries)]
    partial_keys = [list(x) for x in zip(*partial_keys)]
    partial_values = [list(x) for x in zip(*partial_values)]

    for rank in range(ulysses_size):
        ring_outputs, ring_lses = ring_sdpa_sequential(partial_queries[rank], partial_keys[rank], partial_values[rank], ring_size=ring_size)
        outputs.append(ring_outputs)
        lses.append(ring_lses)

    outputs = [list(x) for x in zip(*outputs)]
    lses = [list(x) for x in zip(*lses)]

    for ring_rank in range(ring_size):
        for rank in range(ulysses_size):
            outputs[ring_rank][rank] = outputs[ring_rank][rank].reshape(B, H_LOCAL, ulysses_size, S_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
            lses[ring_rank][rank] = lses[ring_rank][rank].reshape(B, H_LOCAL, ulysses_size, S_LOCAL).permute(2, 1, 0, 3).contiguous()
        outputs[ring_rank] = all_to_all_single_sequential(outputs[ring_rank], ulysses_size)
        lses[ring_rank] = all_to_all_single_sequential(lses[ring_rank], ulysses_size)
        for rank in range(ulysses_size):
            outputs[ring_rank][rank] = outputs[ring_rank][rank].permute(1, 0, 2, 3).contiguous()
            lses[ring_rank][rank] = lses[ring_rank][rank].permute(1, 0, 2).contiguous()

    return outputs, lses


def all_to_all_single_sequential(partials, world_size):
    output_partials = []
    for i in range(world_size):
        received_chunks = [p[i] for p in partials]
        output_partials.append(torch.cat(received_chunks, dim=0))
    return output_partials
device = "cuda" | |
dtype = torch.bfloat16 | |
WORLD_SIZE = 8 | |
ulysses_size = 4 | |
ring_size = 2 | |
assert ulysses_size * ring_size == WORLD_SIZE, "ulysses_size * ring_size must equal WORLD_SIZE" | |
batch_size = 1 | |
image_sequence_length = 4096 | |
text_sequence_length = 512 | |
sequence_length = image_sequence_length + text_sequence_length | |
num_attention_heads = 24 | |
attention_head_dim = 128 | |
query = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype) | |
key = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype) | |
value = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype) | |
partial_queries = list(query.chunk(WORLD_SIZE, dim=2)) | |
partial_keys = list(key.chunk(WORLD_SIZE, dim=2)) | |
partial_values = list(value.chunk(WORLD_SIZE, dim=2)) | |
# R=1, U=4 => [[tensor1, tensor2, tensor3, tensor4]] | |
# R=2, U=2 => [[tensor1, tensor2], [tensor3, tensor4]] | |
# R=4, U=1 => [[tensor1], [tensor2], [tensor3], [tensor4]] | |
partial_queries = [partial_queries[i:i + ulysses_size] for i in range(0, WORLD_SIZE, ulysses_size)] | |
partial_keys = [partial_keys[i:i + ulysses_size] for i in range(0, WORLD_SIZE, ulysses_size)] | |
partial_values = [partial_values[i:i + ulysses_size] for i in range(0, WORLD_SIZE, ulysses_size)] | |
torch_sdpa_out, torch_sdpa_lse = torch_sdpa(query, key, value) | |
unified_sdpa_out, unified_sdpa_lse = unified_ulysses_ring_sdpa_sequential(partial_queries, partial_keys, partial_values, ulysses_size=ulysses_size, ring_size=ring_size) | |
all_unified_sdpa_out = torch.cat([torch.cat(out, dim=2) for out in unified_sdpa_out], dim=2) | |
all_unified_sdpa_lse = torch.cat([torch.cat(lse, dim=2) for lse in unified_sdpa_lse], dim=2) | |
assert torch_sdpa_out.shape == all_unified_sdpa_out.shape, "Output shapes do not match!" | |
assert torch_sdpa_lse.shape == all_unified_sdpa_lse.shape, "LSE shapes do not match!" | |
assert torch.allclose(all_unified_sdpa_out, torch_sdpa_out, atol=1e-3, rtol=1e-3), "Outputs do not match!" | |
assert torch.allclose(all_unified_sdpa_lse, torch_sdpa_lse, atol=1e-3, rtol=1e-3), "LSEs do not match!" |
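
The nested-list layout above packs consecutive sequence shards into ring groups of ulysses_size entries, so a flat rank r corresponds to (ring_rank, ulysses_rank) = divmod(r, ulysses_size). A minimal sketch of that mapping under the same WORLD_SIZE = 8, ulysses_size = 4, ring_size = 2 configuration (illustrative only, not part of the gist):

WORLD_SIZE, ulysses_size = 8, 4
ring_size = WORLD_SIZE // ulysses_size
groups = [list(range(i, i + ulysses_size)) for i in range(0, WORLD_SIZE, ulysses_size)]
assert groups == [[0, 1, 2, 3], [4, 5, 6, 7]]
for flat_rank in range(WORLD_SIZE):
    ring_rank, ulysses_rank = divmod(flat_rank, ulysses_size)
    assert groups[ring_rank][ulysses_rank] == flat_rank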
import argparse
from dataclasses import dataclass
from typing import Callable, Literal, List

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch.distributed import DeviceMesh


@dataclass
class ContextParallelOptions:
    mode: Literal["ring", "ulysses", "unified"] = "ring"
    ring_mesh: DeviceMesh | None = None
    ulysses_mesh: DeviceMesh | None = None
    convert_to_fp32: bool = True
    op: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]] | None = None


cp_options = ContextParallelOptions()

def _templated_ring_attention(query, key, value):
    rank = cp_options.ring_mesh.get_rank()
    world_size = cp_options.ring_mesh.size()
    if world_size == 1:
        return cp_options.op(query, key, value)

    next_rank = (rank + 1) % world_size
    prev_out = prev_lse = None

    # Gather the flattened KV shards from every rank once, then walk them in ring order.
    # (A true ring implementation would exchange KV with its neighbours step by step.)
    kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
    kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=cp_options.ring_mesh.get_group())
    kv_buffer = kv_buffer.chunk(world_size)

    for i in range(world_size):
        if i > 0:
            kv = kv_buffer[next_rank]
            key = kv[:key.numel()].reshape_as(key)
            value = kv[key.numel():].reshape_as(value)
            next_rank = (next_rank + 1) % world_size

        out, lse = cp_options.op(query, key, value)

        if cp_options.convert_to_fp32:
            out = out.to(torch.float32)
            lse = lse.to(torch.float32)

        # Numerically stable log-sum-exp merge of the new partial result into the running one.
        lse = lse.unsqueeze(-1)
        if prev_out is not None:
            out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
            lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
        prev_out = out
        prev_lse = lse

    out = out.to(query.dtype)
    lse = lse.squeeze(-1)
    return out, lse

def _templated_ulysses_attention(query, key, value):
    world_size = cp_options.ulysses_mesh.size()
    group = cp_options.ulysses_mesh.get_group()
    if world_size == 1:
        return cp_options.op(query, key, value)

    B, H, S_LOCAL, D = query.shape
    H_LOCAL = H // world_size

    # (B, H, S_local, D) -> (world_size, S_local, B, H_local, D): the all-to-all scatters
    # heads across ranks and gathers the full sequence for the local head shard.
    query, key, value = (
        x.reshape(B, world_size, H_LOCAL, S_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
        for x in (query, key, value)
    )
    query, key, value = (
        funcol.all_to_all_single(x, None, None, group=group)
        for x in (query, key, value)
    )
    query, key, value = (
        x.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
        for x in (query, key, value)
    )

    out, lse = cp_options.op(query, key, value)

    # Inverse all-to-all: go back from head sharding to sequence sharding.
    out = out.reshape(B, H_LOCAL, world_size, S_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
    lse = lse.reshape(B, H_LOCAL, world_size, S_LOCAL).permute(2, 1, 0, 3).contiguous()
    out = funcol.all_to_all_single(out, None, None, group=group).wait()
    lse = funcol.all_to_all_single(lse, None, None, group=group).wait()
    out = out.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
    lse = lse.flatten(0, 1).permute(1, 0, 2).contiguous()
    return out, lse

def _templated_unified_attention(query, key, value):
    ring_size = cp_options.ring_mesh.size()
    ulysses_size = cp_options.ulysses_mesh.size()
    ulysses_group = cp_options.ulysses_mesh.get_group()
    world_size = ring_size * ulysses_size
    if world_size == 1:
        return cp_options.op(query, key, value)

    B, H, S_LOCAL, D = query.shape
    H_LOCAL = H // ulysses_size

    # Ulysses all-to-all over the ulysses mesh, ring attention over the ring mesh,
    # then the inverse all-to-all.
    query, key, value = (
        x.reshape(B, ulysses_size, H_LOCAL, S_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
        for x in (query, key, value)
    )
    query, key, value = (
        funcol.all_to_all_single(x, None, None, group=ulysses_group)
        for x in (query, key, value)
    )
    query, key, value = (
        x.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
        for x in (query, key, value)
    )

    out, lse = _templated_ring_attention(query, key, value)

    out = out.reshape(B, H_LOCAL, ulysses_size, S_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
    lse = lse.reshape(B, H_LOCAL, ulysses_size, S_LOCAL).permute(2, 1, 0, 3).contiguous()
    out = funcol.all_to_all_single(out, None, None, group=ulysses_group).wait()
    lse = funcol.all_to_all_single(lse, None, None, group=ulysses_group).wait()
    out = out.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
    lse = lse.flatten(0, 1).permute(1, 0, 2).contiguous()
    return out, lse

def torch_cudnn_attention(query, key, value):
    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_cudnn_attention(
            query=query,
            key=key,
            value=value,
            attn_bias=None,
            compute_log_sumexp=True,
        )
    )
    return out, lse


def torch_flash_attention(query, key, value):
    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_flash_attention(
            query=query,
            key=key,
            value=value,
        )
    )
    return out, lse


OPS = {
    "cudnn": torch_cudnn_attention,
    "flash": torch_flash_attention,
}

WORLD_SIZE = -1
RANK = -1


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ring_degree", type=int, default=1)
    parser.add_argument("--ulysses_degree", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_heads", type=int, default=24)
    parser.add_argument("--head_dim", type=int, default=128)
    parser.add_argument("--seq_lens", type=int, nargs="+", default=[512, 1024, 2048, 4096, 4224, 4352, 4480, 4608, 8192])
    parser.add_argument(
        "--ops",
        type=str,
        nargs="+",
        choices=list(OPS.keys()),
        default=list(OPS.keys()),
    )
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    return args

def main(
    ring_degree: int,
    ulysses_degree: int,
    batch_size: int,
    num_heads: int,
    head_dim: int,
    seq_lens: List[int],
    ops: List[str],
    seed: int,
):
    global cp_options, WORLD_SIZE, RANK

    mesh_names = ["ring", "ulysses"]
    mesh_dims = [ring_degree, ulysses_degree]
    mesh = dist.device_mesh.init_device_mesh("cuda", mesh_dims, mesh_dim_names=mesh_names)
    cp_options.ring_mesh = mesh["ring"]
    cp_options.ulysses_mesh = mesh["ulysses"]
    cp_options.convert_to_fp32 = True

    cp_attention = None
    num_warmups = 5
    num_repeats = 10
    device = torch.device("cuda")
    dtype = torch.bfloat16

    # Pick the context-parallel variant from the requested degrees.
    if ring_degree > 1 and ulysses_degree > 1:
        cp_options.mode = "unified"
        cp_attention = _templated_unified_attention
    elif ulysses_degree > 1:
        cp_options.mode = "ulysses"
        cp_attention = _templated_ulysses_attention
    else:
        cp_options.mode = "ring"
        cp_attention = _templated_ring_attention

    results = {}
    for op_name in ops:
        op = OPS[op_name]
        cp_options.op = op
        results[op_name] = {}

        for seq_len in seq_lens:
            shape = (batch_size, num_heads, seq_len, head_dim)
            query = torch.randn(shape, device=device, dtype=dtype)
            key = torch.randn(shape, device=device, dtype=dtype)
            value = torch.randn(shape, device=device, dtype=dtype)
            dist.broadcast(query, src=0)
            dist.broadcast(key, src=0)
            dist.broadcast(value, src=0)
            dist.barrier()
            torch.cuda.synchronize()

            # Reference on the full (unsharded) sequence, then shard QKV for this rank.
            reference_out, reference_lse = torch_cudnn_attention(query, key, value)
            query, key, value = (x.chunk(WORLD_SIZE, dim=2)[RANK].contiguous() for x in (query, key, value))

            # Warmup, then accuracy check of the gathered output against the reference.
            for _ in range(num_warmups):
                if WORLD_SIZE == 1:
                    out, lse = op(query, key, value)
                else:
                    out, lse = cp_attention(query, key, value)
            out = funcol.all_gather_tensor(out, gather_dim=2, group=mesh._flatten().get_group())
            lse = funcol.all_gather_tensor(lse, gather_dim=2, group=mesh._flatten().get_group())
            torch.cuda.synchronize()

            diff = out - reference_out
            absdiff = torch.abs(diff)
            absmax = torch.max(absdiff)
            mae = torch.mean(absdiff)
            mse = torch.mean(diff * diff)
            if RANK == 0:
                print(f"op: {op_name}, seq_len: {seq_len}, absmax: {absmax:.5f}, mae: {mae:.5f}, mse: {mse:.5f}")
            # if not torch.allclose(out, reference_out, atol=1e-2, rtol=1e-2):
            #     raise ValueError(f"Output mismatch for op: {op_name}, seq_len: {seq_len}")
            # if not torch.allclose(lse, reference_lse, atol=1e-2, rtol=1e-2):
            #     raise ValueError(f"LSE mismatch for op: {op_name}, seq_len: {seq_len}")

            # Timed benchmark.
            dist.barrier()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            for _ in range(num_repeats):
                if WORLD_SIZE == 1:
                    out, lse = op(query, key, value)
                else:
                    out, lse = cp_attention(query, key, value)
            end_event.record()
            torch.cuda.synchronize()
            dist.barrier()
            elapsed_time = start_event.elapsed_time(end_event) / num_repeats
            results[op_name][seq_len] = elapsed_time

    if RANK == 0:
        print("Benchmark results:")
        for op_name, seq_times in results.items():
            print(f"\n\n===== op: {op_name} =====")
            for seq_len, time in seq_times.items():
                print(f" {seq_len=}, {time:.5f} ms")

if __name__ == "__main__":
    args = get_args()
    torch.manual_seed(args.seed)

    try:
        dist.init_process_group(backend="nccl")
        WORLD_SIZE = dist.get_world_size()
        RANK = dist.get_rank()
        torch.cuda.set_device(RANK)

        if args.ring_degree * args.ulysses_degree != WORLD_SIZE:
            raise ValueError(
                f"ring_degree * ulysses_degree must equal world size, got {args.ring_degree} * {args.ulysses_degree} != {WORLD_SIZE}"
            )

        main(
            ring_degree=args.ring_degree,
            ulysses_degree=args.ulysses_degree,
            batch_size=args.batch_size,
            num_heads=args.num_heads,
            head_dim=args.head_dim,
            seq_lens=args.seq_lens,
            ops=args.ops,
            seed=args.seed,
        )
    finally:
        dist.destroy_process_group()
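
The templated file above is the distributed counterpart of the sequential scripts and expects one process per GPU. Assuming it is saved as cp_attention_benchmark.py (the gist does not state a filename, so that name is a placeholder), a typical launch on 8 GPUs would look like:

torchrun --nproc_per_node=8 cp_attention_benchmark.py --ring_degree 2 --ulysses_degree 4 --ops cudnn flash

ring_degree * ulysses_degree must match the number of launched processes, and setting both degrees to 1 benchmarks the plain single-GPU attention ops.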