# pip install transformers==4.26.0
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
import tempfile
import os
import torch_mlir


class HfMaskedLM(torch.nn.Module):
    def __init__(self, model_name: str):
        super().__init__()
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,  # The pretrained model name.
        )
        self.model.eval()

    def forward(self, input_ids, decoder_input_ids):
        # Preprocess: prepend decoder_input_ids with the start token, which is the pad token for T5.
        # This is not needed for torch's T5ForConditionalGeneration, as it does this internally via the labels arg.
        decoder_input_ids = self.model._shift_right(decoder_input_ids)
        return self.model.forward(input_ids=input_ids, decoder_input_ids=decoder_input_ids)[0]
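

# Instantiate the wrapper around the pretrained t5-small checkpoint and build
# example encoder/decoder inputs with its tokenizer.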
hf_model_name = "t5-small"
model = HfMaskedLM(hf_model_name)
tokenizer = AutoTokenizer.from_pretrained(hf_model_name)

input_ids = tokenizer(
    "Studies have been shown that owning a dog is good for you", return_tensors="pt"
).input_ids  # Batch size 1
decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
test_inputs = (input_ids, decoder_input_ids)

# outputs = model.generate(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
print("model(test_input): ")
print(outputs)
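
# Trace the wrapped model into a torch.fx graph with make_fx. The listed aten
# ops are decomposed into simpler primitives so the torch-mlir lowering below
# does not have to handle them directly.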
fx_g = make_fx(
    model,
    decomposition_table=get_decompositions(
        [
            torch.ops.aten.embedding_dense_backward,
            torch.ops.aten.native_layer_norm_backward,
            torch.ops.aten.select_backward,
            torch.ops.aten.norm.ScalarOpt_dim,
            torch.ops.aten.native_group_norm,
            torch.ops.aten.upsample_bilinear2d.vec,
            torch.ops.aten.split.Tensor,
            torch.ops.aten.split_with_sizes,
            torch.ops.aten._to_copy,
            torch.ops.aten._softmax,
        ]
    ),
)(*test_inputs)
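
# Optional sanity check (a minimal sketch; fx_outputs is only a local name used
# here): the traced graph should reproduce the eager outputs for the same
# example inputs before it is lowered any further.
with torch.no_grad():
    fx_outputs = fx_g(input_ids, decoder_input_ids)
print("max |eager - traced|:", (outputs - fx_outputs).abs().max().item())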
# print("fx_g.graph: ") | |
# print(fx_g.graph) | |
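
# make_fx installs a pytree-aware codegen on the graph; resetting it to the
# default CodeGen removes the (un)flattening wrappers, which torch.jit.script
# cannot handle.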
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()


def strip_overloads(gm):
    """
    Modifies the target of graph nodes in :attr:`gm` to strip overloads.

    Args:
        gm (fx.GraphModule): The input FX graph module to be modified.
    """
    for node in gm.graph.nodes:
        if isinstance(node.target, torch._ops.OpOverload):
            node.target = node.target.overloadpacket
    gm.recompile()
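

# torch.jit.script cannot script OpOverload call targets, so replace each one
# with its parent OpOverloadPacket before scripting the graph module.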
strip_overloads(fx_g)
ts_g = torch.jit.script(fx_g)
# print("ts_g.graph: ")
# print(ts_g.graph)
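
# Lower the TorchScript module to the TOSA dialect with torch-mlir, tracing it
# with the example inputs.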
module = torch_mlir.compile(
    ts_g,
    (input_ids, decoder_input_ids),
    torch_mlir.OutputType.TOSA,
    use_tracing=True,
    verbose=False,
)

# module.dump()
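
# Serialize the compiled module to MLIR assembly and write it to a file in the
# system temp directory.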
mlir_str = module.operation.get_asm()
out_dir = tempfile.gettempdir()
with open(os.path.join(out_dir, "t5small_tosa_0420_transformers4.26.0.mlir"), "w") as mlir_file:
    mlir_file.write(mlir_str)