Last active
March 5, 2026 13:41
-
-
Save edp1096/8670b744d88fddf89da0d0bc4ac56f95 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # after 1076f97 | |
| # https://github.com/karpathy/nanochat/commit/1076f97059785ed6d763706bf2304ce7721ab75c | |
| """ | |
| Unified Flash Attention interface with automatic FA3/FA2/SDPA switching. | |
| Exports `flash_attn` module that matches the FA3 API exactly, falls back | |
| to FA2 (with sliding window support) on non-Hopper GPUs, then SDPA. | |
| Usage (drop-in replacement for FA3): | |
| from nanochat.flash_attention import flash_attn | |
| # Training (no KV cache) | |
| y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size) | |
| # Inference (with KV cache) | |
| y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, ...) | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| # ============================================================================= | |
| # Detection: Try to load FA3 on Hopper+ GPUs | |
| # ============================================================================= | |
| def _load_flash_attention_3(): | |
| """Try to load Flash Attention 3 (requires Hopper GPU, sm90).""" | |
| if not torch.cuda.is_available(): | |
| return None | |
| try: | |
| major, _ = torch.cuda.get_device_capability() | |
| # FA3 kernels are compiled for Hopper (sm90) only | |
| # Ada (sm89), Blackwell (sm100/sm121) need FA2 or SDPA fallback | |
| if major != 9: | |
| return None | |
| import os | |
| os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" | |
| from kernels import get_kernel | |
| return get_kernel('varunneal/flash-attention-3').flash_attn_interface | |
| except Exception: | |
| return None | |
| # ============================================================================= | |
| # Detection: Try to load FA2 | |
| # FA2 uses identical (B, T, H, D) layout as FA3. | |
| # sm120/sm121 (GB10 DGX Spark) can run sm120-compiled FA2 kernels. | |
| # ============================================================================= | |
| def _load_flash_attention_2(): | |
| """Try to load Flash Attention 2.""" | |
| if not torch.cuda.is_available(): | |
| return None | |
| try: | |
| from flash_attn import flash_attn_func as _fa2_func | |
| from flash_attn import flash_attn_with_kvcache as _fa2_kvcache | |
| from types import SimpleNamespace | |
| return SimpleNamespace( | |
| flash_attn_func=_fa2_func, | |
| flash_attn_with_kvcache=_fa2_kvcache, | |
| ) | |
| except Exception: | |
| return None | |
# Probe for the fastest available backend once at import time. FA2 is only
# probed when FA3 is absent, so at most one of _fa3/_fa2 is non-None.
_fa3 = _load_flash_attention_3()
_fa2 = None if _fa3 is not None else _load_flash_attention_2()
HAS_FA2 = _fa2 is not None
# HAS_FA3 is True when either FA3 or FA2 is available,
# so external files (base_train.py, chat_sft.py) won't show "FA3 not available" warnings.
HAS_FA3 = _fa3 is not None or HAS_FA2
# Override for testing: set to 'fa3', 'fa2', 'sdpa', or None (auto)
_override_impl = None
def _resolve_use_fa3():
    """Decide once whether to use FA3/FA2, based on availability, override, and dtype.

    Returns True when a flash-attention kernel (FA3 or FA2) should be used,
    False when the SDPA fallback should run instead. The concrete module
    (FA3 vs FA2) is resolved separately by _get_impl().
    """
    if _override_impl == 'fa3':
        assert _fa3 is not None, "Cannot override to FA3: not available on this hardware"
        return True
    if _override_impl == 'fa2':
        assert _fa2 is not None, "Cannot override to FA2: not available on this hardware"
        # True here means "use a flash kernel"; _get_impl() maps it to FA2.
        return True
    if _override_impl == 'sdpa':
        return False
    if _fa3 is not None:
        # FA3 Hopper kernels only support bf16 and fp8; fp16/fp32 must use SDPA fallback
        from nanochat.common import COMPUTE_DTYPE
        if COMPUTE_DTYPE == torch.bfloat16:
            return True
        return False
    if _fa2 is not None:
        return True
    return False
# Resolved once at import; the public API functions below branch on this flag.
USE_FA3 = _resolve_use_fa3()
| # ============================================================================= | |
| # Internal: pick actual implementation (fa3 vs fa2) when USE_FA3 is True | |
| # ============================================================================= | |
def _get_impl():
    """Pick the concrete flash-attention module (FA3 or FA2), or None for SDPA."""
    # Explicit test override wins when the requested backend is present.
    if _fa3 is not None and _override_impl == 'fa3':
        return _fa3
    if _fa2 is not None and _override_impl == 'fa2':
        return _fa2
    # Auto selection: FA3 only for bf16 compute, otherwise fall through to FA2.
    if _fa3 is not None:
        from nanochat.common import COMPUTE_DTYPE
        if COMPUTE_DTYPE == torch.bfloat16:
            return _fa3
    # _fa2 is None when neither backend loaded, which correctly yields None.
    return _fa2
| _fa_impl = _get_impl() | |
| # ============================================================================= | |
| # SDPA helpers | |
| # ============================================================================= | |
| def _sdpa_attention(q, k, v, window_size, enable_gqa): | |
| """ | |
| SDPA attention with sliding window support. | |
| q, k, v are (B, H, T, D) format. | |
| """ | |
| Tq = q.size(2) | |
| Tk = k.size(2) | |
| window = window_size[0] | |
| # Full context, same length | |
| if (window < 0 or window >= Tq) and Tq == Tk: | |
| return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa) | |
| # Single token generation | |
| if Tq == 1: | |
| if window >= 0 and window < Tk: | |
| # window is "left" tokens we need to include (window + 1) keys total | |
| start = max(0, Tk - (window + 1)) | |
| k = k[:, :, start:, :] | |
| v = v[:, :, start:, :] | |
| return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa) | |
| # Need explicit mask for sliding window/chunk inference | |
| device = q.device | |
| # For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask | |
| row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1) | |
| col_idx = torch.arange(Tk, device=device).unsqueeze(0) | |
| mask = col_idx <= row_idx | |
| # sliding window (left) | |
| if window >= 0 and window < Tk: | |
| mask = mask & ((row_idx - col_idx) <= window) | |
| return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa) | |
| # ============================================================================= | |
| # Public API: Same interface as FA3 | |
| # ============================================================================= | |
def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
    """
    Flash Attention for training (no KV cache).

    Args:
        q, k, v: tensors of shape (B, T, H, D)
        causal: whether to apply causal masking
        window_size: (left, right) sliding window; -1 means unlimited

    Returns:
        Attention output of shape (B, T, H, D).
    """
    if USE_FA3:
        if _fa_impl is _fa3:
            # FA3 takes no dropout argument.
            return _fa_impl.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
        # FA2 shares the (B, T, H, D) layout but requires an explicit dropout_p.
        return _fa_impl.flash_attn_func(q, k, v, dropout_p=0.0, causal=causal, window_size=window_size)
    # SDPA fallback works in (B, H, T, D); transpose in and back out.
    # NOTE(review): this fallback always masks causally regardless of `causal` --
    # call sites here appear to pass causal=True; confirm before reusing elsewhere.
    qh, kh, vh = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    needs_gqa = qh.size(1) != kh.size(1)
    out = _sdpa_attention(qh, kh, vh, window_size, needs_gqa)
    return out.transpose(1, 2)  # back to (B, T, H, D)
def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None,
                            causal=False, window_size=(-1, -1)):
    """
    Flash Attention with KV cache for inference.

    FA3/FA2 update k_cache/v_cache in-place; the SDPA fallback mirrors that.

    Args:
        q: queries, shape (B, T_new, H, D)
        k_cache, v_cache: pre-allocated cache tensors, shape (B, T_max, H_kv, D)
        k, v: new keys/values to insert, shape (B, T_new, H_kv, D)
        cache_seqlens: current position in cache, shape (B,) int32
        causal: whether to use causal masking
        window_size: (left, right) sliding window; -1 means unlimited

    Returns:
        Output tensor of shape (B, T_new, H, D).
    """
    if USE_FA3:
        return _fa_impl.flash_attn_with_kvcache(
            q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
            causal=causal, window_size=window_size
        )
    # --- SDPA fallback: manage the KV cache manually ---
    n_new = q.size(1)
    start = cache_seqlens[0].item()  # assume uniform position across batch
    stop = start + n_new
    # Write the new keys/values into the cache in-place, matching FA3 behavior.
    if k is not None and v is not None:
        k_cache[:, start:stop, :, :] = k
        v_cache[:, start:stop, :, :] = v
    # Attend over the valid cache prefix (old tokens + the new ones),
    # transposed to SDPA's (B, H, T, D) layout.
    queries = q.transpose(1, 2)
    keys = k_cache[:, :stop, :, :].transpose(1, 2)
    values = v_cache[:, :stop, :, :].transpose(1, 2)
    needs_gqa = queries.size(1) != keys.size(1)
    out = _sdpa_attention(queries, keys, values, window_size, needs_gqa)
    return out.transpose(1, 2)  # back to (B, T_new, H, D)
| # ============================================================================= | |
| # Export: flash_attn module interface (drop-in replacement for FA3) | |
| # ============================================================================= | |
# `flash_attn` mimics the FA3 module surface so call sites can keep doing
# `from nanochat.flash_attention import flash_attn` unchanged.
from types import SimpleNamespace
flash_attn = SimpleNamespace(
    flash_attn_func=flash_attn_func,
    flash_attn_with_kvcache=flash_attn_with_kvcache,
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # before c7ba252 | |
| # https://github.com/karpathy/nanochat/commit/c7ba25214276d165eeefca7cb2060587975db189 | |
| """ | |
| Unified Flash Attention interface with automatic FA3/FA2/SDPA switching. | |
| Exports `flash_attn` module that matches the FA3 API exactly, falls back | |
| to FA2 (with sliding window support) on non-Hopper GPUs, then SDPA. | |
| Usage (drop-in replacement for FA3): | |
| from nanochat.flash_attention import flash_attn | |
| # Training (no KV cache) | |
| y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size) | |
| # Inference (with KV cache) | |
| y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, ...) | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| # ============================================================================= | |
| # Detection: Try to load FA3 on Hopper GPUs (sm90 only) | |
| # ============================================================================= | |
| def _load_flash_attention_3(): | |
| """Try to load Flash Attention 3 (requires Hopper GPU, sm90).""" | |
| if not torch.cuda.is_available(): | |
| return None | |
| try: | |
| major, _ = torch.cuda.get_device_capability() | |
| # FA3 kernels are compiled for Hopper (sm90) only | |
| # Ada (sm89), Blackwell (sm100/sm121) need FA2 or SDPA fallback | |
| if major != 9: | |
| return None | |
| import os | |
| os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" | |
| from kernels import get_kernel | |
| return get_kernel('varunneal/flash-attention-3').flash_attn_interface | |
| except Exception: | |
| return None | |
| # ============================================================================= | |
| # Detection: Try to load FA2 | |
| # FA2 uses identical (B, T, H, D) layout as FA3. | |
| # sm120/sm121 (GB10 DGX Spark) can run sm120-compiled FA2 kernels. | |
| # ============================================================================= | |
| def _load_flash_attention_2(): | |
| """Try to load Flash Attention 2.""" | |
| if not torch.cuda.is_available(): | |
| return None | |
| try: | |
| from flash_attn import flash_attn_func as _fa2_func | |
| from flash_attn import flash_attn_with_kvcache as _fa2_kvcache | |
| from types import SimpleNamespace | |
| return SimpleNamespace( | |
| flash_attn_func=_fa2_func, | |
| flash_attn_with_kvcache=_fa2_kvcache, | |
| ) | |
| except Exception: | |
| return None | |
# Probe backends once at import time; FA2 is only tried when FA3 is absent,
# so at most one of _fa3/_fa2 is non-None.
_fa3 = _load_flash_attention_3()
_fa2 = None if _fa3 is not None else _load_flash_attention_2()
HAS_FA2 = _fa2 is not None
# Reported as "FA3" availability even when only FA2 is present, so external
# callers don't emit "FA3 not available" warnings on FA2-capable hardware.
HAS_FA3 = _fa3 is not None or HAS_FA2
# Override for testing: set to 'fa3', 'fa2', 'sdpa', or None (auto)
_override_impl = None
def _use_fa3():
    """Return True when the FA3 path should be taken (evaluated per call)."""
    if _override_impl == 'fa3':
        assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
        return True
    # An explicit 'fa2' or 'sdpa' override disables FA3; otherwise auto.
    forced_off = _override_impl in ('fa2', 'sdpa')
    return (not forced_off) and HAS_FA3
