Gist: johnnynunez/65d7830f232027b15de3bde2408037b4

Install vLLM on DGX Spark

  1. Install uv
curl -LsSf https://astral.sh/uv/install.sh | sh
  2. Create the environment
sudo apt install python3-dev
uv venv .vllm --python 3.12
source .vllm/bin/activate
  3. Install vLLM (nightly cu130 wheels) and a prerelease Triton build
uv pip install -U vllm --torch-backend=auto --extra-index-url https://wheels.vllm.ai/nightly/cu130
uv pip install --prerelease=allow --force-reinstall triton --index-url https://download.pytorch.org/whl/test/cu132
  4. Export the target architecture
export TORCH_CUDA_ARCH_LIST=12.1a
  5. Drop the page cache to free memory
sudo sysctl -w vm.drop_caches=3

Wiring flash-attn-4-sm120 into vLLM on SM120 / SM121

How to make vLLM use the SecondNatureComputing/flash-attn-4-sm120 Hugging Face kernel for the FA4 path on consumer Blackwell GPUs (RTX 5090, RTX PRO 6000 Blackwell, DGX Spark GB10 / SM121a).

Tested on:

  • DGX Spark (GB10), compute capability 12.1
  • vLLM installed in a venv at ~/Projects/vllm/.vllm
  • Python 3.12

Heads up before you start. On SM120/121 this kernel is not a free speedup. The HF README's own benchmarks show FA4 is 4–10% slower than vLLM's bundled FA2 on realistic Qwen3 prefill shapes, and ~1.7× slower on very short sequences (S = 128). Use this when you specifically need FA4-only features (paged KV with FA4, score_mod, block sparse, dropout in attention). For pure throughput, stay on FA2 and don't read further.


1. Why a shim is needed

On CUDA, vLLM does not import the upstream flash_attn package. It uses its own bundled fork:

# vllm/v1/attention/backends/fa_utils.py
if current_platform.is_cuda():
    from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata

vllm.vllm_flash_attn calls compiled C++ extensions (_vllm_fa2_C, _vllm_fa3_C) and its own CuTe DSL implementation under vllm/vllm_flash_attn/cute/ for FA4. So aliasing sys.modules['flash_attn'] to the HF kernel does nothing for the main attention path.

vLLM already ships SM120 CuTe kernels (vllm/vllm_flash_attn/cute/flash_fwd_sm120.py), but in the version most people have installed they are gated:

# vllm/vllm_flash_attn/cute/interface.py
assert page_table is None,    "Paged KV not supported on SM 12.0 in this PR"
assert not is_split_kv,       "SplitKV not supported on SM 12.0 in this PR"

vLLM always serves with paged KV, so the bundled FA4 SM120 path can't run during serving. Net effect: on SM12x, vLLM falls back to FA2.

The HF kernel bundles two PRs that vLLM's bundled version is missing:

  • #2348 — SM120 kernel-level paged KV cache support
  • #2336 — SM120 split-KV (FlashDecoding)

So the only way to use FA4-with-paged-KV on SM12x today is to redirect vLLM's FA4 call site (vllm.vllm_flash_attn.cute.interface._flash_attn_fwd) to the HF kernel. That's what this shim does.


2. Prerequisites

# Inside your vLLM venv
uv pip install -U "kernels>=0.4" "nvidia-cutlass-dsl>=4.4.1" einops apache-tvm-ffi

Confirm hardware:

nvidia-smi --query-gpu=name,compute_cap --format=csv
# Expect compute_cap >= 12.0  (e.g. 12.1 on DGX Spark GB10)

CUDA Toolkit must be 12.8 or newer (FA4 baseline).
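If you want the hardware check above to fail loudly in a script instead of being eyeballed, the CSV printed by nvidia-smi --query-gpu=name,compute_cap --format=csv can be parsed as below. This is a minimal sketch: the helpers parse_compute_caps, all_sm12x and query_gpus are hypothetical, and the parser assumes the header-plus-rows layout that --format=csv produces. all_sm12x tests for major version 12, which is the same condition the shim's _is_sm12x() uses.

```python
import csv
import io
import subprocess


def parse_compute_caps(csv_text: str) -> list[tuple[str, tuple[int, int]]]:
    """Parse `nvidia-smi --query-gpu=name,compute_cap --format=csv` output.

    Returns (gpu_name, (major, minor)) per GPU row; the header row is skipped.
    """
    rows = csv.reader(io.StringIO(csv_text), skipinitialspace=True)
    next(rows, None)  # header line: "name, compute_cap"
    gpus = []
    for row in rows:
        if len(row) < 2:
            continue  # ignore blank or trailing lines
        name, cap = row[0].strip(), row[1].strip()
        major, minor = (int(x) for x in cap.split("."))
        gpus.append((name, (major, minor)))
    return gpus


def all_sm12x(gpus) -> bool:
    """True if at least one GPU is present and every GPU reports capability 12.x."""
    return bool(gpus) and all(major == 12 for _, (major, _minor) in gpus)


def query_gpus():
    """Run nvidia-smi and parse its output (needs the NVIDIA driver installed)."""
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,compute_cap", "--format=csv"],
        capture_output=True, text=True, check=True,
    ).stdout
    return parse_compute_caps(out)
```

On the target machine, all_sm12x(query_gpus()) should be True before you go any further.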


3. Pre-download the kernel

python -c "from kernels import get_kernel; get_kernel('SecondNatureComputing/flash-attn-4-sm120')"

Sanity check:

python - <<'PY'
import torch
from kernels import get_kernel
fa4 = get_kernel("SecondNatureComputing/flash-attn-4-sm120")
q = k = v = torch.randn(1, 1024, 16, 128, device="cuda", dtype=torch.bfloat16)
out, _ = fa4.flash_attn_func(q, k, v, causal=True)
print("OK", out.shape, out.dtype)
PY

4. Install the shim

The shim does three things, lazily, when the relevant vLLM modules first load:

  1. Patches vllm.vllm_flash_attn.flash_attn_interface._is_fa4_supported to accept SM 9.x / 10.x / 11.x / 12.x.
  2. Wraps vllm.vllm_flash_attn.cute.interface._flash_attn_fwd so that on SM12x the call dispatches to the HF kernel's _flash_attn_fwd (which has paged KV).
  3. Wraps vllm.v1.attention.backends.fa_utils.get_flash_attn_version so that on SM12x with head_dim ≤ 128 it returns 4 (instead of the default 2).
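Stripped of the vLLM imports, the version choice in rule 3 reduces to a small pure function. This sketch mirrors the decision order in the shim's _patch_fa_utils; choose_fa_version and its boolean parameters are hypothetical stand-ins for the shim's _is_sm12x() and is_fa_version_supported(4) calls, and original stands in for vLLM's unpatched get_flash_attn_version (its default of 2 is only a placeholder).

```python
from typing import Callable, Optional


def choose_fa_version(
    is_sm12x: bool,
    fa4_supported: bool,
    requires_alibi: bool = False,
    head_size: Optional[int] = None,
    original: Callable[[], int] = lambda: 2,  # placeholder for vLLM's own choice
) -> int:
    """Same decision order as the shim's patched get_flash_attn_version()."""
    if not is_sm12x:
        return original()          # other GPUs: defer to unpatched vLLM
    if requires_alibi:
        return 2                   # ALiBi: stay on FA2
    if head_size is not None and head_size > 128:
        return 2                   # head_dim > 128: SM120 kernel can't fit it
    return 4 if fa4_supported else 2
```

Keeping the SM12x check first means non-Blackwell-consumer GPUs never see the shim's policy at all.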

The wiring is done via a sys.meta_path finder that runs the patches the moment those vLLM modules finish loading — before any other module captures get_flash_attn_version by name.
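The post-load hook can be tried in isolation before trusting it with vLLM. This sketch writes a throwaway module demo_target to a temporary directory (the module, its answer() function, PostLoadPatcher and _PatchAfterExec are all made up for the demo), installs a finder that wraps the real loader, and patches the module the moment exec_module returns. As a result, even the very first `from demo_target import answer` binds the patched function.

```python
import importlib
import importlib.abc
import os
import sys
import tempfile


class _PatchAfterExec(importlib.abc.Loader):
    """Run the real loader, then apply patch(module) before any importer sees it."""

    def __init__(self, real, patch):
        self._real, self._patch = real, patch

    def create_module(self, spec):
        create = getattr(self._real, "create_module", None)
        return create(spec) if create is not None else None

    def exec_module(self, module):
        self._real.exec_module(module)
        self._patch(module)


class PostLoadPatcher(importlib.abc.MetaPathFinder):
    def __init__(self, patches):
        self._patches = patches  # {module_name: patch_fn}

    def find_spec(self, name, path=None, target=None):
        if name not in self._patches:
            return None
        for finder in sys.meta_path:  # ask the normal finders for the real spec
            if finder is self or not hasattr(finder, "find_spec"):
                continue
            spec = finder.find_spec(name, path, target)
            if spec is not None and spec.loader is not None:
                spec.loader = _PatchAfterExec(spec.loader, self._patches[name])
                return spec
        return None


def demo() -> str:
    tmp = tempfile.mkdtemp()
    with open(os.path.join(tmp, "demo_target.py"), "w") as f:
        f.write("def answer():\n    return 'original'\n")

    def patch(mod):
        mod.answer = lambda: "patched"

    finder = PostLoadPatcher({"demo_target": patch})
    sys.path.insert(0, tmp)
    sys.meta_path.insert(0, finder)
    importlib.invalidate_caches()
    try:
        from demo_target import answer  # bound only after the patch has run
        return answer()
    finally:
        sys.meta_path.remove(finder)
        sys.path.remove(tmp)
        sys.modules.pop("demo_target", None)


print(demo())  # prints: patched
```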

Note on sitecustomize.py. Ubuntu ships /usr/lib/python3.12/sitecustomize.py, which takes precedence over a venv-local one because the stdlib directory comes before the venv's site-packages on sys.path. So we use a .pth file (zzz_*.pth) instead: import … lines in .pth files are processed by site.py at interpreter startup, regardless of any OS-level sitecustomize.
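site.py loads sitecustomize with an ordinary import, so the first match on sys.path wins. importlib.util.find_spec("sitecustomize") shows which file your interpreter actually uses. The second function below reproduces the shadowing with two temporary directories; shadowdemo is a made-up module name, and the directory called early plays the role of /usr/lib/python3.12 while late plays the venv. A sketch for diagnosis, not part of the shim.

```python
import importlib
import importlib.util
import os
import sys
import tempfile


def active_sitecustomize():
    """Path of the sitecustomize.py this interpreter uses, or None if there is none."""
    spec = importlib.util.find_spec("sitecustomize")
    return spec.origin if spec is not None else None


def first_match_wins(name: str = "shadowdemo") -> str:
    """Put two dirs on sys.path, each with <name>.py; report which one is imported."""
    early, late = tempfile.mkdtemp(), tempfile.mkdtemp()
    for d, label in ((early, "early"), (late, "late")):
        with open(os.path.join(d, f"{name}.py"), "w") as f:
            f.write(f"WHO = {label!r}\n")
    sys.path[:0] = [early, late]  # early is searched before late
    importlib.invalidate_caches()
    try:
        return importlib.import_module(name).WHO
    finally:
        sys.path.remove(early)
        sys.path.remove(late)
        sys.modules.pop(name, None)


print("sitecustomize in use:", active_sitecustomize())
print("winner:", first_match_wins())  # prints: winner: early
```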

4.1 Create the shim file

SP="$VIRTUAL_ENV/lib/python3.12/site-packages"
# Remove any earlier experimental shim
rm -f "$SP/zzz_fa4_sm120_shim.pth" "$SP/fa4_sm120_shim.py" "$SP/sitecustomize.py"

cat > "$SP/fa4_sm120_shim.py" <<'PY'
"""
Force vLLM on SM12x (RTX 5090 / RTX PRO 6000 / DGX Spark) to use the
SecondNatureComputing/flash-attn-4-sm120 HF kernel for the FA4 path.

Disable by setting env var FA4_SM120_SHIM=0 before launching.
"""
from __future__ import annotations

import os
import sys
import warnings
from importlib.abc import Loader, MetaPathFinder

if os.environ.get("FA4_SM120_SHIM", "1") != "0":

    _HF_KERNEL = None

    def _hf_kernel():
        global _HF_KERNEL
        if _HF_KERNEL is None:
            from kernels import get_kernel
            _HF_KERNEL = get_kernel("SecondNatureComputing/flash-attn-4-sm120")
        return _HF_KERNEL

    _PATCHED: set[str] = set()

    def _is_sm12x() -> bool:
        try:
            import torch
            if not torch.cuda.is_available():
                return False
            major, _ = torch.cuda.get_device_capability()
            return major == 12
        except Exception:
            return False

    def _patch_fa_iface(mod):
        def _is_fa4_supported_patched():
            if not getattr(mod, "FA4_AVAILABLE", False):
                return False, getattr(mod, "FA4_UNAVAILABLE_REASON", "FA4 unavailable")
            try:
                import torch
                major, _ = torch.cuda.get_device_capability()
            except Exception:
                return False, "no CUDA device"
            if major in (9, 10, 11, 12):
                return True, None
            return False, f"FA4 not supported on capability {major}.x"
        mod._is_fa4_supported = _is_fa4_supported_patched

    def _patch_cute_iface(mod):
        orig = mod._flash_attn_fwd

        def _dispatch(*args, **kwargs):
            if _is_sm12x():
                return _hf_kernel().interface._flash_attn_fwd(*args, **kwargs)
            return orig(*args, **kwargs)

        mod._flash_attn_fwd = _dispatch

    def _patch_fa_utils(mod):
        import functools
        orig = mod.get_flash_attn_version

        @functools.wraps(orig)
        def patched(requires_alibi: bool = False,
                    head_size: int | None = None,
                    head_size_v: int | None = None,
                    has_sinks: bool = False):
            if not _is_sm12x():
                return orig(requires_alibi=requires_alibi,
                            head_size=head_size,
                            head_size_v=head_size_v,
                            has_sinks=has_sinks)
            if requires_alibi:
                return 2
            if head_size is not None and head_size > 128:
                return 2
            try:
                from vllm.vllm_flash_attn.flash_attn_interface import is_fa_version_supported
                if is_fa_version_supported(4):
                    return 4
            except Exception:
                pass
            return 2

        mod.get_flash_attn_version = patched

    _DISPATCH = {
        "vllm.vllm_flash_attn.flash_attn_interface": _patch_fa_iface,
        "vllm.vllm_flash_attn.cute.interface":      _patch_cute_iface,
        "vllm.v1.attention.backends.fa_utils":      _patch_fa_utils,
    }

    def _try_patch(name: str):
        if name in _PATCHED or name not in _DISPATCH:
            return
        mod = sys.modules.get(name)
        if mod is None:
            return
        try:
            _DISPATCH[name](mod)
            _PATCHED.add(name)
        except Exception as e:
            warnings.warn(f"[fa4_sm120_shim] failed to patch {name}: {e!r}")

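    class _WrappedLoader(Loader):
        """Delegate to the real loader, then patch the module right after it executes."""

        def __init__(self, real, name):
            self._real = real
            self._name = name

        def __getattr__(self, attr):
            # Forward everything else (get_source, get_code, is_package, ...)
            # to the real loader so tools that introspect __loader__ keep working.
            real = self.__dict__.get("_real")
            if real is None:
                raise AttributeError(attr)
            return getattr(real, attr)

        def create_module(self, spec):
            if hasattr(self._real, "create_module"):
                return self._real.create_module(spec)
            return None

        def exec_module(self, module):
            self._real.exec_module(module)
            _try_patch(self._name)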

    class _PatchingFinder(MetaPathFinder):
        def find_spec(self, name, path=None, target=None):
            if name not in _DISPATCH or name in _PATCHED:
                return None
            for finder in list(sys.meta_path):
                if finder is self or not hasattr(finder, "find_spec"):
                    continue
                spec = finder.find_spec(name, path, target)
                if spec is not None and spec.loader is not None:
                    spec.loader = _WrappedLoader(spec.loader, name)
                    return spec
            return None

    if not getattr(sys, "_fa4_sm120_shim_installed", False):
        sys.meta_path.insert(0, _PatchingFinder())
        try:
            sys._fa4_sm120_shim_installed = True  # type: ignore[attr-defined]
        except Exception:
            pass

    for _n in list(_DISPATCH):
        _try_patch(_n)
PY

4.2 Auto-load via a .pth file

cat > "$SP/zzz_fa4_sm120_shim.pth" <<'PTH'
import fa4_sm120_shim
PTH

Lines starting with import … in .pth files are executed by site.py at interpreter startup, and site.py processes the .pth files in alphabetical order. The zzz_ prefix sorts our file after the .pth files other packages install, so it runs after them.
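site.addsitedir() applies the same .pth processing to any directory, so the ordering can be checked without touching the venv. This sketch creates aaa_demo.pth and zzz_demo.pth in a temporary directory (written in reverse order on purpose); each contains one import line that records its tag in a shared list. All file and module names here are made up for the demo.

```python
import os
import site
import sys
import tempfile


def pth_execution_order() -> list:
    tmp = tempfile.mkdtemp()
    # Shared recorder module plus one tiny marker module per .pth file.
    with open(os.path.join(tmp, "pth_recorder.py"), "w") as f:
        f.write("order = []\n")
    for tag in ("zzz", "aaa"):  # creation order deliberately reversed
        with open(os.path.join(tmp, f"pth_mark_{tag}.py"), "w") as f:
            f.write(f"import pth_recorder\npth_recorder.order.append({tag!r})\n")
        with open(os.path.join(tmp, f"{tag}_demo.pth"), "w") as f:
            f.write(f"import pth_mark_{tag}\n")  # "import ..." lines get exec'd
    try:
        # Appends tmp to sys.path, then processes its .pth files sorted by name.
        site.addsitedir(tmp)
        import pth_recorder
        return list(pth_recorder.order)
    finally:
        if tmp in sys.path:
            sys.path.remove(tmp)
        for mod in ("pth_recorder", "pth_mark_aaa", "pth_mark_zzz"):
            sys.modules.pop(mod, None)


print(pth_execution_order())  # prints: ['aaa', 'zzz']
```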


5. Verify

5.1 Hook is installed

python - <<'PY'
import sys
import fa4_sm120_shim
print("hook installed:", any("PatchingFinder" in type(f).__name__ for f in sys.meta_path))
PY

Expected: hook installed: True.

5.2 Patches applied to vLLM

python - <<'PY'
import vllm.vllm_flash_attn.flash_attn_interface as fai
import vllm.vllm_flash_attn.cute.interface as cute
import vllm.v1.attention.backends.fa_utils as fau

print("FA4_AVAILABLE:                       ", fai.FA4_AVAILABLE)
print("_is_fa4_supported():                  ", fai._is_fa4_supported())
print("get_flash_attn_version(head_size=128):", fau.get_flash_attn_version(head_size=128))
print("get_flash_attn_version(head_size=256):", fau.get_flash_attn_version(head_size=256))
print("_flash_attn_fwd qualname:             ", cute._flash_attn_fwd.__qualname__)
PY

Expected:

FA4_AVAILABLE:                        True
_is_fa4_supported():                  (True, None)
get_flash_attn_version(head_size=128): 4
get_flash_attn_version(head_size=256): 2
_flash_attn_fwd qualname:              _patch_cute_iface.<locals>._dispatch

If the _flash_attn_fwd qualname shows the original function (typically just _flash_attn_fwd), the hook didn't run for that module; see "Troubleshooting".

5.3 End-to-end FA4 fwd through vLLM's call site

This is closer to what vLLM actually invokes:

python - <<'PY'
import torch
from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd

B, S, Hq, Hkv, D = 1, 1024, 16, 8, 128
q = torch.randn(B*S, Hq,  D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B*S, Hkv, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B*S, Hkv, D, device="cuda", dtype=torch.bfloat16)
cu = torch.tensor([0, B*S], device="cuda", dtype=torch.int32)

out, lse = _flash_attn_fwd(
    q, k, v,
    cu_seqlens_q=cu, cu_seqlens_k=cu,
    max_seqlen_q=B*S, max_seqlen_k=B*S,
    softmax_scale=D**-0.5, causal=True, return_lse=True,
)
print("OK", out.shape, out.dtype, "lse:", None if lse is None else lse.shape)
PY

If this completes without error, the HF kernel is being driven through vLLM's FA4 entry point.
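Completing without error proves the call path, not the numbers. For a numerical check, compare a tiny problem against a plain reference. This sketch is a pure-Python causal attention for one sequence with grouped-query heads (query head h reads KV head h // (Hq // Hkv)), returning the output and the natural-log log-sum-exp of the scaled scores per (head, query). ref_causal_gqa is a hypothetical helper. It assumes the [S, H, D] layout used in the script above, so convert kernel tensors with .float().cpu().tolist() and compare with bf16-level tolerances. The LSE layout and log base the kernel returns are assumptions to confirm before comparing lse.

```python
import math


def ref_causal_gqa(q, k, v, softmax_scale):
    """Reference attention for one sequence.

    q: [S][Hq][D], k and v: [S][Hkv][D] nested lists of floats.
    Returns (out [S][Hq][D], lse [Hq][S]); query i attends to keys 0..i.
    """
    S, Hq, D = len(q), len(q[0]), len(q[0][0])
    Hkv = len(k[0])
    group = Hq // Hkv                       # query heads sharing one KV head
    out = [[[0.0] * D for _ in range(Hq)] for _ in range(S)]
    lse = [[0.0] * S for _ in range(Hq)]
    for h in range(Hq):
        kv = h // group                     # GQA mapping
        for i in range(S):
            scores = [softmax_scale * sum(a * b for a, b in zip(q[i][h], k[j][kv]))
                      for j in range(i + 1)]  # causal: keys 0..i only
            m = max(scores)                 # subtract the max for stability
            weights = [math.exp(s - m) for s in scores]
            total = sum(weights)
            lse[h][i] = m + math.log(total)
            for d in range(D):
                out[i][h][d] = sum(w * v[j][kv][d] for j, w in enumerate(weights)) / total
    return out, lse
```

The loop is O(S² · D) per head, so shrink S (for example to 16) before comparing against the kernel.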


6. Launch vLLM

The first time, keep INFO-level logging and save the output to vllm.log so you can verify the routing:

VLLM_LOGGING_LEVEL=INFO vllm serve Qwen/Qwen3.5-27B \
  --speculative-config '{"method": "dflash", "model": "z-lab/Qwen3.5-27B-DFlash", "num_speculative_tokens": 15}' \
  --attention-backend flash_attn \
  --gpu-memory-utilization 0.85 \
  --max-model-len 65536 \
  --load-format fastsafetensors \
  --max-num-batched-tokens 32768 2>&1 | tee vllm.log

In another terminal:

grep -iE "fa.version|flash_attn|attention backend|sm12|fa4" vllm.log | head -50

Look for log lines mentioning fa_version=4 or the chosen attention backend.
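If you'd rather check the log from Python, the same case-insensitive search looks like this. It is a sketch only: scan_log and mentions_fa4 are hypothetical helpers, the pattern uses fa.version so both "fa version" and "fa_version" match, and since it doesn't know vLLM's exact log wording, a hit is a line to read, not proof that FA4 is active.

```python
import re
from typing import Iterable, List

# Like: grep -iE "fa.version|flash_attn|attention backend|sm12|fa4"
PATTERN = re.compile(r"fa.version|flash_attn|attention backend|sm12|fa4", re.IGNORECASE)


def scan_log(lines: Iterable[str], limit: int = 50) -> List[str]:
    """Return up to `limit` matching lines, like `grep ... | head -50`."""
    hits = []
    for line in lines:
        if PATTERN.search(line):
            hits.append(line.rstrip("\n"))
            if len(hits) >= limit:
                break
    return hits


def mentions_fa4(hits: List[str]) -> bool:
    """Heuristic: does any matched line mention fa_version=4 or FA4?"""
    return any(re.search(r"fa_version\s*=\s*4|\bfa4\b", h, re.IGNORECASE) for h in hits)
```

Typical use: with open("vllm.log") as f: hits = scan_log(f), then print the hits and mentions_fa4(hits).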


7. Disable the shim

Without removing files (the shim only checks for the exact value 0; values such as false leave it enabled):

FA4_SM120_SHIM=0 vllm serve ...

Permanently:

SP="$VIRTUAL_ENV/lib/python3.12/site-packages"
rm -f "$SP/fa4_sm120_shim.py" "$SP/zzz_fa4_sm120_shim.pth"

8. Troubleshooting

_flash_attn_fwd qualname still shows the original

Most likely the module was imported before the finder was active, for example by another .pth file that runs earlier, or by an interpreter that never processed the venv's site-packages. Make sure:

  • You created zzz_fa4_sm120_shim.pth (the zzz_ prefix matters for ordering).
  • You don't have a competing .pth that imports vLLM modules earlier.
  • The shim file is in the venv's site-packages, not the system one.

Sanity:

python -c "import sys; print('\n'.join(p for p in sys.path if 'site-packages' in p))"
ls "$VIRTUAL_ENV/lib/python3.12/site-packages/" | grep fa4_sm120
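To hunt for a competing .pth, the check below lists the .pth files in a site-packages directory in the order site.py runs them (sorted by filename). It flags import lines that mention vllm and run before zzz_fa4_sm120_shim.pth. It's a sketch: audit_pth is a hypothetical helper, and it only inspects import lines, not plain path entries.

```python
import os
from typing import List, Tuple

SHIM_PTH = "zzz_fa4_sm120_shim.pth"


def audit_pth(site_packages: str) -> Tuple[List[str], List[str]]:
    """Return (.pth files in execution order, warnings about early vllm imports)."""
    names = sorted(n for n in os.listdir(site_packages)
                   if n.endswith(".pth") and not n.startswith("."))
    warnings = []
    if SHIM_PTH not in names:
        warnings.append(f"{SHIM_PTH} not found")
    for name in names:
        if name == SHIM_PTH:
            break  # files after the shim run once the finder is already installed
        path = os.path.join(site_packages, name)
        with open(path, encoding="utf-8", errors="replace") as f:
            for line in f:
                if line.startswith(("import ", "import\t")) and "vllm" in line:
                    warnings.append(f"{name}: {line.strip()}")
    return names, warnings
```

Inside the venv, sysconfig.get_paths()["purelib"] gives the site-packages path to pass in.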

assert is_fa_version_supported(4) fires somewhere

We already patch _is_fa4_supported, but if some module captured a reference to the unpatched function (or cached its result) before the patch ran, it can still see the old answer. Reproduce with the verifier in §5.2, pinpoint which module is involved, and either open an issue or extend _DISPATCH to patch that module too.

Tensor-shape / stride asserts inside the HF kernel

Possible signature drift between vLLM's call site and the HF kernel's _flash_attn_fwd. Compare:

  • vLLM call: vllm/vllm_flash_attn/flash_attn_interface.py around elif fa_version == 4:
  • HF _flash_attn_fwd: ~/.cache/huggingface/hub/models--SecondNatureComputing--flash-attn-4-sm120/snapshots/<hash>/build/torch-cuda/interface.py

If a kwarg vLLM passes isn't accepted by the HF kernel, drop it inside _dispatch before forwarding. As of HF kernel v0.1.0 all kwargs vLLM passes (cu_seqlens_q/k, seqused_k, max_seqlen_q/k, page_table, softmax_scale, causal, softcap, window_size_left/right, num_splits, return_lse, out, learnable_sink) are present.
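Dropping unknown kwargs can be made mechanical with inspect.signature: keep the keyword arguments the HF function declares and report the rest. A sketch, assuming the HF _flash_attn_fwd is a plain Python function whose parameters are introspectable. If it is wrapped without functools.wraps, its signature may read (*args, **kwargs) and everything will be kept. split_kwargs is hypothetical, and dropping an argument can change results, so always look at what was dropped.

```python
import inspect
from typing import Any, Callable, Dict, Tuple


def split_kwargs(fn: Callable, kwargs: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split kwargs into (accepted by fn, not accepted by fn)."""
    params = list(inspect.signature(fn).parameters.values())
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return dict(kwargs), {}  # fn takes **kwargs, so it accepts everything
    names = {p.name for p in params
             if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                           inspect.Parameter.KEYWORD_ONLY)}
    accepted = {key: val for key, val in kwargs.items() if key in names}
    dropped = {key: val for key, val in kwargs.items() if key not in names}
    return accepted, dropped
```

In _dispatch, one option is to call split_kwargs(_hf_kernel().interface._flash_attn_fwd, kwargs), warn about the keys of the second dict, and forward only the first.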

head_dim > 128 model

The SM120 kernel cannot fit head_dim > 128 in 99 KB SMEM. The shim's get_flash_attn_version already returns 2 in that case, so vLLM falls back to FA2. If the model has head_dim == 256, don't bother with this shim — neither this kernel nor vLLM's bundled FA4 will run, and vLLM's FA2 already handles it.

Check head_dim:

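python - <<'PY'
import json
from huggingface_hub import hf_hub_download

p = hf_hub_download("Qwen/Qwen3.5-27B", "config.json")
with open(p) as f:
    c = json.load(f)
# Some multimodal checkpoints nest the language-model settings under "text_config".
c = c.get("text_config", c)
print("head_dim:", c.get("head_dim") or (c["hidden_size"] // c["num_attention_heads"]))
PY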

Paged KV not supported on SM 12.0 in this PR

This assert is in vLLM's bundled cute/interface.py. If you see it, the dispatch didn't reach our wrapper. Re-run the §5.2 verifier; if _flash_attn_fwd qualname doesn't say _dispatch, the patch is not in effect.


9. How it works (under the hood)

.vllm/lib/python3.12/site-packages/
├── zzz_fa4_sm120_shim.pth          ← processed by site.py at startup; runs `import fa4_sm120_shim`
├── fa4_sm120_shim.py               ← installs sys.meta_path finder, patches three vLLM modules
│
└── vllm/
    ├── v1/attention/backends/
    │   └── fa_utils.py             ← get_flash_attn_version()  — patched: returns 4 on SM12x
    └── vllm_flash_attn/
        ├── flash_attn_interface.py ← _is_fa4_supported()       — patched: accepts SM 9-12.x
        │   └── flash_attn_varlen_func — dispatches on fa_version, calls _flash_attn_fwd if 4
        └── cute/
            └── interface.py        ← _flash_attn_fwd()         — patched: dispatches to HF kernel on SM12x

~/.cache/huggingface/hub/models--SecondNatureComputing--flash-attn-4-sm120/
└── snapshots/<hash>/build/torch-cuda/
    └── interface.py                ← real _flash_attn_fwd with paged KV (PR #2348)

Flow at request time:

  1. Some vLLM backend (vllm/v1/attention/backends/flash_attn.py) calls get_flash_attn_version(head_size=128).
  2. Patched version returns 4 (instead of 2).
  3. The backend calls flash_attn_varlen_func(..., fa_version=4).
  4. Inside, the FA4 branch does from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd.
  5. Patched _flash_attn_fwd is a _dispatch wrapper. On SM12x it forwards to hf_kernel.interface._flash_attn_fwd, otherwise it calls the original.
  6. The HF kernel runs, returns (out, softmax_lse).
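Step 4 is why replacing the module attribute is enough: an import executed inside the function body looks _flash_attn_fwd up on every call, while a name bound at module import time keeps whatever object existed then. This sketch reproduces both cases with in-memory modules; fake_cute, early_caller and late_caller are made-up names standing in for vLLM's modules and call sites.

```python
import sys
import types

# Stand-in for vllm.vllm_flash_attn.cute.interface.
fake_cute = types.ModuleType("fake_cute")
fake_cute._flash_attn_fwd = lambda: "bundled"
sys.modules["fake_cute"] = fake_cute

from fake_cute import _flash_attn_fwd as captured_early  # bound once, right now


def early_caller():
    return captured_early()


def late_caller():
    from fake_cute import _flash_attn_fwd  # re-resolved on every call, like step 4
    return _flash_attn_fwd()


# What _patch_cute_iface does: replace the attribute on the module object.
fake_cute._flash_attn_fwd = lambda: "hf kernel"

print(late_caller())   # prints: hf kernel
print(early_caller())  # prints: bundled
```

The early binding is the case the meta_path finder guards against by patching each module before anything else can import from it.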

10. Caveats

  • The shim assumes vLLM's internal API names (_flash_attn_fwd, is_fa_version_supported, get_flash_attn_version). If you upgrade vLLM and these change, the shim's verifier in §5.2 will tell you immediately.
  • On SM12x, the HF kernel's interface.py clamps num_splits to 1 (no SplitKV). Decode workloads use a single split anyway, but if you've tuned vLLM with num_splits > 1 it'll be silently ignored.
  • Dropout in this kernel falls back to smaller tiles to avoid register spills. Throughput cost is small but real.
  • Backward pass on SM12x has its own restrictions (no block sparse, no score_mod, no mask_mod, no deterministic). Inference is forward-only so this doesn't affect serving.
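The first caveat can be checked mechanically before and after a vLLM upgrade. missing_names is a hypothetical helper; the table mirrors the three patch targets from §4 plus is_fa_version_supported and FA4_AVAILABLE, which the shim also reads. Importing these vLLM modules pulls in torch, so run it inside the venv.

```python
import importlib
from typing import Dict, List

SHIM_TARGETS: Dict[str, List[str]] = {
    "vllm.vllm_flash_attn.flash_attn_interface": [
        "_is_fa4_supported", "is_fa_version_supported", "FA4_AVAILABLE"],
    "vllm.vllm_flash_attn.cute.interface": ["_flash_attn_fwd"],
    "vllm.v1.attention.backends.fa_utils": ["get_flash_attn_version"],
}


def missing_names(targets: Dict[str, List[str]]) -> List[str]:
    """Return 'module (Error)' or 'module.attr' for anything that can't be found."""
    missing = []
    for mod_name, attrs in targets.items():
        try:
            mod = importlib.import_module(mod_name)
        except Exception as e:  # ImportError, or a torch/CUDA error at import time
            missing.append(f"{mod_name} ({type(e).__name__})")
            continue
        missing.extend(f"{mod_name}.{a}" for a in attrs if not hasattr(mod, a))
    return missing
```

print(missing_names(SHIM_TARGETS) or "all shim targets present") gives a one-line answer.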

License

The HF kernel is BSD-3-Clause (inherited from Dao-AILab/flash-attention). The shim above is provided as-is, no warranty.
