@a-r-r-o-w
Created June 30, 2025 11:36
Sequential and templated ring / Ulysses / unified attention implementations
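# ----------------------------------------------------------------------------------------
# Part 1: sequential (single-process) simulation of ring attention. The sequence is split
# into world_size chunks; each simulated rank attends its query chunk against every
# key/value chunk in turn and merges the partial results via their log-sum-exp statistics.
# ----------------------------------------------------------------------------------------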
import torch
torch.manual_seed(42)
def torch_sdpa(query, key, value):
    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_cudnn_attention(
            query=query,
            key=key,
            value=value,
            attn_bias=None,
            compute_log_sumexp=True,
        )
    )
    return out, lse
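# Online-softmax merge used by ring_sdpa_sequential below: given partial results
# (out_a, lse_a) and (out_b, lse_b) over disjoint key blocks,
#   out = out_a + sigmoid(lse_b - lse_a) * (out_b - out_a)
#   lse = lse_a - logsigmoid(lse_a - lse_b) = log(exp(lse_a) + exp(lse_b))
# which is exactly the update applied to (prev_out, prev_lse) at each ring step.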
def ring_sdpa_sequential(partial_queries, partial_keys, partial_values, *, world_size: int = 1, convert_to_fp32: bool = True):
    outputs, lses = [], []
    for rank in range(world_size):
        query, key, value = partial_queries[rank], partial_keys[rank], partial_values[rank]
        next_rank = (rank + 1) % world_size
        prev_out = prev_lse = None
        for i in range(world_size):
            if i > 0:
                key, value = partial_keys[next_rank], partial_values[next_rank]
                next_rank = (next_rank + 1) % world_size
            out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
                torch.ops.aten._scaled_dot_product_cudnn_attention(
                    query=query,
                    key=key,
                    value=value,
                    attn_bias=None,
                    compute_log_sumexp=True,
                )
            )
            if convert_to_fp32:
                out = out.to(torch.float32)
                lse = lse.to(torch.float32)
            # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
            lse = lse.unsqueeze(-1)
            if prev_out is not None:
                out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
                lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
            prev_out = out
            prev_lse = lse
        out = out.to(query.dtype)
        lse = lse.squeeze(-1)
        outputs.append(out)
        lses.append(lse)
    return outputs, lses
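# Reference check: compare the concatenated per-chunk ring outputs/LSEs against a single
# full-sequence SDPA call.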
device = "cuda"
dtype = torch.bfloat16
world_size = 4
batch_size = 1
image_sequence_length = 4096
text_sequence_length = 512
sequence_length = image_sequence_length + text_sequence_length
num_attention_heads = 24
attention_head_dim = 128
query = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
value = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
partial_queries = query.chunk(world_size, dim=2)
partial_keys = key.chunk(world_size, dim=2)
partial_values = value.chunk(world_size, dim=2)
torch_sdpa_out, torch_sdpa_lse = torch_sdpa(query, key, value)
ring_sdpa_out, ring_sdpa_lse = ring_sdpa_sequential(partial_queries, partial_keys, partial_values, world_size=world_size)
all_ring_sdpa_out = torch.cat(ring_sdpa_out, dim=2)
all_ring_sdpa_lse = torch.cat(ring_sdpa_lse, dim=2)
assert torch_sdpa_out.shape == all_ring_sdpa_out.shape, "Output shapes do not match!"
assert torch_sdpa_lse.shape == all_ring_sdpa_lse.shape, "LSE shapes do not match!"
assert torch.allclose(all_ring_sdpa_out, torch_sdpa_out, atol=1e-3, rtol=1e-3), "Outputs do not match!"
assert torch.allclose(all_ring_sdpa_lse, torch_sdpa_lse, atol=1e-3, rtol=1e-3), "LSE values do not match!"
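# ----------------------------------------------------------------------------------------
# Part 2: sequential (single-process) simulation of Ulysses attention. An all-to-all swaps
# sequence sharding for head sharding, every simulated rank runs full-sequence attention on
# H // world_size heads, and a second all-to-all restores the (B, H, S // world_size, D)
# layout per rank.
# ----------------------------------------------------------------------------------------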
import torch
torch.manual_seed(42)
def torch_sdpa(query, key, value):
    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_cudnn_attention(
            query=query,
            key=key,
            value=value,
            attn_bias=None,
            compute_log_sumexp=True,
        )
    )
    return out, lse
def ulysses_sdpa_sequential(partial_queries, partial_keys, partial_values, *, world_size: int = 1):
    B, H, S_LOCAL, D = partial_queries[0].shape
    H_LOCAL = H // world_size
    outputs, lses = [], []
    for partials in [partial_queries, partial_keys, partial_values]:
        for rank in range(world_size):
            x_local = partials[rank]
            # (B, H, S // world_size, D) -> (world_size, S // world_size, B, H // world_size, D)
            partials[rank] = x_local.reshape(B, world_size, H_LOCAL, S_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
        x = all_to_all_single_sequential(partials, world_size)
        for rank in range(world_size):
            x_local = x[rank]
            # (S, B, H // world_size, D) -> (B, H // world_size, S, D)
            partials[rank] = x_local.permute(1, 2, 0, 3).contiguous()
    for rank in range(world_size):
        query_local, key_local, value_local = partial_queries[rank], partial_keys[rank], partial_values[rank]
        out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
            torch.ops.aten._scaled_dot_product_cudnn_attention(
                query=query_local,
                key=key_local,
                value=value_local,
                attn_bias=None,
                compute_log_sumexp=True,
            )
        )
        outputs.append(out)
        lses.append(lse)
    for rank in range(world_size):
        out_local = outputs[rank]
        lse_local = lses[rank]
        # (B, H // world_size, S, D) -> (B, H // world_size, world_size, S // world_size, D) -> (world_size, H // world_size, B, S // world_size, D)
        outputs[rank] = out_local.reshape(B, H_LOCAL, world_size, S_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
        lses[rank] = lse_local.reshape(B, H_LOCAL, world_size, S_LOCAL).permute(2, 1, 0, 3).contiguous()
    outputs = all_to_all_single_sequential(outputs, world_size)
    lses = all_to_all_single_sequential(lses, world_size)
    for rank in range(world_size):
        out_local = outputs[rank]
        lse_local = lses[rank]
        # (H, B, S // world_size, D) -> (B, H, S // world_size, D)
        outputs[rank] = out_local.permute(1, 0, 2, 3).contiguous()
        lses[rank] = lse_local.permute(1, 0, 2).contiguous()
    return outputs, lses
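# all_to_all_single_sequential emulates dist.all_to_all_single without a process group:
# each "rank" has pre-split its tensor along dim 0 into world_size chunks, and rank i's
# output is the concatenation of chunk i taken from every rank.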
def all_to_all_single_sequential(partials, world_size):
    output_partials = []
    for i in range(world_size):
        received_chunks = [p[i] for p in partials]
        output_partials.append(torch.cat(received_chunks, dim=0))
    return output_partials
device = "cuda"
dtype = torch.bfloat16
world_size = 4
batch_size = 1
image_sequence_length = 4096
text_sequence_length = 512
sequence_length = image_sequence_length + text_sequence_length
num_attention_heads = 24
attention_head_dim = 128
query = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
value = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
partial_queries = list(query.chunk(world_size, dim=2))
partial_keys = list(key.chunk(world_size, dim=2))
partial_values = list(value.chunk(world_size, dim=2))
torch_sdpa_out, torch_sdpa_lse = torch_sdpa(query, key, value)
ulysses_sdpa_out, ulysses_sdpa_lse = ulysses_sdpa_sequential(partial_queries, partial_keys, partial_values, world_size=world_size)
all_ulysses_sdpa_out = torch.cat(ulysses_sdpa_out, dim=2)
all_ulysses_sdpa_lse = torch.cat(ulysses_sdpa_lse, dim=2)
assert torch_sdpa_out.shape == all_ulysses_sdpa_out.shape, "Output shapes do not match!"
assert torch_sdpa_lse.shape == all_ulysses_sdpa_lse.shape, "LSE shapes do not match!"
assert torch.allclose(all_ulysses_sdpa_out, torch_sdpa_out, atol=1e-3, rtol=1e-3), "Outputs do not match!"
assert torch.allclose(all_ulysses_sdpa_lse, torch_sdpa_lse, atol=1e-3, rtol=1e-3), "LSEs do not match!"
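# ----------------------------------------------------------------------------------------
# Part 3: sequential simulation of unified attention (Ulysses x ring). Ranks form a
# (ring_size, ulysses_size) grid: an all-to-all within each Ulysses group swaps sequence
# sharding for head sharding, ring attention then runs across the ring groups, and a final
# all-to-all restores the per-rank (B, H, S // world_size, D) layout.
# ----------------------------------------------------------------------------------------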
import torch
torch.manual_seed(42)
def torch_sdpa(query, key, value):
    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_cudnn_attention(
            query=query,
            key=key,
            value=value,
            attn_bias=None,
            compute_log_sumexp=True,
        )
    )
    return out, lse
def ring_sdpa_sequential(partial_queries, partial_keys, partial_values, *, ring_size: int = 1, convert_to_fp32: bool = True):
    outputs, lses = [], []
    for rank in range(ring_size):
        query, key, value = partial_queries[rank], partial_keys[rank], partial_values[rank]
        next_rank = (rank + 1) % ring_size
        prev_out = prev_lse = None
        for i in range(ring_size):
            if i > 0:
                key, value = partial_keys[next_rank], partial_values[next_rank]
                next_rank = (next_rank + 1) % ring_size
            out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
                torch.ops.aten._scaled_dot_product_cudnn_attention(
                    query=query,
                    key=key,
                    value=value,
                    attn_bias=None,
                    compute_log_sumexp=True,
                )
            )
            if convert_to_fp32:
                out = out.to(torch.float32)
                lse = lse.to(torch.float32)
            # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
            lse = lse.unsqueeze(-1)
            if prev_out is not None:
                out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
                lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
            prev_out = out
            prev_lse = lse
        out = out.to(query.dtype)
        lse = lse.squeeze(-1)
        outputs.append(out)
        lses.append(lse)
    return outputs, lses
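# partial_queries / partial_keys / partial_values are nested lists indexed as
# [ring_rank][ulysses_rank]; each leaf tensor is one of the WORLD_SIZE sequence chunks with
# shape (B, H, S // (ring_size * ulysses_size), D).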
def unified_ulysses_ring_sdpa_sequential(partial_queries, partial_keys, partial_values, *, ulysses_size: int = 1, ring_size: int = 1):
    B, H, S_LOCAL, D = partial_queries[0][0].shape
    H_LOCAL = H // ulysses_size
    outputs, lses = [], []
    for partials in [partial_queries, partial_keys, partial_values]:
        for ring_rank in range(ring_size):
            for rank in range(ulysses_size):
                x_local = partials[ring_rank][rank]
                partials[ring_rank][rank] = x_local.reshape(B, ulysses_size, H_LOCAL, S_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
            x = all_to_all_single_sequential(partials[ring_rank], ulysses_size)
            for rank in range(ulysses_size):
                x_local = x[rank]
                partials[ring_rank][rank] = x_local.permute(1, 2, 0, 3).contiguous()
    partial_queries = [list(x) for x in zip(*partial_queries)]
    partial_keys = [list(x) for x in zip(*partial_keys)]
    partial_values = [list(x) for x in zip(*partial_values)]
    for rank in range(ulysses_size):
        ring_outputs, ring_lses = ring_sdpa_sequential(partial_queries[rank], partial_keys[rank], partial_values[rank], ring_size=ring_size)
        outputs.append(ring_outputs)
        lses.append(ring_lses)
    outputs = [list(x) for x in zip(*outputs)]
    lses = [list(x) for x in zip(*lses)]
    for ring_rank in range(ring_size):
        for rank in range(ulysses_size):
            outputs[ring_rank][rank] = outputs[ring_rank][rank].reshape(B, H_LOCAL, ulysses_size, S_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
            lses[ring_rank][rank] = lses[ring_rank][rank].reshape(B, H_LOCAL, ulysses_size, S_LOCAL).permute(2, 1, 0, 3).contiguous()
        outputs[ring_rank] = all_to_all_single_sequential(outputs[ring_rank], ulysses_size)
        lses[ring_rank] = all_to_all_single_sequential(lses[ring_rank], ulysses_size)
        for rank in range(ulysses_size):
            outputs[ring_rank][rank] = outputs[ring_rank][rank].permute(1, 0, 2, 3).contiguous()
            lses[ring_rank][rank] = lses[ring_rank][rank].permute(1, 0, 2).contiguous()
    return outputs, lses
def all_to_all_single_sequential(partials, world_size):
    output_partials = []
    for i in range(world_size):
        received_chunks = [p[i] for p in partials]
        output_partials.append(torch.cat(received_chunks, dim=0))
    return output_partials
device = "cuda"
dtype = torch.bfloat16
WORLD_SIZE = 8
ulysses_size = 4
ring_size = 2
assert ulysses_size * ring_size == WORLD_SIZE, "ulysses_size * ring_size must equal WORLD_SIZE"
batch_size = 1
image_sequence_length = 4096
text_sequence_length = 512
sequence_length = image_sequence_length + text_sequence_length
num_attention_heads = 24
attention_head_dim = 128
query = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
key = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
value = torch.randn(batch_size, num_attention_heads, sequence_length, attention_head_dim, device=device, dtype=dtype)
partial_queries = list(query.chunk(WORLD_SIZE, dim=2))
partial_keys = list(key.chunk(WORLD_SIZE, dim=2))
partial_values = list(value.chunk(WORLD_SIZE, dim=2))
# R=1, U=4 => [[tensor1, tensor2, tensor3, tensor4]]
# R=2, U=2 => [[tensor1, tensor2], [tensor3, tensor4]]
# R=4, U=1 => [[tensor1], [tensor2], [tensor3], [tensor4]]
partial_queries = [partial_queries[i:i + ulysses_size] for i in range(0, WORLD_SIZE, ulysses_size)]
partial_keys = [partial_keys[i:i + ulysses_size] for i in range(0, WORLD_SIZE, ulysses_size)]
partial_values = [partial_values[i:i + ulysses_size] for i in range(0, WORLD_SIZE, ulysses_size)]
torch_sdpa_out, torch_sdpa_lse = torch_sdpa(query, key, value)
unified_sdpa_out, unified_sdpa_lse = unified_ulysses_ring_sdpa_sequential(partial_queries, partial_keys, partial_values, ulysses_size=ulysses_size, ring_size=ring_size)
all_unified_sdpa_out = torch.cat([torch.cat(out, dim=2) for out in unified_sdpa_out], dim=2)
all_unified_sdpa_lse = torch.cat([torch.cat(lse, dim=2) for lse in unified_sdpa_lse], dim=2)
assert torch_sdpa_out.shape == all_unified_sdpa_out.shape, "Output shapes do not match!"
assert torch_sdpa_lse.shape == all_unified_sdpa_lse.shape, "LSE shapes do not match!"
assert torch.allclose(all_unified_sdpa_out, torch_sdpa_out, atol=1e-3, rtol=1e-3), "Outputs do not match!"
assert torch.allclose(all_unified_sdpa_lse, torch_sdpa_lse, atol=1e-3, rtol=1e-3), "LSEs do not match!"
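# ----------------------------------------------------------------------------------------
# Part 4: templated, genuinely distributed implementations (ring / Ulysses / unified) built
# on torch.distributed functional collectives, plus a small benchmark harness.
# Example launch on an 8-GPU node (assuming this file is saved as benchmark_cp.py, a
# hypothetical name):
#   torchrun --nproc_per_node=8 benchmark_cp.py --ring_degree 2 --ulysses_degree 4
# ----------------------------------------------------------------------------------------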
import argparse
from dataclasses import dataclass
from typing import Callable, Literal, List
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch.distributed import DeviceMesh
@dataclass
class ContextParallelOptions:
    mode: Literal["ring", "ulysses", "unified"] = "ring"
    ring_mesh: DeviceMesh | None = None
    ulysses_mesh: DeviceMesh | None = None
    convert_to_fp32: bool = True
    op: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]] | None = None
cp_options = ContextParallelOptions()
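# cp_options is module-level state: main() fills in the device meshes and the attention
# backend (cp_options.op), which the _templated_* functions read at call time. For
# simplicity, _templated_ring_attention all-gathers every rank's flattened K/V up front
# instead of using point-to-point ring exchanges; the merge math matches the sequential
# version above.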
def _templated_ring_attention(query, key, value):
    rank = cp_options.ring_mesh.get_rank()
    world_size = cp_options.ring_mesh.size()
    if world_size == 1:
        return cp_options.op(query, key, value)
    next_rank = (rank + 1) % world_size
    prev_out = prev_lse = None
    kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
    kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=cp_options.ring_mesh.get_group())
    kv_buffer = kv_buffer.chunk(world_size)
    for i in range(world_size):
        if i > 0:
            kv = kv_buffer[next_rank]
            key = kv[:key.numel()].reshape_as(key)
            value = kv[key.numel():].reshape_as(value)
            next_rank = (next_rank + 1) % world_size
        out, lse = cp_options.op(query, key, value)
        if cp_options.convert_to_fp32:
            out = out.to(torch.float32)
            lse = lse.to(torch.float32)
        lse = lse.unsqueeze(-1)
        if prev_out is not None:
            out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
            lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
        prev_out = out
        prev_lse = lse
    out = out.to(query.dtype)
    lse = lse.squeeze(-1)
    return out, lse
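# In the distributed Ulysses path, funcol.all_to_all_single splits dim 0 into world_size
# chunks and concatenates the chunks received from all ranks along dim 0, so tensors are
# permuted to put the rank dimension first and flattened back with .flatten(0, 1) after
# the collective.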
def _templated_ulysses_attention(query, key, value):
    world_size = cp_options.ulysses_mesh.size()
    group = cp_options.ulysses_mesh.get_group()
    if world_size == 1:
        return cp_options.op(query, key, value)
    B, H, S_LOCAL, D = query.shape
    H_LOCAL = H // world_size
    query, key, value = (
        x.reshape(B, world_size, H_LOCAL, S_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
        for x in (query, key, value)
    )
    query, key, value = (
        funcol.all_to_all_single(x, None, None, group=group)
        for x in (query, key, value)
    )
    query, key, value = (
        x.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
        for x in (query, key, value)
    )
    out, lse = cp_options.op(query, key, value)
    out = out.reshape(B, H_LOCAL, world_size, S_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
    lse = lse.reshape(B, H_LOCAL, world_size, S_LOCAL).permute(2, 1, 0, 3).contiguous()
    out = funcol.all_to_all_single(out, None, None, group=group).wait()
    lse = funcol.all_to_all_single(lse, None, None, group=group).wait()
    out = out.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
    lse = lse.flatten(0, 1).permute(1, 0, 2).contiguous()
    return out, lse
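# Unified attention composes the two: a Ulysses all-to-all over the ulysses mesh swaps
# sequence sharding for head sharding, ring attention runs over the ring mesh, and the
# reverse all-to-all restores the sequence-sharded, full-head layout.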
def _templated_unified_attention(query, key, value):
    ring_size = cp_options.ring_mesh.size()
    ulysses_size = cp_options.ulysses_mesh.size()
    ulysses_group = cp_options.ulysses_mesh.get_group()
    world_size = ring_size * ulysses_size
    if world_size == 1:
        return cp_options.op(query, key, value)
    B, H, S_LOCAL, D = query.shape
    H_LOCAL = H // ulysses_size
    query, key, value = (
        x.reshape(B, ulysses_size, H_LOCAL, S_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
        for x in (query, key, value)
    )
    query, key, value = (
        funcol.all_to_all_single(x, None, None, group=ulysses_group)
        for x in (query, key, value)
    )
    query, key, value = (
        x.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
        for x in (query, key, value)
    )
    out, lse = _templated_ring_attention(query, key, value)
    out = out.reshape(B, H_LOCAL, ulysses_size, S_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
    lse = lse.reshape(B, H_LOCAL, ulysses_size, S_LOCAL).permute(2, 1, 0, 3).contiguous()
    out = funcol.all_to_all_single(out, None, None, group=ulysses_group).wait()
    lse = funcol.all_to_all_single(lse, None, None, group=ulysses_group).wait()
    out = out.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
    lse = lse.flatten(0, 1).permute(1, 0, 2).contiguous()
    return out, lse
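# Attention backends: both wrappers return (out, logsumexp) so either can be used as
# cp_options.op.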
def torch_cudnn_attention(query, key, value):
    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_cudnn_attention(
            query=query,
            key=key,
            value=value,
            attn_bias=None,
            compute_log_sumexp=True,
        )
    )
    return out, lse
def torch_flash_attention(query, key, value):
    out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
        torch.ops.aten._scaled_dot_product_flash_attention(
            query=query,
            key=key,
            value=value,
        )
    )
    return out, lse
OPS = {
    "cudnn": torch_cudnn_attention,
    "flash": torch_flash_attention,
}
WORLD_SIZE = -1
RANK = -1
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ring_degree", type=int, default=1)
    parser.add_argument("--ulysses_degree", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_heads", type=int, default=24)
    parser.add_argument("--head_dim", type=int, default=128)
    parser.add_argument("--seq_lens", type=int, nargs="+", default=[512, 1024, 2048, 4096, 4224, 4352, 4480, 4608, 8192])
    parser.add_argument(
        "--ops",
        type=str,
        nargs="+",
        choices=list(OPS.keys()),
        default=list(OPS.keys()),
    )
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    return args
def main(
    ring_degree: int,
    ulysses_degree: int,
    batch_size: int,
    num_heads: int,
    head_dim: int,
    seq_lens: List[int],
    ops: List[str],
    seed: int,
):
    global cp_options, WORLD_SIZE, RANK
    mesh_names = ["ring", "ulysses"]
    mesh_dims = [ring_degree, ulysses_degree]
    mesh = dist.device_mesh.init_device_mesh("cuda", mesh_dims, mesh_dim_names=mesh_names)
    cp_options.ring_mesh = mesh["ring"]
    cp_options.ulysses_mesh = mesh["ulysses"]
    cp_options.convert_to_fp32 = True
    cp_attention = None
    num_warmups = 5
    num_repeats = 10
    device = torch.device("cuda")
    dtype = torch.bfloat16
    if ring_degree > 1 and ulysses_degree > 1:
        cp_options.mode = "unified"
        cp_attention = _templated_unified_attention
    elif ulysses_degree > 1:
        cp_options.mode = "ulysses"
        cp_attention = _templated_ulysses_attention
    else:
        cp_options.mode = "ring"
        cp_attention = _templated_ring_attention
    results = {}
    for op_name in ops:
        op = OPS[op_name]
        cp_options.op = op
        results[op_name] = {}
        for seq_len in seq_lens:
            shape = (batch_size, num_heads, seq_len, head_dim)
            query = torch.randn(shape, device=device, dtype=dtype)
            key = torch.randn(shape, device=device, dtype=dtype)
            value = torch.randn(shape, device=device, dtype=dtype)
            dist.broadcast(query, src=0)
            dist.broadcast(key, src=0)
            dist.broadcast(value, src=0)
            dist.barrier()
            torch.cuda.synchronize()
            reference_out, reference_lse = torch_cudnn_attention(query, key, value)
            query, key, value = (x.chunk(WORLD_SIZE, dim=2)[RANK].contiguous() for x in (query, key, value))
            for _ in range(num_warmups):
                if WORLD_SIZE == 1:
                    out, lse = op(query, key, value)
                else:
                    out, lse = cp_attention(query, key, value)
            out = funcol.all_gather_tensor(out, gather_dim=2, group=mesh._flatten().get_group())
            lse = funcol.all_gather_tensor(lse, gather_dim=2, group=mesh._flatten().get_group())
            torch.cuda.synchronize()
            diff = out - reference_out
            absdiff = torch.abs(diff)
            absmax = torch.max(absdiff)
            mae = torch.mean(absdiff)
            mse = torch.mean(diff * diff)
            if RANK == 0:
                print(f"op: {op_name}, seq_len: {seq_len}, absmax: {absmax:.5f}, mae: {mae:.5f}, mse: {mse:.5f}")
            # if not torch.allclose(out, reference_out, atol=1e-2, rtol=1e-2):
            #     raise ValueError(f"Output mismatch for op: {op_name}, seq_len: {seq_len}")
            # if not torch.allclose(lse, reference_lse, atol=1e-2, rtol=1e-2):
            #     raise ValueError(f"LSE mismatch for op: {op_name}, seq_len: {seq_len}")
            dist.barrier()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            for _ in range(num_repeats):
                if WORLD_SIZE == 1:
                    out, lse = op(query, key, value)
                else:
                    out, lse = cp_attention(query, key, value)
            end_event.record()
            torch.cuda.synchronize()
            dist.barrier()
            elapsed_time = start_event.elapsed_time(end_event) / num_repeats
            results[op_name][seq_len] = elapsed_time
    if RANK == 0:
        print("Benchmark results:")
        for op_name, seq_times in results.items():
            print(f"\n\n===== op: {op_name} =====")
            for seq_len, time in seq_times.items():
                print(f" {seq_len=}, {time:.5f} ms")
if __name__ == "__main__":
args = get_args()
torch.manual_seed(args.seed)
try:
dist.init_process_group(backend="nccl")
WORLD_SIZE = dist.get_world_size()
RANK = dist.get_rank()
torch.cuda.set_device(RANK)
if args.ring_degree * args.ulysses_degree != WORLD_SIZE:
raise ValueError(
f"ring_degree * ulysses_degree must equal world size, got {args.ring_degree} * {args.ulysses_degree} != {WORLD_SIZE}"
)
main(
ring_degree=args.ring_degree,
ulysses_degree=args.ulysses_degree,
batch_size=args.batch_size,
num_heads=args.num_heads,
head_dim=args.head_dim,
seq_lens=args.seq_lens,
ops=args.ops,
seed=args.seed,
)
finally:
dist.destroy_process_group()