Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Last active February 26, 2025 18:26
Show Gist options
  • Save AmosLewis/a4f72fdf803fddf935bcda83815b63f0 to your computer and use it in GitHub Desktop.
Save AmosLewis/a4f72fdf803fddf935bcda83815b63f0 to your computer and use it in GitHub Desktop.
(.venv) ➜ 32 python3 -m sharktank.examples.export_paged_llm_v1 --irpa-file=/sharedfile/llama3_8b_fp8.irpa \
--output-mlir=/sharedfile/32/fp8_attn.mlir \
--output-config=/sharedfile/32/config_attn.json \
--bs=1 --attention-kernel sharktank \
--attention-dtype=bfloat16 \
--activation-dtype=bfloat16 \
--kv-cache-dtype=float8_e4m3fnuz \
--use-hf \
--use-attention-mask
/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/iree/turbine/aot/params.py:163: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
return torch.from_numpy(wrapper)
Exporting prefill_bs1
attention dtype
torch.bfloat16
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/home/chi/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 359, in <module>
main()
File "/home/chi/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 340, in main
generate_batch_prefill(bs)
File "/home/chi/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 206, in generate_batch_prefill
@fxb.export_program(
^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/iree/turbine/aot/fx_programs.py", line 239, in export_program
program = torch.export.export(
^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/__init__.py", line 270, in export
return _export(
^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1017, in wrapper
raise e
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 990, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/exported_program.py", line 114, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1880, in _export
export_artifact = export_func( # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1683, in _non_strict_export
aten_export_artifact = _to_aten_func( # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 637, in _export_to_aten_ir
gm, graph_signature = transform(aot_export_module)(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1611, in _aot_export_non_strict
gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1246, in aot_export_module
fx_g, metadata, in_spec, out_spec = _aot_export_function(
^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1480, in _aot_export_function
fx_g, meta = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 623, in _create_aot_dispatcher_function
fw_metadata = run_functionalized_fw_and_collect_metadata(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 173, in inner
flat_f_outs = f(*flat_f_args)
^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 182, in flat_fn
tree_out = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 863, in functional_call
out = mod(*args[params_len:], **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1598, in forward
tree_out = self._export_root(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/iree/turbine/aot/fx_programs.py", line 226, in new_forward
return f(self.root, *forward_args, **forward_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 237, in _
logits = model.prefill(
^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/sharktank/sharktank/models/llama/llama.py", line 151, in prefill
h = block(
^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/sharktank/sharktank/models/llama/llama.py", line 289, in forward
h = self.attn(
^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/sharktank/sharktank/layers/paged_llama_attention_block.py", line 215, in forward
assert self.attention_scale is not None
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
(.venv) ➜ 32 python3 -m sharktank.examples.export_paged_llm_v1 --irpa-file=/sharedfile/llama3_8b_fp8.irpa \
--output-mlir=/sharedfile/32/fp8_attn.mlir \
--output-config=/sharedfile/32/config_attn.json \
--bs=1 --attention-kernel sharktank \
--attention-dtype=float8_e4m3fnuz \
--activation-dtype=bfloat16 \
--kv-cache-dtype=float8_e4m3fnuz \
--use-hf \
--use-attention-mask
/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/iree/turbine/aot/params.py:163: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
return torch.from_numpy(wrapper)
Exporting prefill_bs1
attention dtype
torch.float8_e4m3fnuz
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/home/chi/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 359, in <module>
main()
File "/home/chi/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 340, in main
generate_batch_prefill(bs)
File "/home/chi/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 206, in generate_batch_prefill
@fxb.export_program(
^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/iree/turbine/aot/fx_programs.py", line 239, in export_program
program = torch.export.export(
^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/__init__.py", line 270, in export
return _export(
^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1017, in wrapper
raise e
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 990, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/exported_program.py", line 114, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1880, in _export
export_artifact = export_func( # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1683, in _non_strict_export
aten_export_artifact = _to_aten_func( # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 637, in _export_to_aten_ir
gm, graph_signature = transform(aot_export_module)(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1611, in _aot_export_non_strict
gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1246, in aot_export_module
fx_g, metadata, in_spec, out_spec = _aot_export_function(
^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1480, in _aot_export_function
fx_g, meta = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 623, in _create_aot_dispatcher_function
fw_metadata = run_functionalized_fw_and_collect_metadata(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 173, in inner
flat_f_outs = f(*flat_f_args)
^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 182, in flat_fn
tree_out = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 863, in functional_call
out = mod(*args[params_len:], **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1598, in forward
tree_out = self._export_root(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/iree/turbine/aot/fx_programs.py", line 226, in new_forward
return f(self.root, *forward_args, **forward_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 237, in _
logits = model.prefill(
^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/sharktank/sharktank/models/llama/llama.py", line 151, in prefill
h = block(
^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/sharktank/sharktank/models/llama/llama.py", line 289, in forward
h = self.attn(
^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/sharktank/sharktank/layers/paged_llama_attention_block.py", line 215, in forward
assert self.attention_scale is not None
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment