@AmosLewis
Created February 19, 2025 20:33
(.venv) ➜ 128 python3 -m sharktank.examples.export_paged_llm_v1 --irpa-file=/sharedfile/attn/fp8_attn.irpa \
--output-mlir=/sharedfile/attn/128/fp8_attn.mlir \
--output-config=/sharedfile/attn/128/config_attn.json \
--bs=4 --attention-kernel sharktank \
--attention-dtype=float8_e4m3fnuz --activation-dtype=bfloat16 --use-attention-mask --use-hf
/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/iree/turbine/aot/params.py:163: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
return torch.from_numpy(wrapper)
Exporting prefill_bs4
attention dtype
torch.float8_e4m3fnuz
attention dtype
torch.float8_e4m3fnuz
Exporting decode_bs4
GENERATED!
Exporting
Saving to '/sharedfile/attn/128/fp8_attn.mlir'
(.venv) ➜ 128 /home/chi/src/iree-build/tools/iree-compile \
/sharedfile/attn/128/fp8_attn.mlir \
--iree-hip-target=gfx942 \
-o=/sharedfile/attn/128/fp8_attn.vmfb \
--iree-hal-target-device=hip \
--iree-dispatch-creation-enable-aggressive-fusion=true \
--iree-global-opt-propagate-transposes=true \
--iree-opt-aggressively-propagate-transposes=true \
--iree-opt-data-tiling=false \
--iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
--iree-hal-indirect-command-buffers=true \
--iree-stream-resource-memory-model=discrete \
--iree-hal-memoization=true \
--iree-opt-strip-assertions
/sharedfile/attn/128/fp8_attn.mlir:29732:13: error: 'util.call' op function type mismatch; expected '(tensor<4x32x?x128xf8E4M3FNUZ>, tensor<4x32x?x128xf8E4M3FNUZ>, tensor<4x32x?x128xf8E4M3FNUZ>, tensor<f32>, tensor<?x?x?x?xf8E4M3FNUZ>) -> tensor<4x32x?x128xf32>' but callee is '(tensor<4x32x?x128xf8E4M3FNUZ>, tensor<4x32x?x128xf8E4M3FNUZ>, tensor<4x32x?x128xf8E4M3FNUZ>, tensor<f32>, tensor<?x?xf8E4M3FNUZ>) -> tensor<4x32x?x128xf32>'
%2032 = "util.call"(%2026, %2027, %2028, %2031, %2030) <{callee = @sharktank_masked_flash_attention_4_32_128_128_f8E4M3FNUZ_f32_f32}> : (tensor<4x32x?x128xf8E4M3FNUZ>, tensor<4x32x?x128xf8E4M3FNUZ>, tensor<4x32x?x128xf8E4M3FNUZ>, tensor<f32>, tensor<?x?x?x?xf8E4M3FNUZ>) -> tensor<4x32x?x128xf32>
^
/sharedfile/attn/128/fp8_attn.mlir:29732:13: note: see current operation: %1905 = "util.call"(%1899, %1900, %1901, %1904, %1903) <{callee = @sharktank_masked_flash_attention_4_32_128_128_f8E4M3FNUZ_f32_f32}> : (tensor<4x32x?x128xf8E4M3FNUZ>, tensor<4x32x?x128xf8E4M3FNUZ>, tensor<4x32x?x128xf8E4M3FNUZ>, tensor<f32>, tensor<?x?x?x?xf8E4M3FNUZ>) -> tensor<4x32x?x128xf32>
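For context, the error is a type mismatch at the `util.call` site: the caller passes the attention mask as a rank-4 `tensor<?x?x?x?xf8E4M3FNUZ>`, while the emitted kernel `@sharktank_masked_flash_attention_4_32_128_128_f8E4M3FNUZ_f32_f32` declares a rank-2 `tensor<?x?xf8E4M3FNUZ>` mask argument, and MLIR's verifier requires the types to match exactly. The following is an illustrative sketch in plain NumPy (not sharktank code; shapes taken from the log) of that rank disagreement:

```python
import numpy as np

# Shapes from the log: batch 4, 32 heads, head dim 128; assume seq len 128.
bs, heads, seq = 4, 32, 128

# Mask as the call site produces it: one [seq, seq] mask expanded per
# batch and head, giving a rank-4 tensor (tensor<?x?x?x?x...> in the IR).
mask_4d = np.zeros((bs, heads, seq, seq), dtype=np.float32)

# Mask shape the callee signature declares: a single rank-2 [seq, seq]
# tensor (tensor<?x?x...> in the IR).
mask_2d = np.zeros((seq, seq), dtype=np.float32)

# Each (batch, head) slice of the rank-4 mask has the rank-2 shape, so the
# data is broadcast-compatible -- but util.call checks the function type,
# not broadcastability, so the differing ranks fail verification.
assert mask_4d[0, 0].shape == mask_2d.shape
assert mask_4d.ndim != mask_2d.ndim
```

Under this reading, the fix would land on the exporter side (emit the mask with the rank the kernel signature expects, or regenerate the kernel with a rank-4 mask argument) rather than in `iree-compile` flags.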