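# Export HuggingFace t5-small to TOSA MLIR with torch-mlir:
# eager model -> make_fx (with aten decompositions) -> TorchScript -> torch_mlir.compile.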
# pip install transformers==4.26.0
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
import tempfile
import torch_mlir
class HfMaskedLM(torch.nn.Module):
    def __init__(self, model_name: str):
        super().__init__()
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,  # The pretrained model name.
        )
        self.model.eval()

    def forward(self, input_ids, decoder_input_ids):
        # Preprocess: prepend decoder_input_ids with the decoder start token,
        # which is the pad token for T5. This is not needed when calling torch's
        # T5ForConditionalGeneration directly, as it does this internally via the labels arg.
        decoder_input_ids = self.model._shift_right(decoder_input_ids)
        # Return only the logits (the first element of the Seq2SeqLMOutput).
        return self.model.forward(input_ids=input_ids, decoder_input_ids=decoder_input_ids)[0]
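# A rough illustration of _shift_right (hypothetical token ids): the decoder input
# [[a, b, c]] becomes [[0, a, b]], i.e. the sequence is shifted right by one, T5's
# decoder start token (the pad token, id 0) is prepended, and the last token is dropped.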
hf_model_name = "t5-small"
model = HfMaskedLM(hf_model_name)
tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
input_ids = tokenizer(
    "Studies have been shown that owning a dog is good for you", return_tensors="pt"
).input_ids  # Batch size 1
decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
test_inputs = (input_ids, decoder_input_ids)
# outputs = model.generate(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
print("model(test_input): ")
print(outputs)
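# Optional sanity check (a sketch, not part of the export itself): model(...) returns
# raw logits rather than generated text, so greedily decode the argmax to eyeball the prediction.
pred_ids = outputs.argmax(dim=-1)
print(tokenizer.decode(pred_ids[0], skip_special_tokens=True))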
fx_g = make_fx(
    model,
    decomposition_table=get_decompositions(
        [
            torch.ops.aten.embedding_dense_backward,
            torch.ops.aten.native_layer_norm_backward,
            torch.ops.aten.select_backward,
            torch.ops.aten.norm.ScalarOpt_dim,
            torch.ops.aten.native_group_norm,
            torch.ops.aten.upsample_bilinear2d.vec,
            torch.ops.aten.split.Tensor,
            torch.ops.aten.split_with_sizes,
            torch.ops.aten._to_copy,
            torch.ops.aten._softmax,
        ]
    ),
)(*test_inputs)
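# Optional check (a sketch): the traced module should reproduce the eager outputs on
# the same inputs; small numeric drift from the decompositions is expected.
fx_out = fx_g(*test_inputs)
print("max |fx - eager| diff:", (fx_out - outputs).abs().max().item())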
# print("fx_g.graph: ")
# print(fx_g.graph)
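# Reset to the default CodeGen so the recompiled Python source drops make_fx's
# pytree (un)flattening wrappers, which torch.jit.script cannot handle.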
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
def strip_overloads(gm):
    """
    Modifies the target of graph nodes in :attr:`gm` to strip overloads.

    Args:
        gm (fx.GraphModule): The input FX graph module to be modified.
    """
    for node in gm.graph.nodes:
        if isinstance(node.target, torch._ops.OpOverload):
            node.target = node.target.overloadpacket
    gm.recompile()
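# For example, a node targeting torch.ops.aten.add.Tensor (an OpOverload) is
# rewritten to target torch.ops.aten.add (its OpOverloadPacket), which
# TorchScript can script. A quick sketch of the relationship:
assert torch.ops.aten.add.Tensor.overloadpacket is torch.ops.aten.add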
strip_overloads(fx_g)
ts_g = torch.jit.script(fx_g)
# print("ts_g.graph: ")
# print(ts_g.graph)
module = torch_mlir.compile(
    ts_g,
    (input_ids, decoder_input_ids),
    torch_mlir.OutputType.TOSA,
    use_tracing=True,
    verbose=False,
)
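# The same call can target other torch-mlir output types (a sketch, left commented
# out to keep this export to a single TOSA artifact):
# linalg_module = torch_mlir.compile(
#     ts_g, (input_ids, decoder_input_ids),
#     torch_mlir.OutputType.LINALG_ON_TENSORS, use_tracing=True)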
# module.dump()
import os
mlir_str = module.operation.get_asm()
out_dir = tempfile.gettempdir()
with open(os.path.join(out_dir, "t5small_tosa_0420_transformers4.26.0.mlir"), "w") as mlir_file:
    mlir_file.write(mlir_str)