# wget https://sharkpublic.blob.core.windows.net/sharkpublic/halo-models/llm-dev/llama3_8b/prefill_args_bs4_128_stride_32/cs_f16.npy
# wget https://sharkpublic.blob.core.windows.net/sharkpublic/halo-models/llm-dev/llama3_8b/prefill_args_bs4_128_stride_32/seq_block_ids.npy
# wget https://sharkpublic.blob.core.windows.net/sharkpublic/halo-models/llm-dev/llama3_8b/prefill_args_bs4_128_stride_32/seq_lens.npy
# wget https://sharkpublic.blob.core.windows.net/sharkpublic/halo-models/llm-dev/llama3_8b/prefill_args_bs4_128_stride_32/tokens.npy
import numpy as np
import torch
prefills = ['cs_f16','seq_block_ids','seq_lens','tokens']
for prefill in prefills:
    # Minimal sketch of the loop body (the original is truncated here):
    # load each downloaded .npy and wrap it as a torch tensor.
    arr = np.load(prefill + ".npy")
    tensor = torch.from_numpy(arr)
    print(prefill, tensor.shape, tensor.dtype)
# Download the f32 npy files, then use this script to cast them from f32 to 16-bit
# wget https://gist.github.com/aviator19941/380acabc77aeb4749fac14262e17db69
# wget https://sharkpublic.blob.core.windows.net/sharkpublic/halo-models/llm-dev/llama3_8b/prefill_args_bs4_128_stride_32/cs_f16.npy
# wget https://sharkpublic.blob.core.windows.net/sharkpublic/halo-models/llm-dev/llama3_8b/prefill_args_bs4_128_stride_32/seq_block_ids.npy
# wget https://sharkpublic.blob.core.windows.net/sharkpublic/halo-models/llm-dev/llama3_8b/prefill_args_bs4_128_stride_32/seq_lens.npy
# wget https://sharkpublic.blob.core.windows.net/sharkpublic/halo-models/llm-dev/llama3_8b/prefill_args_bs4_128_stride_32/tokens.npy
# pip install numpy==1.26
# pip install bfloat16
import numpy as np
from bfloat16 import bfloat16
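# Minimal sketch of the cast step (the body is truncated in the original).
# "cs_f32.npy" is a hypothetical name for the f32 dump; the other prefill
# arrays are integer-typed and need no cast. Swap np.float16 for the imported
# bfloat16 dtype if a bf16 cache state is wanted instead.
cs_f32 = np.load("cs_f32.npy")
cs_f16 = cs_f32.astype(np.float16)
np.save("cs_f16.npy", cs_f16)
print(cs_f32.dtype, "->", cs_f16.dtype, cs_f16.shape)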
The exported MLIR for the fp8 model (truncated excerpt):
module @module {
util.global private @__auto.token_embd.weight = #stream.parameter.named<"model"::"token_embd.weight"> : tensor<128256x4096xbf16>
util.global private @__auto.blk.0.attn_norm.weight = #stream.parameter.named<"model"::"blk.0.attn_norm.weight"> : tensor<4096xbf16>
util.global private @"__auto.blk.0.attn_q.q_input:rscale" = #stream.parameter.named<"model"::"blk.0.attn_q.q_input:rscale"> : tensor<f32>
util.global private @"__auto.blk.0.attn_q.weight:qs" = #stream.parameter.named<"model"::"blk.0.attn_q.weight:qs"> : tensor<4096x4096xf8E4M3FNUZ>
util.global private @"__auto.blk.0.attn_k.q_input:rscale" = #stream.parameter.named<"model"::"blk.0.attn_k.q_input:rscale"> : tensor<f32>
util.global private @"__auto.blk.0.attn_k.weight:qs" = #stream.parameter.named<"model"::"blk.0.attn_k.weight:qs"> : tensor<1024x4096xf8E4M3FNUZ>
util.global private @"__auto.blk.0.attn_v.q_input:rscale" = #stream.parameter.named<"model"::"blk.0.attn_v.q_input:rscale"> : tensor<f32>
util.global private @"__au
(.venv) ➜ llama /home/chi/src/iree-build-trace/tools/iree-compile \
fp8.mlir \
--iree-hip-target=gfx942 \
-o=fp8_tracy.vmfb \
--iree-hal-target-device=hip \
--iree-dispatch-creation-enable-aggressive-fusion=true \
--iree-global-opt-propagate-transposes=true \
--iree-opt-aggressively-propagate-transposes=true \
--iree-opt-data-tiling=false \
--iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))'
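Once compilation succeeds, the prefill entry point can be exercised with iree-run-module, feeding it the .npy arguments downloaded above plus the parameter file referenced by the #stream.parameter.named globals (scope "model"). A rough sketch only; the function name prefill_bs4, the parameter file name, and the input order are assumptions that must match the exported signature:

iree-run-module \
  --device=hip \
  --module=fp8_tracy.vmfb \
  --parameters=model=llama3_8b_fp8.irpa \
  --function=prefill_bs4 \
  --input=@tokens.npy \
  --input=@seq_lens.npy \
  --input=@seq_block_ids.npy \
  --input=@cs_f16.npy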
iree-compile --help
OVERVIEW: IREE compilation driver
USAGE: iree-compile [options] <input file or '-' for stdin>
OPTIONS:
CUDA HAL Target:
--iree-cuda-target=<string> - CUDA target as expected by LLVM NVPTX backend; e.g., 'sm_80'/'sm_90' for targeting Ampere/Hopper GPUs. Additionally this also supports architecture code names like 'turing'/'ampere' or some product names like 'a100'/'rtx3090ti' for a better experience. See https://iree.dev/guides/deployment-configurations/gpu-cuda for more details.
torch-mlir-opt -pass-pipeline='builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{extra-library=})' /tmp/UnnammedModule.mlir --debug
Args: torch-mlir-opt -pass-pipeline=builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{extra-library=}) /tmp/UnnammedModule.mlir --debug
Load new dialect in Context builtin
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ShapedType)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::MemRefLayoutAttrInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::TypedAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ElementsAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DistinctAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::BytecodeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SymbolOpInterface)
import torch
from torch_mlir_e2e_test.annotations import annotate_args, export

class NonzeroDecomposeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1], torch.bool, True),
        ]
    )
    def forward(self, t):
        # assumed forward body (the original is cut off); the decomposition itself happens in the compiler
        return torch.nonzero(t)
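To run this module through the torch-mlir e2e test suite it would be registered roughly as below; a sketch assuming the standard torch_mlir_e2e_test registry API, with a made-up test name and sample input:

from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case

@register_test_case(module_factory=lambda: NonzeroDecomposeModule())
def NonzeroDecomposeModule_basic(module, tu: TestUtils):
    # boolean 1-D input matching the ([-1], torch.bool, True) annotation above
    module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool))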
import torch

def nonzero(t):
    print("t: ", t)  # tensor([0, 0, 1, 1, 0, 0])
    # Flatten the input tensor
    t_flat = t.flatten()  # torch.flatten(t, 0, 0)
    print(
        "t_flat: ", t_flat
    )  # tensor([0, 0, 1, 1, 0, 0]), torch.Size([6]), !torch.vtensor<[?],si64>
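    # Hedged sketch of one way to finish the decomposition without calling
    # torch.nonzero (not necessarily what the original gist does): a stable
    # descending sort on the mask moves nonzero positions to the front.
    nonzero_mask = (t_flat != 0).to(torch.int64)
    _, sorted_indices = torch.sort(nonzero_mask, dim=0, descending=True, stable=True)
    num_nonzero = int(nonzero_mask.sum())
    result = sorted_indices[:num_nonzero].unsqueeze(1)
    print("result: ", result)  # tensor([[2], [3]]) for the example input
    return result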
iree-compile --iree-hal-target-backends=llvm-cpu model.linalg.mlir -o model.vmfb --dump-compilation-phases-to=./tmp/
failed to translate executables
failed to translate executables
model.linalg.mlir:21:10: error: 'memref.alloca' op expected no unbounded stack allocations
%1 = tensor.empty(%dim) : tensor<?xi64>
^
model.linalg.mlir:10:3: note: called from
func.func @main_graph(%arg0: tensor<?xi1>) -> tensor<1x1xi64> {
^
model.linalg.mlir:21:10: note: see current operation: %14 = "memref.alloca"(%11) <{alignment = 64 : i64, operandSegmentSizes = array<i32: 1, 0>}> : (index) -> memref<?xi64>
Running MaskedFillTensorIntValueStaticModule_basic...
*** RUNNING TEST: MaskedScatterStaticBasic_basic ***
Compiling MaskedScatterStaticBasic_basic...
/proj/gdba/shark/chi/src/torch-mlir/mlir_venv/lib/python3.10/site-packages/torch/onnx/symbolic_opset10.py:513: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
return g.op("Constant", value_t=torch.tensor(list_or_value))
====================
ONNX RAW IR
module {
func.func @main_graph(%arg0: !torch.vtensor<[4,4],f32>, %arg1: !torch.vtensor<[4,4],i1>, %arg2: !torch.vtensor<[8,8],f32>) -> !torch.vtensor<[4,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.6.0"} {