@AmosLewis
Created September 4, 2025 00:26
((.venv12) ) ➜ shark-ai git:(2bb2d590b) ✗ /sharedfile/f16/export_run_f16_8b_tp1.sh
No flag provided. Using default iree_day 0828.
No flag provided. Using default shark_day 0828_2bb_kv8.
/sharedfile/f16/128/8b/out/8b_fp16_iree0828.shark0828_2bb_kv8.mlir
/sharedfile/f16/128/8b/out/8b_fp16_iree0828.shark0828_2bb_kv8.json
/sharedfile/f16/128/8b/out/8b_fp16_iree0828.shark0828_2bb_kv8.vmfb
/sharedfile/f16/128/8b/out/8b_fp16_iree0828.shark0828_2bb_kv8.prefill.txt
File created: /sharedfile/f16/128/8b/out/8b_fp16_iree0828.shark0828_2bb_kv8.prefill.txt
/sharedfile/f16/128/8b/out/8b_fp16_iree0828.shark0828_2bb_kv8.decode.txt
File created: /sharedfile/f16/128/8b/out/8b_fp16_iree0828.shark0828_2bb_kv8.decode.txt
python3 -m sharktank.examples.export_paged_llm_v1 --irpa-file=/shark-dev/llama3.1/8b/fp16/weight/8b_fp16.irpa --output-mlir=/sharedfile/f16/128/8b/out/8b_fp16_iree0828.shark0828_2bb_kv8.mlir --output-config=/sharedfile/f16/128/8b/out/8b_fp16_iree0828.shark0828_2bb_kv8.json --bs-prefill=4 --bs-decode=4 --block-seq-stride=32 --attention-dtype=float16 --activation-dtype=float16 --attention-kernel=torch --device-block-count=8192 --kv-cache-dtype=float8_e4m3fn
/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/iree/turbine/aot/params.py:163: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
return torch.from_numpy(wrapper)
Exporting prefill_bs4
Exporting decode_bs4
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/home/chiliu12/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 279, in <module>
main()
File "/home/chiliu12/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 263, in main
output_export, output_config = export_llm_v1(
^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 183, in export_llm_v1
generate_batch_decode(bs)
File "/home/chiliu12/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 149, in generate_batch_decode
@fxb.export_program(
^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/iree/turbine/aot/fx_programs.py", line 239, in export_program
program = torch.export.export(
^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/export/__init__.py", line 270, in export
return _export(
^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/export/_trace.py", line 1017, in wrapper
raise e
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/export/_trace.py", line 990, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/export/exported_program.py", line 114, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/export/_trace.py", line 1880, in _export
export_artifact = export_func( # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/export/_trace.py", line 1683, in _non_strict_export
aten_export_artifact = _to_aten_func( # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/export/_trace.py", line 637, in _export_to_aten_ir
gm, graph_signature = transform(aot_export_module)(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/export/_trace.py", line 1611, in _aot_export_non_strict
gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1246, in aot_export_module
fx_g, metadata, in_spec, out_spec = _aot_export_function(
^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1480, in _aot_export_function
fx_g, meta = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 623, in _create_aot_dispatcher_function
fw_metadata = run_functionalized_fw_and_collect_metadata(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 173, in inner
flat_f_outs = f(*flat_f_args)
^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 182, in flat_fn
tree_out = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 863, in functional_call
out = mod(*args[params_len:], **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/export/_trace.py", line 1598, in forward
tree_out = self._export_root(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/iree/turbine/aot/fx_programs.py", line 226, in new_forward
return f(self.root, *forward_args, **forward_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 170, in _
return model.decode(
^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/sharktank/sharktank/models/llm/export.py", line 124, in decode
logits = self.model.decode(
^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/sharktank/sharktank/models/llm/llm.py", line 233, in decode
h = block(
^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/sharktank/sharktank/models/llm/llm.py", line 412, in forward
h = self.attn(
^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/sharktank/sharktank/layers/paged_llama_attention_block.py", line 293, in forward
attn_output = self.paged_attention.forward_decode(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/sharktank/sharktank/layers/paged_attention.py", line 580, in forward_decode
return self.paged_attention(
^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/sharktank/sharktank/layers/paged_attention.py", line 634, in paged_attention
return self.attention(
^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/sharktank/sharktank/layers/paged_attention.py", line 533, in attention
return ops.scaled_dot_product_attention(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/sharktank/sharktank/ops/_registry.py", line 260, in __call__
selected_override, *results = trampoline(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/sharktank/sharktank/ops/_registry.py", line 554, in trampoline
result = override(*bound_args.args, **call_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/sharktank/sharktank/ops/attention_impls.py", line 222, in scaled_dot_product_attention_torch
return torch.nn.functional.scaled_dot_product_attention(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/torch/_export/non_strict_utils.py", line 520, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype: c10::Half key.dtype: c10::Float8_e4m3fn and value.dtype: c10::Float8_e4m3fn instead.
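
Analysis: the decode export fails inside torch.nn.functional.scaled_dot_product_attention because the query is produced in float16 (per --attention-dtype=float16) while the key/value tensors read back from the paged KV cache keep the cache dtype (per --kv-cache-dtype=float8_e4m3fn), and torch's SDPA requires all three inputs to share one dtype. A minimal standalone repro of the same check, assuming a PyTorch build with float8_e4m3fn support (shapes are illustrative, not the model's):

import torch

# Query in float16 (the attention/activation dtype), K/V in float8_e4m3fn
# (the KV-cache dtype) -- the same combination the export produced.
q = torch.randn(1, 8, 1, 64, dtype=torch.float16)
k = torch.randn(1, 8, 32, 64).to(torch.float8_e4m3fn)
v = torch.randn(1, 8, 32, 64).to(torch.float8_e4m3fn)

try:
    torch.nn.functional.scaled_dot_product_attention(q, k, v)
except RuntimeError as e:
    # Prints: "Expected query, key, and value to have the same dtype, ..."
    print(e)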
@AmosLewis (author) commented:

((.venv12) ) ➜  shark-ai git:(2bb2d590b) ✗ /sharedfile/f16/export_run_f16_8b_tp1_kv8.sh
No flag provided. Using default iree_day 0828.
No flag provided. Using default shark_day 0828_2bb_kv8.
/sharedfile/f16/128/8b/out/8b_fp16_iree0828.shark0828_2bb_kv8.mlir
/sharedfile/f16/128/8b/out/8b_fp16_iree0828.shark0828_2bb_kv8.json
/sharedfile/f16/128/8b/out/8b_fp16_iree0828.shark0828_2bb_kv8.vmfb
/sharedfile/f16/128/8b/out/8b_fp16_iree0828.shark0828_2bb_kv8.prefill.txt
File already exists: /sharedfile/f16/128/8b/out/8b_fp16_iree0828.shark0828_2bb_kv8.prefill.txt
/sharedfile/f16/128/8b/out/8b_fp16_iree0828.shark0828_2bb_kv8.decode.txt
File already exists: /sharedfile/f16/128/8b/out/8b_fp16_iree0828.shark0828_2bb_kv8.decode.txt
python3 -m sharktank.examples.export_paged_llm_v1   --irpa-file=/shark-dev/llama3.1/8b/fp16/weight/8b_fp16.irpa   --output-mlir=/sharedfile/f16/128/8b/out/8b_fp16_iree0828.shark0828_2bb_kv8.mlir   --output-config=/sharedfile/f16/128/8b/out/8b_fp16_iree0828.shark0828_2bb_kv8.json   --bs-prefill=4 --bs-decode=4   --block-seq-stride=32   --attention-dtype=float16   --activation-dtype=float16   --attention-kernel=torch   --device-block-count=8192   --use-hf   --kv-cache-dtype=float8_e4m3fn
/home/chiliu12/src/shark-ai/.venv12/lib/python3.12/site-packages/iree/turbine/aot/params.py:163: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
  return torch.from_numpy(wrapper)
Exporting prefill_bs4
Exporting decode_bs4
Traceback (most recent call last):
  [... identical to the traceback from the first run above ...]
RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype: c10::Half key.dtype: c10::Float8_e4m3fn and value.dtype: c10::Float8_e4m3fn instead.
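
Rerunning via the kv8 script with --use-hf added hits the same dtype check, so that flag does not affect the attention dtypes. One plausible direction for a workaround (a sketch, not how sharktank's fp8 attention path actually resolves this) is to cast the cached K/V back to the query's dtype before calling SDPA. The helper below, sdpa_with_fp8_cache, is hypothetical; a real fix would also apply the quantization scales stored alongside the fp8 cache:

import torch

def sdpa_with_fp8_cache(q, k_cache, v_cache):
    # Hypothetical helper: upcast the float8_e4m3fn cache contents to the
    # query's dtype so the SDPA dtype check passes. The fp8 quantization
    # scale factors are ignored here for brevity.
    k = k_cache.to(q.dtype)
    v = v_cache.to(q.dtype)
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)

q = torch.randn(1, 8, 1, 64, dtype=torch.float16)
k_cache = torch.randn(1, 8, 32, 64).to(torch.float8_e4m3fn)
v_cache = torch.randn(1, 8, 32, 64).to(torch.float8_e4m3fn)
out = sdpa_with_fp8_cache(q, k_cache, v_cache)
print(out.shape, out.dtype)  # torch.Size([1, 8, 1, 64]) torch.float16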
