@AmosLewis
Last active August 22, 2023 01:39
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
import tempfile
import torch_mlir
class Test(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input_ids, decoder_input_ids):
        shifted_input_ids = decoder_input_ids.new_zeros(decoder_input_ids.shape)  # tensor([[0, 0, 0, 0]])
        shifted_input_ids[..., 1:] = decoder_input_ids[..., :-1].clone()  # tensor([[6536, 504, 24]])
        return shifted_input_ids
model = Test()
input_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]])
decoder_input_ids = torch.tensor([[6536, 504, 24, 1]])
test_inputs = (input_ids, decoder_input_ids)
outputs = model(*test_inputs)
print("model(test_input): ")
print(outputs)
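# Expected output: decoder_input_ids shifted right by one position, with 0 at
# index 0 and the last token dropped: tensor([[   0, 6536,  504,   24]])

# get_decompositions builds a decomposition table so the listed composite ops
# (the split variants here) are expanded into primitive aten ops during
# tracing; this forward() contains no splits, so the table is not exercised
# by this particular graph.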
fx_g = make_fx(
    model,
    decomposition_table=get_decompositions(
        [
            torch.ops.aten.split.Tensor,
            torch.ops.aten.split_with_sizes,
        ]
    ),
)(*test_inputs)
# print("fx_g.graph: ")
# print(fx_g.graph)
# graph():
# %arg0_1 : [#users=0] = placeholder[target=arg0_1]
# %arg1_1 : [#users=1] = placeholder[target=arg1_1]
# %new_zeros : [#users=2] = call_function[target=torch.ops.aten.new_zeros.default](args = (%arg1_1, [1, 4]), kwargs = {dtype: torch.int64, layout: torch.strided, device: cpu, pin_memory: False})
# %_tensor_constant0 : [#users=1] = get_attr[target=_tensor_constant0]
# %lift_fresh_copy : [#users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%_tensor_constant0,), kwargs = {})
# %select : [#users=1] = call_function[target=torch.ops.aten.select.int](args = (%new_zeros, 1, 0), kwargs = {})
# %fill_ : [#users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%select, %lift_fresh_copy), kwargs = {})
# return new_zeros
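# Reset the graph's codegen to the default CodeGen and recompile, so that the
# generated forward() takes plain positional tensor arguments; make_fx can
# attach a pytree-flattening codegen that torch.jit.script cannot handle.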
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
def strip_overloads(gm):
    """
    Modifies the target of graph nodes in :attr:`gm` to strip overloads.

    Args:
        gm (fx.GraphModule): The input FX graph module to be modified.
    """
    for node in gm.graph.nodes:
        if isinstance(node.target, torch._ops.OpOverload):
            node.target = node.target.overloadpacket
    gm.recompile()
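# torch.jit.script rejects graphs whose call targets are OpOverload objects
# (e.g. aten.new_zeros.default), so rewrite each target back to its
# OverloadPacket (aten.new_zeros) before scripting.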
strip_overloads(fx_g)
ts_g = torch.jit.script(fx_g)
# print("ts_g.graph: ")
# print(ts_g.graph)
# ts_g.graph:
# graph(%self : __torch__.torch.fx.graph_module._lambda,
# %arg0_1 : Tensor,
# %arg1_1.1 : Tensor):
# %21 : NoneType = prim::Constant()
# %16 : int = prim::Constant[value=-1]() # <eval_with_key>.2:6:49
# %11 : bool = prim::Constant[value=0]() # <eval_with_key>.2:5:144
# %45 : Device = prim::Constant[value="cpu"]()
# %4 : int = prim::Constant[value=1]() # <eval_with_key>.2:5:50
# %5 : int = prim::Constant[value=4]() # <eval_with_key>.2:5:53
# %14 : int = prim::Constant[value=0]() # <eval_with_key>.2:6:46
# %25 : int = prim::Constant[value=9223372036854775807]() # <eval_with_key>.2:8:52
# %6 : int[] = prim::ListConstruct(%4, %5)
# %new_zeros.1 : Tensor = aten::new_zeros(%arg1_1.1, %6, %5, %14, %45, %11) # <eval_with_key>.2:5:16
# %slice_1.1 : Tensor = aten::slice(%arg1_1.1, %4, %14, %16, %4) # <eval_with_key>.2:6:14
# %clone.1 : Tensor = aten::clone(%slice_1.1, %21) # <eval_with_key>.2:7:12
# %slice_2.1 : Tensor = aten::slice(%new_zeros.1, %4, %4, %25, %4) # <eval_with_key>.2:8:14
# %copy_ : Tensor = aten::copy_(%slice_2.1, %clone.1, %11) # <eval_with_key>.2:9:12
# return (%new_zeros.1)
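# Compile the scripted module with torch-mlir. OutputType.RAW emits the torch
# dialect as imported, before any lowering pipeline runs; use_tracing=True
# makes torch_mlir.compile trace ts_g with the example inputs instead of
# scripting it again.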
module = torch_mlir.compile(
    ts_g,
    (input_ids, decoder_input_ids),
    torch_mlir.OutputType.RAW,
    use_tracing=True,
    verbose=False,
)
import os

mlir_str = module.operation.get_asm()
out_dir = tempfile.gettempdir()
with open(os.path.join(out_dir, "test_slicecopy_torchscript_0327_transformers4.26.0.mlir"), "w") as mlir_file:
    mlir_file.write(mlir_str)
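# A minimal sketch, not part of the original flow: the same module could be
# lowered further by requesting a backend contract such as linalg-on-tensors,
# assuming the index_put.hacked_twin lowering is supported end to end. Left
# commented out, since that support is exactly what this gist is probing.
# linalg_module = torch_mlir.compile(
#     ts_g,
#     (input_ids, decoder_input_ids),
#     torch_mlir.OutputType.LINALG_ON_TENSORS,
#     use_tracing=True,
# )
# print(linalg_module.operation.get_asm(large_elements_limit=10))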
@AmosLewis (Author) commented:
The RAW torch-dialect output recomposes the slice + copy_ pattern into torch.aten.index_put.hacked_twin:

module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<[1,4],si64> {
    %none = torch.constant.none
    %int4 = torch.constant.int 4
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %int-1 = torch.constant.int -1
    %false = torch.constant.bool false
    %0 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.aten.zeros %0, %int4, %none, %none, %false : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1,4],si64>
    %2 = torch.aten.slice.Tensor %arg1, %int1, %int0, %int-1, %int1 : !torch.vtensor<[1,4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3],si64>
    %3 = torch.aten.clone %2, %none : !torch.vtensor<[1,3],si64>, !torch.none -> !torch.vtensor<[1,3],si64>
    %4 = torch.aten.arange.start_step %int1, %int4, %int1, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
    %5 = torch.aten.arange.start_step %int0, %int1, %int1, %int4, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64>
    %6 = torch.aten.unsqueeze %5, %int-1 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
    %7 = torch.prim.ListConstruct %6, %4 : (!torch.vtensor<[1,1],si64>, !torch.vtensor<[3],si64>) -> !torch.list<vtensor>
    %8 = torch.aten.index_put.hacked_twin %1, %7, %3, %false : !torch.vtensor<[1,4],si64>, !torch.list<vtensor>, !torch.vtensor<[1,3],si64>, !torch.bool -> !torch.vtensor<[1,4],si64>
    return %8 : !torch.vtensor<[1,4],si64>
  }
}
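
For reference, a minimal standalone sketch (not part of the gist) of what this lowering computes: the in-place slice assignment becomes an index_put_ over broadcast row/column index tensors, mirroring the two arange ops, the unsqueeze, and the index_put.hacked_twin in the IR above.

import torch

decoder_input_ids = torch.tensor([[6536, 504, 24, 1]])
z = torch.zeros(1, 4, dtype=torch.int64)  # %1: torch.aten.zeros
cols = torch.arange(1, 4)                 # %4: arange.start_step over columns 1..3
rows = torch.arange(0, 1).unsqueeze(-1)   # %6: row index, broadcast as [[0]]
z.index_put_((rows, cols), decoder_input_ids[..., :-1].clone())
print(z)  # tensor([[   0, 6536,  504,   24]])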
