Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Last active April 6, 2023 16:24
Show Gist options
  • Save AmosLewis/85b7c19409bfdbe45f216d689e947578 to your computer and use it in GitHub Desktop.
Save AmosLewis/85b7c19409bfdbe45f216d689e947578 to your computer and use it in GitHub Desktop.
import os
import tempfile

import torch
import torch_mlir
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx
class Test(torch.nn.Module):
    """Minimal repro module: write 0 into index 0 of the last dim of a zero tensor.

    ``forward`` allocates a zero tensor with ``decoder_input_ids.new_zeros``.
    That tensor therefore has ``decoder_input_ids``'s shape, dtype and device.
    It then assigns the scalar 0 through ``[..., 0]``. When traced with
    ``make_fx``, this indexed assignment is recorded as ``aten.select``
    followed by ``aten.fill_`` (see the graph dump in this script).

    The assignment writes 0 over zeros, so the returned tensor is always all
    zeros. It is not a shifted copy of ``decoder_input_ids``.
    """

    # No __init__ override: nn.Module.__init__ is all that is needed.

    def forward(self, input_ids, decoder_input_ids):
        # ``input_ids`` only keeps the two-input signature; it is never read.
        out = decoder_input_ids.new_zeros(decoder_input_ids.shape)
        # Keep this exact indexing form: it is what yields the select + fill_ pair.
        out[..., 0] = 0
        return out
# --- Eager reference run ----------------------------------------------------
model = Test()
input_ids = torch.tensor([list(range(15))])  # 0..14, int64, shape [1, 15]
decoder_input_ids = torch.tensor([[6536, 504, 24, 1]])  # int64, shape [1, 4]
test_inputs = (input_ids, decoder_input_ids)
outputs = model(*test_inputs)
print("model(test_input): ", outputs, sep="\n")

# --- Trace to an aten-level FX graph -----------------------------------------
# An empty decomposition list is passed, so ops are recorded as-is.
fx_g = make_fx(model, decomposition_table=get_decompositions([]))(*test_inputs)
print("fx_g.graph: ", fx_g.graph, sep="\n")
# graph():
# %arg0_1 : [#users=0] = placeholder[target=arg0_1]
# %arg1_1 : [#users=1] = placeholder[target=arg1_1]
# %new_zeros : [#users=2] = call_function[target=torch.ops.aten.new_zeros.default](args = (%arg1_1, [1, 4]), kwargs = {dtype: torch.int64, layout: torch.strided, device: cpu, pin_memory: False})
# %_tensor_constant0 : [#users=1] = get_attr[target=_tensor_constant0]
# %lift_fresh_copy : [#users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%_tensor_constant0,), kwargs = {})
# %select : [#users=1] = call_function[target=torch.ops.aten.select.int](args = (%new_zeros, 1, 0), kwargs = {})
# %fill_ : [#users=0] = call_function[target=torch.ops.aten.fill_.Tensor](args = (%select, %lift_fresh_copy), kwargs = {})
# return new_zeros

# Install the base fx CodeGen and regenerate the module's code.
# NOTE(review): presumably this replaces make_fx's default codegen so that
# torch.jit.script (run later in this script) can compile forward(); confirm.
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
def strip_overloads(gm):
    """Rewrite ``gm`` in place so call targets are op packets, not overloads.

    Every node whose ``target`` is a ``torch._ops.OpOverload`` (for example
    ``aten.new_zeros.default``) is re-pointed at that overload's
    ``overloadpacket`` (``aten.new_zeros``). Other nodes are left unchanged.
    The module's code is then regenerated so the edit takes effect.

    Args:
        gm (torch.fx.GraphModule): Graph module to rewrite; mutated in place.

    Returns:
        None.
    """
    overload_type = torch._ops.OpOverload
    pinned = [n for n in gm.graph.nodes if isinstance(n.target, overload_type)]
    for n in pinned:
        n.target = n.target.overloadpacket
    # Regenerate forward() from the edited graph.
    gm.recompile()
