Skip to content

Instantly share code, notes, and snippets.

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
import tempfile
import torch_mlir
def prepare_sentence_tokens(hf_model: str, sentence: str):
tokenizer = AutoTokenizer.from_pretrained(hf_model)
module attributes {torch.debug_module_name = "_lambda"} {
func.func @forward(%arg0: tensor<1x128xi64>) -> tensor<1x2xf32> {
%0 = "tosa.const"() {value = dense<[[65536, 512, 1]]> : tensor<1x3xi32>} : () -> tensor<1x3xi32>
%1 = "tosa.const"() {value = dense_resource<__elided__> : tensor<2x768xf32>} : () -> tensor<2x768xf32>
%2 = "tosa.const"() {value = dense_resource<__elided__> : tensor<768x768xf32>} : () -> tensor<768x768xf32>
%3 = "tosa.const"() {value = dense_resource<__elided__> : tensor<768xf32>} : () -> tensor<768xf32>
%4 = "tosa.const"() {value = dense_resource<__elided__> : tensor<768x3072xf32>} : () -> tensor<768x3072xf32>
%5 = "tosa.const"() {value = dense_resource<__elided__> : tensor<3072xf32>} : () -> tensor<3072xf32>
%6 = "tosa.const"() {value = dense_resource<__elided__> : tensor<3072x768xf32>} : () -> tensor<3072x768xf32>
%7 = "tosa.const"() {value = dense<-3.40282347E+38> : tensor<f32>} : () -> tensor<f32>
// Minimal reproducer: a single torch.aten.gather over fully dynamic shapes.
//   %arg0 : rank-3 f32 value tensor (all dims dynamic), the tensor gathered from.
//   %arg1 : rank-4 si64 value tensor (all dims dynamic), passed as the index operand.
// Returns a rank-3 f32 value tensor with dynamic dims.
// NOTE(review): the index operand is rank 4 while the input is rank 3.
// torch.gather documents that index must have the same number of dimensions
// as input, so this signature looks like an experiment rather than valid
// PyTorch semantics -- confirm before relying on it.
func.func @torch.aten.gather(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],si64>) -> !torch.vtensor<[?,?,?],f32> {
// Gather dimension: -1 (operand position matches torch.gather's `dim`).
%int-1 = torch.constant.int -1
// Fourth operand of aten.gather; presumably torch.gather's `sparse_grad` flag.
%false = torch.constant.bool false
// Operand order: input, dim, index, bool flag.
%0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.vtensor<[?,?,?,?],si64>, !torch.bool -> !torch.vtensor<[?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?],f32>
}
➜ deberta git:(main) ✗ torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-tosa-backend-pipeline)' /tmp/_lambda.mlir --mlir-print-ir-after-failure -mlir-disable-threading
<eval_with_key>.2:8:54: error: failed to legalize operation 'torch.constant.int'
<eval_with_key>.2:8:54: note: see current operation: %1 = "torch.constant.int"() {value = 0 : i64} : () -> !torch.int
// -----// IR Dump After FinalizingBackendTypeConversion Failed (torch-finalizing-backend-type-conversion) //----- //
func.func @forward(%arg0: tensor<1x128xi64>) -> tensor<1x2xf32> {
%0 = "tosa.const"() {value = dense<[[65536, 512, 1]]> : tensor<1x3xi32>} : () -> tensor<1x3xi32>
%int0 = torch.constant.int 0
%1 = "tosa.const"() {value = dense_resource<__elided__> : tensor<2x768xf32>} : () -> tensor<2x768xf32>
%2 = "tosa.const"() {value = dense_resource<__elided__> : tensor<768x768xf32>} : () -> tensor<768x768xf32>
%3 = "tosa.const"() {value = dense_resource<__elided__> : tensor<768xf32>} : () -> tensor<768xf32>
#loc = loc(unknown)
module attributes {torch.debug_module_name = "_lambda"} {
func.func @forward(%arg0: !torch.vtensor<[1,128],si64> loc(unknown)) -> !torch.vtensor<[1,2],f32> {
%int0 = torch.constant.int 0 loc(#loc1)
%int1 = torch.constant.int 1 loc(#loc2)
%true = torch.constant.bool true loc(#loc3)
%float1.000000e00 = torch.constant.float 1.000000e+00 loc(#loc3)
%none = torch.constant.none loc(#loc)
%int11 = torch.constant.int 11 loc(#loc4)
%false = torch.constant.bool false loc(#loc5)
// Reproducer: torch.aten.gather with fully static shapes.
//   %arg0 : 12x128x512 f32 value tensor, the tensor gathered from.
//   %arg1 : 1x128x128 si64 value tensor, passed as the index operand.
// Declared result: 12x128x128 f32.
// NOTE(review): torch.gather documents that the output has the same shape as
// index, which would be 1x128x128 here. The declared 12x128x128 result implies
// the index is broadcast along dim 0 -- confirm whether an explicit
// expand/broadcast of %arg1 is missing from this reproducer.
func.func @torch.aten.gather(%arg0: !torch.vtensor<[12,128,512],f32>, %arg1: !torch.vtensor<[1,128,128],si64>) -> !torch.vtensor<[12,128,128],f32> {
// Gather dimension: -1 (operand position matches torch.gather's `dim`).
%int-1 = torch.constant.int -1
// Fourth operand of aten.gather; presumably torch.gather's `sparse_grad` flag.
%false = torch.constant.bool false
// Operand order: input, dim, index, bool flag.
%0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[12,128,512],f32>, !torch.int, !torch.vtensor<[1,128,128],si64>, !torch.bool -> !torch.vtensor<[12,128,128],f32>
return %0 : !torch.vtensor<[12,128,128],f32>
}
// Small static-shape reproducer for torch.aten.gather.
//   %arg0 : 1x4x3 f32 value tensor, the tensor gathered from.
//   %arg1 : 1x4x2 si64 value tensor, passed as the index operand.
// Returns a 1x4x2 f32 value tensor (same shape as the index operand).
func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],f32> {
// Gather dimension: -1 (operand position matches torch.gather's `dim`).
%int-1 = torch.constant.int -1
// Fourth operand of aten.gather; presumably torch.gather's `sparse_grad` flag.
%false = torch.constant.bool false
// Operand order: input, dim, index, bool flag.
%0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[1,4,3],f32>, !torch.int, !torch.vtensor<[1,4,2],si64>, !torch.bool -> !torch.vtensor<[1,4,2],f32>
return %0 : !torch.vtensor<[1,4,2],f32>
}
value1:
vec:[0,0,0,0,0,0,0]
shape: 1*4*2*1
tosa::getConstTensor
tensor([[
[
[0], [0]
],
[
[0], [0]
module {
func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.vtensor<[1,4,2,3],i32>) -> !torch.vtensor<[1,4,2],f32> {
%0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,4,3],f32> -> tensor<1x4x3xf32>
%0= torch.tensor([[
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11, 12]
]]) # 1*4*3
%1 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,4,2,3],i32> -> tensor<1x4x2x3xi32>
// Static-shape torch.aten.gather variant whose index operand carries an extra
// trailing dimension.
//   %arg0 : 1x4x3 f32 value tensor, the tensor gathered from.
//   %arg1 : 1x4x2x3 si64 value tensor, passed as the index operand.
// Returns a 1x4x2 f32 value tensor.
// NOTE(review): the index operand is rank 4 while the input is rank 3.
// torch.gather documents that index must have the same number of dimensions
// as input. The trailing size-3 dimension presumably holds one coordinate per
// input dimension (a gather_nd-style layout) -- confirm the intent.
func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.vtensor<[1,4,2,3],si64>) -> !torch.vtensor<[1,4,2],f32> {
// Gather dimension: -1 (operand position matches torch.gather's `dim`).
%int-1 = torch.constant.int -1
// Fourth operand of aten.gather; presumably torch.gather's `sparse_grad` flag.
%false = torch.constant.bool false
// Operand order: input, dim, index, bool flag.
%0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[1,4,3],f32>, !torch.int, !torch.vtensor<[1,4,2,3],si64>, !torch.bool -> !torch.vtensor<[1,4,2],f32>
return %0 : !torch.vtensor<[1,4,2],f32>
}