Install vLLM on DGX Spark
- Install uv
curl -LsSf https://astral.sh/uv/install.sh | sh
- Create environment
sudo apt install python3-dev
uv venv .vllm --python 3.12
source .vllm/bin/activate
- Install vLLM
uv pip install -U vllm --torch-backend=auto --extra-index-url https://wheels.vllm.ai/nightly/cu130
uv pip install --prerelease=allow --force-reinstall triton --index-url https://download.pytorch.org/whl/test/cu132
- Export variables
export TORCH_CUDA_ARCH_LIST=12.1a
- Clean memory
sudo sysctl -w vm.drop_caches=3

How to make vLLM use the SecondNatureComputing/flash-attn-4-sm120 Hugging Face kernel for the FA4 path on consumer Blackwell GPUs (RTX 5090, RTX PRO 6000 Blackwell, DGX Spark GB10 / SM121a).
Tested on:
- DGX Spark (GB10), compute capability 12.1
- vLLM installed in a venv at ~/Projects/vllm/.vllm
- Python 3.12
Heads up before you start. On SM120/121 this kernel is not a free speedup. The HF README's own benchmarks show FA4 is 4–10% slower than vLLM's bundled FA2 on realistic Qwen3 prefill shapes, and ~1.7× slower on very short sequences (S = 128). Use this when you specifically need FA4-only features (paged KV with FA4,
score_mod, block sparse, dropout in attention). For pure throughput, stay on FA2 and don't read further.
vLLM in CUDA does not import the upstream flash_attn package. It uses its own bundled fork:
# vllm/v1/attention/backends/fa_utils.py
if current_platform.is_cuda():
    from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata

vllm.vllm_flash_attn calls compiled C++ extensions (_vllm_fa2_C, _vllm_fa3_C) and its own CuTe DSL implementation under vllm/vllm_flash_attn/cute/ for FA4. So aliasing sys.modules['flash_attn'] to the HF kernel does nothing for the main attention path.
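A quick illustration of why the alias is a no-op (the stand-in module below is just a placeholder for the HF kernel):

python - <<'PY'
# Alias the generic name; vLLM's CUDA attention path never looks at it.
import sys, types
sys.modules["flash_attn"] = types.ModuleType("flash_attn")  # pretend this is the HF kernel

import vllm.vllm_flash_attn as bundled  # what the attention backend actually imports
print("vLLM attention imports:", bundled.__file__)  # still the bundled fork
PY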
vLLM already ships SM120 CuTe kernels (vllm/vllm_flash_attn/cute/flash_fwd_sm120.py), but in the version most people have installed they are gated:
# vllm/vllm_flash_attn/cute/interface.py
assert page_table is None, "Paged KV not supported on SM 12.0 in this PR"
assert not is_split_kv, "SplitKV not supported on SM 12.0 in this PR"

vLLM serves with paged KV always, so the bundled FA4 SM120 path can't run during serving. Net effect: on SM12x, vLLM falls back to FA2.
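You can see the fallback on an unpatched install with the same helper the verifier in §5.2 uses; on SM12x it should report 2 (FA2):

python - <<'PY'
# Stock vLLM on SM12x: the FA4 gate fails, so the backend picks FA2.
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
print("fa version for head_size=128:", get_flash_attn_version(head_size=128))
PY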
The HF kernel bundles two PRs that vLLM's bundled version is missing:
- #2348 — SM120 kernel-level paged KV cache support
- #2336 — SM120 split-KV (FlashDecoding)
So the only way to use FA4-with-paged-KV on SM12x today is to redirect vLLM's FA4 call site (vllm.vllm_flash_attn.cute.interface._flash_attn_fwd) to the HF kernel. That's what this shim does.
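Conceptually, that redirect is one attribute swap. A minimal, non-persistent sketch (it assumes the kernels package installed in the next step, and it skips the FA version gating and import-order handling the real shim adds):

python - <<'PY'
# Core idea only: point vLLM's FA4 entry point at the HF kernel.
from kernels import get_kernel
import vllm.vllm_flash_attn.cute.interface as cute

hf = get_kernel("SecondNatureComputing/flash-attn-4-sm120")
cute._flash_attn_fwd = hf.interface._flash_attn_fwd
print("FA4 entry point now comes from:", cute._flash_attn_fwd.__module__)
PY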
# Inside your vLLM venv
uv pip install -U "kernels>=0.4" "nvidia-cutlass-dsl>=4.4.1" einops apache-tvm-ffiConfirm hardware:
nvidia-smi --query-gpu=name,compute_cap --format=csv
# Expect compute_cap >= 12.0 (e.g. 12.1 on DGX Spark GB10)

CUDA Toolkit must be 12.8 or newer (FA4 baseline).
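One quick check, assuming nvcc from the toolkit is on PATH:

nvcc --version | grep release
# Expect release 12.8 or newer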
python -c "from kernels import get_kernel; get_kernel('SecondNatureComputing/flash-attn-4-sm120')"Sanity check:
python - <<'PY'
import torch
from kernels import get_kernel
fa4 = get_kernel("SecondNatureComputing/flash-attn-4-sm120")
q = k = v = torch.randn(1, 1024, 16, 128, device="cuda", dtype=torch.bfloat16)
out, _ = fa4.flash_attn_func(q, k, v, causal=True)
print("OK", out.shape, out.dtype)
PY

The shim does three things, lazily, when the relevant vLLM modules first load:
- Patches vllm.vllm_flash_attn.flash_attn_interface._is_fa4_supported to accept SM 9.x / 10.x / 11.x / 12.x.
- Wraps vllm.vllm_flash_attn.cute.interface._flash_attn_fwd so that on SM12x the call dispatches to the HF kernel's _flash_attn_fwd (which has paged KV).
- Wraps vllm.v1.attention.backends.fa_utils.get_flash_attn_version so that on SM12x with head_dim ≤ 128 it returns 4 (instead of the default 2).
The wiring is done via a sys.meta_path finder that runs the patches the moment those vLLM modules finish loading — before any other module captures get_flash_attn_version by name.
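The "captures … by name" problem the finder avoids is plain Python import semantics; a tiny illustration with no vLLM involved:

python - <<'PY'
# "from m import f" copies the function object into the importer's namespace,
# so patching m.f afterwards doesn't change what that importer already holds.
import sys, types

m = types.ModuleType("m")
m.f = lambda: 2
sys.modules["m"] = m

from m import f    # consumer captures the old function
m.f = lambda: 4    # patch lands too late for that consumer
print(f(), m.f())  # -> 2 4
PY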
Note on sitecustomize.py. Ubuntu ships /usr/lib/python3.12/sitecustomize.py, which takes precedence over a venv-local one, so we use a .pth file (zzz_*.pth) instead. .pth files with import … lines are processed by site.py at interpreter startup, regardless of the OS-level sitecustomize.
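To see which sitecustomize (if any) your interpreter actually loads, a quick check using only the standard library:

python - <<'PY'
# Prints where sitecustomize comes from, or confirms there isn't one.
import inspect
try:
    import sitecustomize
    print("sitecustomize from:", inspect.getfile(sitecustomize))
except ImportError:
    print("no sitecustomize on this interpreter")
PY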
SP="$VIRTUAL_ENV/lib/python3.12/site-packages"
# Remove any earlier experimental shim
rm -f "$SP/zzz_fa4_sm120_shim.pth" "$SP/fa4_sm120_shim.py" "$SP/sitecustomize.py"
cat > "$SP/fa4_sm120_shim.py" <<'PY'
"""
Force vLLM on SM12x (RTX 5090 / RTX PRO 6000 / DGX Spark) to use the
SecondNatureComputing/flash-attn-4-sm120 HF kernel for the FA4 path.
Disable by setting env var FA4_SM120_SHIM=0 before launching.
"""
from __future__ import annotations
import os
import sys
import warnings
from importlib.abc import Loader, MetaPathFinder
if os.environ.get("FA4_SM120_SHIM", "1") != "0":
_HF_KERNEL = None
def _hf_kernel():
global _HF_KERNEL
if _HF_KERNEL is None:
from kernels import get_kernel
_HF_KERNEL = get_kernel("SecondNatureComputing/flash-attn-4-sm120")
return _HF_KERNEL
_PATCHED: set[str] = set()
def _is_sm12x() -> bool:
try:
import torch
if not torch.cuda.is_available():
return False
major, _ = torch.cuda.get_device_capability()
return major == 12
except Exception:
return False
def _patch_fa_iface(mod):
def _is_fa4_supported_patched():
if not getattr(mod, "FA4_AVAILABLE", False):
return False, getattr(mod, "FA4_UNAVAILABLE_REASON", "FA4 unavailable")
try:
import torch
major, _ = torch.cuda.get_device_capability()
except Exception:
return False, "no CUDA device"
if major in (9, 10, 11, 12):
return True, None
return False, f"FA4 not supported on capability {major}.x"
mod._is_fa4_supported = _is_fa4_supported_patched
def _patch_cute_iface(mod):
orig = mod._flash_attn_fwd
def _dispatch(*args, **kwargs):
if _is_sm12x():
return _hf_kernel().interface._flash_attn_fwd(*args, **kwargs)
return orig(*args, **kwargs)
mod._flash_attn_fwd = _dispatch
def _patch_fa_utils(mod):
import functools
orig = mod.get_flash_attn_version
@functools.wraps(orig)
def patched(requires_alibi: bool = False,
head_size: int | None = None,
head_size_v: int | None = None,
has_sinks: bool = False):
if not _is_sm12x():
return orig(requires_alibi=requires_alibi,
head_size=head_size,
head_size_v=head_size_v,
has_sinks=has_sinks)
if requires_alibi:
return 2
if head_size is not None and head_size > 128:
return 2
try:
from vllm.vllm_flash_attn.flash_attn_interface import is_fa_version_supported
if is_fa_version_supported(4):
return 4
except Exception:
pass
return 2
mod.get_flash_attn_version = patched
_DISPATCH = {
"vllm.vllm_flash_attn.flash_attn_interface": _patch_fa_iface,
"vllm.vllm_flash_attn.cute.interface": _patch_cute_iface,
"vllm.v1.attention.backends.fa_utils": _patch_fa_utils,
}
def _try_patch(name: str):
if name in _PATCHED or name not in _DISPATCH:
return
mod = sys.modules.get(name)
if mod is None:
return
try:
_DISPATCH[name](mod)
_PATCHED.add(name)
except Exception as e:
warnings.warn(f"[fa4_sm120_shim] failed to patch {name}: {e!r}")
class _WrappedLoader(Loader):
def __init__(self, real, name):
self._real = real
self._name = name
def create_module(self, spec):
if hasattr(self._real, "create_module"):
return self._real.create_module(spec)
return None
def exec_module(self, module):
self._real.exec_module(module)
_try_patch(self._name)
class _PatchingFinder(MetaPathFinder):
def find_spec(self, name, path=None, target=None):
if name not in _DISPATCH or name in _PATCHED:
return None
for finder in list(sys.meta_path):
if finder is self or not hasattr(finder, "find_spec"):
continue
spec = finder.find_spec(name, path, target)
if spec is not None and spec.loader is not None:
spec.loader = _WrappedLoader(spec.loader, name)
return spec
return None
if not getattr(sys, "_fa4_sm120_shim_installed", False):
sys.meta_path.insert(0, _PatchingFinder())
try:
sys._fa4_sm120_shim_installed = True # type: ignore[attr-defined]
except Exception:
pass
for _n in list(_DISPATCH):
_try_patch(_n)
PYcat > "$SP/zzz_fa4_sm120_shim.pth" <<'PTH'
import fa4_sm120_shim
PTH

.pth files starting with import … are executed by site.py at interpreter startup. The zzz_ prefix sorts our file last, so it is processed after the .pth hooks installed by other packages (torch, kernels, etc.).
python - <<'PY'
import sys
import fa4_sm120_shim
print("hook installed:", any("PatchingFinder" in type(f).__name__ for f in sys.meta_path))
PY

Expected: hook installed: True.
python - <<'PY'
import vllm.vllm_flash_attn.flash_attn_interface as fai
import vllm.vllm_flash_attn.cute.interface as cute
import vllm.v1.attention.backends.fa_utils as fau
print("FA4_AVAILABLE: ", fai.FA4_AVAILABLE)
print("_is_fa4_supported(): ", fai._is_fa4_supported())
print("get_flash_attn_version(head_size=128):", fau.get_flash_attn_version(head_size=128))
print("get_flash_attn_version(head_size=256):", fau.get_flash_attn_version(head_size=256))
print("_flash_attn_fwd qualname: ", cute._flash_attn_fwd.__qualname__)
PY

Expected:
FA4_AVAILABLE: True
_is_fa4_supported(): (True, None)
get_flash_attn_version(head_size=128): 4
get_flash_attn_version(head_size=256): 2
_flash_attn_fwd qualname: _patch_cute_iface.<locals>._dispatch
If _flash_attn_fwd qualname shows the original (something like cute.interface._flash_attn_fwd), the hook didn't run for that module — see "Troubleshooting".
This is closer to what vLLM actually invokes:
python - <<'PY'
import torch
from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd
B, S, Hq, Hkv, D = 1, 1024, 16, 8, 128
q = torch.randn(B*S, Hq, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B*S, Hkv, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B*S, Hkv, D, device="cuda", dtype=torch.bfloat16)
cu = torch.tensor([0, B*S], device="cuda", dtype=torch.int32)
out, lse = _flash_attn_fwd(
q, k, v,
cu_seqlens_q=cu, cu_seqlens_k=cu,
max_seqlen_q=B*S, max_seqlen_k=B*S,
softmax_scale=D**-0.5, causal=True, return_lse=True,
)
print("OK", out.shape, out.dtype, "lse:", None if lse is None else lse.shape)
PY

If this completes without error, the HF kernel is being driven through vLLM's FA4 entry point.
Run with debug logging the first time so you can verify the routing:
VLLM_LOGGING_LEVEL=INFO vllm serve Qwen/Qwen3.5-27B \
--speculative-config '{"method": "dflash", "model": "z-lab/Qwen3.5-27B-DFlash", "num_speculative_tokens": 15}' \
--attention-backend flash_attn \
--gpu-memory-utilization 0.85 \
--max-model-len 65536 \
--load-format fastsafetensors \
--max-num-batched-tokens 32768 2>&1 | tee vllm.log

In another terminal:
grep -iE "fa version|flash_attn|attention backend|sm12|fa4" vllm.log | head -50

Look for log lines mentioning fa_version=4 or the chosen attention backend.
Without removing files:
FA4_SM120_SHIM=0 vllm serve ...

Permanently:
SP="$VIRTUAL_ENV/lib/python3.12/site-packages"
rm -f "$SP/fa4_sm120_shim.py" "$SP/zzz_fa4_sm120_shim.pth"Most likely the module was imported in some other Python process or by some import path that loaded it before the finder was active. Make sure:
- You created zzz_fa4_sm120_shim.pth (the zzz_ prefix matters for ordering).
- You don't have a competing .pth that imports vLLM modules earlier.
- The shim file is in the venv's site-packages, not the system one.
Sanity:
python -c "import sys; print('\n'.join(p for p in sys.path if 'site-packages' in p))"
ls "$VIRTUAL_ENV/lib/python3.12/site-packages/" | grep fa4_sm120We already patch _is_fa4_supported, but if some code path imported is_fa_version_supported before the patch ran and captured it by name, it might still see the unpatched version. Reproduce with the verifier in §5.2 and pinpoint which module — open an issue or extend _DISPATCH to also patch the offending module.
Possible signature drift between vLLM's call site and the HF kernel's _flash_attn_fwd. Compare:
- vLLM call: vllm/vllm_flash_attn/flash_attn_interface.py around elif fa_version == 4:
- HF _flash_attn_fwd: ~/.cache/huggingface/hub/models--SecondNatureComputing--flash-attn-4-sm120/snapshots/<hash>/build/torch-cuda/interface.py
If a kwarg vLLM passes isn't accepted by the HF kernel, drop it inside _dispatch before forwarding. As of HF kernel v0.1.0 all kwargs vLLM passes (cu_seqlens_q/k, seqused_k, max_seqlen_q/k, page_table, softmax_scale, causal, softcap, window_size_left/right, num_splits, return_lse, out, learnable_sink) are present.
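If a future vLLM release does pass something the HF kernel rejects, one option is a defensive variant of _dispatch inside _patch_cute_iface that filters kwargs against the kernel's signature before forwarding. This is a sketch, not part of the shim above, and it silently ignores whatever gets dropped:

# Sketch: a drop-in replacement for _dispatch inside _patch_cute_iface that
# forwards only kwargs the HF kernel's _flash_attn_fwd actually accepts.
import inspect

def _dispatch(*args, **kwargs):
    if _is_sm12x():
        hf_fwd = _hf_kernel().interface._flash_attn_fwd
        accepted = set(inspect.signature(hf_fwd).parameters)
        dropped = sorted(k for k in kwargs if k not in accepted)
        if dropped:
            warnings.warn(f"[fa4_sm120_shim] dropping kwargs not accepted by the HF kernel: {dropped}")
        return hf_fwd(*args, **{k: v for k, v in kwargs.items() if k in accepted})
    return orig(*args, **kwargs)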
The SM120 kernel cannot fit head_dim > 128 in 99 KB SMEM. The shim's get_flash_attn_version already returns 2 in that case, so vLLM falls back to FA2. If the model has head_dim == 256, don't bother with this shim — neither this kernel nor vLLM's bundled FA4 will run, and vLLM's FA2 already handles it.
Check head_dim:
python - <<'PY'
from huggingface_hub import hf_hub_download
import json
p = hf_hub_download("Qwen/Qwen3.5-27B", "config.json")
c = json.load(open(p))
print("head_dim:", c.get("head_dim") or (c["hidden_size"] // c["num_attention_heads"]))
PY

This assert is in vLLM's bundled cute/interface.py. If you see it, the dispatch didn't reach our wrapper. Re-run the §5.2 verifier; if _flash_attn_fwd qualname doesn't say _dispatch, the patch is not in effect.
.vllm/lib/python3.12/site-packages/
├── zzz_fa4_sm120_shim.pth ← processed by site.py at startup; runs `import fa4_sm120_shim`
├── fa4_sm120_shim.py ← installs sys.meta_path finder, patches three vLLM modules
│
└── vllm/
├── v1/attention/backends/
│ └── fa_utils.py ← get_flash_attn_version() — patched: returns 4 on SM12x
└── vllm_flash_attn/
├── flash_attn_interface.py ← _is_fa4_supported() — patched: accepts SM 9-12.x
│ └── flash_attn_varlen_func — dispatches on fa_version, calls _flash_attn_fwd if 4
└── cute/
└── interface.py ← _flash_attn_fwd() — patched: dispatches to HF kernel on SM12x
~/.cache/huggingface/hub/models--SecondNatureComputing--flash-attn-4-sm120/
└── snapshots/<hash>/build/torch-cuda/
└── interface.py ← real _flash_attn_fwd with paged KV (PR #2348)
Flow at request time:
- Some vLLM backend (vllm/v1/attention/backends/flash_attn.py) calls get_flash_attn_version(head_size=128).
- Patched version returns 4 (instead of 2).
- The backend calls flash_attn_varlen_func(..., fa_version=4).
- Inside, the FA4 branch does from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd.
- Patched _flash_attn_fwd is a _dispatch wrapper. On SM12x it forwards to hf_kernel.interface._flash_attn_fwd, otherwise it calls the original.
- The HF kernel runs, returns (out, softmax_lse).
- The shim assumes vLLM's internal API names (_flash_attn_fwd, is_fa_version_supported, get_flash_attn_version). If you upgrade vLLM and these change, the shim's verifier in §5.2 will tell you immediately.
- On SM12x, the HF kernel's interface.py clamps num_splits to 1 (no SplitKV). Decode workloads use a single split anyway, but if you've tuned vLLM with num_splits > 1 it'll be silently ignored.
- Dropout in this kernel falls back to smaller tiles to avoid register spills. Throughput cost is small but real.
- Backward pass on SM12x has its own restrictions (no block sparse, no score_mod, no mask_mod, no deterministic). Inference is forward-only so this doesn't affect serving.
The HF kernel is BSD-3-Clause (inherited from Dao-AILab/flash-attention). The shim above is provided as-is, no warranty.