def _use_fa2():
    """Return True when the FA2 path should be taken (evaluated per call)."""
    override = _override_impl
    if override == 'fa2':
        assert HAS_FA2, "Cannot override to FA2: not available on this hardware"
        return True
    if override == 'fa3' or override == 'sdpa':
        return False
    # Auto mode: FA2 runs only when available and FA3 was not selected.
    return HAS_FA2 and not _use_fa3()
| # ============================================================================= | |
| # SDPA helpers | |
| # ============================================================================= | |
| def _sdpa_attention(q, k, v, window_size, enable_gqa): | |
| """ | |
| SDPA attention with sliding window support. | |
| q, k, v are (B, H, T, D) format. | |
| """ | |
| Tq = q.size(2) | |
| Tk = k.size(2) | |
| window = window_size[0] | |
| # Full context, same length | |
| if (window < 0 or window >= Tq) and Tq == Tk: | |
| return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa) | |
| # Single token generation | |
| if Tq == 1: | |
| if window >= 0 and window < Tk: | |
| # Keep only the last (window + 1) key/value tokens | |
| start = max(0, Tk - (window + 1)) | |
| k = k[:, :, start:, :] | |
| v = v[:, :, start:, :] | |
| return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa) | |
| # Need explicit mask for sliding window / chunk inference (Tq != Tk) | |
| device = q.device | |
| row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1) | |
| col_idx = torch.arange(Tk, device=device).unsqueeze(0) | |
| mask = col_idx <= row_idx | |
| # Apply sliding window constraint | |
| if window >= 0 and window < Tk: | |
| mask = mask & ((row_idx - col_idx) <= window) | |
| return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa) | |
| # ============================================================================= | |
| # Public API: Same interface as FA3 | |
| # ============================================================================= | |
def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
    """
    Flash Attention for training (no KV cache).

    Args:
        q, k, v: tensors of shape (B, T, H, D) -- identical layout in FA2 and FA3
        causal: whether to apply causal masking
        window_size: (left, right) sliding window; -1 means unlimited

    Returns:
        Attention output of shape (B, T, H, D).
    """
    if _use_fa3():
        return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
    if _use_fa2():
        # Same layout and window semantics as FA3; FA2 additionally wants dropout_p.
        return _fa2.flash_attn_func(q, k, v, dropout_p=0.0, causal=causal, window_size=window_size)
    # SDPA fallback operates in (B, H, T, D); transpose in and back out.
    # NOTE(review): this path always masks causally regardless of `causal`;
    # call sites here appear to pass causal=True -- confirm before reusing elsewhere.
    qt, kt, vt = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    needs_gqa = qt.size(1) != kt.size(1)
    return _sdpa_attention(qt, kt, vt, window_size, needs_gqa).transpose(1, 2)
