  | # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # See LICENSE for license information. | |
| """ | |
| Wrapper module for Transformer related layers with FP8 support. | |
| """ | |
| import functools | |
| from enum import Enum | |
| from math import sqrt | |
| import os | |
| from typing import Any, Callable, Optional, Sequence, Tuple, Union | |
| import warnings | |
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
| from flax import linen as nn | |
| from flax.linen.attention import combine_masks | |
| from jax import nn as jax_nn | |
| from jax import random as jax_random | |
| from jax import lax, vmap | |
| from jax.ad_checkpoint import checkpoint_name | |
| from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP | |
| from .module import LayerNorm, Softmax | |
| from ..attention import AttnBiasType, AttnMaskType, QKVLayout, SequenceDescriptor, CPStrategy | |
| from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type | |
| from ..attention import fused_attn | |
| from ..softmax import SoftmaxType | |
| from ..sharding import num_of_devices | |
| from ..sharding import get_sharding_map_logic_axis_to_mesh_axis | |
| from ..sharding import with_sharding_constraint_by_logical_axes | |
| from ..sharding import BATCH_AXES, SEQLEN_AXES, SEQLEN_TP_AXES, HEAD_AXES | |
| from ..sharding import HIDDEN_AXES, HIDDEN_TP_AXES, JOINED_AXES | |
| from ..sharding import W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES | |
| PRNGKey = Any | |
| Shape = Tuple[int, ...] | |
| DType = jnp.dtype | |
| Array = jnp.ndarray | |
| PrecisionLike = Union[ | |
| None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision] | |
| ] | |
| Initializer = Callable[[PRNGKey, Shape, DType], Array] | |
| LogicalRules = Sequence[Tuple[str, Union[str, None]]] | |
| def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]: | |
| # Generate broadcast dims for drop_path. | |
| drop_path_shape = list(range(0, len(shape))) | |
| drop_path_shape.pop(batch_dim) | |
| return drop_path_shape | |
| def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: | |
| """ | |
| Extend the given Flax logical axis rules with the predefined TransformerLayer's | |
| logical axis rules. | |
| .. note:: | |
| We currently only support logical axis rules for single GPU training, data parallel | |
| training and 1D-sharding tensor parallel training. | |
| Refer to Figure 3 of `Megatron-LM tensor parallel <https://arxiv.org/pdf/1909.08053.pdf>`_ | |
| for 1D-sharding tensor parallelism. | |
| .. warning:: | |
| Please make sure ShardingResource is set via fp8_autocast before calling this function. | |
| .. note:: | |
| This function is only needed when using TransformerLayer. For other modules, such as | |
| DenseGeneral, please properly set axes of kernels and bias. | |
| Parameters | |
| ---------- | |
| rules: Sequence[Tuple[str, Union[str, None]]] | |
| the base Flax logical axis rules to extend. | |
| Returns | |
| ------- | |
| extended_rules: Sequence[Tuple[str, Union[str, None]]] | |
| the extended Flax logical axis rules. | |
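| Example | |
| ------- | |
| A minimal sketch of extending user rules and activating them for a Flax model. The | |
| ``'data'`` mesh axis name and the ``transformer_layer`` / ``rng`` / ``inputs`` names are | |
| illustrative placeholders; the sharding resources are assumed to be already configured | |
| (e.g. via fp8_autocast). | |
| .. code-block:: python | |
|     base_rules = (("my_activation_batch", "data"),) | |
|     rules = extend_logical_axis_rules(base_rules) | |
|     with flax.linen.logical_axis_rules(rules): | |
|         params = transformer_layer.init(rng, inputs, inputs) | |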
| """ | |
| rules_map = {} | |
| for item in rules: | |
| assert len(item) == 2, "The logical axis rule should be like (axis_name, mesh_axis_name)." | |
| key = item[0] | |
| val = item[1] | |
| assert isinstance(key, str), f"The axis_name should be str, but got {type(key)}." | |
| assert isinstance(val, str) or ( | |
| val is None | |
| ), f"The mesh_axis_name should be str or None, but got {type(val)}." | |
| if key in rules_map: | |
| rules_map[key].append(val) | |
| else: | |
| rules_map[key] = [val] | |
| extended_rules = [*rules] | |
| for item in get_sharding_map_logic_axis_to_mesh_axis().items(): | |
| key = item[0] | |
| val = item[1] | |
| if key in rules_map: | |
| assert len(rules_map[key]) == 1 and rules_map[key][0] == val, ( | |
| "The rules diverge between TE and the given rules. " | |
| f"Axis {key} maps to {rules_map[key]} in the given" | |
| f" rules, but to {val} in TE's rules." | |
| ) | |
| else: | |
| extended_rules.append(item) | |
| return tuple(extended_rules) | |
| class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-methods | |
| attention_dropout: float = 0.0 | |
| attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK | |
| attn_bias_type: Optional[AttnBiasType] = None | |
| dtype: DType = jnp.float32 | |
| float32_logits: bool = False | |
| scale_factor: Optional[float] = None | |
| transpose_batch_sequence: bool = True | |
| window_size: Optional[Tuple[int, int]] = None | |
| @nn.compact | |
| def __call__( | |
| self, | |
| query: Array, | |
| key: Array, | |
| value: Array, | |
| mask: Optional[Array] = None, | |
| bias: Optional[Array] = None, | |
| *, | |
| dropout_rng: Optional[PRNGKey] = None, | |
| deterministic: bool = False, | |
| ) -> Array: | |
| assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank." | |
| batch_dim = 1 if self.transpose_batch_sequence else 0 | |
| assert ( | |
| query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim] | |
| ), "q, k, v batch dims must match." | |
| sequence_dim = 0 if self.transpose_batch_sequence else 1 | |
| assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match." | |
| assert key.shape[-2] == value.shape[-2], "k, v num_attention_heads must match." | |
| assert query.shape[-1] == key.shape[-1], "q, k head_dim must match." | |
| input_dtype = query.dtype | |
| if self.scale_factor is None: | |
| scale_factor = 1.0 / sqrt(query.shape[-1]) | |
| else: | |
| scale_factor = self.scale_factor | |
| del self.scale_factor | |
| if self.float32_logits: | |
| query = query.astype(jnp.float32) | |
| key = key.astype(jnp.float32) | |
| h_q, h_kv = query.shape[-2], key.shape[-2] | |
| # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv. | |
| # Therefore, we have to maintain two code paths. | |
| is_gqa = h_q != h_kv | |
| if is_gqa: | |
| assert (h_q % h_kv == 0) and (h_q >= h_kv) | |
| group_size = h_q // h_kv | |
| grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1])) | |
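| # Shape sketch (illustrative numbers): with h_q=16 and h_kv=4, group_size is 4 and a | |
| # query of shape [..., 16, d] is viewed as [..., 4, 4, d], so each KV head serves a | |
| # group of 4 query heads in the grouped einsum below. | |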
| if self.transpose_batch_sequence: | |
| if is_gqa: | |
| attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key) | |
| else: | |
| attn_weights = jnp.einsum("qbhd,kbhd->bhqk", query, key) | |
| else: | |
| if is_gqa: | |
| attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key) | |
| else: | |
| attn_weights = jnp.einsum("bqhd,bkhd->bhqk", query, key) | |
| attn_weights = checkpoint_name(attn_weights, "logits") | |
| if is_gqa: | |
| b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape | |
| attn_weights_without_groups_shape = (b, h * g, q, k) | |
| attn_weights = attn_weights.reshape(attn_weights_without_groups_shape) | |
| attn_weights = with_sharding_constraint_by_logical_axes( | |
| attn_weights, (BATCH_AXES, HEAD_AXES, SEQLEN_AXES, SEQLEN_AXES) | |
| ) | |
| # When post_scale_bias is present, the computation is Softmax(attn_weights * scale + bias) | |
| # In this case, the scale cannot be fused into the Softmax module. | |
| if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS: | |
| attn_weights = attn_weights * scale_factor | |
| fused_scale_factor = 1.0 | |
| else: | |
| # If not post_scale_bias, the scale can be fused into the Softmax module | |
| fused_scale_factor = scale_factor | |
| if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS: | |
| attn_weights += bias | |
| def apply_swa_mask(original_mask: Array) -> Array: | |
| """Apply the sliding window mask to a given mask""" | |
| batch = original_mask.shape[0] | |
| max_seqlen_q = original_mask.shape[-2] | |
| max_seqlen_kv = original_mask.shape[-1] | |
| # TODO(rewang): Support THD format pos | |
| pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q)) | |
| pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv)) | |
| # In inv_swa_mask 0 is masked out, in original_mask 1 is masked out | |
| inv_swa_mask = make_swa_mask(pos_q, pos_kv, self.window_size, original_mask.dtype) | |
| swa_mask = 1 - inv_swa_mask | |
| new_mask = jnp.where(original_mask == 0, swa_mask, original_mask) | |
| return new_mask | |
| def convert_to_softmax_type(attn_mask_type, mask): | |
| """Convert the attn_mask_type to SoftmaxType""" | |
| # mask is ignored for no_mask and causal_mask without sliding window | |
| if attn_mask_type == AttnMaskType.NO_MASK: | |
| mask = None | |
| if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size is None: | |
| mask = None | |
| if mask is not None: | |
| mask = apply_swa_mask(mask) | |
| # Currently the cuDNN backend only supports SWA for causal/padding_causal; follow the same behavior here | |
| if mask is not None: | |
| return SoftmaxType.SCALED_MASKED, mask | |
| if attn_mask_type is AttnMaskType.CAUSAL_MASK: | |
| return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask | |
| if attn_mask_type is AttnMaskType.NO_MASK: | |
| return SoftmaxType.SCALED, mask | |
| raise ValueError( | |
| f"Unsupported {attn_mask_type=}, supported attn_mask_type=" | |
| "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}" | |
| ) | |
| softmax_type, mask = convert_to_softmax_type(self.attn_mask_type, mask) | |
| attn_weights = Softmax(softmax_type=softmax_type, scale_factor=fused_scale_factor)( | |
| attn_weights, mask, bias | |
| ).astype(input_dtype) | |
| if is_gqa: | |
| attn_weights = attn_weights.reshape(attn_weights_with_groups_shape) | |
| if not deterministic and self.attention_dropout > 0.0: | |
| keep_prob = 1.0 - self.attention_dropout | |
| dropout_shape = list(attn_weights.shape) | |
| # TODO(rewang): add attention dropout broadcast dimension arguments for users | |
| keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape) | |
| multiplier = keep.astype(input_dtype) / jnp.asarray(keep_prob, dtype=input_dtype) | |
| attn_weights = attn_weights * multiplier | |
| assert ( | |
| attn_weights.dtype == input_dtype | |
| ), f"output={attn_weights.dtype}, input={input_dtype}" | |
| if self.transpose_batch_sequence: | |
| if is_gqa: | |
| return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape) | |
| return jnp.einsum("bhqk,kbhd->qbhd", attn_weights, value) | |
| if is_gqa: | |
| return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape) | |
| return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value) | |
| class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-methods | |
| attention_dropout: float = 0.0 | |
| attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK | |
| attn_bias_type: Optional[AttnBiasType] = None | |
| dtype: DType = jnp.float32 | |
| qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD | |
| scale_factor: Optional[float] = None | |
| transpose_batch_sequence: bool = False | |
| window_size: Optional[Tuple[int, int]] = None | |
| max_segments_per_seq: Optional[int] = 1 | |
| context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT | |
| context_parallel_causal_load_balanced: bool = False | |
| context_parallel_axis: str = "" | |
| @nn.compact | |
| def __call__( | |
| self, | |
| query: Array, | |
| key: Array, | |
| value: Array, | |
| sequence_descriptor: Optional[SequenceDescriptor] = None, | |
| bias: Optional[Array] = None, | |
| *, | |
| dropout_rng: Optional[PRNGKey] = None, | |
| deterministic: bool = False, | |
| ) -> Array: | |
| seed = None | |
| if dropout_rng is not None: | |
| seed = jax.random.split(dropout_rng, num_of_devices()) | |
| if self.scale_factor is None: | |
| scale_factor = 1.0 / sqrt(query.shape[-1]) | |
| else: | |
| scale_factor = self.scale_factor | |
| del self.scale_factor | |
| if self.qkv_layout.is_qkvpacked(): | |
| """qkvpacked format, treat | |
| query: qkvpacked tensor, shape = [..., 3, h, d] | |
| key: ignore | |
| value: ignore | |
| """ | |
| qkv_packed = query | |
| if self.transpose_batch_sequence: | |
| qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4]) | |
| x = fused_attn( | |
| (qkv_packed,), | |
| bias, | |
| sequence_descriptor, | |
| seed, | |
| attn_mask_type=self.attn_mask_type, | |
| attn_bias_type=self.attn_bias_type, | |
| qkv_layout=self.qkv_layout, | |
| scaling_factor=scale_factor, | |
| dropout_probability=self.attention_dropout, | |
| is_training=not deterministic, | |
| window_size=self.window_size, | |
| max_segments_per_seq=self.max_segments_per_seq, | |
| context_parallel_strategy=self.context_parallel_strategy, | |
| context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, | |
| context_parallel_axis=self.context_parallel_axis, | |
| ) | |
| elif self.qkv_layout.is_kvpacked(): | |
| """kvpacked format, treat | |
| query: query tensor, shape = [..., h, d] | |
| key: kvpacked tensor, shape = [..., 2, h, d] | |
| value: ignore | |
| """ | |
| kv_packed = key | |
| if self.transpose_batch_sequence: | |
| query = query.transpose([1, 0, 2, 3]) | |
| kv_packed = kv_packed.transpose([1, 0, 2, 3, 4]) | |
| x = fused_attn( | |
| (query, kv_packed), | |
| bias, | |
| sequence_descriptor, | |
| seed, | |
| attn_mask_type=self.attn_mask_type, | |
| attn_bias_type=self.attn_bias_type, | |
| qkv_layout=self.qkv_layout, | |
| scaling_factor=scale_factor, | |
| dropout_probability=self.attention_dropout, | |
| is_training=not deterministic, | |
| window_size=self.window_size, | |
| max_segments_per_seq=self.max_segments_per_seq, | |
| context_parallel_strategy=self.context_parallel_strategy, | |
| context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, | |
| context_parallel_axis=self.context_parallel_axis, | |
| ) | |
| elif self.qkv_layout.is_separate(): | |
| if self.transpose_batch_sequence: | |
| query = query.transpose([1, 0, 2, 3]) | |
| key = key.transpose([1, 0, 2, 3]) | |
| value = value.transpose([1, 0, 2, 3]) | |
| x = fused_attn( | |
| (query, key, value), | |
| bias, | |
| sequence_descriptor, | |
| seed, | |
| attn_mask_type=self.attn_mask_type, | |
| attn_bias_type=self.attn_bias_type, | |
| qkv_layout=self.qkv_layout, | |
| scaling_factor=scale_factor, | |
| dropout_probability=self.attention_dropout, | |
| is_training=not deterministic, | |
| window_size=self.window_size, | |
| max_segments_per_seq=self.max_segments_per_seq, | |
| context_parallel_strategy=self.context_parallel_strategy, | |
| context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, | |
| context_parallel_axis=self.context_parallel_axis, | |
| ) | |
| else: | |
| raise ValueError(f"Unsupported {self.qkv_layout=}.") | |
| if self.transpose_batch_sequence: | |
| x = x.transpose([1, 0, 2, 3]) | |
| assert x.dtype == query.dtype | |
| return x | |
| class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods | |
| r""" | |
| Dot Product Attention (DPA). Allows the model to jointly attend to information from different | |
| representation subspaces as described in the paper: | |
| `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_. | |
| .. note:: | |
| The DotProductAttention module supports two backends: the unfused and the fused attention | |
| mechanisms. The unfused attention is implemented using JAX native operations, providing | |
| broad compatibility and flexibility. In contrast, the fused attention uses `cuDNN fused | |
| attention | |
| <https://github.com/NVIDIA/cudnn-frontend/blob/main/docs/operations/Attention.md>`_ for | |
| higher performance and lower memory usage on supported hardware. | |
| Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment | |
| variable: | |
| * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention (default). | |
| * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention. If the required cuDNN fused attention | |
| kernel is not available on the system, a warning will be issued, and the module will | |
| automatically fall back to the unfused backend. | |
| .. note:: | |
| The DotProductAttention default setting enables non-deterministic kernels for reduced | |
| workspace requirements and faster computation. Users can disable the non-deterministic | |
| kernels via the :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO` environment variable: | |
| * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` to allow only deterministic kernels. | |
| * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=1` to allow non-deterministic kernels (default). | |
| Parameters | |
| ---------- | |
| head_dim: int | |
| The hidden dimension of each attention head. | |
| num_attention_heads: int | |
| The number of attention heads. | |
| num_gqa_groups: int, default = `None` | |
| Number of GQA groups. When `None`, it defaults to num_attention_heads. | |
| Grouped Query Attention is described in | |
| `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_. | |
| This only affects the keys and values, not the queries. | |
| GQA-1 is equivalent to Multi-Query Attention | |
| (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H | |
| is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. | |
| attention_dropout: float, default = 0.0 | |
| Dropout probability for the dropout op after the softmax. | |
| attn_mask_type: str, default = 'causal' | |
| This parameter specifies the type of attention mask to be applied during the softmax | |
| operation. | |
| Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'} | |
| Each described below: | |
| * no_mask: No attention mask is applied. This means the attention will consider the | |
| full sequence without any restrictions. | |
| * padding: Indicates the presence of padding at the end of each sequence. | |
| Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the | |
| :attr:`__call__` method to specify the padding positions. | |
| * causal: An upper triangular mask is applied to the softmax inputs, | |
| ensuring that the prediction for a certain position is only dependent on known outputs | |
| from positions before it. | |
| * causal_padding / padding_causal: A combination of both causal and padding masks. | |
| Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect. | |
| .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'. | |
| .. note:: THD format only supports 'padding' or 'causal_padding' mask type. | |
| attn_mask_type     mask/sequence_descriptor    SWA       softmax type | |
| -------------------------------------------------------------------------------------------- | |
| no_mask            None                        None      SCALED | |
| causal             None                        None      SCALED_UPPER_TRIANG_MASKED | |
| causal             None                        Yes       SCALED_MASKED | |
| padding            Required                    Yes/No    SCALED_MASKED | |
| padding_causal     Required                    Yes/No    SCALED_MASKED | |
| attn_bias_type: Optional[str], default = None | |
| Type of the attention bias passed in the attention. | |
| Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}. | |
| When left as the default (`None`), the type is automatically decided by the MHA's bias | |
| parameter: :attr:`post_scale_bias` if a bias is given, otherwise :attr:`no_bias`. | |
| dropout_rng_name: str, default = 'dropout' | |
| The key in given RNGs via flax.linen.Module.apply that is used | |
| to generate Dropout masks in the core attention. | |
| float32_logits: bool, default = False | |
| Whether to compute attention logits in float32 for the unfused attention backend. | |
| For fused attention backend, the accumulation is always float32 without the perf overhead. | |
| qkv_layout: str, default = 'bshd_bshd_bshd' | |
| Specifies the dimensional layout format for the query, key, and value tensors in __call__(). | |
| It indicates how the inputs are processed. | |
| Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd', 't3hd', 'thd_t2hd', 'thd_thd_thd'}. | |
| * bs3hd: query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d]. | |
| key and value arguments in :attr:`__call__()` are ignored in this layout. | |
| * bshd_bs2hd: query tensor with shape = [b, s, h, d]. key tensor is treated as a kvpacked | |
| tensor with shape = [b, s, 2, h, d]. `value` argument in :attr:`__call__()` is ignored. | |
| * bshd_bshd_bshd: query, key, and value are separate tensors with shape = [b, s, h, d]. | |
| * t3hd/thd_t2hd/thd_thd_thd: Have the same layouts as the bshd series, but allow multiple | |
| sequences to be packed in a batch, also known as sequence packing. | |
| Explanation of denotations: | |
| * b: batch size | |
| * s: sequence length | |
| * h: num_attention_heads or num_gqa_groups | |
| * d: head dimension | |
| scale_factor: Optional[float], default = None | |
| Scale factor to apply to the query. When :attr:`None`, the scale factor is | |
| :math:`\frac{1}{\sqrt{head\_dim}}`. This is useful for models like T5X, which do not | |
| scale the query; in that case, set :attr:`scale_factor=1.`. | |
| transpose_batch_sequence: bool, default = True | |
| Indicates whether the batch and sequence-length axes of the input tensors are swapped. | |
| If set to True, the input tensors should be in (seqlen, batch, ...); | |
| otherwise (batch, seqlen, ...). | |
| window_size: Optional[Tuple[int, int]], default = None | |
| Sliding window size. The default value is no sliding window. | |
| max_segments_per_seq: Optional[int], default = 1 | |
| The maximum number of segments per sequence, also used for THD format (sequence packing). | |
| context_parallel_causal_load_balanced (bool): | |
| Indicates whether the sequences are reordered for causal-mask load balancing when running context parallelism. | |
| context_parallel_axis (str): The name of the context parallel axis. | |
| Optimization parameters | |
| ----------------------- | |
| dtype: jax.numpy.dtype, default = jax.numpy.float32 | |
| The data type used to allocate the initial parameters. | |
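| Example | |
| ------- | |
| A minimal sketch with separate query/key/value tensors (shapes, dtypes, and the RNG | |
| handling are illustrative; set ``NVTE_FUSED_ATTN=1`` to opt into the fused backend): | |
| .. code-block:: python | |
|     import jax | |
|     import jax.numpy as jnp | |
|     dpa = DotProductAttention( | |
|         head_dim=64, | |
|         num_attention_heads=16, | |
|         num_gqa_groups=16, | |
|         attn_mask_type="causal", | |
|         transpose_batch_sequence=False, | |
|     ) | |
|     # query/key/value in [batch, seqlen, heads, head_dim] | |
|     q = k = v = jnp.zeros((2, 128, 16, 64), jnp.float32) | |
|     variables = dpa.init(jax.random.PRNGKey(0), q, k, v) | |
|     out = dpa.apply(variables, q, k, v, rngs={"dropout": jax.random.PRNGKey(1)}) | |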
| """ | |
| head_dim: int | |
| num_attention_heads: int | |
| num_gqa_groups: Optional[int] = None | |
| attention_dropout: float = 0.0 | |
| attn_mask_type: AttnMaskType = "causal" | |
| attn_bias_type: AttnBiasType = None | |
| dtype: DType = jnp.float32 | |
| dropout_rng_name: str = "dropout" | |
| float32_logits: bool = False | |
| qkv_layout: str = "bshd_bshd_bshd" | |
| scale_factor: Optional[float] = None | |
| transpose_batch_sequence: bool = True | |
| window_size: Optional[Tuple[int, int]] = None | |
| max_segments_per_seq: Optional[int] = 1 | |
| context_parallel_strategy: str = "default" | |
| context_parallel_causal_load_balanced: bool = False | |
| context_parallel_axis: str = "" | |
| @nn.compact | |
| def __call__( | |
| self, | |
| query: Array, | |
| key: Array, | |
| value: Array, | |
| sequence_descriptor: Optional[Union[SequenceDescriptor, Array]] = None, | |
| bias: Optional[Array] = None, | |
| *, | |
| deterministic: bool = False, | |
| mask: Optional[Union[SequenceDescriptor, Array]] = None, | |
| ) -> Array: | |
| """ | |
| Parameters | |
| ---------- | |
| query: jax.numpy.ndarray | |
| The details of the query tensor representation are described in :attr:`qkv_layout`. | |
| key: jax.numpy.ndarray | |
| The details of the key tensor representation are described in :attr:`qkv_layout`. | |
| value: jax.numpy.ndarray | |
| The details of the value tensor representation are described in :attr:`qkv_layout`. | |
| mask: jax.numpy.ndarray, default = None | |
| Boolean tensor used to mask out the attention softmax input. | |
| :attr:`True` means to mask out the corresponding values. | |
| Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'. | |
| bias: jax.numpy.ndarray, default = None | |
| A tensor used to shift attention softmax input. | |
| *: | |
| The parameters below are keyword-only. | |
| deterministic: bool, default = False | |
| Disable dropout layers if set to True. | |
| Returns | |
| ------- | |
| outputs: jax.numpy.ndarray | |
| Output tensors. | |
| """ | |
| input_dtype = query.dtype | |
| if mask is not None: | |
| if sequence_descriptor is not None: | |
| raise ValueError( | |
| "sequence_descriptor and mask cannot be provided at the same time." | |
| ) | |
| warnings.warn("mask is deprecated, please use sequence_descriptor instead.") | |
| sequence_descriptor = mask | |
| del mask | |
| # For the internal API, we convert the string options to enums. | |
| if self.attn_bias_type is None: | |
| attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS | |
| else: | |
| attn_bias_type = AttnBiasType[self.attn_bias_type.upper()] | |
| attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type) | |
| qkv_layout = QKVLayout[self.qkv_layout.upper()] | |
| del self.attn_bias_type, self.attn_mask_type, self.qkv_layout | |
| if attn_bias_type == AttnBiasType.NO_BIAS: | |
| assert bias is None | |
| else: | |
| assert bias is not None | |
| enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0")) | |
| sequence_dim = 0 if self.transpose_batch_sequence else 1 | |
| seqlen_q = query.shape[sequence_dim] | |
| if qkv_layout == QKVLayout.BS3HD: | |
| seqlen_kv = seqlen_q | |
| else: | |
| seqlen_kv = key.shape[sequence_dim] | |
| if qkv_layout.is_separate(): | |
| head_dim_qk = query.shape[-1] | |
| head_dim_v = value.shape[-1] | |
| else: | |
| head_dim_qk = self.head_dim | |
| head_dim_v = self.head_dim | |
| has_fused_attn_kernel = is_fused_attn_kernel_available( | |
| # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode. | |
| not deterministic, | |
| self.dtype, | |
| self.dtype, | |
| qkv_layout, | |
| attn_bias_type, | |
| attn_mask_type, | |
| self.attention_dropout, | |
| self.num_attention_heads, | |
| self.num_gqa_groups, | |
| seqlen_q, | |
| seqlen_kv, | |
| head_dim_qk, | |
| head_dim_v, | |
| self.window_size, | |
| ) | |
| use_fused_attn = enable_fused_attn and has_fused_attn_kernel | |
| if enable_fused_attn and not has_fused_attn_kernel: | |
| warnings.warn( | |
| "Fused attention is not enabled because there is no available kernel.\n" | |
| "Fall back to the unfused attention.\n" | |
| "Please try to update the cuDNN and TE to the latest version.\n" | |
| f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n" | |
| f"{self.attention_dropout=}\n{self.num_attention_heads=}\n" | |
| f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{head_dim_qk=}\n{head_dim_v=}\n" | |
| ) | |
| dropout_rng = None | |
| if not deterministic and self.attention_dropout > 0.0: | |
| dropout_rng = self.make_rng(self.dropout_rng_name) | |
| if self.scale_factor is None: | |
| scale_factor = 1.0 / sqrt(head_dim_qk) | |
| else: | |
| scale_factor = self.scale_factor | |
| del self.scale_factor | |
| if not use_fused_attn: | |
| # Unfused attention only supports split query, key, value | |
| if qkv_layout.is_qkvpacked(): | |
| query, key, value = jnp.split(query, [1, 2], axis=-3) | |
| query, key, value = map( | |
| functools.partial(jnp.squeeze, axis=-3), [query, key, value] | |
| ) | |
| elif qkv_layout.is_kvpacked(): | |
| key, value = jnp.split(key, [1], axis=-3) | |
| key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value]) | |
| else: | |
| assert qkv_layout.is_separate() | |
| assert sequence_descriptor is None or isinstance( | |
| sequence_descriptor, (jnp.ndarray, np.ndarray) | |
| ) | |
| x = _UnfusedDotProductAttention( | |
| attention_dropout=self.attention_dropout, | |
| attn_mask_type=attn_mask_type, | |
| attn_bias_type=attn_bias_type, | |
| dtype=self.dtype, | |
| float32_logits=self.float32_logits, | |
| scale_factor=scale_factor, | |
| transpose_batch_sequence=self.transpose_batch_sequence, | |
| window_size=self.window_size, | |
| )( | |
| query, | |
| key, | |
| value, | |
| sequence_descriptor, | |
| bias, | |
| dropout_rng=dropout_rng, | |
| deterministic=deterministic, | |
| ) | |
| else: | |
| x = _FusedDotProductAttention( | |
| attention_dropout=self.attention_dropout, | |
| attn_mask_type=attn_mask_type, | |
| attn_bias_type=attn_bias_type, | |
| dtype=self.dtype, | |
| scale_factor=scale_factor, | |
| transpose_batch_sequence=self.transpose_batch_sequence, | |
| qkv_layout=qkv_layout, | |
| window_size=self.window_size, | |
| max_segments_per_seq=self.max_segments_per_seq, | |
| context_parallel_strategy=CPStrategy[self.context_parallel_strategy.upper()], | |
| context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, | |
| context_parallel_axis=self.context_parallel_axis, | |
| )( | |
| query, | |
| key, | |
| value, | |
| sequence_descriptor, | |
| bias, | |
| dropout_rng=dropout_rng, | |
| deterministic=deterministic, | |
| ) | |
| assert x.dtype == input_dtype, f"output_dtype={x.dtype}, input_dtype={input_dtype}" | |
| return x | |
| def rotary_pos_emb( | |
| x: Array, | |
| windows: Tuple[int, int], | |
| transpose_batch_sequence: bool, | |
| group_method: str = "consecutive", | |
| ): | |
| """ | |
| Rotary Positional Embedding | |
| x should be in shape of | |
| [Batch, Seqlen, ..., Heads, Hidden] if transpose_batch_sequence is False, or | |
| [Seqlen, Batch, ..., Heads, Hidden] if transpose_batch_sequence is True. | |
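| Example: a minimal sketch (shapes are illustrative): | |
|     x = jnp.zeros((2, 128, 16, 64))  # [Batch, Seqlen, Heads, Hidden] | |
|     out = rotary_pos_emb(x, windows=(1, 10000), transpose_batch_sequence=False) | |
|     # out has the same shape and dtype as x | |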
| """ | |
| hidden_dim = x.shape[-1] | |
| half_hidden_dim = hidden_dim // 2 | |
| min_window = windows[0] | |
| max_window = windows[1] | |
| fraction = 2 * jnp.arange(0, half_hidden_dim) / hidden_dim | |
| time_scales = min_window * (max_window / min_window) ** fraction | |
| time_scales = jnp.expand_dims(time_scales, axis=tuple(range(x.ndim - 1))) | |
| batch_dim = 1 if transpose_batch_sequence else 0 | |
| seq_dim = 1 - batch_dim | |
| positions = jnp.expand_dims(jnp.arange(x.shape[seq_dim]), axis=batch_dim) | |
| positions = jnp.expand_dims(positions, axis=tuple(range(2, x.ndim))) | |
| def generate_sin_cos(timescales): | |
| sinusoidal_positions = positions / timescales | |
| sin = jnp.sin(sinusoidal_positions) | |
| cos = jnp.cos(sinusoidal_positions) | |
| return sin, cos | |
| def alternate_impl(): | |
| sin, cos = generate_sin_cos(time_scales) | |
| x1, x2 = jnp.split(x, 2, axis=-1) | |
| part_1 = (x1 * cos - x2 * sin).astype(dtype=x.dtype) | |
| part_2 = (x2 * cos + x1 * sin).astype(dtype=x.dtype) | |
| output = jnp.concatenate([part_1, part_2], axis=-1, dtype=x.dtype) | |
| return output | |
| def consecutive_impl(): | |
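| # Rotates consecutive pairs (x_{2i}, x_{2i+1}): even lanes become x_{2i}*cos - x_{2i+1}*sin | |
| # and odd lanes x_{2i+1}*cos + x_{2i}*sin, implemented with rolls and an alternating sign | |
| # instead of an explicit reshape into pairs. | |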
| sin, cos = generate_sin_cos(jnp.repeat(time_scales, 2, axis=-1)) | |
| x_shifted_left = jnp.roll(x, -1, axis=-1) | |
| x_shifted_right = jnp.roll(x, 1, axis=-1) | |
| x_shifted = jax.lax.select( | |
| jnp.tile( | |
| jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2), | |
| x.shape[:-1] + (1,), | |
| ), | |
| x_shifted_right, | |
| x_shifted_left, | |
| ) | |
| sign = jnp.sign(jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2) - 0.5) | |
| output = x * cos + x_shifted * sin * sign | |
| output = output.astype(x.dtype) | |
| return output | |
| def canonicalize_group_method(gm): | |
| canonicalized_gm = gm.lower().strip().replace("-", "").replace("_", "") | |
| assert canonicalized_gm in ["consecutive", "alternate"], ( | |
| "Invalid relative positional embedding group method. " | |
| f"Expect to be in []'alternate' or 'consecutive'], but got {gm}." | |
| ) | |
| return canonicalized_gm | |
| group_method = canonicalize_group_method(group_method) | |
| if group_method == "alternate": | |
| return alternate_impl() | |
| return consecutive_impl() | |
| class LoRAScope: # pylint: disable=too-few-public-methods | |
| """LoRA Scope""" | |
| def __init__(self, qkv_proj=False, output_proj=False, mlp=False): | |
| self.qkv_proj = qkv_proj | |
| self.output_proj = output_proj | |
| self.mlp = mlp | |
| def __eq__(self, other): | |
| return (self.qkv_proj, self.output_proj, self.mlp) == ( | |
| other.qkv_proj, | |
| other.output_proj, | |
| other.mlp, | |
| ) | |
| def _canonicalize_lora_scope(scope): | |
| SCOPE_NONE = "none" | |
| SCOPE_ALL = "all" | |
| SCOPE_QKV_PROJ = "qkv_proj" | |
| SCOPE_OUTPUT_PROJ = "output_proj" | |
| SCOPE_MLP = "mlp" | |
| SCOPE_EX_QKV_PROJ = "exclude_qkv_proj" | |
| SCOPE_EX_OUTPUT_PROJ = "exclude_output_proj" | |
| SCOPE_EX_MLP = "exclude_mlp" | |
| scope = SCOPE_NONE if scope is None else scope | |
| scope = scope.lower() | |
| assert scope in [ | |
| SCOPE_NONE, | |
| SCOPE_ALL, | |
| SCOPE_QKV_PROJ, | |
| SCOPE_OUTPUT_PROJ, | |
| SCOPE_MLP, | |
| SCOPE_EX_QKV_PROJ, | |
| SCOPE_EX_OUTPUT_PROJ, | |
| SCOPE_EX_MLP, | |
| ] | |
| lora_scope = LoRAScope() | |
| if scope in [SCOPE_ALL, SCOPE_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP]: | |
| lora_scope.qkv_proj = True | |
| if scope in [SCOPE_ALL, SCOPE_OUTPUT_PROJ, SCOPE_EX_QKV_PROJ, SCOPE_EX_MLP]: | |
| lora_scope.output_proj = True | |
| if scope in [SCOPE_ALL, SCOPE_MLP, SCOPE_EX_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ]: | |
| lora_scope.mlp = True | |
| return lora_scope | |
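| # Illustrative examples of the mapping above: "exclude_mlp" enables LoRA on the QKV and | |
| # output projections but not the MLP, "qkv_proj" enables it on the QKV projection only, | |
| # and "all" enables it everywhere. | |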
| class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods | |
| r""" | |
| Multi-head Attention (MHA), including Query, | |
| Key, Value and Output projection. | |
| Parameters | |
| ---------- | |
| head_dim: int | |
| The hidden dimension of each attention head. | |
| num_attention_heads: int | |
| The number of attention heads. | |
| num_gqa_groups: int, default = `None` | |
| Number of GQA groups. When `None`, it defaults to num_attention_heads. | |
| Grouped Query Attention is described in | |
| `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_. | |
| This only affects the keys and values, not the queries. | |
| GQA-1 is equivalent to Multi-Query Attention | |
| (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H | |
| is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. | |
| attention_dropout: float, default = 0.0 | |
| Dropout probability for the dropout op after the softmax. | |
| attn_mask_type: str, default = 'causal' | |
| This parameter specifies the type of attention mask to be applied during the softmax | |
| operation. | |
| Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'} | |
| Each described below: | |
| * no_mask: No attention mask is applied. This means the attention will consider the | |
| full sequence without any restrictions. | |
| * padding: Indicates the presence of padding at the end of each sequence. | |
| Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the | |
| :attr:`__call__` method to specify the padding positions. | |
| * causal: An upper triangular mask is applied to the softmax inputs, | |
| ensuring that the prediction for a certain position is only dependent on known outputs | |
| from positions before it. | |
| * causal_padding / padding_causal: A combination of both causal and padding masks. | |
| Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect. | |
| .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'. | |
| attn_bias_type: Optional[str], default = None | |
| Type of the attention bias passed in the attention. | |
| Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}. | |
| When left as the default (`None`), the type is automatically decided by the MHA's bias | |
| parameter: `post_scale_bias` if a bias is given, otherwise `no_bias`. | |
| dropout_rng_name: str, default = 'dropout' | |
| The key in given RNGs via flax.linen.Module.apply that is used | |
| to generate Dropout masks in the core attention. | |
| layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm' | |
| Indicate the type of layer normalization. | |
| layernorm_epsilon: float, default = 1e-6 | |
| A value added to the denominator of layer normalization for numerical stability. | |
| zero_centered_gamma: bool, default = False | |
| If set to `True`, the LayerNorm formula changes to | |
| .. math:: | |
| y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * | |
| (1 + \gamma) + \beta | |
| This parameter is only applicable for 'layernorm'. | |
| kernel_init: Initializer, default = | |
| flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal') | |
| Used for initializing the QKV and output projection weights. | |
| It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). | |
| use_bias: bool, default = False | |
| Indicate whether or not to enable bias shifting for QKV and output projections. | |
| If set to False, the layer will not learn additive biases. | |
| bias_init: Initializer, default = flax.linen.initializers.zeros | |
| Used for initializing bias of QKVO projections, only used when :attr:`use_bias=True`. | |
| It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). | |
| input_layernorm: bool, default = True | |
| If set to False, layer normalization is not applied to the input. | |
| return_layernorm_output: bool, default = False | |
| If set to True, output of layernorm is returned from the forward together with the output | |
| of the linear transformation. | |
| Example use case: residual connection for transformer module is taken post layernorm. | |
| enable_rotary_pos_emb: bool, default = False | |
| Whether to enable rotary position embedding to projected query and key. | |
| rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000) | |
| Indicate the min and max time-scales of rotary position embedding, | |
| only used when :attr:`enable_rotary_pos_emb=True` | |
| rotary_pos_emb_group_method: str, default = 'consecutive' | |
| Indicates the method used to pair the coordinates. It should be one of | |
| ['consecutive', 'alternate']. 'alternate' pairs index :math:`i` with :math:`i + d/2`, | |
| where d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`. | |
| low_rank_adaptation_scope: str, default = 'none' | |
| Indicates the scope in which to apply low rank adaptation. It should be one of | |
| ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', | |
| 'exclude_qkv_proj', 'exclude_output_proj', 'exclude_mlp']. | |
| low_rank_adaptation_dim: int, default = 32 | |
| The dimension for low rank adaptation, only used when | |
| :attr:`low_rank_adaptation_scope` is not 'none'. | |
| low_rank_adaptation_alpha: float, default = None | |
| The alpha for computing the scaling factor of LoRA output. | |
| :math:`\frac{alpha}{rank} * lora_output`. None means no scaling. | |
| enable_sequence_parallel: bool, default = False | |
| Whether to enable sequence parallelism for operations other than the dot (matmul) ops. | |
| num_heads: int, default = None | |
| Deprecated. Please refer to `num_attention_heads`. | |
| dropout_rate: float, default = None | |
| Deprecated. Please refer to `attention_dropout`. | |
| output_layernorm: bool, default = None | |
| Deprecated. Please refer to `input_layernorm`. | |
| apply_residual_connection_post_layernorm: bool, default = None | |
| Deprecated. Please refer to `return_layernorm_output`. | |
| Optimization parameters | |
| ----------------------- | |
| dtype: jax.numpy.dtype, default = jax.numpy.float32 | |
| The data type used to allocate the initial parameters. | |
| fuse_qkv_params: bool, default = True | |
| If set to True, this module exposes a single fused | |
| parameter for query-key-value for self-attention and key-value for | |
| cross-attention. | |
| transpose_batch_sequence: bool, default = True | |
| Indicates whether the batch and sequence-length axes of the input tensors are swapped. | |
| If set to True, the input tensors | |
| should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). | |
| scale_attn_logits: bool, default = False | |
| Indicates whether to scale the attention logits. | |
| If set to True, the logits are computed as :math:`\frac{Q}{\sqrt{head\_dim}} * K`, | |
| else :math:`Q*K`. | |
| scaled_query_init: bool, default = True | |
| Whether to scale WQ on initialization by :math:`\frac{1}{\sqrt{head\_dim}}` | |
| float32_logits: bool, default = False | |
| Whether to compute attention logits in float32 for the unfused attention backend. | |
| For fused attention backend, the accumulation is always float32 without the perf overhead. | |
| fuse_qkv: bool, default = None | |
| Deprecated. Please refer to `fuse_qkv_params`. | |
| window_size: Optional[Tuple[int, int]], default = None | |
| Sliding window size. Default value is no sliding window. | |
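| Example | |
| ------- | |
| A minimal self-attention sketch (shapes are illustrative). Note that ``__call__`` returns | |
| a tuple of the attention output and the layernorm output; the latter may be ``None`` when | |
| ``return_layernorm_output=False``: | |
| .. code-block:: python | |
|     import jax | |
|     import jax.numpy as jnp | |
|     mha = MultiHeadAttention( | |
|         head_dim=64, | |
|         num_attention_heads=16, | |
|         attn_mask_type="causal", | |
|         transpose_batch_sequence=False, | |
|     ) | |
|     x = jnp.zeros((2, 128, 1024), jnp.float32)  # [batch, seqlen, hidden] | |
|     variables = mha.init(jax.random.PRNGKey(0), x, x) | |
|     out, ln_out = mha.apply(variables, x, x) | |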
| """ | |
| head_dim: int | |
| num_attention_heads: int | |
| num_gqa_groups: Optional[int] = None | |
| attention_dropout: float = 0.0 | |
| dropout_rng_name: str = "dropout" | |
| input_layernorm: bool = True | |
| layernorm_type: str = "layernorm" | |
| layernorm_epsilon: float = 1e-6 | |
| return_layernorm_output: bool = False | |
| zero_centered_gamma: bool = False | |
| kernel_init: Initializer = None | |
| use_bias: bool = False | |
| bias_init: Initializer = nn.initializers.zeros | |
| attn_mask_type: str = "causal" | |
| attn_bias_type: Optional[str] = None | |
| enable_rotary_pos_emb: bool = False | |
| rotary_pos_emb_windows: Tuple[int, int] = (1, 10000) | |
| rotary_pos_emb_group_method: str = "consecutive" | |
| low_rank_adaptation_scope: str = "none" | |
| low_rank_adaptation_dim: int = 32 | |
| low_rank_adaptation_alpha: float = None | |
| dtype: DType = jnp.float32 | |
| fuse_qkv_params: bool = True | |
| transpose_batch_sequence: bool = True | |
| enable_sequence_parallel: bool = False | |
| scale_attn_logits: bool = False | |
| scaled_query_init: bool = True | |
| float32_logits: bool = False | |
| window_size: Optional[Tuple[int, int]] = None | |
| # Deprecated parameters | |
| num_heads: Optional[int] = None | |
| dropout_rate: Optional[float] = None | |
| output_layernorm: Optional[bool] = None | |
| apply_residual_connection_post_layernorm: Optional[bool] = None | |
| fuse_qkv: Optional[bool] = None | |
| def __post_init__(self): | |
| # Deal with the deprecated parameters | |
| if self.num_heads is not None: | |
| self.num_attention_heads = self.num_heads | |
| warnings.warn( | |
| f"{__class__}.num_heads is deprecated. It will be removed recently. " | |
| f"Please uses {__class__}.num_attention_heads as the new API.", | |
| DeprecationWarning, | |
| ) | |
| if self.dropout_rate is not None: | |
| self.attention_dropout = self.dropout_rate | |
| warnings.warn( | |
| f"{__class__}.dropout_rate is deprecated. It will be removed recently. " | |
| f"Please use {__class__}.attention_dropout as the new API.", | |
| DeprecationWarning, | |
| ) | |
| if self.apply_residual_connection_post_layernorm is not None: | |
| warnings.warn( | |
| f"{__class__}.apply_residual_connection_post_layernorm is deprecated. " | |
| f"It will be removed recently, please use {__class__}.return_layernorm_output.", | |
| DeprecationWarning, | |
| ) | |
| if self.fuse_qkv is not None: | |
| warnings.warn( | |
| f"{__class__}.fuse_qkv is deprecated. It will be removed recently. " | |
| f"Please use {__class__}.fuse_qkv_params as the new API.", | |
| DeprecationWarning, | |
| ) | |
| assert self.output_layernorm is None, ( | |
| f"{__class__}.output_layernorm is deprecated. It will be removed recently. " | |
| f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm." | |
| ) | |
| if self.kernel_init is None: | |
| self.kernel_init = nn.initializers.variance_scaling( | |
| 1.0, "fan_in", "normal", dtype=self.dtype | |
| ) | |
| if self.num_gqa_groups is None: | |
| self.num_gqa_groups = self.num_attention_heads | |
| super().__post_init__() | |
| @nn.compact | |
| def __call__( | |
| self, | |
| inputs_q: Array, | |
| inputs_kv: Array, | |
| mask: Optional[Array] = None, | |
| bias: Optional[Array] = None, | |
| *, | |
| decode: bool = False, | |
| deterministic: bool = False, | |
| ) -> Array: | |
| """ | |
| MultiHeadAttention Layer: | |
| [Query, Key, Value projection] -> Dot Product Attention -> Output projection. | |
| Parameters | |
| ---------- | |
| inputs_q: jax.numpy.ndarray | |
| Input tensor for query projection. | |
| inputs_kv: jax.numpy.ndarray | |
| Input tensor for key/value projection. | |
| mask: jax.numpy.ndarray, default = None | |
| Boolean tensor used to mask out the attention softmax input. | |
| :attr:`True` means mask out the corresponding values. | |
| Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'. | |
| bias: jax.numpy.ndarray, default = None | |
| A tensor used to shift the attention softmax input. | |
| * | |
| decode: bool, default = False | |
| Indicate whether to prepare and use an autoregressive cache. | |
| deterministic: bool, default = False | |
| Disable dropout layers if set to True. | |
| Returns | |
| ------- | |
| outputs: jax.numpy.ndarray | |
| Output tensors. | |
| """ | |
| assert ( | |
| inputs_q.dtype == inputs_kv.dtype | |
| ), f"q.dtype = {inputs_q.dtype}, kv.dtype = {inputs_kv.dtype}" | |
| input_dtype = inputs_q.dtype | |
| def query_init(*args): | |
| depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) | |
| return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0) | |
| def qkv_init(key, shape, dtype): | |
| assert len(shape) == 3 | |
| assert shape[-2] == 3 | |
| q_key, k_key, v_key = jax_random.split(key, num=3) | |
| q_shape = (shape[0], shape[-1]) | |
| k_shape = (shape[0], shape[-1]) | |
| v_shape = (shape[0], shape[-1]) | |
| q_kernel = query_init(q_key, q_shape, dtype) | |
| k_kernel = self.kernel_init(k_key, k_shape, dtype) | |
| v_kernel = self.kernel_init(v_key, v_shape, dtype) | |
| return jnp.stack([q_kernel, k_kernel, v_kernel], axis=-2, dtype=dtype) | |
| def kv_init(key, shape, dtype): | |
| assert len(shape) == 3 | |
| assert shape[-2] == 2 | |
| k_key, v_key = jax_random.split(key) | |
| k_shape = (shape[0], shape[-1]) | |
| v_shape = (shape[0], shape[-1]) | |
| k_kernel = self.kernel_init(k_key, k_shape, dtype) | |
| v_kernel = self.kernel_init(v_key, v_shape, dtype) | |
| return jnp.stack([k_kernel, v_kernel], axis=-2, dtype=dtype) | |
| def generate_batch_seqlen_logical_axes(is_sharded_seq): | |
| sequence_dim = 0 if self.transpose_batch_sequence else 1 | |
| batch_dim = 1 - sequence_dim | |
| axes = [None, None] | |
| axes[batch_dim] = BATCH_AXES | |
| axes[sequence_dim] = SEQLEN_TP_AXES if is_sharded_seq else SEQLEN_AXES | |
| return tuple(axes) | |
| is_self_attn = inputs_q is inputs_kv | |
| is_gqa = self.num_attention_heads != self.num_gqa_groups | |
| is_qkvpack = is_self_attn and not is_gqa | |
| inputs_logical_axes_maybe_sp = ( | |
| *generate_batch_seqlen_logical_axes(self.enable_sequence_parallel), | |
| HIDDEN_AXES, | |
| ) | |
| inputs_logical_axes_no_sp = (*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES) | |
| inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp) | |
| lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope) | |
| if self.fuse_qkv_params: | |
| if is_qkvpack: | |
| qkv_proj, ln_out = LayerNormDenseGeneral( | |
| enable_layernorm=self.input_layernorm, | |
| layernorm_type=self.layernorm_type, | |
| zero_centered_gamma=self.zero_centered_gamma, | |
| epsilon=self.layernorm_epsilon, | |
| axis=-1, | |
| features=(3, self.num_attention_heads * self.head_dim), | |
| transpose_batch_sequence=self.transpose_batch_sequence, | |
| return_layernorm_output=self.return_layernorm_output, | |
| scale_axes=(W_NO_SHARD_AXES,), | |
| ln_bias_axes=(W_NO_SHARD_AXES,), | |
| kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), | |
| kernel_init=qkv_init, | |
| use_bias=self.use_bias, | |
| bias_init=self.bias_init, | |
| bias_axes=(W_JOINED_AXES, W_TP_AXES), | |
| enable_low_rank_adaptation=lora_scope.qkv_proj, | |
| low_rank_adaptation_dim=self.low_rank_adaptation_dim, | |
| low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, | |
| layernorm_input_axes=inputs_logical_axes_maybe_sp, | |
| dot_input_axes=inputs_logical_axes_no_sp, | |
| name="qkv", | |
| dtype=self.dtype, | |
| )(inputs_q) | |
| qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj") | |
| qkv_layout = QKVLayout.BS3HD | |
| else: | |
| query, ln_out = LayerNormDenseGeneral( | |
| enable_layernorm=self.input_layernorm, | |
| layernorm_type=self.layernorm_type, | |
| zero_centered_gamma=self.zero_centered_gamma, | |
| epsilon=self.layernorm_epsilon, | |
| axis=-1, | |
| features=self.num_attention_heads * self.head_dim, | |
| transpose_batch_sequence=self.transpose_batch_sequence, | |
| return_layernorm_output=(self.return_layernorm_output or is_self_attn), | |
| scale_axes=(W_NO_SHARD_AXES,), | |
| ln_bias_axes=(W_NO_SHARD_AXES,), | |
| kernel_axes=(W_FSDP_AXES, W_TP_AXES), | |
| use_bias=self.use_bias, | |
| bias_init=self.bias_init, | |
| bias_axes=(W_TP_AXES,), | |
| enable_low_rank_adaptation=lora_scope.qkv_proj, | |
| low_rank_adaptation_dim=self.low_rank_adaptation_dim, | |
| low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, | |
| dtype=self.dtype, | |
| kernel_init=query_init, | |
| layernorm_input_axes=inputs_logical_axes_maybe_sp, | |
| dot_input_axes=inputs_logical_axes_no_sp, | |
| name="query", | |
| )(inputs_q) | |
| if is_self_attn: | |
| assert ln_out is not None | |
| inputs_kv = ln_out | |
| kv_proj = DenseGeneral( | |
| axis=-1, | |
| features=(2, self.num_gqa_groups * self.head_dim), | |
| transpose_batch_sequence=self.transpose_batch_sequence, | |
| kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), | |
| kernel_init=kv_init, | |
| use_bias=self.use_bias, | |
| bias_init=self.bias_init, | |
| bias_axes=(W_JOINED_AXES, W_TP_AXES), | |
| enable_low_rank_adaptation=lora_scope.qkv_proj, | |
| low_rank_adaptation_dim=self.low_rank_adaptation_dim, | |
| low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, | |
| name="kv", | |
| dtype=self.dtype, | |
| )(inputs_kv) | |
| kv_proj = checkpoint_name(kv_proj, "combined_kv_proj") | |
| qkv_layout = QKVLayout.BSHD_BS2HD | |
| else: | |
| kv_projection = functools.partial( | |
| DenseGeneral, | |
| axis=-1, | |
| features=self.num_gqa_groups * self.head_dim, | |
| transpose_batch_sequence=self.transpose_batch_sequence, | |
| kernel_axes=(W_FSDP_AXES, W_TP_AXES), | |
| use_bias=self.use_bias, | |
| bias_init=self.bias_init, | |
| bias_axes=(W_TP_AXES,), | |
| enable_low_rank_adaptation=lora_scope.qkv_proj, | |
| low_rank_adaptation_dim=self.low_rank_adaptation_dim, | |
| low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, | |
| dtype=self.dtype, | |
| ) | |
| query, ln_out = LayerNormDenseGeneral( | |
| enable_layernorm=self.input_layernorm, | |
| layernorm_type=self.layernorm_type, | |
| zero_centered_gamma=self.zero_centered_gamma, | |
| epsilon=self.layernorm_epsilon, | |
| axis=-1, | |
| features=self.num_attention_heads * self.head_dim, | |
| transpose_batch_sequence=self.transpose_batch_sequence, | |
| return_layernorm_output=True, | |
| scale_axes=(W_NO_SHARD_AXES,), | |
| ln_bias_axes=(W_NO_SHARD_AXES,), | |
| kernel_axes=(W_FSDP_AXES, W_TP_AXES), | |
| use_bias=self.use_bias, | |
| bias_init=self.bias_init, | |
| bias_axes=(W_TP_AXES,), | |
| enable_low_rank_adaptation=lora_scope.qkv_proj, | |
| low_rank_adaptation_dim=self.low_rank_adaptation_dim, | |
| low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, | |
| dtype=self.dtype, | |
| kernel_init=query_init, | |
| layernorm_input_axes=inputs_logical_axes_maybe_sp, | |
| dot_input_axes=inputs_logical_axes_no_sp, | |
| name="query", | |
| )(inputs_q) | |
| if is_self_attn: | |
| assert ln_out is not None | |
| inputs_kv = ln_out | |
| query = query.astype(input_dtype) | |
| key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv) | |
| key = key.astype(input_dtype) | |
| value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv) | |
| value = value.astype(input_dtype) | |
| query = checkpoint_name(query, "query_proj") | |
| key = checkpoint_name(key, "key_proj") | |
| value = checkpoint_name(value, "value_proj") | |
| qkv_layout = QKVLayout.BSHD_BSHD_BSHD | |
| if self.enable_rotary_pos_emb: | |
| if qkv_layout == QKVLayout.BS3HD: | |
| query, key, value = jnp.split(qkv_proj, [1, 2], axis=-2) | |
| elif qkv_layout == QKVLayout.BSHD_BS2HD: | |
| key, value = jnp.split(kv_proj, [1], axis=-2) | |
| else: | |
| assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD | |
| # No change to the memory layout; this should trigger only a bitcast (ideally no perf impact) | |
| query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim)) | |
| key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim)) | |
| query = rotary_pos_emb( | |
| query, | |
| self.rotary_pos_emb_windows, | |
| self.transpose_batch_sequence, | |
| self.rotary_pos_emb_group_method, | |
| ) | |
| key = rotary_pos_emb( | |
| key, | |
| self.rotary_pos_emb_windows, | |
| self.transpose_batch_sequence, | |
| self.rotary_pos_emb_group_method, | |
| ) | |
| qkv_layout = QKVLayout.BSHD_BSHD_BSHD | |
| if qkv_layout == QKVLayout.BSHD_BSHD_BSHD: | |
| query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim)) | |
| key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim)) | |
| value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim)) | |
| if decode: | |
| assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD | |
| is_initialized = self.has_variable("cache", "cached_key") | |
| cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) | |
| cached_value = self.variable( | |
| "cache", "cached_value", jnp.zeros, value.shape, value.dtype | |
| ) | |
| cache_index = self.variable( | |
| "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32) | |
| ) | |
| if is_initialized: | |
| if self.transpose_batch_sequence: | |
| length, batch, num_attention_heads, head_dim = cached_key.value.shape | |
| expected_shape = (1, batch, num_attention_heads, head_dim) | |
| one_hot_indices_shape = (length, 1, 1, 1) | |
| else: | |
| batch, length, num_attention_heads, head_dim = cached_key.value.shape | |
| expected_shape = (batch, 1, num_attention_heads, head_dim) | |
| one_hot_indices_shape = (1, length, 1, 1) | |
| # Sanity shape check of cached key against input query. | |
| if expected_shape != query.shape: | |
| raise ValueError( | |
| "Autoregressive cache shape error, " | |
| f"expected query shape {expected_shape} instead got {query.shape}." | |
| ) | |
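| # Write this step's key/value into the cache at cur_index: a one-hot vector over the | |
| # sequence axis selects the slot to update; the index is then advanced and positions | |
| # beyond cur_index are masked out for this query step. | |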
| cur_index = cache_index.value.astype(jnp.int32) | |
| one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype) | |
| one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape) | |
| key = cached_key.value + key * one_hot_indices | |
| value = cached_value.value + value * one_hot_indices | |
| cached_key.value = key | |
| cached_value.value = value | |
| cache_index.value = cache_index.value + 1 | |
| mask = combine_masks( | |
| mask, jnp.broadcast_to(jnp.arange(length) > cur_index, (batch, 1, 1, length)) | |
| ) | |
| if bias is not None: | |
| dynamic_vector_slice_in_dim = vmap( | |
| lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None) | |
| ) | |
| bias = dynamic_vector_slice_in_dim( | |
| jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2 | |
| ) | |
| LEADING_AXES = (BATCH_AXES, SEQLEN_AXES) | |
| if self.transpose_batch_sequence: | |
| LEADING_AXES = (SEQLEN_AXES, BATCH_AXES) | |
| if qkv_layout == QKVLayout.BS3HD: | |
| qkv_proj = qkv_proj.reshape( | |
| *qkv_proj.shape[:2], 3, self.num_attention_heads, self.head_dim | |
| ) | |
| qkv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES) | |
| qkv_proj = with_sharding_constraint_by_logical_axes(qkv_proj, qkv_sharding_constraint) | |
| dpa_args = [qkv_proj, None, None] | |
| elif qkv_layout == QKVLayout.BSHD_BS2HD: | |
| query = query.reshape(*query.shape[:2], self.num_attention_heads, self.head_dim) | |
| kv_proj = kv_proj.reshape(*kv_proj.shape[:2], 2, self.num_gqa_groups, self.head_dim) | |
| q_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES) | |
| kv_sharding_constraint = (*LEADING_AXES, JOINED_AXES, HEAD_AXES, HIDDEN_AXES) | |
| query = with_sharding_constraint_by_logical_axes(query, q_sharding_constraint) | |
| kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint) | |
| dpa_args = [query, kv_proj, None] | |
| else: | |
| assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD | |
| query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim)) | |
| key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim)) | |
| value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim)) | |
| qkv_sharding_constraint = (*LEADING_AXES, HEAD_AXES, HIDDEN_AXES) | |
| query = with_sharding_constraint_by_logical_axes(query, qkv_sharding_constraint) | |
| key = with_sharding_constraint_by_logical_axes(key, qkv_sharding_constraint) | |
| value = with_sharding_constraint_by_logical_axes(value, qkv_sharding_constraint) | |
| dpa_args = [query, key, value] | |
| scale_factor = 1.0 / sqrt(self.head_dim) if self.scale_attn_logits else 1.0 | |
| x = DotProductAttention( | |
| head_dim=self.head_dim, | |
| num_attention_heads=self.num_attention_heads, | |
| num_gqa_groups=self.num_gqa_groups, | |
| attn_mask_type=self.attn_mask_type, | |
| attn_bias_type=self.attn_bias_type, | |
| attention_dropout=self.attention_dropout, | |
| dtype=self.dtype, | |
| dropout_rng_name=self.dropout_rng_name, | |
| float32_logits=self.float32_logits, | |
| qkv_layout=qkv_layout.name, | |
| scale_factor=scale_factor, | |
| transpose_batch_sequence=self.transpose_batch_sequence, | |
| window_size=self.window_size, | |
| )(*dpa_args, mask, bias, deterministic=deterministic) | |
| x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3])) | |
| attn_context_sharding_constraint = (*LEADING_AXES, HIDDEN_TP_AXES) | |
| x = with_sharding_constraint_by_logical_axes(x, attn_context_sharding_constraint) | |
| out = DenseGeneral( | |
| features=inputs_q.shape[-1], | |
| transpose_batch_sequence=self.transpose_batch_sequence, | |
| axis=-1, | |
| kernel_init=self.kernel_init, | |
| kernel_axes=(W_TP_AXES, W_FSDP_AXES), | |
| use_bias=self.use_bias, | |
| bias_init=self.bias_init, | |
| bias_axes=(W_NO_SHARD_AXES,), | |
| enable_low_rank_adaptation=lora_scope.output_proj, | |
| low_rank_adaptation_dim=self.low_rank_adaptation_dim, | |
| low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, | |
| dtype=self.dtype, | |
| name="out", | |
| )(x) | |
| out = checkpoint_name(out, "out_proj") | |
| assert ( | |
| inputs_q.dtype == out.dtype | |
| ), f"output_dtype={out.dtype}, input_dtype={inputs_q.dtype}" | |
| return out, ln_out | |
| class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-methods | |
| """ | |
| T5-style relative positional embeddings added to the attention logits. | |
| Parameters | |
| ---------- | |
| num_buckets: int | |
| The number of buckets to bucket distances between key and query positions into. | |
| max_distance: int | |
| The maximum distance before everything is lumped into the last | |
| distance bucket. | |
| num_attention_heads: int | |
| Number of attention heads in the transformer layer. | |
| embedding_init: Initializer, default = flax.linen.linear.default_embed_init | |
| Used for initializing relative embedding tables. | |
| embedding_axes: Tuple[str, ...], default = ('heads', 'relpos_buckets') | |
| The name of axes used to shard embedding attention bias with a corresponding mesh. | |
| Optimization parameters | |
| ----------------------- | |
| dtype: jax.numpy.dtype, default = jax.numpy.float32 | |
| The data type used to allocate the initial parameters. | |
| """ | |
| num_buckets: int | |
| max_distance: int | |
| num_attention_heads: int | |
| embedding_init: Callable[..., Array] = nn.linear.default_embed_init | |
| embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets") | |
| dtype: DType = jnp.float32 | |
| @nn.compact | |
| def __call__(self, q_seqlen, k_seqlen, bidirectional=True): | |
| """ | |
| Generate relative position embedding attention biases. | |
| Parameters | |
| ---------- | |
| q_seqlen: int | |
| The sequence length of query. | |
| k_seqlen: int | |
| The sequence length of key. | |
| bidirectional: bool, default = True | |
| Indicate whether to allow positive memory-query relative position | |
| embeddings. | |
| Returns | |
| ------- | |
| output: jax.numpy.ndarray | |
| An attention bias with shape `(1, num_attention_heads, q_seqlen, k_seqlen)`. | |
| """ | |
| context_position = np.arange(q_seqlen, dtype=jnp.int32)[:, None] | |
| memory_position = np.arange(k_seqlen, dtype=jnp.int32)[None, :] | |
| relative_position = memory_position - context_position | |
| # Compute relative position bucket | |
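| # T5-style bucketing: when bidirectional, half of the buckets are reserved | |
| # for positive (key-after-query) offsets; within each half, the first half | |
| # of the buckets maps small distances exactly and the remainder is spaced | |
| # logarithmically up to max_distance, beyond which everything falls into | |
| # the last bucket. | |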
| rp_bucket = 0 | |
| negative_rp = -relative_position | |
| rpb_num_buckets = self.num_buckets | |
| if bidirectional: | |
| rpb_num_buckets //= 2 | |
| rp_bucket += (negative_rp < 0).astype(np.int32) * rpb_num_buckets | |
| negative_rp = np.abs(negative_rp) | |
| else: | |
| negative_rp = np.maximum(negative_rp, 0) | |
| rpb_max_exact = rpb_num_buckets // 2 | |
| rpb_is_small = negative_rp < rpb_max_exact | |
| rpb_val_if_large = rpb_max_exact + ( | |
| np.log(negative_rp.astype(np.float32) / rpb_max_exact + np.finfo(np.float32).eps) | |
| / np.log(self.max_distance / rpb_max_exact) | |
| * (rpb_num_buckets - rpb_max_exact) | |
| ).astype(np.int32) | |
| rpb_val_if_large = np.minimum(rpb_val_if_large, rpb_num_buckets - 1) | |
| rp_bucket += np.where(rpb_is_small, negative_rp, rpb_val_if_large) | |
| # Compute relative attention bias | |
| relative_attention_bias = self.param( | |
| "rel_embedding", | |
| nn.with_logical_partitioning(self.embedding_init, self.embedding_axes), | |
| (self.num_attention_heads, self.num_buckets), | |
| self.dtype, | |
| ) | |
| relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype) | |
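| # Gather the per-bucket biases with a one-hot contraction (dot_general) | |
| # rather than an integer gather; the dense matmul form generally maps | |
| # better onto sharded parameters and accelerator kernels. | |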
| bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0) | |
| rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype) | |
| values = lax.dot_general( | |
| relative_attention_bias, rp_bucket_one_hot, (((1,), (0,)), ((), ())) | |
| ) | |
| return values[jnp.newaxis, ...] | |
| class TransformerLayerType(Enum): | |
| r""" | |
| TransformerLayerType is an Enum class to specify a type of TransformerLayer | |
| Values | |
| ---------- | |
| ENCODER: | |
| Encoder type of TransformerLayer. | |
| DECODER: | |
| Decoder type of TransformerLayer. | |
| """ | |
| ENCODER = "encoder" | |
| DECODER = "decoder" | |
| class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods | |
| r""" | |
| TransformerLayer is made up of a relative embedding, | |
| an attention block and a feedforward network (MLP). | |
| This standard layer is based on the paper “Attention Is All You Need”. | |
| Parameters | |
| ---------- | |
| hidden_size: int, default = 512 | |
| The hidden size of each input sample. | |
| mlp_hidden_size: int, default = 2048 | |
| Intermediate size to which input samples are projected. | |
| num_attention_heads: int, default = 8 | |
| Number of attention heads in the transformer layer. | |
| num_gqa_groups: int, default = `None` | |
| Number of GQA groups. When `None`, it defaults to num_attention_heads. | |
| Grouped Query Attention is described in | |
| `this paper <https://arxiv.org/pdf/2305.13245.pdf>`_. | |
| This only affects the keys and values, not the queries. | |
| GQA-1 is equivalent to Multi-Query Attention | |
| (`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H | |
| is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. | |
| layernorm_type: {'layernorm', 'rmsnorm'}, default = 'layernorm' | |
| Indicate the type of layer normalization. | |
| layernorm_epsilon: float, default = 1e-6 | |
| A value added to the denominator of layer normalization for numerical stability. | |
| zero_centered_gamma: bool, default = False | |
| If set to `True`, the LayerNorm formula changes to | |
| .. math:: | |
| y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * | |
| (1 + \gamma) + \beta | |
| This parameter is only applicable for 'layernorm'. | |
| hidden_dropout: float, default = 0.1 | |
| Dropout probability for the dropout op after FC2 layer. | |
| hidden_dropout_dims: Sequence[int], default = () | |
| Dimensions that will share the same dropout mask for the hidden dropout. | |
| attention_dropout: float, default = 0.1 | |
| Dropout probability for the dropout op during multi-head attention. | |
| intermediate_dropout: float, default = 0.1 | |
| Dropout probability for the dropout op after FC1 layer. | |
| intermediate_dropout_dims: Sequence[int], default = () | |
| Dimensions that will share the same dropout mask for hidden after FC1 layer. | |
| dropout_rng_name: str, default = 'dropout' | |
| The key of the RNG collection, given via flax.linen.Module.apply, that is used for | |
| generating dropout masks in the Multi-Head Attention. | |
| mha_kernel_init: Initializer, default = | |
| flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'normal') | |
| Used for initializing weights of QKV and Output projection weights. | |
| It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). | |
| mlp_kernel_init: Initializer, default = | |
| flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') | |
| Used for initializing weights of FC1 and FC2 layers. | |
| It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). | |
| mlp_activations: Sequence[str], default = ('relu', ) | |
| The sequence of activation functions to apply after the first linear transformation. | |
| Each activation has its own transformation layer. | |
| use_bias: bool, default = False | |
| Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2. | |
| If set to False, the layer will not learn additive biases. | |
| bias_init: Initializer, default = flax.linen.initializers.zeros | |
| Used for initializing bias of QKVO projections, | |
| FC1 and FC2. It is only used when :attr:`use_bias=True`. | |
| It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). | |
| apply_residual_connection_post_layernorm: bool, default = False | |
| If set to True, residual connections are taken from the output | |
| of layer norm (default is taken from input of layer norm) | |
| output_layernorm: bool, default = False | |
| If set to True, layer normalization is applied on the output side, | |
| after the final dropout-add. The default behavior is to apply layer | |
| normalization on the input side, before the QKV transformation. | |
| float32_attention_logits: bool, default = False | |
| Whether to compute attention logits in float32 for the unfused attention backend. | |
| For the fused attention backend, accumulation is always in float32 without a performance overhead. | |
| layer_type: TransformerLayerType, default = TransformerLayerType.ENCODER | |
| If set to TransformerLayerType.DECODER, an additional cross-attention block | |
| is added after self-attention. This can be used for structures like the `T5` | |
| Transformer in conjunction with the TransformerLayerType.ENCODER option. | |
| self_attn_mask_type: str, default = 'causal' | |
| This parameter specifies the type of attention mask to be applied during the softmax | |
| operation in the self attention. | |
| Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'} | |
| Each described below: | |
| * no_mask: No attention mask is applied. This means the self attention will consider the | |
| full sequence without any restrictions. | |
| * padding: Indicates the presence of padding at the end of each sequence. | |
| Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the | |
| :attr:`__call__` method to specify the padding positions. | |
| * causal: An upper triangular mask is applied to the softmax inputs, | |
| ensuring that the prediction for a certain position is only dependent on known outputs | |
| from positions before it. | |
| * causal_padding / padding_causal: A combination of both causal and padding masks. | |
| Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect. | |
| .. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'. | |
| self_attn_bias_type: Optional[str], default = None | |
| Type of the attention bias passed into the self attention. | |
| Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}. | |
| When left as the default (None), the type is decided automatically by the MHA's bias parameter: | |
| `post_scale_bias` if a bias is given, otherwise `no_bias`. | |
| enable_relative_embedding: bool, default = True | |
| Whether to enable relative embedding as shifting of attention logits. | |
| relative_embedding: flax.linen.Module, default = None | |
| The module for relative embedding execution, only used when | |
| :attr:`enable_relative_embedding=True`. Default is None, which will create | |
| an instance of RelativePositionBiases if :attr:`enable_relative_embedding=True`. | |
| Default: RelativePositionBiases( num_buckets=32, max_distance=128, | |
| num_attention_heads=self.num_attention_heads, dtype=self.dtype, | |
| embedding_init=flax.linen.initializers.variance_scaling(1.0, 'fan_avg', 'uniform'), | |
| name='relpos_bias') | |
| enable_rotary_pos_emb: bool, default = False | |
| Whether to enable rotary position embedding to projected query and key in MHA. | |
| rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000) | |
| Indicate the min and max time-scales of rotary position embedding, | |
| only used when :attr:`enable_rotary_pos_emb=True` | |
| rotary_pos_emb_group_method: str, default = 'consecutive' | |
| Indicate the method to couple the coordinates. It should be one of | |
| ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`, | |
| where :math:`d` is the hidden dimension. 'consecutive' pairs index :math:`i` with | |
| :math:`i + 1`. | |
| low_rank_adaptation_scope: str, default = 'none' | |
| Indicate the scope to apply low rank adaptation. It should be one of | |
| ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj', | |
| 'exclude_output_proj', 'exclude_mlp'] | |
| low_rank_adaptation_dim: int, default = 32 | |
| The dimension for low rank adaptation, only used when | |
| :attr:`low_rank_adaptation_scope` is not 'none'. | |
| low_rank_adaptation_alpha: float, default = None | |
| The alpha for computing the scaling factor of LoRA output. | |
| :math:`\frac{alpha}{rank} * lora\_output`. None means no scaling. | |
| enable_sequence_parallel: bool, default = False | |
| Whether to enable sequence parallelism for operations other than dot. | |
| window_size: Optional[Tuple[int, int]], default = None | |
| Sliding window size. Default value is no sliding window. | |
| Optimization parameters | |
| ----------------------- | |
| dtype: jax.numpy.dtype, default = jax.numpy.float32 | |
| The data type used to allocate the initial parameters. | |
| drop_path: float, default = 0.0 | |
| When > 0.0, applies stochastic depth per sample in the main | |
| path of the residual block. | |
| fuse_qkv_params: bool, default = True | |
| If set to True, `TransformerLayer` module exposes a single fused | |
| parameter for query-key-value for self-attention and key-value for | |
| cross-attention. | |
| transpose_batch_sequence: bool, default = False | |
| Indicate whether the batch and sequence-length axes of the input | |
| tensors are swapped. If set to True, the input tensors | |
| should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). | |
| scale_attn_logits: bool, default = False | |
| Indicate whether to scale the attention logits. | |
| If set to True, the logits are :math:`\frac{Q}{\sqrt{head\_dim}} * K`, | |
| otherwise :math:`Q * K`. | |
| scaled_query_init: bool, default = `True` | |
| Whether to scale WQ on initialization by :math:`\sqrt{head\_dim}`. | |
| """ | |
| hidden_size: int = 512 | |
| mlp_hidden_size: int = 2048 | |
| num_attention_heads: int = 8 | |
| num_gqa_groups: Optional[int] = None | |
| layernorm_type: str = "layernorm" | |
| layernorm_epsilon: float = 1e-6 | |
| zero_centered_gamma: bool = False | |
| hidden_dropout: float = 0.1 | |
| hidden_dropout_dims: Sequence[int] = () | |
| attention_dropout: float = 0.1 | |
| intermediate_dropout: float = 0.1 | |
| intermediate_dropout_dims: Sequence[int] = () | |
| dropout_rng_name: str = "dropout" | |
| mha_kernel_init: Initializer = None | |
| mlp_kernel_init: Initializer = None | |
| mlp_activations: Sequence[str] = ("relu",) | |
| use_bias: bool = False | |
| bias_init: Initializer = nn.initializers.zeros | |
| apply_residual_connection_post_layernorm: bool = False | |
| output_layernorm: bool = False | |
| float32_attention_logits: bool = False | |
| layer_type: TransformerLayerType = TransformerLayerType.ENCODER | |
| self_attn_mask_type: str = "causal" | |
| self_attn_bias_type: Optional[str] = None | |
| enable_relative_embedding: bool = True | |
| relative_embedding: nn.Module = None | |
| enable_rotary_pos_emb: bool = False | |
| rotary_pos_emb_windows: Tuple[int, int] = (1, 10000) | |
| rotary_pos_emb_group_method: str = "consecutive" | |
| low_rank_adaptation_scope: str = "none" | |
| low_rank_adaptation_dim: int = 32 | |
| low_rank_adaptation_alpha: float = None | |
| dtype: DType = jnp.float32 | |
| drop_path: float = 0.0 | |
| fuse_qkv_params: bool = True | |
| transpose_batch_sequence: bool = False | |
| enable_sequence_parallel: bool = False | |
| scale_attn_logits: bool = False | |
| scaled_query_init: bool = True | |
| window_size: Optional[Tuple[int, int]] = None | |
| def __post_init__(self): | |
| if self.mha_kernel_init is None: | |
| self.mha_kernel_init = nn.initializers.variance_scaling( | |
| 1.0, "fan_in", "normal", dtype=self.dtype | |
| ) | |
| if self.mlp_kernel_init is None: | |
| self.mlp_kernel_init = nn.initializers.variance_scaling( | |
| 1.0, "fan_in", "truncated_normal", dtype=self.dtype | |
| ) | |
| if self.num_gqa_groups is None: | |
| self.num_gqa_groups = self.num_attention_heads | |
| super().__post_init__() | |
| @nn.compact | |
| def __call__( | |
| self, | |
| inputs: Array, | |
| encoded: Array = None, | |
| attention_mask: Array = None, | |
| encoder_decoder_mask: Array = None, | |
| deterministic: bool = False, | |
| decode: bool = False, | |
| max_decode_length: Optional[int] = None, | |
| ): | |
| """ | |
| Transformer Layer: attention block and a feedforward network (MLP) | |
| Parameters | |
| ---------- | |
| inputs: jax.numpy.ndarray | |
| Input tensor. | |
| encoded: jax.numpy.ndarray, default = None | |
| Output tensors of the encoder block to be fed into the decoder block if using | |
| :attr:`layer_type=TransformerLayerType.DECODER`. | |
| attention_mask : jax.numpy.ndarray, default = None | |
| Boolean tensor used to mask out self-attention softmax input. | |
| :attr:`True` means mask out the corresponding values. | |
| Ignored when :attr:`self.self_attn_mask_type` is either 'no_mask' or 'causal'. | |
| encoder_decoder_mask: jax.numpy.ndarray, default = None | |
| Boolean tensor used to mask out cross-attention softmax input when | |
| :attr:`layer_type=TransformerLayerType.DECODER`. | |
| :attr:`True` means mask out the corresponding values. | |
| deterministic: bool, default = False | |
| Disable dropout layers if set to True. | |
| decode: bool, default = False | |
| Indicate whether to prepare and use an autoregressive cache | |
| in Multi-head attention (MHA). | |
| max_decode_length: int, default = None | |
| The maximum length to generate relative embedding biases when | |
| :attr:`layer_type=TransformerLayerType.DECODER` and | |
| :attr:`enable_relative_embedding=True`. | |
| Returns | |
| ------- | |
| outputs: jax.numpy.ndarray | |
| Output tensors. | |
| """ | |
| input_dtype = inputs.dtype | |
| assert ( | |
| self.layer_type in TransformerLayerType | |
| ), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}." | |
| assert self.hidden_size % self.num_attention_heads == 0, ( | |
| "hidden_size should be multiples of num_attention_heads" | |
| f", but got {self.hidden_size=} and {self.num_attention_heads=}." | |
| ) | |
| assert self.layer_type == TransformerLayerType.DECODER or ( | |
| self.layer_type == TransformerLayerType.ENCODER and decode is False | |
| ), "decode should be False when layer_type == TransformerLayerType.ENCODER." | |
| head_dim = self.hidden_size // self.num_attention_heads | |
| sequence_dim = 0 if self.transpose_batch_sequence else 1 | |
| batch_dim = 1 - sequence_dim | |
| def generate_batch_seqlen_logical_axes(is_shared_seq=None): | |
| axes = [None, None] | |
| is_shared_seq = ( | |
| self.enable_sequence_parallel if is_shared_seq is None else is_shared_seq | |
| ) | |
| axes[batch_dim] = BATCH_AXES | |
| axes[sequence_dim] = SEQLEN_TP_AXES if is_shared_seq else SEQLEN_AXES | |
| return tuple(axes) | |
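| # Maps the batch/sequence dims to logical mesh axes; with sequence | |
| # parallelism enabled, the sequence dim is additionally sharded over the | |
| # tensor-parallel axis (SEQLEN_TP_AXES). | |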
| attn_bias = None | |
| if self.enable_relative_embedding: | |
| if self.relative_embedding is None: | |
| rel_emb = RelativePositionBiases( | |
| num_buckets=32, | |
| max_distance=128, | |
| num_attention_heads=self.num_attention_heads, | |
| dtype=self.dtype, | |
| embedding_init=nn.initializers.variance_scaling( | |
| 1.0, "fan_avg", "uniform", dtype=self.dtype | |
| ), | |
| name="relpos_bias", | |
| ) | |
| else: | |
| rel_emb = self.relative_embedding | |
| if self.layer_type == TransformerLayerType.ENCODER: | |
| attn_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True) | |
| else: | |
| if decode and max_decode_length: | |
| l = max_decode_length | |
| else: | |
| l = inputs.shape[sequence_dim] | |
| attn_bias = rel_emb(l, l, False) | |
| assert inputs.ndim == 3 | |
| # Make the names exactly the same as in T5X, since names affect the | |
| # RNGKey during init and apply. Maybe not needed in the future. | |
| if self.layer_type == TransformerLayerType.ENCODER: | |
| mha_name = "attention" | |
| else: | |
| mha_name = "self_attention" | |
| inputs = with_sharding_constraint_by_logical_axes( | |
| inputs, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES) | |
| ) | |
| # [batch, length, emb_dim] -> [batch, length, emb_dim] | |
| residual = inputs | |
| x, ln_out = MultiHeadAttention( | |
| num_attention_heads=self.num_attention_heads, | |
| dtype=self.dtype, | |
| head_dim=head_dim, | |
| num_gqa_groups=self.num_gqa_groups, | |
| transpose_batch_sequence=self.transpose_batch_sequence, | |
| enable_sequence_parallel=self.enable_sequence_parallel, | |
| attention_dropout=self.attention_dropout, | |
| dropout_rng_name=self.dropout_rng_name, | |
| float32_logits=self.float32_attention_logits, | |
| scale_attn_logits=self.scale_attn_logits, | |
| scaled_query_init=self.scaled_query_init, | |
| layernorm_type=self.layernorm_type, | |
| layernorm_epsilon=self.layernorm_epsilon, | |
| zero_centered_gamma=self.zero_centered_gamma, | |
| return_layernorm_output=self.apply_residual_connection_post_layernorm, | |
| input_layernorm=not self.output_layernorm, | |
| attn_mask_type=self.self_attn_mask_type, | |
| attn_bias_type=self.self_attn_bias_type, | |
| enable_rotary_pos_emb=self.enable_rotary_pos_emb, | |
| rotary_pos_emb_windows=self.rotary_pos_emb_windows, | |
| rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, | |
| low_rank_adaptation_scope=self.low_rank_adaptation_scope, | |
| low_rank_adaptation_dim=self.low_rank_adaptation_dim, | |
| low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, | |
| fuse_qkv_params=self.fuse_qkv_params, | |
| kernel_init=self.mha_kernel_init, | |
| use_bias=self.use_bias, | |
| bias_init=self.bias_init, | |
| name=mha_name, | |
| window_size=self.window_size, | |
| )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode) | |
| def hidden_dropout(x, deterministic): | |
| assert isinstance(self.hidden_dropout_dims, Sequence) | |
| x_shape_len = len(x.shape) | |
| for dims in self.hidden_dropout_dims: | |
| assert -x_shape_len <= dims < x_shape_len | |
| return nn.Dropout( | |
| rate=self.hidden_dropout, | |
| broadcast_dims=self.hidden_dropout_dims, | |
| rng_collection=self.dropout_rng_name, | |
| )(x, deterministic=deterministic) | |
| x = with_sharding_constraint_by_logical_axes( | |
| x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES) | |
| ) | |
| residual = with_sharding_constraint_by_logical_axes( | |
| residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES) | |
| ) | |
| x = hidden_dropout(x, deterministic) | |
| if self.drop_path > 0.0: | |
| drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim) | |
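| # Stochastic depth: broadcasting the dropout mask over every non-batch dim | |
| # drops the entire residual branch for a random subset of samples. | |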
| x = nn.Dropout( | |
| rate=self.drop_path, | |
| broadcast_dims=drop_path_shape, | |
| rng_collection=self.dropout_rng_name, | |
| )(x, deterministic=deterministic) | |
| if self.apply_residual_connection_post_layernorm: | |
| assert ln_out is not None | |
| residual = ln_out | |
| x = x + residual | |
| mlp_input = x | |
| if self.layer_type == TransformerLayerType.DECODER: | |
| assert ( | |
| encoded is not None | |
| ), "encoded is required when layer_type == TransformerLayerType.DECODER." | |
| x = with_sharding_constraint_by_logical_axes( | |
| x, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES) | |
| ) | |
| residual = x | |
| y, ln_out = MultiHeadAttention( | |
| num_attention_heads=self.num_attention_heads, | |
| dtype=self.dtype, | |
| head_dim=head_dim, | |
| num_gqa_groups=self.num_gqa_groups, | |
| transpose_batch_sequence=self.transpose_batch_sequence, | |
| enable_sequence_parallel=self.enable_sequence_parallel, | |
| attention_dropout=self.attention_dropout, | |
| dropout_rng_name=self.dropout_rng_name, | |
| layernorm_type=self.layernorm_type, | |
| layernorm_epsilon=self.layernorm_epsilon, | |
| zero_centered_gamma=self.zero_centered_gamma, | |
| return_layernorm_output=self.apply_residual_connection_post_layernorm, | |
| input_layernorm=True, # Must do LayerNorm before MHA. | |
| attn_mask_type="padding", | |
| attn_bias_type="no_bias", | |
| enable_rotary_pos_emb=self.enable_rotary_pos_emb, | |
| rotary_pos_emb_windows=self.rotary_pos_emb_windows, | |
| rotary_pos_emb_group_method=self.rotary_pos_emb_group_method, | |
| low_rank_adaptation_scope=self.low_rank_adaptation_scope, | |
| low_rank_adaptation_dim=self.low_rank_adaptation_dim, | |
| low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, | |
| float32_logits=self.float32_attention_logits, | |
| scale_attn_logits=self.scale_attn_logits, | |
| scaled_query_init=self.scaled_query_init, | |
| fuse_qkv_params=self.fuse_qkv_params, | |
| kernel_init=self.mha_kernel_init, | |
| use_bias=self.use_bias, | |
| bias_init=self.bias_init, | |
| name="encoder_decoder_attention", | |
| window_size=self.window_size, | |
| )(x, encoded, encoder_decoder_mask, deterministic=deterministic) | |
| y = with_sharding_constraint_by_logical_axes( | |
| y, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES) | |
| ) | |
| residual = with_sharding_constraint_by_logical_axes( | |
| residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES) | |
| ) | |
| y = hidden_dropout(y, deterministic) | |
| if self.apply_residual_connection_post_layernorm: | |
| assert ln_out is not None | |
| residual = ln_out | |
| mlp_input = y + residual | |
| mlp_input = with_sharding_constraint_by_logical_axes( | |
| mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES) | |
| ) | |
| lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope) | |
| # MlpBlock | |
| residual = mlp_input | |
| z, ln_out = LayerNormMLP( | |
| layernorm_type=self.layernorm_type, | |
| zero_centered_gamma=self.zero_centered_gamma, | |
| epsilon=self.layernorm_epsilon, | |
| transpose_batch_sequence=self.transpose_batch_sequence, | |
| return_layernorm_output=self.apply_residual_connection_post_layernorm, | |
| intermediate_dim=self.mlp_hidden_size, | |
| activations=self.mlp_activations, | |
| intermediate_dropout_rng_name=self.dropout_rng_name, | |
| intermediate_dropout_rate=self.intermediate_dropout, | |
| intermediate_hidden_dropout_dims=self.intermediate_dropout_dims, | |
| dtype=self.dtype, | |
| scale_axes=(W_NO_SHARD_AXES,), | |
| ln_bias_axes=(W_NO_SHARD_AXES,), | |
| kernel_init=self.mlp_kernel_init, | |
| kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), | |
| kernel_axes_2=(W_TP_AXES, W_FSDP_AXES), | |
| use_bias=self.use_bias, | |
| bias_init=self.bias_init, | |
| bias_axes_1=(W_JOINED_AXES, W_TP_AXES), | |
| bias_axes_2=(W_NO_SHARD_AXES,), | |
| enable_low_rank_adaptation=lora_scope.mlp, | |
| low_rank_adaptation_dim=self.low_rank_adaptation_dim, | |
| low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, | |
| layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES), | |
| dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES), | |
| dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES), | |
| name="mlp", | |
| )(mlp_input, deterministic=deterministic) | |
| if self.apply_residual_connection_post_layernorm: | |
| assert ln_out is not None | |
| residual = ln_out | |
| z = with_sharding_constraint_by_logical_axes( | |
| z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES) | |
| ) | |
| residual = with_sharding_constraint_by_logical_axes( | |
| residual, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES) | |
| ) | |
| z = hidden_dropout(z, deterministic) | |
| if self.drop_path > 0.0: | |
| drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim) | |
| z = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)( | |
| z, deterministic=deterministic | |
| ) | |
| z = z + residual | |
| if self.output_layernorm: | |
| z = with_sharding_constraint_by_logical_axes( | |
| z, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES) | |
| ) | |
| z = LayerNorm( | |
| layernorm_type=self.layernorm_type, | |
| zero_centered_gamma=self.zero_centered_gamma, | |
| epsilon=self.layernorm_epsilon, | |
| scale_axes=(W_NO_SHARD_AXES,), | |
| bias_axes=(W_NO_SHARD_AXES,), | |
| transpose_batch_sequence=self.transpose_batch_sequence, | |
| dtype=self.dtype, | |
| name="output_layernorm", | |
| )(z) | |
| assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}" | |
| return z | 
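Below is a minimal usage sketch appended to the listing; it is not part of the original file. It initializes and applies the `TransformerLayer` defined above with Flax's standard `init`/`apply` on a single device, assuming float32 parameters, no `fp8_autocast` context, no mesh sharding, and that the unfused-attention fallback is available; the shapes (`batch=2`, `seqlen=16`, `hidden=512`) are arbitrary, and the `__main__` guard keeps the snippet inert on import.
| if __name__ == "__main__": | |
|     # Illustrative only (not from the original source): exercise the encoder | |
|     # TransformerLayer defined above on a single device. | |
|     rng = jax.random.PRNGKey(0) | |
|     init_rng, dropout_rng = jax.random.split(rng) | |
|     batch, seqlen, hidden = 2, 16, 512 | |
|     layer = TransformerLayer( | |
|         hidden_size=hidden, | |
|         mlp_hidden_size=2048, | |
|         num_attention_heads=8, | |
|         layer_type=TransformerLayerType.ENCODER, | |
|         self_attn_mask_type="causal", | |
|     ) | |
|     x = jnp.zeros((batch, seqlen, hidden), dtype=jnp.float32) | |
|     # Pass both RNG collections at init; dropout itself is disabled under | |
|     # deterministic=True, so only 'params' is strictly required. | |
|     variables = layer.init( | |
|         {"params": init_rng, "dropout": dropout_rng}, x, deterministic=True | |
|     ) | |
|     # Training-style call: dropout is active, so supply the 'dropout' RNG. | |
|     y = layer.apply(variables, x, deterministic=False, rngs={"dropout": dropout_rng}) | |
|     assert y.shape == x.shape | |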
  