@AmosLewis
Created March 21, 2025 20:18
(.venv) ➜ shark-ai git:(cbd6b7a6) ✗ python3 -m sharktank.examples.export_paged_llm_v1 --irpa-file=/sharedfile/attn/fp8_attn.irpa \
--output-mlir=/sharedfile/attn/128/fp8_attn.mlir \
--output-config=/sharedfile/attn/128/config_attn.json \
--bs-prefill=4 --bs-decode=4 --attention-kernel sharktank \
--attention-dtype=float8_e4m3fnuz --activation-dtype=bfloat16 --use-attention-mask --use-hf --kv-cache-dtype=float8_e4m3fnuz
cat_default
conv2d_default
conv2d_default
einsum_2args
elementwise_unary
elementwise_binary
elementwise_ternary
embedding_lookup_default
embedding_lookup_Tensor_QuantizedTensor
equal_default
expand_default
flatten_default
gather_default
get_index_default
get_index_QuantizedTensor
gemm
group_norm_affine_default
index_copy__default
index_put__default
index_select_default
interpolate_default
layer_norm_default
layer_norm_default
layer_norm_default
linear_default
linear_default
matmul_default
scaled_dot_product_attention_torch
mean_default
module_register_buffer_default
repeat_default
reshape_default
rms_norm_default
rms_norm_Tensor_QuantizedTensor
permute
softmax_default
to_default
trace_tensor
transfer_to_logical_device_default
barrier_on_device_default
transpose_default
sharded_cat_unsharded
sharded_sum_unsharded
unflatten_default
unsqueeze_default
squeeze_default
topk_default
view_default
view_QuantizedTensor
view_as_complex_default
view_as_real_default
einsum_2args_QuantizedTensor
matmul_generic_tensor_block_scaled
matmul_generic_tensor_block_scaled_i4
matmul_generic_tensor_super_block_offset_scaled_4_6_i4
view_as_complex
view_as_real
all_gather_split
all_reduce_split_or_unreduced
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
index_copy__split_replicated_split
index_put__split
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
replicate_replicated
replicate_split
replicate_unreduced
replicate_unsharded
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
func_wrapper
flash_attention
qconv2d_tensor_scaled
qconv2d_tensor_scaled
qlinear_tensor_scaled
qlinear_tensor_scaled
linear_quantized_weight
linear_quantized_weight
['scaled_dot_product_attention_torch', 'func_wrapper', 'flash_attention']
masked_flash_attention
['scaled_dot_product_attention_torch', 'func_wrapper', 'flash_attention', 'masked_flash_attention']
/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/iree/turbine/aot/params.py:163: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
  return torch.from_numpy(wrapper)
Exporting prefill_bs4
/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_export/non_strict_utils.py:520: UserWarning: Tensor.T is deprecated on 0-D tensors. This function is the identity in these cases. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3691.)
  return func(*args, **kwargs)
/home/chi/src/shark-ai/sharktank/sharktank/types/quantizers.py:69: UserWarning: Requantizing already quantized tensor PlanarQuantized(<unnamed>, [4, s0, 4096], layout=TensorScaledLayout(d([], dtype=torch.float32), qs([4, s0, 4096], dtype=torch.float32)) -> torch.bfloat16) to StaticScaledQuantizer(blk.0.attn_q.q_output, torch.Size([]), scale=(torch.Size([]), torch.float32) along None) offset=None -> dtype=torch.float8_e4m3fnuz)
  warnings.warn(f"Requantizing already quantized tensor {t} to {self}")
/home/chi/src/shark-ai/sharktank/sharktank/types/quantizers.py:69: UserWarning: Requantizing already quantized tensor PlanarQuantized(<unnamed>, [4, s0, 1024], layout=TensorScaledLayout(d([], dtype=torch.float32), qs([4, s0, 1024], dtype=torch.float32)) -> torch.bfloat16) to StaticScaledQuantizer(blk.0.attn_k.q_output, torch.Size([]), scale=(torch.Size([]), torch.float32) along None) offset=None -> dtype=torch.float8_e4m3fnuz)
  warnings.warn(f"Requantizing already quantized tensor {t} to {self}")
/home/chi/src/shark-ai/sharktank/sharktank/types/quantizers.py:69: UserWarning: Requantizing already quantized tensor PlanarQuantized(<unnamed>, [4, s0, 1024], layout=TensorScaledLayout(d([], dtype=torch.float32), qs([4, s0, 1024], dtype=torch.float32)) -> torch.bfloat16) to StaticScaledQuantizer(blk.0.attn_v.q_output, torch.Size([]), scale=(torch.Size([]), torch.float32) along None) offset=None -> dtype=torch.float8_e4m3fnuz)
  warnings.warn(f"Requantizing already quantized tensor {t} to {self}")
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/home/chi/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 371, in <module>
main()
File "/home/chi/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 351, in main
generate_batch_prefill(bs)
File "/home/chi/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 217, in generate_batch_prefill
@fxb.export_program(
^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/iree/turbine/aot/fx_programs.py", line 239, in export_program
program = torch.export.export(
^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/__init__.py", line 270, in export
return _export(
^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1017, in wrapper
raise e
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 990, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/exported_program.py", line 114, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1880, in _export
export_artifact = export_func( # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1683, in _non_strict_export
aten_export_artifact = _to_aten_func( # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 637, in _export_to_aten_ir
gm, graph_signature = transform(aot_export_module)(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1611, in _aot_export_non_strict
gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1246, in aot_export_module
fx_g, metadata, in_spec, out_spec = _aot_export_function(
^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 1480, in _aot_export_function
fx_g, meta = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 623, in _create_aot_dispatcher_function
fw_metadata = run_functionalized_fw_and_collect_metadata(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 173, in inner
flat_f_outs = f(*flat_f_args)
^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 182, in flat_fn
tree_out = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 863, in functional_call
out = mod(*args[params_len:], **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1598, in forward
tree_out = self._export_root(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/iree/turbine/aot/fx_programs.py", line 226, in new_forward
return f(self.root, *forward_args, **forward_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 248, in _
logits = model.prefill(
^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/sharktank/sharktank/models/llama/llama.py", line 146, in prefill
h = block(
^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/sharktank/sharktank/models/llama/llama.py", line 284, in forward
h = self.attn(
^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/sharktank/sharktank/layers/paged_llama_attention_block.py", line 127, in forward
xq = xq.view(bs, batch_seq_len, self.head_count, self.head_dim)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/sharktank/sharktank/types/tensors.py", line 420, in view
return view(self, shape)
^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/sharktank/sharktank/ops/_registry.py", line 199, in __call__
selected_override, *results = trampoline(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/chi/src/shark-ai/sharktank/sharktank/ops/signatures.py", line 1244, in _view_trampoline
d.fail(tensors)
File "/home/chi/src/shark-ai/sharktank/sharktank/ops/_registry.py", line 249, in fail
raise NotImplementedError(
NotImplementedError: Overridable operator sharktank.ops.signatures.view does not have an implementation for argument types: [<class 'sharktank.types.tensors.PlanarQuantizedTensor'>]
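The failure sits in sharktank's type-dispatched op registry: `tensors.py` forwards `PlanarQuantizedTensor.view` to the overridable `view` signature, `_view_trampoline` finds no override willing to accept a `PlanarQuantizedTensor` for this call, and `d.fail(tensors)` raises. As a rough mental model only (a simplified sketch, not the actual `_registry.py` implementation), the dispatch works like this:

```python
from typing import Callable

class Overridable:
    """Toy model of an overridable op: overrides are registered per
    argument type and tried in order; if none matches, dispatch fails."""

    def __init__(self, name: str):
        self.name = name
        self._overrides: list[tuple[type, Callable]] = []

    def override(self, arg_type: type):
        def decorator(fn: Callable):
            self._overrides.append((arg_type, fn))
            return fn
        return decorator

    def __call__(self, tensor, *args, **kwargs):
        for arg_type, fn in self._overrides:
            if isinstance(tensor, arg_type):
                return fn(tensor, *args, **kwargs)
        # Mirrors the d.fail(tensors) call in the traceback above.
        raise NotImplementedError(
            f"Overridable operator {self.name} does not have an "
            f"implementation for argument types: [{type(tensor)}]"
        )

view = Overridable("view")

class PlainTensor: ...
class PlanarQuantizedTensor: ...  # stand-in for sharktank's quantized type

@view.override(PlainTensor)
def view_default(tensor, shape):
    return f"viewed as {shape}"

print(view(PlainTensor(), (4, 128, 32, 128)))  # dispatches to view_default
try:
    view(PlanarQuantizedTensor(), (4, 128, 32, 128))
except NotImplementedError as e:
    print(e)  # same shape of failure as the export above
```

Note that the registration log at the top does list a `view_QuantizedTensor` override, so presumably an override was registered but declined this particular tensor/shape combination, rather than never existing at all.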
@AmosLewis (Author)
shark-ai commit: 46773898a7d1cf9dcc5f476112b84179d759cf1a
[sharktank] Refactor PagedKVCache to PagedAttention (https://github.com/nod-ai/shark-ai/pull/1098, commit https://github.com/nod-ai/shark-ai/commit/46773898a7d1cf9dcc5f476112b84179d759cf1a)

python3 -m sharktank.examples.export_paged_llm_v1 --irpa-file=/sharedfile/attn/fp8_attn.irpa \
--output-mlir=/sharedfile/attn/128/fp8_attn.mlir \
--output-config=/sharedfile/attn/128/config_attn.json \
--bs-prefill=4 --bs-decode=4 --attention-kernel sharktank \
--attention-dtype=float8_e4m3fnuz --activation-dtype=bfloat16 --use-attention-mask --use-hf --kv-cache-dtype=float8_e4m3fnuz
...
['scaled_dot_product_attention_torch', 'func_wrapper', 'flash_attention']
masked_flash_attention
['scaled_dot_product_attention_torch', 'func_wrapper', 'flash_attention', 'masked_flash_attention']
/home/chi/src/shark-ai/.venv/lib/python3.11/site-packages/iree/turbine/aot/params.py:163: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
  return torch.from_numpy(wrapper)
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/chi/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 371, in <module>
    main()
  File "/home/chi/src/shark-ai/sharktank/sharktank/examples/export_paged_llm_v1.py", line 106, in main
    model = PagedLlamaModelV1(dataset.root_theta, llama_config)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chi/src/shark-ai/sharktank/sharktank/models/llama/llama.py", line 104, in __init__
    [
  File "/home/chi/src/shark-ai/sharktank/sharktank/models/llama/llama.py", line 105, in <listcomp>
    AttentionFFNBlock(
  File "/home/chi/src/shark-ai/sharktank/sharktank/models/llama/llama.py", line 245, in __init__
    PagedLlamaAttentionBlock(
  File "/home/chi/src/shark-ai/sharktank/sharktank/layers/paged_llama_attention_block.py", line 87, in __init__
    self.probs_quantizer = StaticScaledQuantizer(
                           ^^^^^^^^^^^^^^^^^^^^^
NameError: name 'StaticScaledQuantizer' is not defined
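At this newer commit the failure is a plain missing name: `paged_llama_attention_block.py` line 87 constructs `StaticScaledQuantizer` without it being in scope. The first run's warnings point at `sharktank/sharktank/types/quantizers.py` as where that class lives, so a likely fix is the import below (untested; the exact import path is an assumption inferred from those warnings):

```python
# sharktank/sharktank/layers/paged_llama_attention_block.py
# Assumed fix: bring StaticScaledQuantizer into scope before it is used at
# line 87 (self.probs_quantizer = StaticScaledQuantizer(...)).
from sharktank.types.quantizers import StaticScaledQuantizer
```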
