Skip to content

Instantly share code, notes, and snippets.

// Minimal repro: a single torch.aten._index_put_impl on a [1,4] si64 tensor.
//   %input      - [1,4] si64 tensor that is written into.
//   %index      - rank-1 tensor holding 3 si64 index values.
//   %fillValues - rank-0 (scalar) si64 tensor holding the value to write.
// The result has the same [1,4] si64 type as %input.
func.func @torch.aten._index_put_impl(%input: !torch.vtensor<[1,4],si64>, %index: !torch.vtensor<[3],si64>, %fillValues: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,4],si64>{
// Used for both trailing bool operands of the op below. In the ATen
// _index_put_impl_ schema these are presumably `accumulate` and `unsafe`;
// confirm against native_functions.yaml.
%false = torch.constant.bool false
%none = torch.constant.none
// Index list is [None, %index]. Dim 0 gets no index tensor and %index
// addresses dim 1. This is presumably equivalent to
// `input[:, index] = fillValues` in PyTorch (TODO confirm).
%indices = torch.prim.ListConstruct %none, %index : (!torch.none, !torch.vtensor<[3],si64>) -> !torch.list<optional<vtensor>>
%out = torch.aten._index_put_impl %input, %indices, %fillValues, %false, %false : !torch.vtensor<[1,4],si64>, !torch.list<optional<vtensor>>, !torch.vtensor<[],si64>, !torch.bool, !torch.bool -> !torch.vtensor<[1,4],si64>
return %out : !torch.vtensor<[1,4],si64>
}
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
import tempfile
import torch_mlir
def prepare_sentence_tokens(hf_model: str, sentence: str):
tokenizer = AutoTokenizer.from_pretrained(hf_model)
# pip install transformers==4.26.0
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
import tempfile
import torch_mlir
class HfMaskedLM(torch.nn.Module):
#loc = loc(unknown)
module attributes {torch.debug_module_name = "HuggingFaceLanguage"} {
func.func @forward(%arg0: !torch.vtensor<[?,?],si64> loc(unknown)) -> !torch.vtensor<[?,2],f32> {
%int768 = torch.constant.int 768 loc(#loc1)
%true = torch.constant.bool true loc(#loc1)
%float1.000000e00 = torch.constant.float 1.000000e+00 loc(#loc2)
%none = torch.constant.none loc(#loc)
%int0 = torch.constant.int 0 loc(#loc3)
%int1 = torch.constant.int 1 loc(#loc3)
%false = torch.constant.bool false loc(#loc4)
#loc = loc(unknown)
module attributes {torch.debug_module_name = "_lambda"} {
func.func @forward(%arg0: !torch.vtensor<[1,128],si64> loc(unknown)) -> !torch.vtensor<[1,2],f32> {
%int128 = torch.constant.int 128 loc(#loc1)
%int127 = torch.constant.int 127 loc(#loc2)
%int1 = torch.constant.int 1 loc(#loc3)
%true = torch.constant.bool true loc(#loc4)
%int0 = torch.constant.int 0 loc(#loc5)
%int2 = torch.constant.int 2 loc(#loc6)
%none = torch.constant.none loc(#loc)
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
import tempfile
import torch_mlir
def prepare_sentence_tokens(hf_model: str, sentence: str):
spec_scatter.txt
size(inputs) = size(updates) = size(results) = 1 = N
input = inputs[0]
update = updates[0]
result = results[0]
// %input:
// [
// RUN: stablehlo-translate --interpret -split-input-file %s
func.func @scatter_op_test() {
%inputs = stablehlo.constant dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>
%scatter_indices = stablehlo.constant dense<[[0, 1], [0, 2], [0, 3]]> : tensor<3x2xi64>
%updates = stablehlo.constant dense<[[4], [5], [6]]> : tensor<3x1xi64>
%result = "stablehlo.scatter"(%inputs, %scatter_indices, %updates) ({
^bb0(%arg0: tensor<i64>):
stablehlo.return %arg0 : tensor<i64>
}) {
import torch
# a = torch.tensor([[0, 1, 2, 3]])
# a[..., 1:] = torch.tensor([4, 5, 6])
# = a[..., 1:4] = torch.tensor([4, 5, 6])
# = a[[0, 0, 0], [1, 2, 3]] = torch.tensor([4, 5, 6])  # -> tensor([[0, 4, 5, 6]])
# = torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]),                      # input
#                            (torch.tensor([0, 0, 0]), torch.tensor([1, 2, 3])),  # indices
#                            torch.tensor([4, 5, 6]))                            # value
# = torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]), # input
(mlir_venv) ➜ torch-mlir git:(decompose) ✗ torch-mlir-opt --convert-torch-to-stablehlo /nodclouddata/chi/src/models/t5/slicecopy/test_indexputhackedtwin.mlir --debug
Args: torch-mlir-opt --convert-torch-to-stablehlo /nodclouddata/chi/src/models/t5/slicecopy/test_indexputhackedtwin.mlir --debug
Load new dialect in Context builtin
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ShapedType)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::MemRefLayoutAttrInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::TypedAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ElementsAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DistinctAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::BytecodeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SymbolOpInterface)