# --- Script the FX graph --------------------------------------------------------
# Retarget aten overloads (e.g. aten.new_zeros.default) to their packets.
# The TorchScript dump below shows the resulting un-suffixed aten:: calls.
# NOTE(review): presumably required because torch.jit.script cannot compile
# OpOverload call targets; confirm against the torch version in use.
strip_overloads(fx_g)
ts_g = torch.jit.script(fx_g)
print("ts_g.graph: ")
print(ts_g.graph)
# ts_g.graph:
# graph(%self : __torch__.torch.fx.graph_module._lambda,
# %arg0_1 : Tensor,
# %arg1_1.1 : Tensor):
# %11 : bool = prim::Constant[value=0]() # <eval_with_key>.2:5:144
# %37 : Device = prim::Constant[value="cpu"]()
# %4 : int = prim::Constant[value=1]() # <eval_with_key>.2:5:50
# %5 : int = prim::Constant[value=4]() # <eval_with_key>.2:5:53
# %19 : int = prim::Constant[value=0]() # <eval_with_key>.2:8:49
# %6 : int[] = prim::ListConstruct(%4, %5)
# %new_zeros.1 : Tensor = aten::new_zeros(%arg1_1.1, %6, %5, %19, %37, %11) # <eval_with_key>.2:5:16
# %_tensor_constant0.1 : Tensor = prim::GetAttr[name="_tensor_constant0"](%self)
# %lift_fresh_copy.1 : Tensor = aten::lift_fresh_copy(%_tensor_constant0.1) # <eval_with_key>.2:7:22
# %select.1 : Tensor = aten::select(%new_zeros.1, %4, %19) # <eval_with_key>.2:8:13
# %fill_ : Tensor = aten::fill_(%select.1, %lift_fresh_copy.1) # <eval_with_key>.2:9:12
# return (%new_zeros.1)

# --- Import into torch-mlir -----------------------------------------------------
module = torch_mlir.compile(
    ts_g,
    (input_ids, decoder_input_ids),
    torch_mlir.OutputType.RAW,
    use_tracing=True,
    verbose=False,
)

# --- Save the MLIR text to the system temp directory ------------------------------
mlir_str = module.operation.get_asm()
# Named out_dir rather than ``dir`` so the builtin dir() is not shadowed.
out_dir = tempfile.gettempdir()
mlir_path = os.path.join(
    out_dir, "test_masked_fill_torchscript_0327_transformers4.26.0.mlir"
)
# Explicit UTF-8 so the output does not depend on the platform's locale encoding.
with open(mlir_path, "w", encoding="utf-8") as mlir_file:
    mlir_file.write(mlir_str)
@AmosLewis
Copy link
Author

Done

// Torch-dialect MLIR for the repro module (debug module name "_lambda").
// @forward takes si64 tensors of shape [1,15] and [1,4] and returns a [1,4] si64 tensor.
// NOTE(review): the value-semantic !torch.vtensor types and decomposed ops suggest
// this was produced with a lowered output type, not OutputType.RAW; confirm.
module attributes {torch.debug_module_name = "_lambda"} {
  func.func @forward(%arg0: !torch.vtensor<[1,15],si64>, %arg1: !torch.vtensor<[1,4],si64>) -> !torch.vtensor<[1,4],si64> {
    // Neither %arg0 nor %arg1 is read below: the shape [1,4] is a constant list.
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %false = torch.constant.bool false
    %int4 = torch.constant.int 4
    %none = torch.constant.none
    // %0: 0-d si64 literal holding the fill value 0.
    %0 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
    %cpu = torch.constant.device "cpu"
    %1 = torch.prim.ListConstruct %int1, %int4 : (!torch.int, !torch.int) -> !torch.list<int>
    // %2: zero tensor of shape [1,4]; dtype argument is %int4 and the result type is si64.
    %2 = torch.aten.zeros %1, %int4, %int0, %cpu, %false : !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.bool -> !torch.vtensor<[1,4],si64>
    // %3: copy of the fill value %0.
    %3 = torch.aten.clone %0, %none : !torch.vtensor<[],si64>, !torch.none -> !torch.vtensor<[],si64>
    // %4/%5: column 0 of %2 (slice [0,1) along dim 1, then squeeze dim 1).
    // NOTE(review): %4 and %5 are never used afterwards — dead values.
    %4 = torch.aten.slice.Tensor %2, %int1, %int0, %int1, %int1 : !torch.vtensor<[1,4],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,1],si64>
    %5 = torch.aten.squeeze.dim %4, %int1 : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.vtensor<[1],si64>
    // %6/%7: indices list with a single entry, the scalar index 0.
    %6 = torch.prim.NumToTensor.Scalar %int0 : !torch.int -> !torch.vtensor<[],si64>
    %7 = torch.prim.ListConstruct %6 : (!torch.vtensor<[],si64>) -> !torch.list<optional<vtensor>>
    // %8: writes %3 into %2 at the indices in %7 (accumulate = false).
    // NOTE(review): a one-entry indices list indexes dim 0, while the source op
    // selected index 0 of dim 1 (select(%new_zeros, 1, 0)). The result is still all
    // zeros here, but confirm the lowering for other shapes or fill values.
    %8 = torch.aten._index_put_impl %2, %7, %3, %false, %false : !torch.vtensor<[1,4],si64>, !torch.list<optional<vtensor>>, !torch.vtensor<[],si64>, !torch.bool, !torch.bool -> !torch.vtensor<[1,4],si64>
    return %8 : !torch.vtensor<[1,4],si64>
  }
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment