# pip install transformers==4.26.0
import os
import tempfile

import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

import torch_mlir

class HfMaskedLM(torch.nn.Module):
    def __init__(self, model_name: str):
        super().__init__()
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,  # The pretrained model name.
        )
        self.model.eval()

    def forward(self, input_ids, decoder_input_ids):
        # Preprocess: prepend decoder_input_ids with the decoder start token, which is
        # the pad token for T5. T5ForConditionalGeneration does this internally when the
        # labels argument is used, but we call forward directly, so we shift here.
        decoder_input_ids = self.model._shift_right(decoder_input_ids)
        # Return only the logits tensor (index 0 of the model output).
        return self.model.forward(input_ids=input_ids, decoder_input_ids=decoder_input_ids)[0]

hf_model_name = "t5-large"
model = HfMaskedLM(hf_model_name)
tokenizer = AutoTokenizer.from_pretrained(hf_model_name)

input_ids = tokenizer(
    "Studies have been shown that owning a dog is good for you", return_tensors="pt"
).input_ids  # Batch size 1
decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
test_inputs = (input_ids, decoder_input_ids)

# outputs = model.generate(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
print("model(test_input): ")
print(outputs)
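
# Illustrative addition, not part of the original gist: the wrapper returns raw logits,
# so a greedy argmax plus detokenization gives a rough look at the prediction. Proper
# text generation would go through self.model.generate instead.
predicted_ids = outputs.argmax(dim=-1)
print(tokenizer.batch_decode(predicted_ids, skip_special_tokens=True))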

# Trace the wrapper with make_fx, decomposing the listed aten ops into simpler ones
# during tracing.
fx_g = make_fx(
    model,
    decomposition_table=get_decompositions(
        [
            torch.ops.aten.embedding_dense_backward,
            torch.ops.aten.native_layer_norm_backward,
            torch.ops.aten.select_backward,
            torch.ops.aten.norm.ScalarOpt_dim,
            torch.ops.aten.native_group_norm,
            torch.ops.aten.upsample_bilinear2d.vec,
            torch.ops.aten.split.Tensor,
            torch.ops.aten.split_with_sizes,
            torch.ops.aten._to_copy,
            torch.ops.aten._softmax,
        ]
    ),
)(*test_inputs)

print("fx_g.graph: ")
# print(fx_g.graph)

# Reset to the default FX CodeGen before recompiling.
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
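
# Optional sanity check (not in the original gist): count the distinct ops remaining in
# the traced graph, which helps spot ops that may still need a decomposition before
# handing the module to torch-mlir.
remaining_ops = {str(node.target) for node in fx_g.graph.nodes if node.op == "call_function"}
print(f"{len(remaining_ops)} distinct ops in the traced graph")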

def strip_overloads(gm):
    """
    Modifies the target of graph nodes in :attr:`gm` to strip overloads.

    Args:
        gm (fx.GraphModule): The input FX graph module to be modified.
    """
    for node in gm.graph.nodes:
        if isinstance(node.target, torch._ops.OpOverload):
            node.target = node.target.overloadpacket
    gm.recompile()


strip_overloads(fx_g)
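
# Illustrative check, not in the original gist: after strip_overloads no node should
# still target an OpOverload; if one does, torch.jit.script below tends to fail.
assert not any(isinstance(n.target, torch._ops.OpOverload) for n in fx_g.graph.nodes)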

# Script the stripped FX graph and hand it to torch-mlir. OutputType.RAW returns the
# imported module without running the usual lowering pipeline.
ts_g = torch.jit.script(fx_g)
# print("ts_g.graph: ")
# print(ts_g.graph)

module = torch_mlir.compile(
    ts_g,
    (input_ids, decoder_input_ids),
    torch_mlir.OutputType.RAW,
    use_tracing=True,
    verbose=False,
)
# module.dump()
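
# A possible follow-up, not exercised in this gist: torch_mlir.compile can also target
# further-lowered output types such as LINALG_ON_TENSORS. Whether that succeeds for this
# model depends on the torch-mlir build, so the call is left commented out as a sketch.
# linalg_module = torch_mlir.compile(
#     ts_g,
#     (input_ids, decoder_input_ids),
#     output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
#     use_tracing=True,
# )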

# Serialize the MLIR module to a file in the system temp directory.
mlir_str = module.operation.get_asm()
out_dir = tempfile.gettempdir()
mlir_path = os.path.join(out_dir, "t5large_torchscript_0306_transformers4.26.0.mlir")
with open(mlir_path, "w") as mlir_file:
    mlir_file.write(mlir_str)
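
# Not in the original gist: report where the MLIR text was written.
print(f"MLIR written to {mlir_path}")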