Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Last active February 5, 2025 17:56
Show Gist options
  • Save AmosLewis/43b03ac241a9b904ec664f37bbfd7f98 to your computer and use it in GitHub Desktop.
Save AmosLewis/43b03ac241a9b904ec664f37bbfd7f98 to your computer and use it in GitHub Desktop.
python3 -m sharktank.examples.export_paged_llm_v1 --irpa-file=/home/chi/src/test/llama/dan/fp8_attn.irpa \
--output-mlir=/home/chi/src/test/llama/dan/f8_attn_chi.mlir \
--output-config=/home/chi/src/test/llama/dan/config_attn_chi.json \
--bs=1 --attention-kernel sharktank \
--attention-dtype=float8_e4m3fnuz --activation-dtype=bfloat16 --use-hf
/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/iree/turbine/aot/params.py:163: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:203.)
  return torch.from_numpy(wrapper)
Exporting prefill_bs1
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/chi/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 358, in <module>
    main()
  File "/home/chi/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 339, in main
    generate_batch_prefill(bs)
  File "/home/chi/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 205, in generate_batch_prefill
    @fxb.export_program(
     ^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/iree/turbine/aot/fx_programs.py", line 239, in export_program
    program = torch.export.export(
              ^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/__init__.py", line 368, in export
    return _export(
           ^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1970, in _export
    return _export_for_training(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1834, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1772, in _non_strict_export
    aten_export_artifact = _to_aten_func(  # type: ignore[operator]
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1564, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1702, in _aot_export_non_strict
    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1485, in _make_fx_helper
    gm = make_fx(
         ^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 2196, in wrapped
    return make_fx_tracer.trace(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 2134, in trace
    return self._trace_inner(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 2105, in _trace_inner
    t = dispatch_trace(
        ^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_compile.py", line 32, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 1138, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 1694, in trace
    res = super().trace(root, concrete_args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 843, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 1193, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
          ^^^^^^^^^^^
  File "<string>", line 1, in <lambda>
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1469, in wrapped_fn
    return tuple(flat_fn(*args))
                 ^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 879, in functional_call
    out = mod(*args[params_len:], **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 821, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 1764, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 539, in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 814, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1689, in forward
    tree_out = mod(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 821, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 1764, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 539, in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 814, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/iree/turbine/aot/fx_programs.py", line 226, in new_forward
    return f(self.root, *forward_args, **forward_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 236, in _
    logits = model.prefill(
             ^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/sharktank/sharktank/models/llama/llama.py", line 150, in prefill
    h = block(
        ^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 821, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 1764, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 539, in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 814, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/sharktank/sharktank/models/llama/llama.py", line 286, in forward
    h = self.attn(
        ^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 821, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 1764, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 539, in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/_symbolic_trace.py", line 814, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/sharktank/sharktank/layers/paged_llama_attention_block.py", line 215, in forward
    attn_output = kernels.flash_attention(xq, keys, values)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_ops.py", line 1123, in __call__
    return self._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 1241, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_ops.py", line 1123, in __call__
    return self._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/fx/experimental/proxy_tensor.py", line 1288, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_ops.py", line 1123, in __call__
    return self._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_export/non_strict_utils.py", line 557, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_ops.py", line 1123, in __call__
    return self._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: sharktank::flash_attention() is missing value for argument 'scale'. Declaration: sharktank::flash_attention(Tensor q, Tensor k, Tensor v, Tensor scale) -> Tensor
  
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment