Skip to content

Instantly share code, notes, and snippets.

@edp1096
Last active March 5, 2026 13:41
Show Gist options
  • Select an option

  • Save edp1096/8670b744d88fddf89da0d0bc4ac56f95 to your computer and use it in GitHub Desktop.

Select an option

Save edp1096/8670b744d88fddf89da0d0bc4ac56f95 to your computer and use it in GitHub Desktop.
# after 1076f97
# https://github.com/karpathy/nanochat/commit/1076f97059785ed6d763706bf2304ce7721ab75c
"""
Unified Flash Attention interface with automatic FA3/FA2/SDPA switching.
Exports `flash_attn` module that matches the FA3 API exactly, falls back
to FA2 (with sliding window support) on non-Hopper GPUs, then SDPA.
Usage (drop-in replacement for FA3):
from nanochat.flash_attention import flash_attn
# Training (no KV cache)
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
# Inference (with KV cache)
y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, ...)
"""
import torch
import torch.nn.functional as F
# =============================================================================
# Detection: Try to load FA3 on Hopper+ GPUs
# =============================================================================
def _load_flash_attention_3():
"""Try to load Flash Attention 3 (requires Hopper GPU, sm90)."""
if not torch.cuda.is_available():
return None
try:
major, _ = torch.cuda.get_device_capability()
# FA3 kernels are compiled for Hopper (sm90) only
# Ada (sm89), Blackwell (sm100/sm121) need FA2 or SDPA fallback
if major != 9:
return None
import os
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
from kernels import get_kernel
return get_kernel('varunneal/flash-attention-3').flash_attn_interface
except Exception:
return None
# =============================================================================
# Detection: Try to load FA2
# FA2 uses identical (B, T, H, D) layout as FA3.
# sm120/sm121 (GB10 DGX Spark) can run sm120-compiled FA2 kernels.
# =============================================================================
def _load_flash_attention_2():
"""Try to load Flash Attention 2."""
if not torch.cuda.is_available():
return None
try:
from flash_attn import flash_attn_func as _fa2_func
from flash_attn import flash_attn_with_kvcache as _fa2_kvcache
from types import SimpleNamespace
return SimpleNamespace(
flash_attn_func=_fa2_func,
flash_attn_with_kvcache=_fa2_kvcache,
)
except Exception:
return None
# Probe kernels once at import time. FA3 wins; FA2 is only tried when FA3
# is absent, so _fa2 stays None on machines where FA3 loaded.
_fa3 = _load_flash_attention_3()
_fa2 = None if _fa3 is not None else _load_flash_attention_2()
HAS_FA2 = _fa2 is not None
# HAS_FA3 is True when either FA3 or FA2 is available,
# so external files (base_train.py, chat_sft.py) won't show "FA3 not available" warnings.
HAS_FA3 = _fa3 is not None or HAS_FA2
# Override for testing: set to 'fa3', 'fa2', 'sdpa', or None (auto)
_override_impl = None
def _resolve_use_fa3():
    """Decide once whether a flash kernel (FA3 or FA2) should handle attention.

    Honors the testing override first, then kernel availability, then dtype:
    the FA3 Hopper kernels only support bf16/fp8, so other compute dtypes
    fall back to SDPA.
    """
    if _override_impl == 'fa3':
        assert _fa3 is not None, "Cannot override to FA3: not available on this hardware"
        return True
    if _override_impl == 'fa2':
        assert _fa2 is not None, "Cannot override to FA2: not available on this hardware"
        return True
    if _override_impl == 'sdpa':
        return False
    if _fa3 is not None:
        # FA3 Hopper kernels only support bf16 and fp8; fp16/fp32 must use SDPA.
        from nanochat.common import COMPUTE_DTYPE
        return COMPUTE_DTYPE == torch.bfloat16
    return _fa2 is not None
# Resolved once at import; True means route calls through a flash kernel (FA3 or FA2).
USE_FA3 = _resolve_use_fa3()
# =============================================================================
# Internal: pick actual implementation (fa3 vs fa2) when USE_FA3 is True
# =============================================================================
def _get_impl():
    """Return the flash attention module that calls should dispatch to.

    Mirrors _resolve_use_fa3: explicit override first, then FA3 (bf16 only),
    then FA2. Returns None when neither kernel applies (SDPA path).
    """
    if _override_impl == 'fa3' and _fa3 is not None:
        return _fa3
    if _override_impl == 'fa2' and _fa2 is not None:
        return _fa2
    if _fa3 is not None:
        from nanochat.common import COMPUTE_DTYPE
        if COMPUTE_DTYPE == torch.bfloat16:
            return _fa3
    # _fa2 is None when unavailable, which doubles as the "no impl" result.
    return _fa2
# Concrete kernel module used when USE_FA3 is True (FA3 or FA2); None on the SDPA path.
_fa_impl = _get_impl()
# =============================================================================
# SDPA helpers
# =============================================================================
def _sdpa_attention(q, k, v, window_size, enable_gqa):
"""
SDPA attention with sliding window support.
q, k, v are (B, H, T, D) format.
"""
Tq = q.size(2)
Tk = k.size(2)
window = window_size[0]
# Full context, same length
if (window < 0 or window >= Tq) and Tq == Tk:
return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
# Single token generation
if Tq == 1:
if window >= 0 and window < Tk:
# window is "left" tokens we need to include (window + 1) keys total
start = max(0, Tk - (window + 1))
k = k[:, :, start:, :]
v = v[:, :, start:, :]
return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
# Need explicit mask for sliding window/chunk inference
device = q.device
# For chunk inference (Tq != Tk), is_causal is not aligned to cache position => build an explicit bool mask
row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
mask = col_idx <= row_idx
# sliding window (left)
if window >= 0 and window < Tk:
mask = mask & ((row_idx - col_idx) <= window)
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
# =============================================================================
# Public API: Same interface as FA3
# =============================================================================
def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
    """
    Flash Attention for training (no KV cache).

    Args:
        q, k, v: Tensors of shape (B, T, H, D)
        causal: Whether to use causal masking
        window_size: (left, right) sliding window. -1 means unlimited.

    Returns:
        Output tensor of shape (B, T, H, D)
    """
    if USE_FA3:
        if _fa_impl is _fa3:
            return _fa_impl.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
        # FA2 shares the (B, T, H, D) layout but takes an explicit dropout_p.
        return _fa_impl.flash_attn_func(q, k, v, dropout_p=0.0, causal=causal, window_size=window_size)
    # SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    enable_gqa = q.size(1) != k.size(1)
    if not causal and window_size[0] < 0 and window_size[1] < 0:
        # Bug fix: _sdpa_attention always masks causally, so the fallback
        # previously ignored causal=False and diverged from the FA3/FA2
        # kernels. Unwindowed non-causal attention maps onto plain SDPA.
        # NOTE(review): non-causal *windowed* attention still falls through
        # to the causal helper; nanochat only calls with causal=True.
        y = F.scaled_dot_product_attention(q, k, v, enable_gqa=enable_gqa)
    else:
        y = _sdpa_attention(q, k, v, window_size, enable_gqa)
    return y.transpose(1, 2)  # back to (B, T, H, D)
def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None,
                            causal=False, window_size=(-1, -1)):
    """
    Flash Attention with a KV cache for inference.

    FA3/FA2 update k_cache/v_cache in place; the SDPA fallback replicates
    that behavior manually.

    Args:
        q: Queries, shape (B, T_new, H, D)
        k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D)
        k, v: New keys/values to insert, shape (B, T_new, H_kv, D)
        cache_seqlens: Current position in cache, shape (B,) int32
        causal: Whether to use causal masking
        window_size: (left, right) sliding window. -1 means unlimited.

    Returns:
        Output tensor of shape (B, T_new, H, D)
    """
    if USE_FA3:
        return _fa_impl.flash_attn_with_kvcache(
            q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
            causal=causal, window_size=window_size
        )
    # ---- SDPA fallback: manage the KV cache by hand ----
    T_new = q.size(1)
    # NOTE(review): assumes every batch row sits at the same cache position.
    pos = cache_seqlens[0].item()
    end_pos = pos + T_new
    # Write the incoming keys/values into the cache in place (FA3 semantics).
    if k is not None and v is not None:
        k_cache[:, pos:end_pos, :, :] = k
        v_cache[:, pos:end_pos, :, :] = v
    # Attend over everything cached so far, including the new tokens.
    # Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D).
    q_bhtd = q.transpose(1, 2)
    k_bhtd = k_cache[:, :end_pos, :, :].transpose(1, 2)
    v_bhtd = v_cache[:, :end_pos, :, :].transpose(1, 2)
    enable_gqa = q_bhtd.size(1) != k_bhtd.size(1)
    out = _sdpa_attention(q_bhtd, k_bhtd, v_bhtd, window_size, enable_gqa)
    return out.transpose(1, 2)  # back to (B, T_new, H, D)
# =============================================================================
# Export: flash_attn module interface (drop-in replacement for FA3)
# =============================================================================
# Export a namespace that mirrors the FA3 module surface, so callers can use
# `from nanochat.flash_attention import flash_attn` as a drop-in replacement.
from types import SimpleNamespace
flash_attn = SimpleNamespace(
    flash_attn_func=flash_attn_func,
    flash_attn_with_kvcache=flash_attn_with_kvcache,
)
# before c7ba252
# https://github.com/karpathy/nanochat/commit/c7ba25214276d165eeefca7cb2060587975db189
"""
Unified Flash Attention interface with automatic FA3/FA2/SDPA switching.
Exports `flash_attn` module that matches the FA3 API exactly, falls back
to FA2 (with sliding window support) on non-Hopper GPUs, then SDPA.
Usage (drop-in replacement for FA3):
from nanochat.flash_attention import flash_attn
# Training (no KV cache)
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
# Inference (with KV cache)
y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, ...)
"""
import torch
import torch.nn.functional as F
# =============================================================================
# Detection: Try to load FA3 on Hopper GPUs (sm90 only)
# =============================================================================
def _load_flash_attention_3():
"""Try to load Flash Attention 3 (requires Hopper GPU, sm90)."""
if not torch.cuda.is_available():
return None
try:
major, _ = torch.cuda.get_device_capability()
# FA3 kernels are compiled for Hopper (sm90) only
# Ada (sm89), Blackwell (sm100/sm121) need FA2 or SDPA fallback
if major != 9:
return None
import os
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
from kernels import get_kernel
return get_kernel('varunneal/flash-attention-3').flash_attn_interface
except Exception:
return None
# =============================================================================
# Detection: Try to load FA2
# FA2 uses identical (B, T, H, D) layout as FA3.
# sm120/sm121 (GB10 DGX Spark) can run sm120-compiled FA2 kernels.
# =============================================================================
def _load_flash_attention_2():
"""Try to load Flash Attention 2."""
if not torch.cuda.is_available():
return None
try:
from flash_attn import flash_attn_func as _fa2_func
from flash_attn import flash_attn_with_kvcache as _fa2_kvcache
from types import SimpleNamespace
return SimpleNamespace(
flash_attn_func=_fa2_func,
flash_attn_with_kvcache=_fa2_kvcache,
)
except Exception:
return None
# Probe kernels once at import time. FA3 wins; FA2 is only tried when FA3
# is absent, so _fa2 stays None on machines where FA3 loaded.
_fa3 = _load_flash_attention_3()
_fa2 = None if _fa3 is not None else _load_flash_attention_2()
HAS_FA2 = _fa2 is not None
# True when either kernel is available, keeping external "FA3" checks quiet.
HAS_FA3 = _fa3 is not None or HAS_FA2
# Override for testing: set to 'fa3', 'fa2', 'sdpa', or None (auto)
_override_impl = None
def _use_fa3():
    """Return True when calls should route to the FA3 kernel."""
    if _override_impl == 'fa3':
        assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
        return True
    # Auto mode uses FA3 whenever available; FA2/SDPA overrides disable it.
    return _override_impl not in ('fa2', 'sdpa') and HAS_FA3
def _use_fa2():
    """Return True when calls should route to the FA2 kernel."""
    if _override_impl == 'fa2':
        assert HAS_FA2, "Cannot override to FA2: not available on this hardware"
        return True
    # Auto mode: FA2 is only the fallback when FA3 is not being used.
    return _override_impl not in ('fa3', 'sdpa') and HAS_FA2 and not _use_fa3()
