@AmosLewis
Last active April 6, 2023 16:24
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
import tempfile
import torch_mlir
class Test(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input_ids, decoder_input_ids):
        shifted_input_ids = decoder_input_ids.new_zeros(decoder_input_ids.shape)  # tensor([[0, 0, 0, 0]])
        shifted_input_ids[..., 0] = 0  # still tensor([[0, 0, 0, 0]])
        return shifted_input_ids
model = Test()
input_ids = torch.tensor([[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14]])
decoder_input_ids = torch.tensor([[6536, 504, 24, 1]])
test_inputs = (input_ids, decoder_input_ids)
outputs = model(*test_inputs)
print("model(test_input): ")
print(outputs)
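# expected: tensor([[0, 0, 0, 0]])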
fx_g = make_fx(
    model,
    decomposition_table=get_decompositions([]),
)(*test_inputs)
print("fx_g.graph: ")
print(fx_g.graph)
# graph():
# %arg0_1 : [#users=0] = placeholder[target=arg0_1]
# %arg1_1 : [#users=1] = placeholder[target=arg1_1]
# %new_zeros : [#users=2] = call_function[target=torch.ops.aten.new_zeros.default](args = (%arg1_1, [1, 4]), kwargs = {dtype: torch.int64, layout: torch.strided, device: cpu, pin_memory: False})
# %_tensor_constant0 : [#users=1] = get_attr[target=_tensor_constant0]
# %lift_fresh_copy : [#users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%_tensor_constant0,), kwargs = {})
# %select : [#users=1] = call_function[target=torch.ops.aten.select.int](args = (%new_zeros, 1, 0), kwargs = {})
# %fill_ : [#users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%select, %lift_fresh_copy), kwargs = {})
# return new_zeros
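# Reset the graph to the default CodeGen so the regenerated Python source uses a
# plain calling convention that torch.jit.script can consume.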
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
def strip_overloads(gm):
    """
    Modifies the target of graph nodes in :attr:`gm` to strip overloads.

    Args:
        gm (fx.GraphModule): The input FX graph module to be modified.
    """
    for node in gm.graph.nodes:
        if isinstance(node.target, torch._ops.OpOverload):
            node.target = node.target.overloadpacket
    gm.recompile()
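# make_fx leaves OpOverload objects (e.g. aten.new_zeros.default) as node targets,
# which torch.jit.script cannot resolve, so strip them down to their overload packets.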
strip_overloads(fx_g)
ts_g = torch.jit.script(fx_g)
print("ts_g.graph: ")
print(ts_g.graph)
# ts_g.graph:
# graph(%self : __torch__.torch.fx.graph_module._lambda,
# %arg0_1 : Tensor,
# %arg1_1.1 : Tensor):
# %11 : bool = prim::Constant[value=0]() # <eval_with_key>.2:5:144
# %37 : Device = prim::Constant[value="cpu"]()
# %4 : int = prim::Constant[value=1]() # <eval_with_key>.2:5:50
# %5 : int = prim::Constant[value=4]() # <eval_with_key>.2:5:53
# %19 : int = prim::Constant[value=0]() # <eval_with_key>.2:8:49
# %6 : int[] = prim::ListConstruct(%4, %5)
# %new_zeros.1 : Tensor = aten::new_zeros(%arg1_1.1, %6, %5, %19, %37, %11) # <eval_with_key>.2:5:16
# %_tensor_constant0.1 : Tensor = prim::GetAttr[name="_tensor_constant0"](%self)
# %lift_fresh_copy.1 : Tensor = aten::lift_fresh_copy(%_tensor_constant0.1) # <eval_with_key>.2:7:22
# %select.1 : Tensor = aten::select(%new_zeros.1, %4, %19) # <eval_with_key>.2:8:13
# %fill_ : Tensor = aten::fill_(%select.1, %lift_fresh_copy.1) # <eval_with_key>.2:9:12
# return (%new_zeros.1)
module = torch_mlir.compile(
    ts_g,
    (input_ids, decoder_input_ids),
    torch_mlir.OutputType.RAW,
    use_tracing=True,
    verbose=False,
)
import os
mlir_str = module.operation.get_asm()
tmp_dir = tempfile.gettempdir()
with open(os.path.join(tmp_dir, "test_masked_fill_torchscript_0327_transformers4.26.0.mlir"), "w") as mlir_file:
    mlir_file.write(mlir_str)
@AmosLewis (Author)

test_masked_fill_torchscript_0327_transformers4.26.0.mlir

module attributes {torch.debug_module_name = "_lambda"} {
  func.func private @__torch__.torch.fx.graph_module._lambda.__code_getter(%arg0: !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda">) -> !torch.str {
    %2 = torch.prim.GetAttr %arg0["_code"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.str
    return %2 : !torch.str
  }
  func.func private @__torch__.torch.fx.graph_module._lambda.forward(%arg0: !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda">, %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[1,15],si64>}, %arg2: !torch.tensor {torch.type_bound = !torch.vtensor<[1,4],si64>}) -> !torch.tensor {
    %false = torch.constant.bool false
    %cpu = torch.constant.device "cpu"
    %int1 = torch.constant.int 1
    %int4 = torch.constant.int 4
    %int0 = torch.constant.int 0
    %2 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
    %3 = torch.aten.new_zeros %arg2, %2, %int4, %int0, %cpu, %false : !torch.tensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.tensor
    %4 = torch.prim.GetAttr %arg0["_tensor_constant0"] : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda"> -> !torch.tensor
    %5 = torch.aten.lift_fresh_copy %4 : !torch.tensor -> !torch.tensor
    %6 = torch.aten.select.int %3, %int1, %int0 : !torch.tensor, !torch.int, !torch.int -> !torch.tensor
    %7 = torch.aten.fill_.Tensor %6, %5 : !torch.tensor, !torch.tensor -> !torch.tensor
    return %3 : !torch.tensor
  }
  torch.class_type @__torch__.torch.fx.graph_module._lambda {
    torch.attr private "_tensor_constant0" : !torch.tensor
    torch.attr private "training" : !torch.bool
    torch.attr private "_is_full_backward_hook" : !torch.optional<bool>
    torch.attr private "_code" : !torch.str
    torch.method private "__code_getter", @__torch__.torch.fx.graph_module._lambda.__code_getter
    torch.method "forward", @__torch__.torch.fx.graph_module._lambda.forward
  }
  %0 = torch.tensor.literal(dense<0> : tensor<si64>) : !torch.tensor<[],si64>
  %true = torch.constant.bool true
  %none = torch.constant.none
  %str = torch.constant.str "\0A\0A\0Adef forward(self, arg0_1, arg1_1):\0A    new_zeros = torch.ops.aten.new_zeros(arg1_1, [1, 4], dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False);  arg1_1 = None\0A    _tensor_constant0 = self._tensor_constant0\0A    lift_fresh_copy = torch.ops.aten.lift_fresh_copy(_tensor_constant0);  _tensor_constant0 = None\0A    select = torch.ops.aten.select(new_zeros, 1, 0)\0A    fill_ = torch.ops.aten.fill_(select, lift_fresh_copy);  select = lift_fresh_copy = None\0A    return new_zeros\0A    "
  %1 = torch.nn_module {
    torch.slot "_tensor_constant0", %0 : !torch.tensor<[],si64>
    torch.slot "training", %true : !torch.bool
    torch.slot "_is_full_backward_hook", %none : !torch.none
    torch.slot "_code", %str : !torch.str
  } : !torch.nn.Module<"__torch__.torch.fx.graph_module._lambda">
}

@AmosLewis (Author) commented Mar 27, 2023

ERROR

➜  t5small git:(main) ✗ torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops=torch.aten.flatten.using_ints,torch.aten.native_layer_norm,torch.aten.linear})' ./test_masked_fill_torchscript_0327_transformers4.26.0.mlir  -mlir-print-ir-after-failure -mlir-disable-threading
./test_masked_fill_torchscript_0327_transformers4.26.0.mlir:13:10: error: unsupported by backend contract: tensor with unknown rank
    %3 = torch.aten.new_zeros %arg2, %2, %int4, %int0, %cpu, %false : !torch.tensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.tensor
         ^
./test_masked_fill_torchscript_0327_transformers4.26.0.mlir:13:10: note: see current operation: %9 = "torch.tensor_static_info_cast"(%8) : (!torch.vtensor<[1,4],si64>) -> !torch.vtensor<*,si64>
./test_masked_fill_torchscript_0327_transformers4.26.0.mlir:13:10: note: this is likely due to a missing transfer function in abstract_interp_lib_gen.py
// -----// IR Dump After LowerToBackendContract Failed (torch-lower-to-backend-contract) //----- //
module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<*,si64> {
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %false = torch.constant.bool false
    %int4 = torch.constant.int 4
    %none = torch.constant.none
    %0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
    %cpu = torch.constant.device "cpu"
    %1 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.aten.zeros %1, %int4, %int0, %cpu, %false : !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,4],si64>
    %3 = torch.tensor_static_info_cast %2 : !torch.vtensor<[1,4],si64> to !torch.vtensor<*,si64>
    %4 = torch.copy.to_tensor %3 : !torch.tensor<*,si64>
    %5 = torch.aten.clone %0, %none : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
    %6 = torch.aten.slice.Tensor %4, %int1, %int0, %int1, %int1 : !torch.tensor<*,si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor<[1,1],si64>
    %7 = torch.aten.squeeze.dim %6, %int1 : !torch.tensor<[1,1],si64>, !torch.int -> !torch.tensor<[1],si64>
    %8 = torch.tensor_static_info_cast %7 : !torch.tensor<[1],si64> to !torch.tensor<*,si64>
    %9 = torch.copy.to_vtensor %8 : !torch.vtensor<*,si64>
    %10 = torch.aten.fill.Tensor %9, %5 : !torch.vtensor<*,si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[1],si64>
    %11 = torch.tensor_static_info_cast %10 : !torch.vtensor<[1],si64> to !torch.vtensor<*,si64>
    torch.overwrite.tensor.contents %11 overwrites %8 : !torch.vtensor<*,si64>, !torch.tensor<*,si64>
    %12 = torch.copy.to_vtensor %4 : !torch.vtensor<*,si64>
    return %12 : !torch.vtensor<*,si64>
  }
}
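For context on the note in the error above: in torch-mlir, shape and dtype transfer functions live in abstract_interp_lib_gen.py and simply compute the result shape from the operand shapes. Below is a minimal sketch of what adding one looks like; the function name follows that file's naming convention as I recall it, and the choice of aten.fill.Tensor (the value-semantic form of fill_) as the op to illustrate is my assumption, not something stated in the log. The actual fix here may instead come from rewriting the select/fill_ pattern, since the "Done" IR below shows it lowered to aten._index_put_impl.

from typing import List

# Sketch only (assumed op and naming convention): fill returns a tensor with the
# same shape as its input, so the transfer function forwards the input shape.
def aten〇fill〇Tensor〡shape(self: List[int], value: List[int]) -> List[int]:
    return self

Once the result shape and rank are computable for every op on the path, the "tensor with unknown rank" violation of the backend contract should no longer be reported.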

@AmosLewis (Author)

Done

module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<[1,4],si64> {
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %false = torch.constant.bool false
    %int4 = torch.constant.int 4
    %none = torch.constant.none
    %0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
    %cpu = torch.constant.device "cpu"
    %1 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.aten.zeros %1, %int4, %int0, %cpu, %false : !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,4],si64>
    %3 = torch.aten.clone %0, %none : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
    %4 = torch.aten.slice.Tensor %2, %int1, %int0, %int1, %int1 : !torch.vtensor<[1,4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1],si64>
    %5 = torch.aten.squeeze.dim %4, %int1 : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.vtensor<[1],si64>
    %6 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
    %7 = torch.prim.ListConstruct %6 : (!torch.vtensor<[],si64>) -> !torch.list<optional<vtensor>>
    %8 = torch.aten._index_put_impl %2, %7, %3, %false, %false : !torch.vtensor<[1,4],si64>, !torch.list<optional<vtensor>>, !torch.vtensor<[],si64>, !torch.bool, !torch.bool -> !torch.vtensor<[1,4],si64>
    return %8 : !torch.vtensor<[1,4],si64>
  }
}
