Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Created February 24, 2025 17:14
Show Gist options
  • Save AmosLewis/85f4437a179465adf8d0595b61da5089 to your computer and use it in GitHub Desktop.
Save AmosLewis/85f4437a179465adf8d0595b61da5089 to your computer and use it in GitHub Desktop.
python3 -m sharktank.examples.export_paged_llm_v1 --irpa-file=/sharedfile/llama3_8b_fp8.irpa \
--output-mlir=/sharedfile/32/fp8_32.mlir \
--output-config=/sharedfile/32/config_32.json \
--bs=1 --attention-kernel torch \
--attention-dtype=float8_e4m3fnuz --activation-dtype=bfloat16 \
--use-hf \
--kv-cache-dtype=float8_e4m3fnuz
/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/iree/turbine/aot/params.py:163: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
return torch.from_numpy(wrapper)
Exporting prefill_bs1
/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_export/non_strict_utils.py:520: UserWarning: Tensor.T is deprecated on 0-D tensors. This function is the identity in these cases. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3691.)
return func(*args, **kwargs)
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/home/chi/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 360, in <module>
main()
File "/home/chi/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 341, in main
generate_batch_prefill(bs)
File "/home/chi/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 207, in generate_batch_prefill
@fxb.export_program(
^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/iree/turbine/aot/fx_programs.py", line 239, in export_program
program = torch.export.export(
^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/__init__.py", line 270, in export
return _export(
^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1017, in wrapper
raise e
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 990, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/exported_program.py", line 114, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1880, in _export
export_artifact = export_func( # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1683, in _non_strict_export
aten_export_artifact = _to_aten_func( # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 637, in _export_to_aten_ir
gm, graph_signature = transform(aot_export_module)(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1611, in _aot_export_non_strict
gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1246, in aot_export_module
fx_g, metadata, in_spec, out_spec = _aot_export_function(
^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1480, in _aot_export_function
fx_g, meta = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 623, in _create_aot_dispatcher_function
fw_metadata = run_functionalized_fw_and_collect_metadata(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 173, in inner
flat_f_outs = f(*flat_f_args)
^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 182, in flat_fn
tree_out = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 863, in functional_call
out = mod(*args[params_len:], **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1598, in forward
tree_out = self._export_root(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/iree/turbine/aot/fx_programs.py", line 226, in new_forward
return f(self.root, *forward_args, **forward_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 238, in _
logits = model.prefill(
^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/sharktank/sharktank/models/llama/llama.py", line 151, in prefill
h = block(
^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/sharktank/sharktank/models/llama/llama.py", line 289, in forward
h = self.attn(
^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/sharktank/sharktank/layers/paged_llama_attention_block.py", line 206, in forward
attn_output = ops.scaled_dot_product_attention(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/sharktank/sharktank/ops/_registry.py", line 199, in __call__
selected_override, *results = trampoline(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/sharktank/sharktank/ops/signatures.py", line 863, in _scaled_dot_product_attention
result = override(q, k, v, a, is_causal=is_causal, scale=scale)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/sharktank/sharktank/ops/default_impls.py", line 372, in scaled_dot_product_attention_torch
return torch.nn.functional.scaled_dot_product_attention(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_export/non_strict_utils.py", line 520, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype: c10::BFloat16 key.dtype: c10::Float8_e4m3fnuz and value.dtype: c10::Float8_e4m3fnuz instead.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment