Last active
April 6, 2023 16:24
-
-
Save AmosLewis/85b7c19409bfdbe45f216d689e947578 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import tempfile

import torch
import torch_mlir
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx
class Test(torch.nn.Module):
    """Minimal reproducer: allocate zeros, then assign into a slice in place.

    The forward pass is kept deliberately tiny so that the traced graph
    contains only ``new_zeros`` followed by an in-place write through a
    ``select`` view.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input_ids, decoder_input_ids):
        """Return a zero tensor with the shape of ``decoder_input_ids``.

        Args:
            input_ids: Not used by the computation; kept so the module
                still takes two inputs.
            decoder_input_ids: Tensor whose shape (and, via ``new_zeros``,
                dtype and device) the result copies. It is not modified.

        Returns:
            A newly allocated tensor of zeros.
        """
        # For a (1, 4) input this is tensor([[0, 0, 0, 0]]).
        shifted_input_ids = decoder_input_ids.new_zeros(decoder_input_ids.shape)
        # In-place write into index 0 of the last dimension. The value written
        # is 0, so the tensor is still all zeros afterwards; the statement is
        # here for the slice-assignment pattern, not for its numeric effect.
        shifted_input_ids[..., 0] = 0
        return shifted_input_ids
# Run the module eagerly once so the expected output is printed before any
# tracing or compilation happens.
model = Test()
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]])
decoder_input_ids = torch.tensor([[6536, 504, 24, 1]])
test_inputs = (input_ids, decoder_input_ids)
outputs = model(*test_inputs)
print("model(test_input): ")
print(outputs)
# Trace the module into an FX graph of ATen ops. The decomposition table is
# built from an empty op list, i.e. no decompositions are requested.
fx_g = make_fx(
    model,
    decomposition_table=get_decompositions([]),
)(*test_inputs)
print("fx_g.graph: ")
print(fx_g.graph)
# graph():
# %arg0_1 : [#users=0] = placeholder[target=arg0_1]
# %arg1_1 : [#users=1] = placeholder[target=arg1_1]
# %new_zeros : [#users=2] = call_function[target=torch.ops.aten.new_zeros.default](args = (%arg1_1, [1, 4]), kwargs = {dtype: torch.int64, layout: torch.strided, device: cpu, pin_memory: False})
# %_tensor_constant0 : [#users=1] = get_attr[target=_tensor_constant0]
# %lift_fresh_copy : [#users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%_tensor_constant0,), kwargs = {})
# %select : [#users=1] = call_function[target=torch.ops.aten.select.int](args = (%new_zeros, 1, 0), kwargs = {})
# %fill_ : [#users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%select, %lift_fresh_copy), kwargs = {})
# return new_zeros

# Install the default CodeGen on the graph and regenerate the module's code.
# NOTE(review): presumably this is needed so the generated forward() can be
# handled by torch.jit.script later on -- confirm.
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
def strip_overloads(gm):
    """
    Modifies the target of graph nodes in :attr:`gm` to strip overloads.

    Every node whose target is a ``torch._ops.OpOverload`` is retargeted to
    that overload's ``overloadpacket``; nodes with any other kind of target
    are left alone. The module is recompiled once after all nodes have been
    visited. The change is made in place and nothing is returned.

    Args:
        gm(fx.GraphModule): The input Fx graph module to be modified
    """
    for node in gm.graph.nodes:
        if isinstance(node.target, torch._ops.OpOverload):
            node.target = node.target.overloadpacket
    # Regenerate the module's code so it reflects the retargeted nodes.
    gm.recompile()
# Replace overload-specific targets with their overload packets, then script
# the resulting graph module and print its TorchScript graph.
strip_overloads(fx_g)
ts_g = torch.jit.script(fx_g)
print("ts_g.graph: ")
print(ts_g.graph)
# ts_g.graph:
# graph(%self : __torch__.torch.fx.graph_module._lambda,
# %arg0_1 : Tensor,
# %arg1_1.1 : Tensor):
# %11 : bool = prim::Constant[value=0]() # <eval_with_key>.2:5:144
# %37 : Device = prim::Constant[value="cpu"]()
# %4 : int = prim::Constant[value=1]() # <eval_with_key>.2:5:50
# %5 : int = prim::Constant[value=4]() # <eval_with_key>.2:5:53
# %19 : int = prim::Constant[value=0]() # <eval_with_key>.2:8:49
# %6 : int[] = prim::ListConstruct(%4, %5)
# %new_zeros.1 : Tensor = aten::new_zeros(%arg1_1.1, %6, %5, %19, %37, %11) # <eval_with_key>.2:5:16
# %_tensor_constant0.1 : Tensor = prim::GetAttr[name="_tensor_constant0"](%self)
# %lift_fresh_copy.1 : Tensor = aten::lift_fresh_copy(%_tensor_constant0.1) # <eval_with_key>.2:7:22
# %select.1 : Tensor = aten::select(%new_zeros.1, %4, %19) # <eval_with_key>.2:8:13
# %fill_ : Tensor = aten::fill_(%select.1, %lift_fresh_copy.1) # <eval_with_key>.2:9:12
# return (%new_zeros.1)
# Import the scripted module into torch-mlir with OutputType.RAW and write the
# resulting assembly to the system temp directory, so it can be fed to
# torch-mlir-opt by hand.
module = torch_mlir.compile(
    ts_g,
    (input_ids, decoder_input_ids),
    torch_mlir.OutputType.RAW,
    use_tracing=True,
    verbose=False,
)
mlir_str = module.operation.get_asm()
# `out_dir` rather than `dir`, which would shadow the builtin of that name.
out_dir = tempfile.gettempdir()
mlir_path = os.path.join(
    out_dir, "test_masked_fill_torchscript_0327_transformers4.26.0.mlir"
)
with open(mlir_path, "w", encoding="utf-8") as mlir_file:
    mlir_file.write(mlir_str)
ERROR
➜ t5small git:(main) ✗ torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops=torch.aten.flatten.using_ints,torch.aten.native_layer_norm,torch.aten.linear})' ./test_masked_fill_torchscript_0327_transformers4.26.0.mlir -mlir-print-ir-after-failure -mlir-disable-threading
./test_masked_fill_torchscript_0327_transformers4.26.0.mlir:13:10: error: unsupported by backend contract: tensor with unknown rank
%3 = torch.aten.new_zeros %arg2, %2, %int4, %int0, %cpu, %false : !torch.tensor, !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.tensor
^
./test_masked_fill_torchscript_0327_transformers4.26.0.mlir:13:10: note: see current operation: %9 = "torch.tensor_static_info_cast"(%8) : (!torch.vtensor<[1,4],si64>) -> !torch.vtensor<*,si64>
./test_masked_fill_torchscript_0327_transformers4.26.0.mlir:13:10: note: this is likely due to a missing transfer function in abstract_interp_lib_gen.py
// -----// IR Dump After LowerToBackendContract Failed (torch-lower-to-backend-contract) //----- //
// IR as printed when the backend-contract lowering failed.
// @forward still returns an unranked `!torch.vtensor<*,si64>`: the statically
// shaped [1,4] result of torch.aten.zeros is cast to an unranked type and
// copied into a non-value tensor (%4), and the slice/squeeze/fill sequence
// writes back into that tensor through torch.overwrite.tensor.contents.
module attributes {torch.debug_module_name = "_lambda"} {
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<*,si64> {
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%false = torch.constant.bool false
%int4 = torch.constant.int 4
%none = torch.constant.none
%0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%cpu = torch.constant.device "cpu"
%1 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten.zeros %1, %int4, %int0, %cpu, %false : !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,4],si64>
// The static [1,4] shape is dropped here, before the copy to a non-value tensor.
%3 = torch.tensor_static_info_cast %2 : !torch.vtensor<[1,4],si64> to !torch.vtensor<*,si64>
%4 = torch.copy.to_tensor %3 : !torch.tensor<*,si64>
%5 = torch.aten.clone %0, %none : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
%6 = torch.aten.slice.Tensor %4, %int1, %int0, %int1, %int1 : !torch.tensor<*,si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor<[1,1],si64>
%7 = torch.aten.squeeze.dim %6, %int1 : !torch.tensor<[1,1],si64>, !torch.int -> !torch.tensor<[1],si64>
%8 = torch.tensor_static_info_cast %7 : !torch.tensor<[1],si64> to !torch.tensor<*,si64>
%9 = torch.copy.to_vtensor %8 : !torch.vtensor<*,si64>
%10 = torch.aten.fill.Tensor %9, %5 : !torch.vtensor<*,si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[1],si64>
%11 = torch.tensor_static_info_cast %10 : !torch.vtensor<[1],si64> to !torch.vtensor<*,si64>
// The filled value is written back into the squeezed slice (%8).
torch.overwrite.tensor.contents %11 overwrites %8 : !torch.vtensor<*,si64>, !torch.tensor<*,si64>
%12 = torch.copy.to_vtensor %4 : !torch.vtensor<*,si64>
return %12 : !torch.vtensor<*,si64>
}
}
Done
// IR printed after the "Done" marker. Here @forward returns a ranked
// `!torch.vtensor<[1,4],si64>` and contains only value-tensor ops: the
// write is expressed as torch.aten._index_put_impl on the zeros result (%2)
// instead of an overwrite of a non-value tensor.
module attributes {torch.debug_module_name = "_lambda"} {
func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<[1,4],si64> {
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%false = torch.constant.bool false
%int4 = torch.constant.int 4
%none = torch.constant.none
%0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
%cpu = torch.constant.device "cpu"
%1 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten.zeros %1, %int4, %int0, %cpu, %false : !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,4],si64>
%3 = torch.aten.clone %0, %none : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
// %4 and %5 are not referenced by the index_put below or by the return.
%4 = torch.aten.slice.Tensor %2, %int1, %int0, %int1, %int1 : !torch.vtensor<[1,4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1],si64>
%5 = torch.aten.squeeze.dim %4, %int1 : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.vtensor<[1],si64>
%6 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
// NOTE(review): the index list holds a single tensor (index 0). If that
// indexes dim 0, it differs from the source's select along dim 1; the
// values agree here only because everything is zero -- confirm.
%7 = torch.prim.ListConstruct %6 : (!torch.vtensor<[],si64>) -> !torch.list<optional<vtensor>>
%8 = torch.aten._index_put_impl %2, %7, %3, %false, %false : !torch.vtensor<[1,4],si64>, !torch.list<optional<vtensor>>, !torch.vtensor<[],si64>, !torch.bool, !torch.bool -> !torch.vtensor<[1,4],si64>
return %8 : !torch.vtensor<[1,4],si64>
}
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
test_masked_fill_torchscript_0327_transformers4.26.0.mlir