def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None,
                            causal=False, window_size=(-1, -1)):
    """
    Flash Attention with KV cache for inference.
    FA3/FA2 update k_cache/v_cache in-place. SDPA fallback does the same manually.
    Args:
        q: Queries, shape (B, T_new, H, D)
        k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D)
        k, v: New keys/values to insert, shape (B, T_new, H_kv, D)
        cache_seqlens: Current position in cache, shape (B,) int32
        causal: Whether to use causal masking
        window_size: (left, right) sliding window. -1 means unlimited.
    Returns:
        Output tensor of shape (B, T_new, H, D)
    """
    # Dispatch order: FA3 -> FA2 -> SDPA, honoring the _override_impl switch.
    if _use_fa3():
        return _fa3.flash_attn_with_kvcache(
            q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
            causal=causal, window_size=window_size,
        )
    # FA2: identical API signature
    if _use_fa2():
        return _fa2.flash_attn_with_kvcache(
            q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
            causal=causal, window_size=window_size,
        )
    # SDPA fallback: manually manage KV cache (mirrors FA3/FA2 in-place update)
    # NOTE(review): `causal` is effectively ignored on this path -- the helper
    # always applies a causal/windowed mask; callers appear to pass causal=True.
    B, T_new, H, D = q.shape
    pos = cache_seqlens[0].item()  # assume uniform position across batch
    # Insert new k, v into cache in-place
    if k is not None and v is not None:
        k_cache[:, pos:pos + T_new, :, :] = k
        v_cache[:, pos:pos + T_new, :, :] = v
    # Full cache up to current position + new tokens
    end_pos = pos + T_new
    k_full = k_cache[:, :end_pos, :, :]
    v_full = v_cache[:, :end_pos, :, :]
    # Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D)
    q_sdpa = q.transpose(1, 2)
    k_sdpa = k_full.transpose(1, 2)
    v_sdpa = v_full.transpose(1, 2)
    # GQA is needed whenever the query head count differs from the KV head count.
    enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
    y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
    return y_sdpa.transpose(1, 2)  # back to (B, T_new, H, D)
| # ============================================================================= | |
| # Export: flash_attn module interface (drop-in replacement for FA3) | |
| # ============================================================================= | |
# `flash_attn` mimics the FA3 module surface so call sites can keep doing
# `from nanochat.flash_attention import flash_attn` unchanged.
from types import SimpleNamespace
flash_attn = SimpleNamespace(
    flash_attn_func=flash_attn_func,
    flash_attn_with_kvcache=flash_attn_with_kvcache,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment