# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for attention"""
import operator
import os
import warnings
from dataclasses import dataclass, replace
from functools import partial, reduce
from typing import Optional, Tuple
from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes, lax
from jax.sharding import PartitionSpec, NamedSharding
from jax.experimental.custom_partitioning import SdyShardingRule
import transformer_engine_jax
from transformer_engine_jax import NVTE_Fused_Attn_Backend
from transformer_engine.jax.attention import (
AttnBiasType,
AttnMaskType,
QKVLayout,
QKVFormat,
CPStrategy,
SequenceDescriptor,
)
from .base import BasePrimitive, register_primitive
from .misc import (
check_valid_batch_dims,
jax_dtype_to_te_dtype,
te_dtype_to_jax_dtype,
get_padded_spec,
get_cudnn_version,
)
from ..sharding import (
global_mesh_resource,
lax_paral_op,
all_reduce_sum_along_dp_fsdp,
get_mesh_axis_size,
get_mesh_axis_rank,
get_mesh_axis_rank_host,
get_all_mesh_axes,
num_of_devices,
with_sharding_constraint,
)
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = [
"FusedAttnHelper",
"fused_attn_fwd",
"fused_attn_bwd",
]
@partial(
jax.tree_util.register_dataclass,
data_fields=[],
meta_fields=[
"attn_bias_type",
"attn_mask_type",
"qkv_layout",
"scaling_factor",
"dropout_probability",
"is_training",
"max_segments_per_seq",
"window_size",
"context_parallel_load_balanced",
"cp_axis",
"cp_striped_window_size",
],
)
@dataclass(frozen=True)
class _FusedAttnConfig:
"""
Passes static configuration properties of fused attention.
"""
attn_bias_type: AttnBiasType
attn_mask_type: AttnMaskType
qkv_layout: QKVLayout
scaling_factor: float
dropout_probability: float
is_training: bool
max_segments_per_seq: int
window_size: Tuple[int, int]
context_parallel_load_balanced: bool
cp_axis: str
cp_striped_window_size: Tuple[int, int] # Only for CP + Ring + THD + SWA
@dataclass(frozen=True)
class FusedAttnHelper:
"""
Helper for the fused attention backend
"""
is_training: bool
q_dtype: jnp.dtype
kv_dtype: jnp.dtype
qkv_layout: QKVLayout
attn_bias_type: AttnBiasType
attn_mask_type: AttnMaskType
dropout_probability: float
q_num_heads: int
kv_num_heads: int
q_max_seqlen: int
kv_max_seqlen: int
head_dim_qk: int
head_dim_v: int
window_size: Tuple[int, int]
def is_fused_attn_kernel_available(self):
"""Check if there is available fused attention kernel"""
return self.get_fused_attn_backend() != NVTE_Fused_Attn_Backend.NVTE_No_Backend
def get_fused_attn_backend(self):
"""Get the fused attention kernel backend"""
return transformer_engine_jax.get_fused_attn_backend(
self.is_training,
jax_dtype_to_te_dtype(self.q_dtype),
jax_dtype_to_te_dtype(self.kv_dtype),
self.qkv_layout.value,
self.attn_bias_type.value,
self.attn_mask_type.value,
self.dropout_probability,
self.q_num_heads,
self.kv_num_heads,
self.q_max_seqlen,
self.kv_max_seqlen,
self.head_dim_qk,
self.head_dim_v,
self.window_size[0],
self.window_size[1],
)
@staticmethod
def is_non_deterministic_allowed():
"""Check if non-deterministic kernels are allowed"""
return bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
@staticmethod
def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout):
"""Parse qkv aval"""
if qkv_layout.get_qkv_format() == QKVFormat.SBHD:
raise NotImplementedError
if qkv_layout.is_qkvpacked():
*q_batch_shape, q_max_seqlen, nqkv, attn_heads, q_head_dim = q_aval.shape
kv_batch_shape = q_batch_shape
kv_max_seqlen = q_max_seqlen
num_gqa_groups = attn_heads
v_head_dim = q_head_dim
assert nqkv == 3
elif qkv_layout.is_kvpacked():
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, v_head_dim = k_aval.shape
assert q_batch_shape == kv_batch_shape
assert q_head_dim == v_head_dim
assert nkv == 2
elif qkv_layout.is_separate():
*q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape
*k_batch_shape, k_max_seqlen, k_num_gqa_groups, k_head_dim = k_aval.shape
*v_batch_shape, v_max_seqlen, v_num_gqa_groups, v_head_dim = v_aval.shape
assert (
q_head_dim == k_head_dim
), f"Mismatched q_head_dim: {q_head_dim} and k_head_dim: {k_head_dim}"
assert (
k_max_seqlen == v_max_seqlen
), f"Mismatched k_max_seqlen: {k_max_seqlen} and v_max_seqlen: {v_max_seqlen}"
kv_max_seqlen = k_max_seqlen
assert q_batch_shape == k_batch_shape == v_batch_shape, (
f"Mismatched qkv batch size for q_batch_shape: {q_batch_shape}, k_batch_shape:"
f" {k_batch_shape} and v_batch_shape: {v_batch_shape}"
)
assert k_num_gqa_groups == v_num_gqa_groups, (
f"Mismatched k_num_gqa_groups: {k_num_gqa_groups} and v_num_gqa_groups:"
f" {v_num_gqa_groups}"
)
num_gqa_groups = k_num_gqa_groups
else:
raise ValueError(f"Unexpected {qkv_layout=}")
assert q_aval.dtype == k_aval.dtype == v_aval.dtype, (
f"Mismatched data types for q_aval: {q_aval.dtype}, k_aval: {k_aval.dtype}, v_aval:"
f" {v_aval.dtype}"
)
return (
q_batch_shape,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
q_head_dim,
v_head_dim,
)
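# Illustrative summary of the input shapes parse_qkv_aval expects (an informal note derived
# from the parsing logic above; "..." denotes optional leading batch dims):
#   - qkvpacked: q = (..., seqlen, 3, heads, head_dim); k and v are ignored
#   - kvpacked:  q = (..., q_seqlen, heads, head_dim),
#                k = (..., kv_seqlen, 2, num_gqa_groups, head_dim)
#   - separate:  q = (..., q_seqlen, heads, qk_head_dim),
#                k = (..., kv_seqlen, num_gqa_groups, qk_head_dim),
#                v = (..., kv_seqlen, num_gqa_groups, v_head_dim)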
@dataclass(frozen=True)
class _FusedAttnRNGStateChecker:
"""
Checker for guarding the fused attention rng state.
The fused attention backend requires a 64-bit seed and a 64-bit offset.
However, JAX doesn't enable 64-bit integers by default,
so we emulate the seed as an array of two 32-bit values.
The offset calculation is maintained in the backend.
"""
rng_state_dtype: jnp.dtype = jnp.uint32
# (seed,) with internal dtype int64
seed_size: int = 2
# (seed, offset) with internal dtype int64
rng_state_size: int = 2 * 2
def check_seed(self, seed, dropout_probability, is_training):
"""
Check the seed and convert the data type of seed if possible.
"""
# Jax can't bind None, create a dummy tensor for None
if seed is None:
dropout_enabled = dropout_probability > 0 and is_training
assert not dropout_enabled, "seed is not allowed to be None when dropout is enabled."
seed = jnp.zeros(2, dtype=self.rng_state_dtype)
seed = jnp.repeat(seed, num_of_devices())
if seed.dtype != self.rng_state_dtype:
warnings.warn(
f"Requested {seed.dtype=} is not available, and will be "
f"casted to dtype {self.rng_state_dtype}. "
"Please use threefry/rbg/unsafe_rbg PRNG implementations to remove this warning."
)
seed = seed.astype(self.rng_state_dtype)
assert seed.dtype == self.rng_state_dtype
# Backend takes an int64_t seed, so only the first two u32 elements are taken
assert seed.size >= self.seed_size
return seed
def generate_cu_seqlen(actual_seqlen):
"""
Generate cumulative sequence lengths (cu_seqlen) for a batch.
"""
actual_seqlen = jnp.where(actual_seqlen < 0, 0, actual_seqlen)
cu_seqlen = jnp.cumulative_sum(actual_seqlen, include_initial=True)
return cu_seqlen
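# Illustrative doctest-style example (derived from the logic above, not executed at import):
# negative entries are treated as zero-length sequences, and include_initial=True prepends a
# leading zero, so the result has one more element than the input.
# >>> generate_cu_seqlen(jnp.asarray([3, 5, -1]))
# Array([0, 3, 8, 8], dtype=int32)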
class FusedAttnFwdPrimitive(BasePrimitive):
"""
Fused Attention Forward Primitive
"""
name = "te_fused_attn_forward_ffi"
multiple_results = True
impl_static_args = (13,)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
q_aval,
k_aval,
v_aval,
bias_aval,
seed_aval,
q_seqlen_or_cu_seqlen_aval,
kv_seqlen_or_cu_seqlen_aval,
_q_seq_offsets,
_k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
*,
config: _FusedAttnConfig,
):
"""
Fused attention fwd abstract
"""
q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
assert (
q_dtype == k_dtype == v_dtype == bias_dtype
), f"q_dtype={q_dtype}, k_dtype={k_dtype}, v_dtype={v_dtype}, bias_dtype={bias_dtype}"
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype, (
f"q_seqlen_or_cu_seqlen_aval={q_seqlen_or_cu_seqlen_aval},"
f" kv_seqlen_or_cu_seqlen_aval={kv_seqlen_or_cu_seqlen_aval}"
)
(
batch_shape,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
q_head_dim,
v_head_dim,
) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
output_shape = (*batch_shape, q_max_seqlen, attn_heads, v_head_dim)
out_aval = q_aval.update(shape=output_shape, dtype=q_dtype)
# backend determines the softmax buffer shape/dtype
backend = FusedAttnHelper(
config.is_training,
q_dtype,
k_dtype,
config.qkv_layout,
config.attn_bias_type,
config.attn_mask_type,
config.dropout_probability,
attn_heads,
num_gqa_groups,
q_max_seqlen,
kv_max_seqlen,
q_head_dim,
v_head_dim,
config.window_size,
).get_fused_attn_backend()
if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen:
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen)
softmax_dtype = q_dtype
elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
# cuDNN 9.6 reduces the required softmax shape
if get_cudnn_version() >= (9, 6, 0):
if config.qkv_layout.is_thd():
softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1)
else:
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
else:
softmax_shape = (
*batch_shape,
attn_heads,
q_max_seqlen,
config.max_segments_per_seq,
)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else:
raise ValueError(f"Unsupported {backend=}")
softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype)
# JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with
# 32-bit unsigned int to get the buffer size we need in the C++ kernel
checker = _FusedAttnRNGStateChecker()
seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype)
assert seed_dtype == checker.rng_state_dtype
rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)
if config.attn_bias_type == AttnBiasType.NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
# do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to
# prepare for the active fused-attn backend
input_batch = reduce(operator.mul, batch_shape)
wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes(
input_batch,
bias_batch,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
bias_heads,
q_head_dim,
v_head_dim,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type.value,
config.attn_mask_type.value,
config.qkv_layout.value,
jax_dtype_to_te_dtype(q_aval.dtype),
config.is_training,
config.max_segments_per_seq,
config.window_size[0],
config.window_size[1],
)
wkspace_aval = q_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
)
return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
Fused attention fwd outer primitive abstract
"""
out_aval, softmax_aux_aval, rng_state_aval, _ = FusedAttnFwdPrimitive.abstract(
*args, **kwargs
)
return out_aval, softmax_aux_aval, rng_state_aval
@staticmethod
def lowering(
ctx,
q,
k,
v,
bias,
seed,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
*,
config: _FusedAttnConfig,
):
"""
Fused attention fwd lowering rules
"""
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
(
batch_shape,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
q_head_dim,
v_head_dim,
) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
input_batch = reduce(operator.mul, batch_shape)
if config.attn_bias_type == AttnBiasType.NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
if config.cp_striped_window_size is not None:
window_size_left = config.cp_striped_window_size[0]
window_size_right = config.cp_striped_window_size[1]
else:
window_size_left = config.window_size[0]
window_size_right = config.window_size[1]
return ffi.ffi_lowering(FusedAttnFwdPrimitive.name)(
ctx,
q,
k,
v,
bias,
seed,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos, # ffi_lowering requires the number of parameters to match primitive.lowering
input_batch=input_batch,
bias_batch=bias_batch,
q_max_seqlen=q_max_seqlen,
kv_max_seqlen=kv_max_seqlen,
attn_heads=attn_heads,
num_gqa_groups=num_gqa_groups,
bias_heads=bias_heads,
qk_head_dim=q_head_dim,
v_head_dim=v_head_dim,
max_segments_per_seq=config.max_segments_per_seq,
scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability),
bias_type=int(config.attn_bias_type.value),
mask_type=int(config.attn_mask_type.value),
qkv_layout=int(config.qkv_layout.value),
is_training=config.is_training,
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=window_size_left,
window_size_right=window_size_right,
)
@staticmethod
def impl(
q,
k,
v,
bias,
seed,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config: _FusedAttnConfig,
):
assert FusedAttnFwdPrimitive.inner_primitive is not None
sequence_descriptor = SequenceDescriptor(
seqlens=(q_seqlen, kv_seqlen),
seq_offsets=(q_seq_offsets, k_seq_offsets),
segment_ids=(_q_segment_ids, _kv_segment_ids),
segment_pos=(_q_segment_pos, _kv_segment_pos),
)
(q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = (
sequence_descriptor.get_seqlens_and_offsets(
config.attn_mask_type,
config.qkv_layout,
config.window_size,
config.max_segments_per_seq,
)
)
if config.qkv_layout.is_thd():
def _fix_len_take(x, condition, fill_value=-1):
x_shape = x.shape
x = x.flatten()
size = x.size
indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0]
y = jnp.take(x, indices, fill_value=fill_value)
return jnp.reshape(y, x_shape)
def convert_to_2d(offsets, batch, max_seqlen):
offsets_2d = jnp.where(
offsets >= 0,
offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis],
offsets,
)
return offsets_2d
batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval(
q, k, v, config.qkv_layout
)
assert len(batch) == 1, f"Expected len(batch) == 1, but got {len(batch)=}"
kv_batch = q_batch = batch[0]
# Gather valid q_seqlen entries, which are greater than 0
# cuDNN version < 9.3.0:
# [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]]
# cuDNN version >= 9.3.0, which supports act_seqlen = 0
# [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]]
if get_cudnn_version() >= (9, 3, 0):
fill_value = 0
else:
fill_value = -1
q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value)
kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value)
# Flatten the offset calculation
# max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]]
q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen)
k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen)
# Gather valid q_seq_offsets, which are greater than or equal to 0
# [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]]
# And set the unused position to max size (batch * max_seqlen)
# [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]]
q_seq_offsets = _fix_len_take(
q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen
)
k_seq_offsets = _fix_len_take(
k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen
)
q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())
output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind(
q,
k,
v,
bias,
seed,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=config,
)
return output, softmax_aux, rng_state
@staticmethod
def batcher(batched_args, batch_dims, *, config):
check_valid_batch_dims(batch_dims)
assert FusedAttnFwdPrimitive.outer_primitive is not None
q_bdim, _, _, _, seed_bdim, *_ = batch_dims
out_bdims = q_bdim, q_bdim, seed_bdim
return (
FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, config=config),
out_bdims,
)
@staticmethod
def infer_sharding_from_operands(config, mesh, arg_infos, result_infos):
del result_infos
q_spec = get_padded_spec(arg_infos[0])
# When supported (THD layout on cuDNN 9.6+), the softmax_aux shape is (b, s, h, 1);
# otherwise the softmax_aux shape is (b, h, s, 1) or (b, h, s, max_segments).
is_packed_softmax = get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd()
if config.qkv_layout.is_qkvpacked():
# q_spec = (...batch, q_seqlen, 3, head, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:]))
if not is_packed_softmax:
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None)
)
else:
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-4], q_spec[-4], q_spec[-2], None)
)
elif config.qkv_layout.is_kvpacked():
# q_spec = (...batch, q_seqlen, head, hidden)
# k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
if not is_packed_softmax:
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
)
else:
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-3], q_spec[-2], None)
)
elif config.qkv_layout.is_separate():
# q_spec = (...batch, q_seqlen, head, hidden)
# k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden)
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
if not is_packed_softmax:
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None)
)
else:
softmax_aux_sharding = NamedSharding(
mesh, PartitionSpec(*q_spec[:-3], q_spec[-3], q_spec[-2], None)
)
else:
raise ValueError(f"Unsupported {config.qkv_layout=}")
rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None))
return (out_sharding, softmax_aux_sharding, rng_state_sharding)
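# Illustrative example (hypothetical mesh axis names): for the separate layout with q
# partitioned as PartitionSpec("dp", "sp", "tp", None), the rule above keeps the output
# sharded like q and, for the non-packed softmax layout, shards softmax_aux as
# PartitionSpec("dp", "tp", "sp", None), i.e. with the head and sequence axes swapped.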
@staticmethod
def partition(config, mesh, arg_infos, result_infos):
out_sharding = result_infos[0].sharding
softmax_aux_sharding = result_infos[1].sharding
rng_state_sharding = seed_sharding = NamedSharding(
mesh, PartitionSpec(get_all_mesh_axes(), None)
)
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding
arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings)
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
impl = partial(FusedAttnFwdPrimitive.impl, config=config)
return mesh, impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(config, mesh, value_types, result_types):
del mesh, result_types
# Keep in sync with `infer_sharding_from_operands`.
# We only need the first input. Fill up the rest with placeholders.
input_spec = [(f"…{x}",) for x in range(len(value_types))]
# The RNG state sharding cannot be expressed as a Shardy rule. We use with_sharding_constraint
# instead. This has to happen outside of the primitive, see `fused_attn_fwd`.
rng_sharding = (f"…{len(value_types)}",)
if config.qkv_layout.is_qkvpacked():
input_spec[0] = ("…0", "seqlen", "three", "head", "hidden")
elif config.qkv_layout.is_kvpacked() or config.qkv_layout.is_separate():
input_spec[0] = ("…0", "seqlen", "head", "hidden")
else:
raise ValueError(f"Unsupported {config.qkv_layout=}")
is_packed_softmax = get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd()
out_sharding = ("…0", "seqlen", "head", "hidden")
if is_packed_softmax:
softmax_aux_sharding = ("…0", "seqlen", "head", "i")
else:
softmax_aux_sharding = ("…0", "head", "seqlen", "i")
return SdyShardingRule(
tuple(input_spec), (out_sharding, softmax_aux_sharding, rng_sharding)
)
register_primitive(FusedAttnFwdPrimitive)
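# Illustrative sketch (hypothetical values, not taken from the library): a typical static
# config for a causal, dropout-free forward pass with separate BSHD inputs might look like
#     config = _FusedAttnConfig(
#         attn_bias_type=AttnBiasType.NO_BIAS,
#         attn_mask_type=AttnMaskType.CAUSAL_MASK,
#         qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
#         scaling_factor=1.0 / math.sqrt(head_dim),  # head_dim is a hypothetical variable
#         dropout_probability=0.0,
#         is_training=True,
#         max_segments_per_seq=1,
#         window_size=(-1, -1),  # (-1, -1) disables sliding-window attention
#         context_parallel_load_balanced=False,
#         cp_axis="cp",  # hypothetical mesh axis name
#         cp_striped_window_size=None,
#     )
# which is then passed as the static `config` argument when binding FusedAttnFwdPrimitive.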
class FusedAttnBwdPrimitive(BasePrimitive):
"""
Fused Attention Backward Primitive
"""
name = "te_fused_attn_backward_ffi"
multiple_results = True
impl_static_args = (16,)
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
q_aval,
k_aval,
v_aval,
bias_aval,
softmax_aux_aval,
rng_state_aval,
output_aval,
doutput_aval,
q_seqlen_or_cu_seqlen_aval,
kv_seqlen_or_cu_seqlen_aval,
_q_seq_offsets,
_k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
*,
config,
):
"""
Fused attention bwd abstract
"""
del softmax_aux_aval, rng_state_aval, output_aval
q_dtype = dtypes.canonicalize_dtype(q_aval.dtype)
k_dtype = dtypes.canonicalize_dtype(k_aval.dtype)
v_dtype = dtypes.canonicalize_dtype(v_aval.dtype)
bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype)
doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype)
assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype
assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype
(
batch_shape,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
qk_head_dim,
v_head_dim,
) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
if config.attn_bias_type == AttnBiasType.NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
deterministic = not FusedAttnHelper.is_non_deterministic_allowed()
input_batch = reduce(operator.mul, batch_shape)
wkspace_shape, wkspace_dtype = transformer_engine_jax.get_fused_attn_bwd_workspace_sizes(
input_batch,
bias_batch,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
bias_heads,
qk_head_dim,
v_head_dim,
config.scaling_factor,
config.dropout_probability,
config.attn_bias_type.value,
config.attn_mask_type.value,
config.qkv_layout.value,
jax_dtype_to_te_dtype(q_aval.dtype),
config.is_training,
deterministic,
config.max_segments_per_seq,
config.window_size[0],
config.window_size[1],
)
dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype)
dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype)
dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype)
dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype)
wkspace_aval = q_aval.update(
shape=wkspace_shape, dtype=te_dtype_to_jax_dtype(wkspace_dtype)
)
return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval
@staticmethod
def outer_abstract(*args, **kwargs):
"""
Fused attention bwd outer primitive abstract
"""
dq_aval, dk_aval, dv_aval, dbias_aval, _ = FusedAttnBwdPrimitive.abstract(*args, **kwargs)
return dq_aval, dk_aval, dv_aval, dbias_aval
@staticmethod
def lowering(
ctx,
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
q_segment_ids,
kv_segment_ids,
q_segment_pos,
kv_segment_pos,
*,
config,
):
"""
Fused attention bwd lowering rules
"""
q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in
(
batch_shape,
q_max_seqlen,
kv_max_seqlen,
attn_heads,
num_gqa_groups,
qk_head_dim,
v_head_dim,
) = FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout)
input_batch = reduce(operator.mul, batch_shape)
if config.attn_bias_type == AttnBiasType.NO_BIAS:
bias_batch = bias_heads = 0
else:
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)
if config.cp_striped_window_size is not None:
window_size_left = config.cp_striped_window_size[0]
window_size_right = config.cp_striped_window_size[1]
else:
window_size_left = config.window_size[0]
window_size_right = config.window_size[1]
return ffi.ffi_lowering(FusedAttnBwdPrimitive.name)(
ctx,
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
q_segment_ids,
kv_segment_ids,
q_segment_pos,
kv_segment_pos, # ffi_lowering requires the number of parameters to match primitive.lowering
input_batch=input_batch,
bias_batch=bias_batch,
q_max_seqlen=q_max_seqlen,
kv_max_seqlen=kv_max_seqlen,
attn_heads=attn_heads,
num_gqa_groups=num_gqa_groups,
bias_heads=bias_heads,
qk_head_dim=qk_head_dim,
v_head_dim=v_head_dim,
max_segments_per_seq=config.max_segments_per_seq,
scaling_factor=float(config.scaling_factor),
dropout_probability=float(config.dropout_probability),
bias_type=int(config.attn_bias_type.value),
mask_type=int(config.attn_mask_type.value),
qkv_layout=int(config.qkv_layout.value),
is_training=config.is_training,
deterministic=not FusedAttnHelper.is_non_deterministic_allowed(),
window_size_left=window_size_left,
window_size_right=window_size_right,
)
@staticmethod
def impl(
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config,
):
assert FusedAttnBwdPrimitive.inner_primitive is not None
sequence_descriptor = SequenceDescriptor(
seqlens=(q_seqlen, kv_seqlen),
seq_offsets=(q_seq_offsets, k_seq_offsets),
segment_ids=(_q_segment_ids, _kv_segment_ids),
segment_pos=(_q_segment_pos, _kv_segment_pos),
)
(q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = (
sequence_descriptor.get_seqlens_and_offsets(
config.attn_mask_type,
config.qkv_layout,
config.window_size,
config.max_segments_per_seq,
)
)
if config.qkv_layout.is_thd():
def _fix_len_take(x, condition, fill_value=-1):
x_shape = x.shape
x = x.flatten()
size = x.size
indices = jnp.nonzero(condition.flatten(), size=size, fill_value=size)[0]
# TODO(rewang): try indices_are_sorted
y = jnp.take(x, indices, fill_value=fill_value)
return jnp.reshape(y, x_shape)
def convert_to_2d(offsets, batch, max_seqlen):
offsets_2d = jnp.where(
offsets >= 0,
offsets + (jnp.arange(batch) * max_seqlen)[..., jnp.newaxis],
offsets,
)
return offsets_2d
batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval(
q, k, v, config.qkv_layout
)
assert len(batch) == 1
kv_batch = q_batch = batch[0]
# Gather valid q_seqlen entries, which are greater than 0
# cuDNN version < 9.3.0:
# [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, -1, -1, -1, -1]]
# cuDNN version >= 9.3.0, which supports act_seqlen = 0
# [[3, 5, 7, -1, -1], [2, 4, 6, -1, -1]] -> [[3, 5, 7, 2, 4], [6, 0, 0, 0, 0]]
if get_cudnn_version() >= (9, 3, 0):
fill_value = 0
else:
fill_value = -1
q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value)
kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value)
# Flatten the offset calculation
# max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]]
q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen)
k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen)
# Gather valid q_seq_offsets, which are greater than or equal to 0
# [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]]
# And set the unused position to max size (batch * max_seqlen)
# [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]]
q_seq_offsets = _fix_len_take(
q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen
)
k_seq_offsets = _fix_len_take(
k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen
)
q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten())
kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten())
dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind(
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=config,
)
return dq, dk, dv, dbias
@staticmethod
def batcher(batched_args, batch_dims, *, config):
check_valid_batch_dims(batch_dims)
assert FusedAttnBwdPrimitive.outer_primitive is not None
q_bdim, k_bdim, v_bdim, *_ = batch_dims
out_bdims = q_bdim, k_bdim, v_bdim, q_bdim
return (
FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, config=config),
out_bdims,
)
@staticmethod
def infer_sharding_from_operands(config, mesh, arg_infos, result_infos):
del config, result_infos
q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
@staticmethod
def partition(config, mesh, arg_infos, result_infos):
del result_infos
q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[-1] = arg_shardings[-3]
arg_shardings[-2] = arg_shardings[-4]
arg_shardings = tuple(arg_shardings)
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
def sharded_impl(
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
):
local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl(
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_cu_seqlen,
kv_cu_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=config,
)
global_dbias = local_dbias
if config.attn_bias_type is not AttnBiasType.NO_BIAS:
global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
return local_dq, local_dk, local_dv, global_dbias
return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(config, mesh, value_types, result_types):
del config, mesh
# We only care about the first four arguments.
# Keep in sync with `infer_sharding_from_operands`.
input_spec = tuple((f"…{x}",) for x in range(len(value_types)))
output_spec = tuple((f"…{x}",) for x in range(len(result_types)))
return SdyShardingRule(input_spec, output_spec)
register_primitive(FusedAttnBwdPrimitive)
def reorder_causal_dual_chunk_swap(tensor, cp_size: int, seq_dim: int, to_contiguous: bool):
"""Reorders a tensor for load balancing the compute of causal attention."""
if cp_size == 1:
return tensor
if cp_size % 2 != 0:
raise ValueError(f"{cp_size=} must be a multiple of 2.")
# Need to ensure we have 2 pairs to swap for balancing between cp ranks
if tensor.shape[seq_dim] % (cp_size * 2) != 0:
raise ValueError(f"{tensor.shape[seq_dim]=} is not a multiple of {cp_size*2=}")
# [B, S, H, D] -> [B, 2*cp_size, S/(2*cp_size), H, D]
# [S, B, H, D] -> [2*cp_size, S/(2*cp_size), B, H, D]
ori_tensor_shape = tensor.shape
tensor = tensor.reshape(
(
*ori_tensor_shape[:seq_dim],
2 * cp_size,
ori_tensor_shape[seq_dim] // (2 * cp_size),
*ori_tensor_shape[seq_dim + 1 :],
)
)
parts = []
if not to_contiguous:
for cp_rank in range(cp_size):
# [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D]
# [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D]
index = jnp.array([cp_rank, (2 * cp_size - cp_rank - 1)])
parts.append(jnp.take(tensor, index, axis=seq_dim))
else:
for cp_rank in range(cp_size // 2):
# [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D]
# [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D]
base = 4 * cp_rank
index = jnp.array([base, base + 2])
parts.append(jnp.take(tensor, index, axis=seq_dim))
for cp_rank in range(cp_size // 2):
# [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D] -> [B, 2, S/2*cp_size, H, D]
# [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D] -> [2, S/2*cp_size, B, H, D]
base = 2 * cp_size - 1 - 4 * cp_rank
index = jnp.array([base, base - 2])
parts.append(jnp.take(tensor, index, axis=seq_dim))
# [B, S, H, D]: [B, 2*cp_size, S/2*cp_size, H, D]
# [S, B, H, D]: [2*cp_size, S/2*cp_size, B, H, D]
combined = jnp.stack(parts, axis=seq_dim)
return combined.reshape(ori_tensor_shape)
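# Illustrative example (informal, derived from the code above): with cp_size=2 and an 8-token
# sequence, the sequence dim is viewed as 2*cp_size=4 chunks [c0, c1, c2, c3].
# to_contiguous=False reorders them to [c0, c3, c1, c2], so after sharding over the CP axis
# rank 0 holds (c0, c3) and rank 1 holds (c1, c2), balancing causal-attention work across
# ranks. to_contiguous=True applies the inverse permutation and recovers [c0, c1, c2, c3].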
def reorder_causal_striped(tensor, cp_size: int, seq_dim: int, is_inverse: bool):
"""Reorders a tensor for load balancing with striped pattern"""
origin_shape = tensor.shape
if origin_shape[seq_dim] % cp_size != 0:
raise ValueError(
"Expected origin_shape[seq_dim] is multiple of cp_size but got"
f" {origin_shape[seq_dim]=} and {cp_size=}"
)
if not is_inverse:
new_shape = [
*origin_shape[:seq_dim],
*[origin_shape[seq_dim] // cp_size, cp_size],
*origin_shape[seq_dim + 1 :],
]
else:
new_shape = [
*origin_shape[:seq_dim],
*[cp_size, origin_shape[seq_dim] // cp_size],
*origin_shape[seq_dim + 1 :],
]
chunked_tensor = tensor.reshape(new_shape)
reordered_chunked_tensor = jnp.swapaxes(chunked_tensor, seq_dim, seq_dim + 1)
return reordered_chunked_tensor.reshape(origin_shape)
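# Illustrative example (informal, derived from the code above): with cp_size=2 and sequence
# positions [0, 1, 2, 3, 4, 5, 6, 7], is_inverse=False yields [0, 2, 4, 6, 1, 3, 5, 7], so
# rank 0 ends up holding the even positions and rank 1 the odd positions (the striped
# pattern). is_inverse=True undoes the reordering.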
@dataclass(frozen=True)
class _FusedAttnCPWithAllGatherHelper:
"""Helper class to assist with running the all-gather strategy for CP attention."""
mesh: jax.sharding.Mesh
config: _FusedAttnConfig
def check_supported(self):
"""Checks if the context parallel implementation is supported by the given arguments."""
header = "Context parallel fused attention"
allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD]
if self.config.qkv_layout not in allowed_layouts:
raise ValueError(
f"{header} only supports layouts:"
f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}"
)
if self.config.attn_bias_type != AttnBiasType.NO_BIAS:
raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")
allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]
if self.config.attn_mask_type not in allowed_masks:
raise ValueError(
f"{header} only supports masking types: "
f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}"
)
if self.config.max_segments_per_seq != 1:
raise ValueError(
f"{header} only supports max_segments_per_seq == 1 got:"
f" {self.config.max_segments_per_seq}"
)
if self.config.dropout_probability != 0.0:
raise ValueError(f"{header} does not support dropout")
def get_adjusted_mask(self):
"""Converts the mask for context parallelism."""
if self.config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK
return self.config.attn_mask_type
def get_step_config(self) -> _FusedAttnConfig:
"""Returns a _FusedAttnConfig for single CP step call to fused attention."""
return _FusedAttnConfig(
attn_bias_type=self.config.attn_bias_type,
attn_mask_type=self.get_adjusted_mask(),
qkv_layout=self.config.qkv_layout,
scaling_factor=self.config.scaling_factor,
dropout_probability=self.config.dropout_probability,
is_training=self.config.is_training,
max_segments_per_seq=self.config.max_segments_per_seq,
window_size=self.config.window_size,
context_parallel_load_balanced=self.config.context_parallel_load_balanced,
cp_axis=self.config.cp_axis,
cp_striped_window_size=None,
)
def all_gather_kv(self, k, v):
"""Performs a all-gather of k and v over context parallel ranks."""
def ag(x):
x = lax_paral_op(
x, lax.all_gather, self.config.cp_axis, mesh=self.mesh, axis=1, tiled=True
)
if self.config.context_parallel_load_balanced:
cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=True)
return x
if self.config.qkv_layout.is_kvpacked():
return ag(k), v
if self.config.qkv_layout.is_separate():
return ag(k), ag(v)
return k, v # fall through
def reduce_scatter_dkv(self, dk, dv):
"""Performs a reduce-scatter of dk and dv over context parallel ranks."""
def rs(x):
if self.config.context_parallel_load_balanced:
cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh)
x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=False)
return lax_paral_op(
x,
lax.psum_scatter,
self.config.cp_axis,
mesh=self.mesh,
scatter_dimension=1,
tiled=True,
)
if self.config.qkv_layout.is_kvpacked():
return rs(dk), dv
if self.config.qkv_layout.is_separate():
return rs(dk), rs(dv)
return dk, dv # fall through
def kv_seqlens_for_rank(self, cp_rank, kv_max_seqlen, kv_seqlen_per_subrank):
"""Returns sequence lengths of KV to use for each sub rank of the given cp_rank.
Example: CP=4, MaxLen = 1024, Unbalanced
cp_rank 0: [128, 256]
cp_rank 1: [384, 512]
cp_rank 2: [640, 768]
cp_rank 3: [896, 1024]
Example: CP=4, MaxLen = 1024, Balanced
cp_rank 0: [128, 1024]
cp_rank 1: [256, 896]
cp_rank 2: [384, 768]
cp_rank 3: [512, 640]
"""
if self.config.context_parallel_load_balanced:
kv_seq_this_rank = [
(cp_rank + 1) * kv_seqlen_per_subrank,
kv_max_seqlen - cp_rank * kv_seqlen_per_subrank,
]
else:
kv_seq_this_rank = [
(cp_rank * 2 + 1) * kv_seqlen_per_subrank,
(cp_rank * 2 + 2) * kv_seqlen_per_subrank,
]
return kv_seq_this_rank
def slice_kv(self, k, v, slice_seq_len):
"""Slices k and v tensors to a sequence length of slice_seq_len."""
def sliced(x):
return lax.dynamic_slice_in_dim(x, 0, slice_seq_len, axis=1)
if self.config.qkv_layout.is_kvpacked():
return sliced(k), v
if self.config.qkv_layout.is_separate():
return sliced(k), sliced(v)
return k, v # fall through
def pad_kv(self, dk, dv, pad_seq_len):
"""Pads dk and dv tensors to a sequence length of pad_seq_len."""
def pad(x, npad):
return jnp.pad(x, npad, "constant", constant_values=0.0)
if self.config.qkv_layout.is_kvpacked():
npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0], [0, 0]]
return pad(dk, npad), dv
if self.config.qkv_layout.is_separate():
npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0]]
return pad(dk, npad), pad(dv, npad)
return dk, dv # fall through
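# Illustrative flow for the all-gather CP strategy (an informal sketch): with cp_size=4, each
# rank holds k of shape (b, s/4, h, d). all_gather_kv returns the full (b, s, h, d) tensor
# (undoing the load-balancing reorder when enabled), the per-rank attention step runs against
# the gathered KV, and reduce_scatter_dkv sums the full-length dk/dv across ranks and scatters
# them back to (b, s/4, h, d) per rank.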
class FusedAttnCPWithAllGatherFwdPrimitive(FusedAttnFwdPrimitive):
"""
Fused Attention Forward with Context Parallelism Primitive
This context parallel implementation uses all-gather to collect KV inputs from context parallel ranks.
"""
@staticmethod
def partition(config, mesh, arg_infos, result_infos):
# Call base implementation for non-context parallel mesh to avoid unnecessary work.
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
assert (
not is_context_parallel or config.window_size[0] == -1
), "Sliding window attention is not supported when context parallelism is enabled"
if not is_context_parallel:
return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)
helper = _FusedAttnCPWithAllGatherHelper(mesh, config)
helper.check_supported()
out_sharding = result_infos[0].sharding
softmax_aux_sharding = result_infos[1].sharding
rng_state_sharding = seed_sharding = NamedSharding(
mesh, PartitionSpec(get_all_mesh_axes(), None)
)
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding
arg_shardings = tuple(arg_shardings)
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
def impl(
q,
k,
v,
bias,
seed,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
):
cp_size = get_mesh_axis_size(config.cp_axis, mesh)
cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
# cuDNN does not support right-aligned masking with dynamic sequence length padding.
# Therefore we must explicitly instantiate each CP rank slicing and use a runtime switch
# to select the appropriate computation. Each case generates a [..., SEQ/CP, ..] tensor
# meeting the expectation of the SPMD model.
# TODO(mgoldfarb-nvidia): When cuDNN supports we should be able to make use of a padding
# mask/sequence length tensor to avoid this unrolled loop.
def _cross_attn(idx, q, k, v, bias, q_seqlen, kv_seqlen, seed):
kv_max_seqlen = k.shape[1]
kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2)
assert kv_max_seqlen % cp_size == 0, "sequence length must be evenly divisible by cp_size"
q_split = jnp.split(q, 2, axis=1)
kv_seqlens_for_rank = helper.kv_seqlens_for_rank(
idx, kv_max_seqlen, kv_seqlen_per_subrank
)
results = []
for sub_idx in range(2):
if config.attn_mask_type == AttnMaskType.NO_MASK:
k_unmasked, v_unmasked = k, v # full kv used for unmasked
else:
k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx])
q_seqlen_for_step = q_seqlen // (cp_size * 2)
num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx]
kv_seqlen_for_step = (kv_seqlen // (cp_size * 2)) * num_kv_chunks
output, softmax_aux, rng_state = FusedAttnFwdPrimitive.impl(
q_split[sub_idx],
k_unmasked,
v_unmasked,
bias,
seed,
q_seqlen_for_step,
kv_seqlen_for_step,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=helper.get_step_config(),
)
results.append((output, softmax_aux, rng_state))
output = jnp.concatenate((results[0][0], results[1][0]), axis=1)
softmax_aux = jnp.concatenate((results[0][1], results[1][1]), axis=2)
rng_state = results[1][2] # Use the final RNG state
return output, softmax_aux, rng_state
k_ag, v_ag = helper.all_gather_kv(k, v)
functions = [
partial(_cross_attn, idx, q, k_ag, v_ag, bias, q_seqlen, kv_seqlen, seed)
for idx in range(cp_size)
]
return lax.switch(cp_rank, functions)
return mesh, impl, out_shardings, arg_shardings
register_primitive(FusedAttnCPWithAllGatherFwdPrimitive)
class FusedAttnCPWithAllGatherBwdPrimitive(FusedAttnBwdPrimitive):
"""
Fused Attention Backward with Context Parallelism Primitive.
This context parallel implementation uses all-gather to collect KV and dKV inputs from context parallel ranks.
The gradients are subsequently reduce-scattered back to each context parallel rank.
"""
@staticmethod
def partition(config, mesh, arg_infos, result_infos):
# Call base implementation for non-context parallel mesh to avoid unnecessary work.
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
assert (
not is_context_parallel or config.window_size[0] == -1
), "Sliding window attention is not supported when context parallelism is enabled"
if not is_context_parallel:
return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
# Ensure we can support this configuration with context parallelism.
helper = _FusedAttnCPWithAllGatherHelper(mesh, config)
helper.check_supported()
del result_infos
q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
def impl(
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
):
cp_size = get_mesh_axis_size(config.cp_axis, mesh)
cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
# See comment in FusedAttnCPWithAllGatherFwdPrimitive.partition for why we define this function.
def _cross_attn_bwd(
idx,
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen,
kv_seqlen,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
):
kv_max_seqlen = k.shape[1]
kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2)
assert kv_max_seqlen % cp_size == 0, "sequence length must be evenly divisible by cp_size"
q_split = jnp.split(q, 2, axis=1)
output_split = jnp.split(output, 2, axis=1)
doutput_split = jnp.split(doutput, 2, axis=1)
softmax_aux_split = jnp.split(softmax_aux, 2, axis=2)
kv_seqlens_for_rank = helper.kv_seqlens_for_rank(
idx, kv_max_seqlen, kv_seqlen_per_subrank
)
results = []
for sub_idx in range(2):
if config.attn_mask_type == AttnMaskType.NO_MASK:
k_unmasked, v_unmasked = k, v # full kv used for unmasked
else:
k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx])
q_seqlen_for_step = q_seqlen // (cp_size * 2)
num_kv_chunks = kv_max_seqlen // kv_seqlens_for_rank[sub_idx]
kv_seqlen_for_step = (kv_seqlen // (cp_size * 2)) * num_kv_chunks
dq_local, dk_local, dv_local, dbias_local = FusedAttnBwdPrimitive.impl(
q_split[sub_idx],
k_unmasked,
v_unmasked,
bias,
softmax_aux_split[sub_idx],
rng_state,
output_split[sub_idx],
doutput_split[sub_idx],
q_seqlen_for_step,
kv_seqlen_for_step,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=helper.get_step_config(),
)
# pad dk/dv to be unsliced shape so we can reduce scatter over all ranks.
if config.attn_mask_type != AttnMaskType.NO_MASK:
pad_length = kv_max_seqlen - kv_seqlens_for_rank[sub_idx]
dk_local, dv_local = helper.pad_kv(dk_local, dv_local, pad_length)
results.append((dq_local, dk_local, dv_local, dbias_local))
dq_local = jnp.concatenate((results[0][0], results[1][0]), axis=1)
dk_local_pad = results[0][1] + results[1][1]
dv_local_pad = results[0][2] + results[1][2]
return dq_local, dk_local_pad, dv_local_pad, results[1][3]
k_ag, v_ag = helper.all_gather_kv(k, v)
functions = [
partial(
_cross_attn_bwd,
idx,
q,
k_ag,
v_ag,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen,
kv_seqlen,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
)
for idx in range(cp_size)
]
dq, dk_local, dv_local, dbias = lax.switch(cp_rank, functions)
dk, dv = helper.reduce_scatter_dkv(dk_local, dv_local)
return dq, dk, dv, dbias
return mesh, impl, out_shardings, arg_shardings
register_primitive(FusedAttnCPWithAllGatherBwdPrimitive)
@dataclass(frozen=True)
class _FusedAttnCPWithP2PHelper:
"""Helper class to assist with running the P2P ring strategy for CP attention."""
mesh: jax.sharding.Mesh
config: _FusedAttnConfig
@staticmethod
def use_scanloop():
"""Returns true if the implementation will use a scan loop for iteration."""
use_scan = bool(int(os.getenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "1")))
return use_scan
def check_supported(self):
"""Checks if the context parallel implementation is supported by the given arguments."""
header = "Context parallel fused ring attention"
if self.config.qkv_layout.is_thd():
allowed_layouts = [QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD]
else:
allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD]
if self.config.qkv_layout not in allowed_layouts:
raise ValueError(
f"{header} only supports layouts:"
f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}"
)
if self.config.attn_bias_type != AttnBiasType.NO_BIAS:
raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}")
if self.config.qkv_layout.is_thd():
allowed_masks = [AttnMaskType.PADDING_CAUSAL_MASK]
else:
allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]
if self.config.attn_mask_type not in allowed_masks:
raise ValueError(
f"{header} only supports masking types: "
f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}"
)
if not self.config.qkv_layout.is_thd() and self.config.max_segments_per_seq != 1:
raise ValueError(
f"{header} only supports max_segments_per_seq == 1 got:"
f" {self.config.max_segments_per_seq}"
)
if self.config.dropout_probability != 0.0:
raise ValueError(f"{header} does not support dropout")
# We want to encourage use of the scan loop to minimize unrolling and ensure more
# predictable scheduling from XLA. The unrolled flavor is still supported but is not
# the preferred implementation.
if not self.use_scanloop():
warnings.warn(
"Scan loop is disabled for fused ring attention. To enable set"
" NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 in your environment"
)
# If using scanloop, idx in scan_kv_block() will be a traced device value, but
# _normalize_window_size_for_cp_striped() requires all parameters to be host values
is_context_parallel = get_mesh_axis_size(self.config.cp_axis, self.mesh) > 1
is_thd_layout = self.config.qkv_layout.is_thd()
is_sliding_window = self.config.window_size[0] != -1
if is_context_parallel and is_thd_layout and is_sliding_window and self.use_scanloop():
raise ValueError(
f"{header} with THD format and sliding window does not support using scan loop"
)
def get_step_config(self, attn_mask_type) -> _FusedAttnConfig:
"""Returns a _FusedAttnConfig for single CP step call to fused attention."""
return _FusedAttnConfig(
attn_bias_type=self.config.attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=QKVLayout.BSHD_BS2HD,
scaling_factor=self.config.scaling_factor,
dropout_probability=self.config.dropout_probability,
is_training=self.config.is_training,
max_segments_per_seq=self.config.max_segments_per_seq,
window_size=self.config.window_size,
context_parallel_load_balanced=self.config.context_parallel_load_balanced,
cp_axis=self.config.cp_axis,
cp_striped_window_size=None,
)
def stack_kv(self, k, v):
"""Stacks k and v tensors if not stacked."""
_not_used = jnp.zeros(0, dtype=k.dtype)
if self.config.qkv_layout.is_kvpacked():
return k
if self.config.qkv_layout.is_separate():
return jnp.stack([k, v], axis=2)
return _not_used
def unstack_kv(self, kv):
"""Un-stacks k and v tensors if not stacked."""
_not_used = jnp.zeros(0, dtype=kv.dtype)
if self.config.qkv_layout.is_kvpacked():
return kv, _not_used
if self.config.qkv_layout.is_separate():
return jnp.unstack(kv, axis=2)
return _not_used, _not_used # fall through
def permute_kv(self, kv, cp_perm):
"""Permutes kv around the ring as described by cp_perm."""
return lax_paral_op(kv, lax.ppermute, self.config.cp_axis, mesh=self.mesh, perm=cp_perm)
@staticmethod
def correct_output_and_softmax_aux(output, softmax_aux, partial_output, partial_softmax_aux):
"""
Corrects the output and softmax_aux tensor after each iteration of ring attention.
See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 for
derivation of this equation.
"""
new_out = output - jax.nn.sigmoid(partial_softmax_aux - softmax_aux).transpose(
0, 2, 1, 3
) * (output - partial_output)
new_aux = softmax_aux - jax.nn.log_sigmoid(softmax_aux - partial_softmax_aux)
return new_out, new_aux
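# Derivation note (added for clarity, not from the original source): with running
# log-sum-exp stats L = softmax_aux and L_p = partial_softmax_aux,
# sigmoid(L_p - L) = exp(L_p) / (exp(L) + exp(L_p)), so new_out is the standard
# online-softmax blend
#     (exp(L) * output + exp(L_p) * partial_output) / (exp(L) + exp(L_p)),
# and new_aux = L - log_sigmoid(L - L_p) = log(exp(L) + exp(L_p)) accumulates the stats.
# The transpose broadcasts the (B, H, S, 1) stats against the (B, S, H, D) output.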
def adjust_seqlen(self, seqlen, max_seqlen, idx):
"""Adjust the sequence length per step."""
seqlen_of_curr_step = seqlen - max_seqlen * idx
seqlen_of_curr_step = jnp.where(seqlen_of_curr_step < 0, 0, seqlen_of_curr_step)
seqlen_per_step = jnp.where(
seqlen_of_curr_step < max_seqlen, seqlen_of_curr_step, max_seqlen
)
return seqlen_per_step
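# Illustrative example for adjust_seqlen (informal): with max_seqlen=128 per CP step and an
# actual seqlen of 200, step idx=0 sees 128 valid tokens, idx=1 sees 72, and any later step
# sees 0.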
class FusedRingAttnFwdPrimitive(FusedAttnFwdPrimitive):
"""
Fused Ring Attention Forward Primitive
"""
@staticmethod
def partition(config, mesh, arg_infos, result_infos):
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
assert (
not is_context_parallel or config.window_size[0] == -1
), "Sliding window attention is not supported when context parallelism is enabled"
if not is_context_parallel:
return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)
helper = _FusedAttnCPWithP2PHelper(mesh, config)
helper.check_supported()
out_sharding = result_infos[0].sharding
softmax_aux_sharding = result_infos[1].sharding
rng_state_sharding = seed_sharding = NamedSharding(
mesh, PartitionSpec(get_all_mesh_axes(), None)
)
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding
arg_shardings = tuple(arg_shardings)
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
def ring_attn_fwd_impl(
q,
k,
v,
bias,
seed,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
):
_not_used = jnp.zeros(0, dtype=v.dtype)
# Combine KV tensors if separate for better permute scheduling and performance.
# Eventually XLA should perform this automatically.
kv = helper.stack_kv(k, v)
batch, q_max_seqlen, head, _ = q.shape
kv_max_seqlen = k.shape[1]
cp_size = get_mesh_axis_size(config.cp_axis, mesh)
cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]
output = jnp.zeros(q.shape).astype(jnp.float32)
softmax_aux = jnp.full((batch, head, q_max_seqlen, 1), -jnp.inf, dtype=jnp.float32)
# RNG shape should be the shared shape. This is unused for ring attention as we do not
# support dropout currently.
rng_state = jnp.zeros(result_infos[2].shape).astype(result_infos[2].dtype)
def scan_kv_block(idx, carry):
kv, output, softmax_aux = carry
# Send KV block to next step so we can overlap compute.
kv_next = helper.permute_kv(kv, cp_perm)
def mask_compute(attn_mask_type):
q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
q,
kv,
_not_used,
bias,
seed,
q_seqlen_per_step,
kv_seqlen_per_step,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=helper.get_step_config(attn_mask_type),
)
return output_per_step, softmax_aux_per_step
causal_mask_compute = partial(mask_compute, AttnMaskType.CAUSAL_MASK)
no_mask_compute = partial(mask_compute, AttnMaskType.NO_MASK)
def half_kv_no_mask_compute():
q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2
kv_part = lax.slice_in_dim(kv, 0, kv.shape[1] // 2, axis=1)
output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
q,
kv_part,
_not_used,
bias,
seed,
q_seqlen_per_step,
kv_seqlen_per_step,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=helper.get_step_config(AttnMaskType.NO_MASK),
)
return output_per_step, softmax_aux_per_step
def half_q_no_mask_compute():
q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) // 2
kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
q_part = lax.slice_in_dim(q, q_max_seqlen // 2, q_max_seqlen, axis=1)
output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl(
q_part,
kv,
_not_used,
bias,
seed,
q_seqlen_per_step,
kv_seqlen_per_step,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=helper.get_step_config(AttnMaskType.NO_MASK),
)
output_per_step = jnp.concat([jnp.zeros_like(q_part), output_per_step], axis=1)
softmax_aux_per_step = jnp.concat(
[
jnp.full_like(softmax_aux_per_step, -jnp.inf),
softmax_aux_per_step,
],
axis=2,
)
return output_per_step, softmax_aux_per_step
def skip_compute():
output_per_step = jnp.zeros_like(q)
softmax_aux_per_step = jnp.full(
(batch, head, q.shape[1], 1), -jnp.inf, dtype=jnp.float32
)
return output_per_step, softmax_aux_per_step
if config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
# This is for nested jax.lax.cond
def jax_cond_wrap():
if config.context_parallel_load_balanced:
return lax.cond(
(idx <= cp_rank), half_kv_no_mask_compute, half_q_no_mask_compute
)
return lax.cond((idx <= cp_rank), no_mask_compute, skip_compute)
output_per_step, softmax_aux_per_step = lax.cond(
idx == 0, causal_mask_compute, jax_cond_wrap
)
else:
output_per_step, softmax_aux_per_step = no_mask_compute()
def skip_correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
# No correction done here but we cast outputs to float32 and perform reduction
# in full precision.
# pylint: disable=unused-argument
return output_per_step.astype(jnp.float32), softmax_aux_per_step
def correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
return helper.correct_output_and_softmax_aux(
output, softmax_aux, output_per_step, softmax_aux_per_step
)
# On the first step there is no correction; we take the initial output and stats.
output, softmax_aux = lax.cond(
(idx == 0),
skip_correction,
correction,
output,
softmax_aux,
output_per_step,
softmax_aux_per_step,
)
return (kv_next, output, softmax_aux)
carry = (kv, output, softmax_aux)
if helper.use_scanloop():
carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
else:
for i in range(0, cp_size):
carry = scan_kv_block(i, carry)
(kv, output, softmax_aux) = carry
output = output.astype(q.dtype)
return output, softmax_aux, rng_state
return mesh, ring_attn_fwd_impl, out_shardings, arg_shardings
register_primitive(FusedRingAttnFwdPrimitive)
class FusedRingAttnBwdPrimitive(FusedAttnBwdPrimitive):
"""
Fused Ring Attention Backward Primitive
"""
@staticmethod
def partition(config, mesh, arg_infos, result_infos):
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
assert (
not is_context_parallel or config.window_size[0] == -1
), "Sliding window attention is not supported when context parallelism is enabled"
if not is_context_parallel:
return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
del result_infos
q_spec = get_padded_spec(arg_infos[0])
k_spec = get_padded_spec(arg_infos[1])
v_spec = get_padded_spec(arg_infos[2])
bias_spec = get_padded_spec(arg_infos[3])
dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec))
dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec))
dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding)
helper = _FusedAttnCPWithP2PHelper(mesh, config)
helper.check_supported()
def ring_attn_bwd_impl(
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
):
_not_used = jnp.zeros(0, dtype=output.dtype)
# Combine KV tensors if separate for better permute scheduling and performance.
# Eventually XLA should perform this automatically.
kv = helper.stack_kv(k, v)
q_max_seqlen = q.shape[1]
kv_max_seqlen = k.shape[1]
cp_size = get_mesh_axis_size(config.cp_axis, mesh)
cp_rank = get_mesh_axis_rank(config.cp_axis, mesh)
cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]
dq = jnp.zeros_like(q)
dk_dv = helper.stack_kv(jnp.zeros_like(k), jnp.zeros_like(v))
dbias = jnp.zeros_like(bias)
def scan_kv_block(idx, carry):
kv, dq, dk_dv, dbias = carry
                # Start communication that feeds the next iteration.
# We further combine the tensors to improve overlap.
kv_dk_dv = jnp.stack([kv, dk_dv])
kv_dk_dv = helper.permute_kv(kv_dk_dv, cp_perm)
def mask_compute(attn_mask_type):
q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
q,
kv,
_not_used,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen_per_step,
kv_seqlen_per_step,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=helper.get_step_config(attn_mask_type),
)
return dq_per_step, dk_dv_per_step, dbias_per_step
causal_mask_compute = partial(mask_compute, AttnMaskType.CAUSAL_MASK)
no_mask_compute = partial(mask_compute, AttnMaskType.NO_MASK)
def half_kv_no_mask_compute():
q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx)
kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx) // 2
kv_part = lax.slice_in_dim(kv, 0, kv_max_seqlen // 2, axis=1)
dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
q,
kv_part,
_not_used,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen_per_step,
kv_seqlen_per_step,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=helper.get_step_config(AttnMaskType.NO_MASK),
)
dk_dv_per_step = jnp.concat(
[dk_dv_per_step, jnp.zeros_like(dk_dv_per_step)], axis=1
)
return dq_per_step, dk_dv_per_step, dbias_per_step
def half_q_no_mask_compute():
q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) // 2
kv_seqlen_per_step = helper.adjust_seqlen(kv_seqlen, kv_max_seqlen, idx)
q_part = lax.slice_in_dim(q, q_max_seqlen // 2, q_max_seqlen, axis=1)
doutput_part = lax.slice_in_dim(
doutput, q_max_seqlen // 2, q_max_seqlen, axis=1
)
output_part = lax.slice_in_dim(output, q_max_seqlen // 2, q_max_seqlen, axis=1)
softmax_aux_part = lax.slice_in_dim(
softmax_aux, q_max_seqlen // 2, q_max_seqlen, axis=2
)
dq_per_step, dk_dv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
q_part,
kv,
_not_used,
bias,
softmax_aux_part,
rng_state,
output_part,
doutput_part,
q_seqlen_per_step,
kv_seqlen_per_step,
q_seq_offsets,
k_seq_offsets,
_q_segment_ids,
_kv_segment_ids,
_q_segment_pos,
_kv_segment_pos,
config=helper.get_step_config(AttnMaskType.NO_MASK),
)
dq_per_step = jnp.concat([jnp.zeros_like(dq_per_step), dq_per_step], axis=1)
return dq_per_step, dk_dv_per_step, dbias_per_step
def skip_compute():
return jnp.zeros_like(q), jnp.zeros_like(kv), jnp.zeros_like(bias)
if config.attn_mask_type == AttnMaskType.CAUSAL_MASK:
# This is for nested jax.lax.cond
def jax_cond_wrap():
if config.context_parallel_load_balanced:
return lax.cond(
(idx <= cp_rank), half_kv_no_mask_compute, half_q_no_mask_compute
)
return lax.cond((idx <= cp_rank), no_mask_compute, skip_compute)
dq_per_step, dk_dv_per_step, dbias_per_step = lax.cond(
idx == 0, causal_mask_compute, jax_cond_wrap
)
else:
dq_per_step, dk_dv_per_step, dbias_per_step = no_mask_compute()
kv_next, dk_dv = jnp.unstack(kv_dk_dv)
dq = dq + dq_per_step
dk_dv = dk_dv + dk_dv_per_step
if config.attn_bias_type is not AttnBiasType.NO_BIAS:
dbias = dbias + dbias_per_step
return (kv_next, dq, dk_dv, dbias)
carry = (kv, dq, dk_dv, dbias)
if helper.use_scanloop():
carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
else:
for i in range(0, cp_size):
carry = scan_kv_block(i, carry)
(kv, dq, dk_dv, dbias) = carry
# Final permute to put gradients back to their final resting place.
dk_dv = helper.permute_kv(dk_dv, cp_perm)
global_dbias = dbias
if config.attn_bias_type is not AttnBiasType.NO_BIAS:
global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)
dk, dv = helper.unstack_kv(dk_dv)
return dq, dk, dv, global_dbias
return mesh, ring_attn_bwd_impl, out_shardings, arg_shardings
register_primitive(FusedRingAttnBwdPrimitive)
def adjust_cp_striped_window_size(q_pos0, kv_pos0, cp_size, window_size):
"""
Adjust window size with cp_size for striped sharding, where both q_pos and
kv_pos are arithmetic sequences like [x, x+cp_size, x+2*cp_size, ...].
Example 1:
q_pos = kv_pos = [0, 8, 16, 24, 32], cp_size = 8, window_size = (15, 0).
q_pos = 32 can look at kv_pos at [24, 32]. The effective mask is:
0 8 16 24 32
----------------
0 | 1 0 0 0 0
8 | 1 1 0 0 0
16 | 0 1 1 0 0
24 | 0 0 1 1 0
32 | 0 0 0 1 1
SequenceDescriptor outputs: {q,kv}_seqlen = [5, ...], {q,kv}_seq_offsets = [0, ...].
Adjusted window size = (1, 0).
Example 2:
q_pos = [0, 8, 16, 24, 32], kv_pos = [1, 9, 17, 25, 33], cp_size = 8,
window_size = (15, 0). The effective mask is:
1 9 17 25 33
----------------
0 | 0 0 0 0 0
8 | 1 0 0 0 0
16 | 1 1 0 0 0
24 | 0 1 1 0 0
32 | 0 0 1 1 0
SequenceDescriptor outputs:
q_seqlen = [4, ...], q_seq_offsets = [1, ...],
kv_seqlen = [4, ...], kv_seq_offsets = [0, ...].
        If the diagonal were all 1s, the left window size would be 2. Since the diagonal is
        all 0s here, we must use left window size = 2 - 1 = 1 for cuDNN to work.
Example 3:
q_pos = [7, 15, 23, 31, 39], kv_pos = [0, 8, 16, 24, 32], cp_size = 8,
window_size = (22, 0). The effective mask is:
0 8 16 24 32
----------------
7 | 1 0 0 0 0
15 | 1 1 0 0 0
23 | 0 1 1 0 0
31 | 0 0 1 1 0
39 | 0 0 0 1 1
SequenceDescriptor outputs: {q,kv}_seqlen = [5, ...], {q,kv}_seq_offsets = [0, ...].
        Adjusted window size = (1, 0).
"""
left_limit = q_pos0 - window_size[0]
right_limit = q_pos0 + window_size[1]
# Count how many left/right steps of size cp_size we can take from kv_pos0 -/+ cp_size
left_steps = (kv_pos0 - cp_size - left_limit) // cp_size + 1
right_steps = (right_limit - kv_pos0 - cp_size) // cp_size + 1
left_steps = max(left_steps, 0)
right_steps = max(right_steps, 0)
# If kv_pos0 > q_pos0, we must reduce left window size by 1
shift = 1 if kv_pos0 > q_pos0 else 0
left_steps = left_steps - shift
return left_steps, right_steps
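# Illustrative sanity checks for the docstring examples above (hand-computed; not part of the
# library's test suite):
#   adjust_cp_striped_window_size(0, 0, 8, (15, 0))  # Example 1 -> (1, 0)
#   adjust_cp_striped_window_size(0, 1, 8, (15, 0))  # Example 2 -> (1, 0)
#   adjust_cp_striped_window_size(7, 0, 8, (22, 0))  # Example 3 -> (1, 0)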
class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive):
"""
Fused Striped Ring Attention Forward Primitive
"""
@staticmethod
def partition(config, mesh, arg_infos, result_infos):
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
if not is_context_parallel:
return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos)
helper = _FusedAttnCPWithP2PHelper(mesh, config)
helper.check_supported()
out_sharding = result_infos[0].sharding
softmax_aux_sharding = result_infos[1].sharding
rng_state_sharding = seed_sharding = NamedSharding(
mesh, PartitionSpec(get_all_mesh_axes(), None)
)
arg_shardings = [arg_i.sharding for arg_i in arg_infos]
arg_shardings[4] = seed_sharding
arg_shardings = tuple(arg_shardings)
out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding)
def fwd_impl(
q,
k,
v,
bias,
seed,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
q_segment_ids,
kv_segment_ids,
q_segment_pos,
kv_segment_pos,
):
if q_segment_ids.size == 0 or kv_segment_ids.size == 0:
raise ValueError("THD + ring attn only supports passing seqment_ids/pos")
_not_used = jnp.zeros(0, dtype=v.dtype)
# Combine KV tensors if separate for better permute scheduling and performance.
# Eventually XLA should perform this automatically.
kv = helper.stack_kv(k, v)
if not config.qkv_layout.is_qkvpacked():
subblock_config = replace(config, qkv_layout=config.qkv_layout.to_kvpacked())
else:
subblock_config = config
cp_size = get_mesh_axis_size(config.cp_axis, mesh)
cp_rank = get_mesh_axis_rank_host(config.cp_axis, mesh)
cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]
batch, q_max_seqlen, head, _ = q.shape
output = jnp.zeros(q.shape).astype(jnp.float32)
softmax_aux = jnp.zeros((batch, q_max_seqlen, head, 1), dtype=jnp.float32)
            # The RNG state takes the sharded result shape. It is unused for ring attention
            # since dropout is not currently supported.
rng_state = jnp.zeros(result_infos[2].shape).astype(result_infos[2].dtype)
def scan_kv_block(idx, carry):
kv, kv_segment_ids, kv_segment_pos, output, softmax_aux = carry
                # TODO(rewang): Check whether the last idx needs special handling
# Send KV block to next step so we can overlap compute.
kv_next = helper.permute_kv(kv, cp_perm)
kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm)
kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm)
def compute(config):
return FusedAttnFwdPrimitive.impl(
q,
kv,
_not_used,
bias,
seed,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
q_segment_ids,
kv_segment_ids,
q_segment_pos,
kv_segment_pos,
config,
)
if config.window_size != (-1, -1):
kv_src_rank = (cp_size + cp_rank - idx) % cp_size
# Note: all inputs of adjust_cp_striped_window_size should be host values
cp_striped_window_size = adjust_cp_striped_window_size(
cp_rank, kv_src_rank, cp_size, config.window_size
)
current_config = replace(
subblock_config, cp_striped_window_size=cp_striped_window_size
)
else:
current_config = subblock_config
output_per_step, softmax_aux_per_step, _ = compute(current_config)
softmax_aux_per_step = softmax_aux_per_step.reshape((batch, q_max_seqlen, head, 1))
def skip_correction(_output, _softmax_aux, output_per_step, softmax_aux_per_step):
# No correction done here but we cast outputs to float32 and perform reduction
# in full precision.
return output_per_step.astype(jnp.float32), softmax_aux_per_step
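                # Merging a new partial result is an online-softmax update. With aux holding
                # the running log-sum-exp statistics:
                #   new_aux = log(exp(aux) + exp(aux_per_step))
                #           = aux - log_sigmoid(aux - aux_per_step)
                #   new_out = w * output + (1 - w) * output_per_step,
                #   where w = exp(aux) / (exp(aux) + exp(aux_per_step))
                #           = 1 - sigmoid(aux_per_step - aux)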
def correction(output, softmax_aux, output_per_step, softmax_aux_per_step):
new_out = output - jax.nn.sigmoid(softmax_aux_per_step - softmax_aux) * (
output - output_per_step
)
new_aux = softmax_aux - jax.nn.log_sigmoid(softmax_aux - softmax_aux_per_step)
return new_out, new_aux
                # On the first step there is no correction; take the initial output and stats.
output, softmax_aux = lax.cond(
idx == 0,
skip_correction,
correction,
output,
softmax_aux,
output_per_step,
softmax_aux_per_step,
)
return (kv_next, kv_segment_ids_next, kv_segment_pos_next, output, softmax_aux)
carry = (kv, kv_segment_ids, kv_segment_pos, output, softmax_aux)
if helper.use_scanloop():
carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
else:
for i in range(0, cp_size):
carry = scan_kv_block(i, carry)
(_, _, _, output, softmax_aux) = carry
return output.astype(q.dtype), softmax_aux, rng_state
return mesh, fwd_impl, out_shardings, arg_shardings
register_primitive(FusedRingAttnStripedFwdPrimitive)
class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive):
"""
Fused Striped Ring Attention Backward Primitive
"""
@staticmethod
def partition(config, mesh, arg_infos, result_infos):
is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1
if not is_context_parallel:
return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos)
arg_shardings = tuple(arg.sharding for arg in arg_infos)
# dq, dk, dv, dbias sharding = q, k, v, bias sharding
out_shardings = tuple(arg.sharding for arg in arg_infos[:4])
helper = _FusedAttnCPWithP2PHelper(mesh, config)
helper.check_supported()
def bwd_impl(
q,
k,
v,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
q_segment_ids,
kv_segment_ids,
q_segment_pos,
kv_segment_pos,
):
if q_segment_ids.size == 0 or kv_segment_ids.size == 0:
raise ValueError("THD + ring attn only supports passing seqment_ids/pos")
_not_used = jnp.zeros(0, dtype=output.dtype)
# Combine KV tensors if separate for better permute scheduling and performance.
# Eventually XLA should perform this automatically.
kv = helper.stack_kv(k, v)
if not config.qkv_layout.is_qkvpacked():
subblock_config = replace(config, qkv_layout=config.qkv_layout.to_kvpacked())
else:
subblock_config = config
cp_size = get_mesh_axis_size(config.cp_axis, mesh)
# We need cp_rank to be a host value for adjust_cp_striped_window_size()
cp_rank = get_mesh_axis_rank_host(config.cp_axis, mesh)
cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)]
dq = jnp.zeros_like(q)
dkv = jnp.zeros_like(kv)
dbias = jnp.zeros_like(bias)
def scan_kv_block(idx, carry):
kv, kv_segment_ids, kv_segment_pos, dq, dkv, dbias = carry
# Start communication that feeds the next iteration.
# We further combine the tensors to improve overlap.
kv_dkv = jnp.stack([kv, dkv])
kv_dkv = helper.permute_kv(kv_dkv, cp_perm)
kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm)
kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm)
def compute(config):
dq_per_step, dkv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl(
q,
kv,
_not_used,
bias,
softmax_aux,
rng_state,
output,
doutput,
q_seqlen,
kv_seqlen,
q_seq_offsets,
k_seq_offsets,
q_segment_ids,
kv_segment_ids,
q_segment_pos,
kv_segment_pos,
config=config,
)
return dq_per_step, dkv_per_step, dbias_per_step
if config.window_size != (-1, -1):
kv_src_rank = (cp_size + cp_rank - idx) % cp_size
# Note: all inputs of adjust_cp_striped_window_size should be host values
cp_striped_window_size = adjust_cp_striped_window_size(
cp_rank, kv_src_rank, cp_size, config.window_size
)
current_config = replace(
subblock_config, cp_striped_window_size=cp_striped_window_size
)
else:
current_config = subblock_config
dq_per_step, dkv_per_step, dbias_per_step = compute(current_config)
kv_next, dkv = jnp.unstack(kv_dkv)
dq += dq_per_step
dkv += dkv_per_step
if config.attn_bias_type is not AttnBiasType.NO_BIAS:
dbias = dbias + dbias_per_step
return (kv_next, kv_segment_ids_next, kv_segment_pos_next, dq, dkv, dbias)
carry = (kv, kv_segment_ids, kv_segment_pos, dq, dkv, dbias)
if helper.use_scanloop():
carry = lax.fori_loop(0, cp_size, scan_kv_block, carry)
else:
for idx in range(cp_size):
carry = scan_kv_block(idx, carry)
(_, _, _, dq, dkv, dbias) = carry
# Final permute to put gradients back to their final resting place.
dkv = helper.permute_kv(dkv, cp_perm)
global_dbias = dbias
if config.attn_bias_type is not AttnBiasType.NO_BIAS:
global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh)
dk, dv = helper.unstack_kv(dkv)
return dq, dk, dv, global_dbias
return mesh, bwd_impl, out_shardings, arg_shardings
register_primitive(FusedRingAttnStripedBwdPrimitive)
def _maybe_context_parallel_axis(cp_axis: str):
if not cp_axis:
gmr = global_mesh_resource()
if gmr is not None:
cp_axis = gmr.cp_resource
else:
cp_axis = ""
return cp_axis
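# For example, _maybe_context_parallel_axis("") resolves to global_mesh_resource().cp_resource
# when a global mesh resource has been configured, and remains "" otherwise.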
def fused_attn_fwd(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
sequence_descriptor: SequenceDescriptor,
seed: Optional[jnp.ndarray],
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
qkv_layout: QKVLayout,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
max_segments_per_seq: int,
window_size: Optional[Tuple[int, int]] = None,
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""
    Perform the forward pass with the cuDNN fused attention implementations.
This function implements the following formula:
BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
Args:
qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
It supports three formats:
- `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
and value have the same shape (e.g., self-attention).
- `(query, kv_packed)`: For separate query and KV packed format, typically used when
query has a different shape (e.g., cross-attention).
- `(query, key, value)`: For separate query, key, and value tensors.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
        sequence_descriptor (SequenceDescriptor):
            Describes the sequence lengths, the offsets in the sequence dimension, and (for
            THD layouts) the segment ids/positions for the query and the key/value tensors.
seed (Optional[jnp.ndarray]): Optional random seed for dropout.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode.
max_segments_per_seq (int):
            The maximum number of segments inside a sequence. This parameter constrains
            resource usage and must be static during end-to-end training; the XLA compile
            time and memory consumption are proportional to `max_segments_per_seq`.
window_size (Optional[Tuple[int, int]]): Sliding window size.
context_parallel_causal_load_balanced (bool):
            Indicates whether the sequences have been reordered for causal-mask load balancing
            when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
Returns:
        Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: The attention output, the softmax
        auxiliary (statistics) tensor, and the RNG state consumed by `fused_attn_bwd`.
"""
seed = _FusedAttnRNGStateChecker().check_seed(seed, dropout_probability, is_training)
    # Custom calls do not support None, so optional tensors are replaced with
    # zero-size placeholders.
_not_used = jnp.zeros(0, dtype=qkv[0].dtype)
if qkv_layout.is_qkvpacked():
assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
qkv_for_primitive = [*qkv, _not_used, _not_used]
elif qkv_layout.is_kvpacked():
assert (
len(qkv) == 2
), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
qkv_for_primitive = [*qkv, _not_used]
elif qkv_layout.is_separate():
assert (
len(qkv) == 3
), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
qkv_for_primitive = qkv
else:
raise ValueError(f"Unknown {qkv_layout=}")
if attn_bias_type == AttnBiasType.NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=qkv[0].dtype)
fused_config = _FusedAttnConfig(
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=(-1, -1) if window_size is None else window_size,
context_parallel_load_balanced=context_parallel_causal_load_balanced,
cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
cp_striped_window_size=None,
)
primitive = None
match context_parallel_strategy:
case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive
case CPStrategy.RING:
            # We must use striped attention for THD + ring
if qkv_layout.is_thd():
primitive = FusedRingAttnStripedFwdPrimitive.outer_primitive
else:
primitive = FusedRingAttnFwdPrimitive.outer_primitive
seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
output, softmax_aux, rng_state = primitive.bind(
*qkv_for_primitive,
bias,
seed,
*seq_desc_flatten,
config=fused_config,
)
rng_state = with_sharding_constraint(rng_state, PartitionSpec(get_all_mesh_axes(), None))
return (output, softmax_aux, rng_state)
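# A minimal usage sketch (illustrative only; the shapes, the QKVLayout member, and the
# SequenceDescriptor.from_seqlens constructor below are assumptions, not taken from this file):
#
#   q = k = v = jnp.zeros((2, 1024, 16, 64), jnp.bfloat16)
#   seq_desc = SequenceDescriptor.from_seqlens(  # assumed helper constructor
#       (jnp.full((2,), 1024, jnp.int32), jnp.full((2,), 1024, jnp.int32))
#   )
#   out, softmax_aux, rng_state = fused_attn_fwd(
#       (q, k, v),
#       bias=None,
#       sequence_descriptor=seq_desc,
#       seed=None,
#       attn_bias_type=AttnBiasType.NO_BIAS,
#       attn_mask_type=AttnMaskType.CAUSAL_MASK,
#       qkv_layout=QKVLayout.BSHD_BSHD_BSHD,
#       scaling_factor=1.0 / 64**0.5,  # 1/sqrt(head_dim)
#       dropout_probability=0.0,
#       is_training=True,
#       max_segments_per_seq=1,
#   )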
def fused_attn_bwd(
qkv: Tuple[jnp.ndarray, ...],
bias: Optional[jnp.ndarray],
softmax_aux: jnp.ndarray,
rng_state: jnp.ndarray,
output: jnp.ndarray,
doutput: jnp.ndarray,
sequence_descriptor: SequenceDescriptor,
attn_bias_type: AttnBiasType,
attn_mask_type: AttnMaskType,
qkv_layout: QKVLayout,
scaling_factor: float,
dropout_probability: float,
is_training: bool,
max_segments_per_seq: int,
window_size: Optional[Tuple[int, int]] = None,
context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
context_parallel_causal_load_balanced: bool = False,
context_parallel_axis: str = "",
):
"""
Perform the backward pass of the cuDNN fused attention implementations.
Args:
qkv (Tuple[jnp.ndarray, ...]): A tuple containing the original query, key, and value tensors
used in the forward pass. It supports three formats:
- `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
and value have the same shape (e.g., self-attention).
- `(query, kv_packed)`: For separate query and KV packed format, typically used when
query has a different shape (e.g., cross-attention).
- `(query, key, value)`: For separate query, key, and value tensors.
bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
softmax_aux (jnp.ndarray): Auxiliary tensors from the softmax step used in the forward pass.
rng_state (jnp.ndarray): Auxiliary tensors to save the random state in the forward pass.
output (jnp.ndarray): The output tensor from the forward pass.
doutput (jnp.ndarray): The gradient with respect to the output.
        sequence_descriptor (SequenceDescriptor):
            Describes the sequence lengths, the offsets in the sequence dimension, and (for
            THD layouts) the segment ids/positions for the query and the key/value tensors.
attn_bias_type (AttnBiasType): Type of attention bias.
attn_mask_type (AttnMaskType): Type of attention mask.
qkv_layout (QKVLayout): Layout of the QKV tensors.
scaling_factor (float): Scaling factor for the attention scores.
dropout_probability (float): Dropout probability to apply during attention.
is_training (bool): Flag indicating whether the model is in training mode.
max_segments_per_seq (int):
            The maximum number of segments inside a sequence. This parameter constrains
            resource usage and must be static during end-to-end training; the XLA compile
            time and memory consumption are proportional to `max_segments_per_seq`.
        window_size (Optional[Tuple[int, int]]): Sliding window size.
context_parallel_causal_load_balanced (bool):
            Indicates whether the sequences have been reordered for causal-mask load balancing
            when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
Returns:
Tuple[jnp.ndarray, ...], jnp.ndarray:
- The first tuple contains the gradients with respect to the input `qkv` tensors in the
same format as the input `qkv`.
- The second value is the gradient with respect to `bias`, or `None` if `bias` is `None`.
"""
    # Custom calls do not support None, so optional tensors are replaced with
    # zero-size placeholders.
_not_used = jnp.zeros(0, dtype=qkv[0].dtype)
if qkv_layout.is_qkvpacked():
assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}"
qkv_for_primitive = [*qkv, _not_used, _not_used]
elif qkv_layout.is_kvpacked():
assert (
len(qkv) == 2
), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}"
qkv_for_primitive = [*qkv, _not_used]
elif qkv_layout.is_separate():
assert (
len(qkv) == 3
), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}"
qkv_for_primitive = qkv
else:
raise ValueError(f"Unknown {qkv_layout=}")
if attn_bias_type == AttnBiasType.NO_BIAS:
assert bias is None
bias = jnp.zeros(0, dtype=qkv[0].dtype)
fused_config = _FusedAttnConfig(
attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type,
qkv_layout=qkv_layout,
scaling_factor=scaling_factor,
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=(-1, -1) if window_size is None else window_size,
context_parallel_load_balanced=context_parallel_causal_load_balanced,
cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
cp_striped_window_size=None,
)
primitive = None
match context_parallel_strategy:
case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER:
primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive
case CPStrategy.RING:
if qkv_layout.is_thd():
primitive = FusedRingAttnStripedBwdPrimitive.outer_primitive
else:
primitive = FusedRingAttnBwdPrimitive.outer_primitive
seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
*qkv_grads, bias_grad = primitive.bind(
*qkv_for_primitive,
bias,
softmax_aux,
rng_state,
output,
doutput,
*seq_desc_flatten,
config=fused_config,
)
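    # The primitive always yields three qkv-gradient slots (padded with zero-size placeholders
    # above); trim to len(qkv) so the returned gradients mirror the input qkv format.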
return tuple(qkv_grads[: len(qkv)]), bias_grad