fx_g.graph:
graph():
    %arg0_1 : [#users=1] = placeholder[target=arg0_1]
    %view : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%arg0_1, [-1, 128]), kwargs = {})
    %arange : [#users=1] = call_function[target=torch.ops.aten.arange.start](args = (0, 128), kwargs = {dtype: torch.int64, device: cpu, pin_memory: False})
    %unsqueeze : [#users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%arange, 0), kwargs = {})
    %view_1 : [#users=1] = call_function[target=torch.ops.aten.view.default](args = (%unsqueeze, [-1, 128]), kwargs = {})
    %_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
    %embedding : [#users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%_param_constant0, %view), kwargs = {})
    %_param_constant1 : [#users=1] = get_attr[target=_param_constant1]
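
For orientation, the first nodes of this dump read as the eager-mode PyTorch below. This is a sketch only: the embedding table shape is an assumption for illustration, since the dump does not record it.

import torch

input_ids = torch.zeros(1, 128, dtype=torch.int64)        # %arg0_1
view = input_ids.view(-1, 128)                            # %view
arange = torch.arange(0, 128, dtype=torch.int64)          # %arange
position_ids = arange.unsqueeze(0).view(-1, 128)          # %unsqueeze, %view_1
weight = torch.randn(30522, 768)                          # %_param_constant0 (shape assumed)
# aten.embedding.default takes (weight, indices); F.embedding takes (indices, weight)
embedding = torch.nn.functional.embedding(view, weight)   # %embedding
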
#loc = loc(unknown)
module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[1,128],si64> loc(unknown)) -> !torch.vtensor<[1,2],f32> {
    %float1.000000e00 = torch.constant.float 1.000000e+00 loc(#loc1)
    %int0 = torch.constant.int 0 loc(#loc2)
    %int1 = torch.constant.int 1 loc(#loc3)
    %int-1 = torch.constant.int -1 loc(#loc4)
    %true = torch.constant.bool true loc(#loc5)
    %none = torch.constant.none loc(#loc)
    %false = torch.constant.bool false loc(#loc)
func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[1,128],i1>) -> !torch.vtensor<[1,128],si64> {
  %int4 = torch.constant.int 4
  %none = torch.constant.none
  %false = torch.constant.bool false
  %0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.vtensor<[1,128],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,128],si64>
  return %0 : !torch.vtensor<[1,128],si64>
}
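
In the torch dialect, the dtype operand 4 is PyTorch's ScalarType code for int64, so this function is just a lowered dtype cast. A minimal eager-mode equivalent:

import torch

mask = torch.ones(1, 128, dtype=torch.bool)
# %int4 is ScalarType code 4 (torch.int64); the two %false operands are the
# non_blocking and copy flags, and %none is the memory_format
converted = mask.to(dtype=torch.int64, non_blocking=False, copy=False)
assert converted.dtype == torch.int64
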
func.func @torch.aten.broadcast_to(%arg0: !torch.vtensor<[1,1,1,128],i1>) -> !torch.vtensor<[1,1,128,128],i1> {
  %int1 = torch.constant.int 1
  %int128 = torch.constant.int 128
  %0 = torch.prim.ListConstruct %int1, %int1, %int128, %int128 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.broadcast_to %arg0, %0 : !torch.vtensor<[1,1,1,128],i1>, !torch.list<int> -> !torch.vtensor<[1,1,128,128],i1>
  return %1 : !torch.vtensor<[1,1,128,128],i1>
}
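
This helper expands a [1,1,1,128] mask across the sequence dimension, the usual attention-mask broadcast. An eager-mode equivalent (variable names are illustrative):

import torch

attention_mask = torch.ones(1, 1, 1, 128, dtype=torch.bool)
# matches the torch.prim.ListConstruct + torch.aten.broadcast_to pair above
expanded = torch.broadcast_to(attention_mask, (1, 1, 128, 128))
assert expanded.shape == (1, 1, 128, 128)
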
# pip install transformers==4.26.0
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
import tempfile
import torch_mlir
class HfMaskedLM(torch.nn.Module):
    # body reconstructed (the gist truncates here); "t5-small" is assumed
    # from the [1,4,32128] logits in the IR dumps below (T5 vocab size is 32128)
    def __init__(self):
        super().__init__()
        self.model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
        self.model.eval()

    def forward(self, input_ids, decoder_input_ids):
        return self.model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
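
The driver code is cut off after the class definition. Below is a minimal sketch of the usual flow with the imports above; the checkpoint, prompts, and decomposition list are assumptions, though torch.debug_module_name = "_lambda" in the IR dumps is consistent with tracing through a lambda like this.

model = HfMaskedLM()
tokenizer = AutoTokenizer.from_pretrained("t5-small")  # checkpoint assumed

input_ids = tokenizer(
    "Studies have been shown that owning a dog is good for you",
    return_tensors="pt").input_ids                       # e.g. [1, 15]
decoder_input_ids = tokenizer(
    "Studies show that", return_tensors="pt").input_ids  # e.g. [1, 4]

# Trace to an FX graph, decomposing composite ops along the way
# (the decomposition list here is illustrative only).
fx_g = make_fx(
    lambda x, y: model(x, y),
    decomposition_table=get_decompositions([
        torch.ops.aten.embedding_dense_backward,
    ]),
)(input_ids, decoder_input_ids)
print(fx_g.graph)  # yields a dump like fx_g.graph at the top
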
#loc = loc(unknown)
module attributes {torch.debug_module_name = "_lambda"} {
  func.func private @__torch__.torch.fx.graph_module._lambda.__code_getter(%arg0: !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> loc(unknown)) -> !torch.str {
    %133 = torch.prim.GetAttr %arg0["_code"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.str loc(#loc)
    return %133 : !torch.str loc(#loc)
  } loc(#loc)
  func.func private @__torch__.torch.fx.graph_module._lambda.forward(%arg0: !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> loc(unknown), %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[1,15],si64>} loc(unknown), %arg2: !torch.tensor {torch.type_bound = !torch.vtensor<[1,4],si64>} loc(unknown)) -> !torch.tensor {
    %int6 = torch.constant.int 6 loc(#loc1)
    %true_0 = torch.constant.bool true loc(#loc2)
    %float-3.402820e38 = torch.constant.float -3.4028234663852886E+38 loc(#loc3)
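
The dump above is the raw TorchScript import; the torch.type_bound annotations come from the example arguments passed at compile time. A hedged guess at how the /tmp/_lambda.mlir consumed by the torch-mlir-opt invocation below was produced:

ts_g = torch.jit.script(fx_g)  # script the traced FX GraphModule
module = torch_mlir.compile(
    ts_g, (input_ids, decoder_input_ids),
    output_type=torch_mlir.OutputType.RAW)  # keep the raw TorchScript import
with open("/tmp/_lambda.mlir", "w") as f:
    f.write(str(module))
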
➜ ~ torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops=torch.aten.flatten.using_ints,torch.aten.native_layer_norm,torch.aten.linear})' /tmp/_lambda.mlir --mlir-print-ir-after-failure -mlir-disable-threading
<eval_with_key>.2:5:16: error: unsupported by backend contract: tensor with unknown rank
<eval_with_key>.2:5:16: note: see current operation: %36 = "torch.tensor_static_info_cast"(%35) : (!torch.vtensor<[1,4],si64>) -> !torch.vtensor<*,si64>
<eval_with_key>.2:5:16: note: this is likely due to a missing transfer function in abstract_interp_lib_gen.py
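
The last note points at torch-mlir's abstract_interp_lib_gen.py, where each op registers a shape transfer function so that static rank/shape information can propagate. The pattern in that file looks like the fragment below; the op shown is only an illustration of the convention, not the op that failed here, and List / upstream_shape_functions come from that file's own imports.

# "〇" stands in for "." and "〡shape" marks the shape transfer function:
def aten〇tanh〡shape(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)
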
// -----// IR Dump After LowerToBackendContract Failed (torch-lower-to-backend-contract) //----- //
module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<[1,4,32128],f32> {
    %int512 = torch.constant.int 512
    %int1 = torch.constant.int 1
module attributes {torch.debug_module_name = "_lambda"} {
  func.func private @__torch__.torch.fx.graph_module._lambda.__code_getter(%arg0: !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda">) -> !torch.str {
    %133 = torch.prim.GetAttr %arg0["_code"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.str
    return %133 : !torch.str
  }
  func.func private @__torch__.torch.fx.graph_module._lambda.forward(%arg0: !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda">, %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[1,15],si64>}, %arg2: !torch.tensor {torch.type_bound = !torch.vtensor<[1,4],si64>}) -> !torch.tensor {
    %none_1 = torch.constant.none
    %int-1 = torch.constant.int -1
    %false = torch.constant.bool false
    %cpu = torch.constant.device "cpu"
    %int1 = torch.constant.int 1
    %int4 = torch.constant.int 4
    %int0 = torch.constant.int 0
    %int-100 = torch.constant.int -100