llama.cpp SOFT_MAX failed: invalid argument on Blackwell consumer (sm_120) — root cause + 12-line fix

If you searched for this error, you're in the right place:

ggml_cuda_compute_forward: SOFT_MAX failed
CUDA error: invalid argument
  current device: 0, in function ggml_cuda_compute_forward at ggml/src/ggml-cuda/ggml-cuda.cu:2962

It crashes during the first prompt-processing batch. Mixture-of-Experts models (Qwen3-30B-A3B, Qwen3.6-35B-A3B, GPT-OSS, etc.) hit it immediately. Dense models may appear to "work" because they don't always exercise the affected code path early, but they're vulnerable too.

Affected setup

  • GPU: any Blackwell consumer card (sm_120 / compute capability 12.0). Verified on RTX 5070 Ti 16 GB.
  • Driver: 595.58.03 (CUDA 13.2 runtime advertised); reproduces on other 5xx-series drivers too.
  • CUDA toolkit (build): 13.2 (also confirmed on 13.1).
  • CUDA runtime (load): distro-shipped libcudart.so.12 (e.g. Ubuntu's nvidia-cuda-toolkit package).
  • llama.cpp: master HEAD b97ebdc9 and tag b9002. Bug exists for any commit since the dynamic shared-memory rework in PR #14497.

The combination that triggers it: build with CUDA 13 headers but link/load against libcudart.so.12 at runtime. This happens routinely on Linux distros that ship a CUDA-12 dev package alongside a self-installed CUDA 13 toolkit — the linker picks up the system libcudart.so first.

ldd of the offending binary tells you immediately:

$ ldd build/bin/libggml-cuda.so | grep cudart
        libcudart.so.12 => /usr/lib/x86_64-linux-gnu/libcudart.so.12   # ← BUG TRIGGER

If you see libcudart.so.13 instead, you don't have this bug.

Root cause

The cudaDeviceProp struct grew between CUDA 12 and 13. Newer fields like sharedMemPerBlockOptin and sharedMemPerMultiprocessor sit at offsets in the CUDA-13 header layout that the CUDA-12 runtime never writes. When cudaGetDeviceProperties runs against the older runtime, those fields read back as uninitialized garbage.

On the affected setups, prop.sharedMemPerBlockOptin reads back as 4294967297 (= 0x1_00000001). The standalone CUDA test below — same toolchain, same driver, same GPU — reads it correctly as 101376 via cudaDeviceGetAttribute, proving the runtime knows the right value; only the struct read is broken.
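
The mechanism, sketched conceptually (these struct definitions are illustrative, not the real cudaDeviceProp; the inserted field is hypothetical):

// Conceptual sketch only -- NOT the real cudaDeviceProp layout.
struct prop_v12 {                      // layout the CUDA-12 runtime writes into
    char   name[256];
    size_t sharedMemPerBlock;
    size_t sharedMemPerBlockOptin;     // lives at offset X in this layout
};

struct prop_v13 {                      // layout the CUDA-13 headers declare
    char   name[256];
    char   some_new_field[16];         // hypothetical field inserted/extended in the newer header
    size_t sharedMemPerBlock;
    size_t sharedMemPerBlockOptin;     // now at offset X + 16
};

// ggml is compiled against the v13 layout, so it reads sharedMemPerBlockOptin at offset X + 16.
// The v12 runtime only fills the v12 layout, so that offset holds whatever unrelated field or
// uninitialized bytes happen to sit there -- e.g. the observed 0x100000001.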

What llama.cpp does next:

// ggml/src/ggml-cuda/ggml-cuda.cu, ggml_cuda_init()
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;    // reads garbage 0x100000001
...
// ggml/src/ggml-cuda/softmax.cu, launch_soft_max_kernels()
const size_t smpbo = ggml_cuda_info().devices[id].smpbo; // 0x100000001
CUDA_CHECK(cudaFuncSetAttribute(kernel,
    cudaFuncAttributeMaxDynamicSharedMemorySize, (int) smpbo)); // ← cast truncates to 1

The (int) cast truncates to 1. The kernel's max dynamic shared memory is now capped at 1 byte. The very first softmax launch that requests 640 bytes of dynamic shared memory fails with cudaErrorInvalidValue. The error sticks; ggml_cuda_compute_forward reports it on the next CUDA-API check, which happens to be the SOFT_MAX op — hence the misleading error message.
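
The truncation itself is plain integer arithmetic; a two-line check using the garbage value reported above:

#include <cstddef>
#include <cstdio>
int main() {
    const size_t smpbo = 0x100000001ULL;        // garbage read from the mismatched struct (4294967297)
    printf("(int) smpbo = %d\n", (int) smpbo);  // prints 1 -> the kernel's dynamic-smem cap becomes 1 byte
    return 0;
}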

This is why MoE crashes immediately and dense models seem to "work": MoE graphs hit a small softmax (e.g. ncols=128 over experts, even when fused-topk-moe handles routing) and also hit attention softmax with mask-broadcast layouts that opt into shared memory. Dense models exercise fewer paths early.

Diagnostic — confirm in 30 seconds

This standalone CUDA program prints the relevant device-prop fields plus the opt-in shared-memory limit fetched via the version-stable attribute API. Run it on the affected box:

// cuda_prop_test.cu — compile: /usr/local/cuda-13.x/bin/nvcc -arch=sm_120 -o cuda_prop_test cuda_prop_test.cu
#include <cstdio>
#include <cuda_runtime.h>
int main() {
    cudaDeviceProp p;
    cudaGetDeviceProperties(&p, 0);
    printf("sharedMemPerBlock          = %zu\n",        p.sharedMemPerBlock);
    printf("sharedMemPerBlockOptin     = %zu (0x%zx)\n", p.sharedMemPerBlockOptin,
                                                         p.sharedMemPerBlockOptin);
    printf("sharedMemPerMultiprocessor = %zu\n",        p.sharedMemPerMultiprocessor);
    int v;
    cudaDeviceGetAttribute(&v, cudaDevAttrMaxSharedMemoryPerBlockOptin, 0);
    printf("attr.MaxSharedMemoryPerBlockOptin = %d\n", v);
    return 0;
}

If prop.sharedMemPerBlockOptin and attr.MaxSharedMemoryPerBlockOptin disagree (e.g. 4294967297 vs 101376), you have the ABI-mismatch bug.
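
If you'd rather embed the same check in a larger program, a minimal detection helper could look like this (a sketch; the function name is mine, not llama.cpp's). It reads the limit both ways and warns when they disagree:

#include <cstdio>
#include <cuda_runtime.h>

// Sketch: read the opt-in shared-memory limit both ways and flag a cudart header/runtime mismatch.
static size_t smpbo_checked(int device) {
    cudaDeviceProp prop{};
    cudaGetDeviceProperties(&prop, device);

    int attr = 0;
    cudaDeviceGetAttribute(&attr, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

    if (prop.sharedMemPerBlockOptin != (size_t) attr) {
        fprintf(stderr, "warning: struct (%zu) != attribute (%d) -- cudart ABI mismatch suspected\n",
                prop.sharedMemPerBlockOptin, attr);
    }
    return (size_t) attr;   // the attribute API is ABI-stable, so prefer it either way
}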

Also instructive — run llama.cpp with these env vars set, and you'll see exactly which softmax launch trips the cap:

GGML_SOFTMAX_DIAG=1 ./llama-server ...
# (only available if you apply the diagnostic patch from this gist's commentary)

The fix (12 lines)

Switch the affected reads from the cudaDeviceProp struct fields to cudaDeviceGetAttribute, which is ABI-stable across cudart versions and immune to header/runtime layout drift.

--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -281,7 +281,18 @@ static ggml_cuda_device_info ggml_cuda_init() {
                       id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",
                       (size_t)(prop.totalGlobalMem / (1024 * 1024)));
 #else
-        info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
+        // Read shared-memory limits via cudaDeviceGetAttribute rather than the cudaDeviceProp
+        // struct fields. The struct layout can differ between the cudart headers used at compile
+        // time and the libcudart.so loaded at runtime (e.g. building with CUDA 13 headers but
+        // dlopen'ing libcudart.so.12 from a distro package). When that happens, fields appended
+        // by newer headers (sharedMemPerBlockOptin, sharedMemPerMultiprocessor, etc.) read back
+        // as uninitialized garbage, which silently breaks every kernel that opts into >48 KB of
+        // dynamic shared memory. The attribute API is ABI-stable and immune to this mismatch.
+        {
+            int v_smpbo = 0;
+            CUDA_CHECK(cudaDeviceGetAttribute(&v_smpbo, cudaDevAttrMaxSharedMemoryPerBlockOptin, id));
+            info.devices[id].smpbo = (size_t) v_smpbo;
+        }
         info.devices[id].cc = 100*prop.major + 10*prop.minor;
         GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB\n",
                       id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no",

Save as sm120_smpbo_fix.patch, then:

cd llama.cpp
git apply sm120_smpbo_fix.patch
cmake --build build -j --target llama-server llama-cli

Verification

After the fix, on RTX 5070 Ti, master HEAD b97ebdc9, CUDA 13.2 toolkit, libcudart.so.12 still loaded:

Model                               Quant    Offload           Before fix   After fix
Llama-3.2-1B-Instruct (dense)       Q4_K_M   full GPU          works        PASS, "Paris"
Qwen3-30B-A3B-Instruct-2507 (MoE)   Q3_K_M   -ot experts→CPU   crash        PASS, "Paris", ~80 TPS
Qwen3.6-35B-A3B (MoE)               IQ4_XS   -ot experts→CPU   crash        PASS, coherent output

For the IQ4_XS case I additionally compared GPU vs CPU output at temperature=0, seed=7, on multiple prompts (factual recall, arithmetic, translation, multi-sentence explanation). Both produce coherent, well-formed text answering the prompt. They diverge at the token level after a few hundred characters, which is normal floating-point non-associativity between CPU and CUDA matmul/softmax — not corruption. No prompts produced gibberish, repetition loops, or non-language output.

Second bug — IQ1_S / IQ2_S / IQ3_S MUL_MAT broken on sm_120 (separate from SOFT_MAX)

While validating the SOFT_MAX fix on a real-world MoE workload (Qwen3.6-35B-A3B-UD-IQ4_XS.gguf), I hit a second, independent Blackwell-consumer bug. Your IQ-quant output may still be garbage even after the SOFT_MAX patch if your model contains tensors of these types.

Symptom

Coherent on short prompts, garbage on multi-token output. With the Unsloth UD-IQ4_XS variant of Qwen3.6-35B-A3B (despite the name, it has 80 tensors in IQ3_S, not IQ4_XS), prompts produced things like:

prompt: "Capital of France?"   →  "2022-02-22 14:38:25"
prompt: "translate to French"  →  Chinese gibberish

Same garbage with GGML_CUDA_FORCE_CUBLAS=1 (rules out MMQ as the sole cause).

Diagnosis

test-backend-ops confirms it cleanly. From build/bin/test-backend-ops -b CUDA0 test -o MUL_MAT_ID:

[MUL_MAT_ID] ERR = 0.957  iq2_s,n=32     FAIL
[MUL_MAT_ID] ERR = 1.015  iq1_s,n=1      FAIL
[MUL_MAT_ID] ERR = 0.974  iq1_s,n=32     FAIL
[MUL_MAT_ID] ERR = 0.912  iq3_s,n=32     FAIL
[MUL_MAT_ID] ERR = 0.604  iq3_s,n=1      FAIL
[MUL_MAT_ID] ERR = 0.512  iq2_s,n=1      FAIL

iq2_xxs, iq3_xxs, iq1_m, iq4_nl, iq4_xs   →  OK

ERR ~0.5–1.0 means the matmul output is essentially noise relative to the reference. Both regular MUL_MAT and MUL_MAT_ID fail for IQ1_S, IQ2_S, IQ3_S. The "_xxs" / "_xs" / "_nl" variants of similar bit-widths work correctly. K-quants are unaffected.
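
For reference, ERR here is a normalized mean-squared error between the backend output and the CPU reference, roughly the quantity below (a sketch, not the actual test-backend-ops code):

#include <cstddef>

// Sketch of the error metric: ~0 for a correct kernel, ~1 when the output is uncorrelated noise.
double nmse(const float * out, const float * ref, size_t n) {
    double num = 0.0, den = 0.0;
    for (size_t i = 0; i < n; ++i) {
        const double d = (double) out[i] - (double) ref[i];
        num += d * d;                     // squared error of the backend result vs the reference
        den += (double) ref[i] * ref[i];  // normalization by the reference's energy
    }
    return num / den;
}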

Watch out for misleading filenames

A model named *-IQ4_XS.gguf may not be uniformly IQ4_XS. The Unsloth UD ("Unsloth Dynamic") variants are mixed-precision. To audit yours:

# requires `numpy` and `pyyaml` (uv run --with numpy --with pyyaml ...)
import sys; sys.path.insert(0, "gguf-py")
import gguf
from collections import Counter
r = gguf.GGUFReader("/path/to/model.gguf")
c = Counter()
for t in r.tensors: c[t.tensor_type.name] += 1
for k, v in sorted(c.items(), key=lambda x: -x[1]):
    print(f"  {v:5d}  {k}")

Example output for Qwen3.6-35B-A3B-UD-IQ4_XS.gguf:

    361  F32        ← norms etc.
    251  Q8_0       ← attention weights (works)
     80  IQ3_S      ← ffn_gate_exps, ffn_up_exps  (BROKEN on sm_120)
     37  IQ4_XS     ← ffn_down_exps  (works)
      4  Q6_K       ← output weight + 1 ffn_down_exps

If you see any IQ1_S / IQ2_S / IQ3_S entries and you're on sm_120, you're hitting the bug whenever those tensors are routed to GPU.

Fix (workaround patch — 13 lines)

The cleanest end-user fix is to refuse these types in the CUDA backend on sm_120 so the ggml scheduler routes them to CPU. Performance impact is small (matches -ot ffn_.*_exps=CPU performance, ~80 TPS on Qwen3.6-35B-A3B), correctness is restored, and other quants are unaffected.

--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -4965,17 +4965,30 @@
                     case GGML_TYPE_Q6_K:
                     case GGML_TYPE_Q8_K:
                     case GGML_TYPE_IQ1_M:
-                    case GGML_TYPE_IQ1_S:
-                    case GGML_TYPE_IQ2_S:
                     case GGML_TYPE_IQ2_XS:
                     case GGML_TYPE_IQ2_XXS:
-                    case GGML_TYPE_IQ3_S:
                     case GGML_TYPE_IQ3_XXS:
                     case GGML_TYPE_IQ4_NL:
                     case GGML_TYPE_IQ4_XS:
                     case GGML_TYPE_BF16:
                         return true;
+                    case GGML_TYPE_IQ1_S:
+                    case GGML_TYPE_IQ2_S:
+                    case GGML_TYPE_IQ3_S:
+                    {
+                        // Workaround: IQ1_S / IQ2_S / IQ3_S MUL_MAT and MUL_MAT_ID kernels
+                        // produce essentially-random output on Blackwell consumer (sm_120).
+                        // Confirmed via test-backend-ops with NMSE in the 0.4-1.0 range for both
+                        // MUL_MAT and MUL_MAT_ID. Bug affects both the MMQ tensor-core path and
+                        // the cuBLAS dequant path (FORCE_CUBLAS does not help). Until the
+                        // underlying kernel bug is fixed, refuse these types here so the
+                        // scheduler routes them to the CPU backend.
+                        const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
+                        if (cc == GGML_CUDA_CC_BLACKWELL) {
+                            return false;
+                        }
+                        return true;
+                    }
                     default:
                         return false;
                 }

After applying this patch and rebuilding:

  • test-backend-ops -b CUDA0 test: 758/758 pass on RTX 5070 Ti (was failing for IQ1_S/IQ2_S/IQ3_S MUL_MAT and MUL_MAT_ID).
  • Qwen3.6-35B-A3B-UD-IQ4_XS produces coherent output across all prompts tested.
  • VRAM usage drops naturally because IQ3_S expert tensors run on CPU instead of GPU. Performance stays comparable to running with -ot 'ffn_.*_exps=CPU'.

Better long-term fix (not yet done)

This is a workaround, not a root-cause fix. The actual kernel bug for IQ1_S/IQ2_S/IQ3_S on sm_120 lives somewhere in ggml/src/ggml-cuda/mmq.cuh (load_tiles_iq[1-3]_s) and/or ggml/src/ggml-cuda/convert.cu (dequantize_block_iq[1-3]_s); both paths produce identical garbage in my tests, suggesting a shared root cause. I didn't dig deeper. Likely candidates: (a) a Blackwell compute_120a codegen issue specific to the bit-twiddling these "_s" variants do, or (b) a shared-memory layout bug only exposed at higher SM counts. The plain _xxs / _xs / _nl variants in the same code paths work correctly, so the problem is localized to whatever is specific to the "_s" load_tiles functions.

If you can isolate the kernel-level cause, please replace this workaround with a real fix.

Note on issue #21371

Issue #21371 reports IQ-quant gibberish on a different Blackwell consumer card (RTX 5080) with gemma-4. The symptoms match this second bug. The workaround above may resolve it, depending on which IQ types the model actually uses (audit with the gguf-py snippet above).

Repro for the MoE case:

./build/bin/llama-server \
  --model /path/to/Qwen3-30B-A3B-Instruct-2507-Q3_K_M.gguf \
  --host 127.0.0.1 --port 8772 \
  -ngl 999 \
  -ot 'blk\.[0-9]+\.ffn_.*_exps\.weight=CPU' \
  --ctx-size 4096 --jinja &

curl -s http://127.0.0.1:8772/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -d '{"model":"x","messages":[{"role":"user","content":"hi"}],"max_tokens":4}'

Workarounds without patching llama.cpp

If you can't apply the patch, force the binary to load the right cudart:

  1. Remove the conflicting distro package (do this only if nothing else needs it):

    sudo apt remove --purge nvidia-cuda-toolkit

    ldd build/bin/libggml-cuda.so | grep cudart should now show libcudart.so.13.

  2. Configure CMake to prefer the modern cudart:

    cmake -B build -DCUDAToolkit_ROOT=/usr/local/cuda-13.2 ... # (and other flags)

    Verify the resulting libggml-cuda.so links against the 13.x cudart with ldd.

  3. Last resort: LD_PRELOAD can in theory interpose the cudart symbols, but on this setup the binary has DT_NEEDED libcudart.so.12 baked in, so LD_PRELOAD=/usr/local/cuda-13/lib64/libcudart.so.13 did not help in practice. The build-time fix is the only reliable route without patching the source.

Why MoE shows the bug fastest

Once smpbo is read as garbage, the (int) cast caps the kernel's dynamic shared memory at 1 byte. The cudaFuncSetAttribute call itself succeeds; nothing fails until a kernel is actually launched with more dynamic shared memory than its cap allows. The first casualty is whichever softmax the model hits first that opts into >48 KB of dynamic shared memory, or, more commonly, any softmax that requests dynamic shared memory at all after its cap has been lowered to 1 byte. MoE graphs include many small softmaxes (router, attention, etc.) that exercise this path on token zero.
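
A standalone sketch of that failure mode, assuming cudaFuncAttributeMaxDynamicSharedMemorySize limits launches as described above (the kernel and sizes are illustrative, not llama.cpp's):

// repro_smem_cap.cu -- compile: nvcc -arch=sm_120 -o repro_smem_cap repro_smem_cap.cu
#include <cstdio>
#include <cuda_runtime.h>

__global__ void fake_softmax(float * dst) {
    extern __shared__ float buf[];   // dynamic shared memory, as the softmax kernels use
    buf[threadIdx.x] = 1.0f;
    dst[threadIdx.x] = buf[threadIdx.x];
}

int main() {
    float * dst = nullptr;
    cudaMalloc(&dst, 32 * sizeof(float));

    // Cap the kernel's dynamic shared memory at 1 byte, as the truncated smpbo does.
    cudaError_t err = cudaFuncSetAttribute(fake_softmax, cudaFuncAttributeMaxDynamicSharedMemorySize, 1);
    printf("set attribute: %s\n", cudaGetErrorString(err));                 // expected: no error

    // Launch requesting 640 bytes of dynamic shared memory, more than the 1-byte cap.
    fake_softmax<<<1, 32, 640>>>(dst);
    printf("launch:        %s\n", cudaGetErrorString(cudaGetLastError()));  // expected: invalid argument
    return 0;
}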

Upstream

I haven't filed a PR: llama.cpp's CONTRIBUTING.md restricts AI-assisted contributions and I don't want to overstep. If you're a maintainer or regular contributor and want to land this, please feel free to take the patch and ship it. The 12 lines are mechanical; the meaningful work is the diagnosis above.

Credit

If this saved you time, a polite linkback to this gist when discussing the fix would be appreciated. Not required.

Acknowledgements

Investigation and patch authored with AI assistance (Claude Code) under direct supervision: hypothesis generation, diagnostic instrumentation, evidence collection, and the final 12-line change were all reviewed before being run on hardware. The bug exists in upstream code; this gist describes how it manifests on Blackwell consumer GPUs and how to fix it.