# =============================================================================
# SDPA helpers
# =============================================================================
def _sdpa_attention(q, k, v, window_size, enable_gqa):
"""
SDPA attention with sliding window support.
q, k, v are (B, H, T, D) format.
"""
Tq = q.size(2)
Tk = k.size(2)
window = window_size[0]
# Full context, same length
if (window < 0 or window >= Tq) and Tq == Tk:
return F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
# Single token generation
if Tq == 1:
if window >= 0 and window < Tk:
# Keep only the last (window + 1) key/value tokens
start = max(0, Tk - (window + 1))
k = k[:, :, start:, :]
v = v[:, :, start:, :]
return F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
# Need explicit mask for sliding window / chunk inference (Tq != Tk)
device = q.device
row_idx = (Tk - Tq) + torch.arange(Tq, device=device).unsqueeze(1)
col_idx = torch.arange(Tk, device=device).unsqueeze(0)
mask = col_idx <= row_idx
# Apply sliding window constraint
if window >= 0 and window < Tk:
mask = mask & ((row_idx - col_idx) <= window)
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
# =============================================================================
# Public API: Same interface as FA3
# =============================================================================
def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
    """
    Flash Attention for training (no KV cache).

    Args:
        q, k, v: Tensors of shape (B, T, H, D) -- identical layout in FA2 and FA3
        causal: Whether to use causal masking
        window_size: (left, right) sliding window. -1 means unlimited.

    Returns:
        Output tensor of shape (B, T, H, D)
    """
    if _use_fa3():
        return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
    # FA2: identical (B, T, H, D) layout and same window_size parameter
    if _use_fa2():
        return _fa2.flash_attn_func(q, k, v, dropout_p=0.0, causal=causal, window_size=window_size)
    # SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    enable_gqa = q.size(1) != k.size(1)
    if not causal and window_size[0] < 0 and window_size[1] < 0:
        # Bug fix: _sdpa_attention always masks causally, so the fallback
        # previously ignored causal=False and diverged from the FA3/FA2
        # kernels. Unwindowed non-causal attention maps onto plain SDPA.
        # NOTE(review): non-causal *windowed* attention still falls through
        # to the causal helper; nanochat only calls with causal=True.
        y = F.scaled_dot_product_attention(q, k, v, enable_gqa=enable_gqa)
    else:
        y = _sdpa_attention(q, k, v, window_size, enable_gqa)
    return y.transpose(1, 2)  # back to (B, T, H, D)
def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None,
                            causal=False, window_size=(-1, -1)):
    """
    Flash Attention with a KV cache for inference.

    FA3/FA2 update k_cache/v_cache in place; the SDPA fallback mirrors that.

    Args:
        q: Queries, shape (B, T_new, H, D)
        k_cache, v_cache: Pre-allocated cache tensors, shape (B, T_max, H_kv, D)
        k, v: New keys/values to insert, shape (B, T_new, H_kv, D)
        cache_seqlens: Current position in cache, shape (B,) int32
        causal: Whether to use causal masking
        window_size: (left, right) sliding window. -1 means unlimited.

    Returns:
        Output tensor of shape (B, T_new, H, D)
    """
    if _use_fa3():
        return _fa3.flash_attn_with_kvcache(
            q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
            causal=causal, window_size=window_size,
        )
    if _use_fa2():
        # FA2 exposes the same kvcache signature as FA3.
        return _fa2.flash_attn_with_kvcache(
            q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
            causal=causal, window_size=window_size,
        )
    # ---- SDPA fallback: manage the KV cache by hand ----
    num_new = q.size(1)
    pos = cache_seqlens[0].item()  # NOTE(review): assumes uniform position across batch
    stop = pos + num_new
    if k is not None and v is not None:
        # In-place cache write, matching the kernel implementations.
        k_cache[:, pos:stop, :, :] = k
        v_cache[:, pos:stop, :, :] = v
    # Attend over the full cache up to and including the new tokens.
    # Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D).
    qt = q.transpose(1, 2)
    kt = k_cache[:, :stop, :, :].transpose(1, 2)
    vt = v_cache[:, :stop, :, :].transpose(1, 2)
    enable_gqa = qt.size(1) != kt.size(1)
    result = _sdpa_attention(qt, kt, vt, window_size, enable_gqa)
    return result.transpose(1, 2)  # back to (B, T_new, H, D)
# =============================================================================
# Export: flash_attn module interface (drop-in replacement for FA3)
# =============================================================================
# Export a namespace that mirrors the FA3 module surface, so callers can use
# `from nanochat.flash_attention import flash_attn` as a drop-in replacement.
from types import SimpleNamespace
flash_attn = SimpleNamespace(
    flash_attn_func=flash_attn_func,
    flash_attn_with_kvcache=flash_attn_with_kvcache,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment