Created
February 14, 2023 07:31
-
-
Save pashu123/c0e16e95f347d61ef8004546cef6af4d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
import torch | |
from torch.fx.experimental.proxy_tensor import make_fx | |
from torch._decomp import get_decompositions | |
import tempfile | |
import torch_mlir | |
def prepare_sentence_tokens(
    hf_model: str,
    encoder_text: str = "Studies have been shown that owning a dog is good for you",
    decoder_text: str = "Studies show that",
):
    """Tokenize an encoder/decoder sentence pair with the tokenizer of *hf_model*.

    Args:
        hf_model: Model id or path passed to ``AutoTokenizer.from_pretrained``.
        encoder_text: Sentence tokenized as the encoder input. Defaults to the
            sample sentence this script has always used.
        decoder_text: Sentence tokenized as the decoder input. Defaults to the
            sample decoder prompt this script has always used.

    Returns:
        A tuple ``(input_ids, decoder_input_ids)``. These are the ``input_ids``
        produced by the tokenizer with ``return_tensors="pt"``. Each comes from
        a single string, so the batch size is 1.
    """
    tokenizer = AutoTokenizer.from_pretrained(hf_model)
    input_ids = tokenizer(encoder_text, return_tensors="pt").input_ids
    decoder_input_ids = tokenizer(decoder_text, return_tensors="pt").input_ids
    return input_ids, decoder_input_ids
class HfMaskedLM(torch.nn.Module):
    """Wrap a pretrained Hugging Face seq2seq model so ``forward`` returns a single output.

    NOTE(review): despite the name, the wrapped model is loaded with
    ``AutoModelForSeq2SeqLM`` (an encoder-decoder model), not a masked-LM head.
    """

    def __init__(self, model_name: str):
        """Load *model_name* via ``AutoModelForSeq2SeqLM`` and switch it to eval mode."""
        super().__init__()
        seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        # Only the wrapped model is put in eval mode; this wrapper's own
        # ``training`` flag is left untouched.
        seq2seq.eval()
        self.model = seq2seq

    def forward(self, input_ids, decoder_input_ids):
        """Run the wrapped model and return element 0 of its output.

        ``decoder_input_ids`` is first passed through the model's
        ``_shift_right`` helper, and the shifted ids are what the decoder sees.
        NOTE(review): ``_shift_right`` is a private T5-family method; other
        seq2seq architectures may not provide it. Element 0 of the output is
        presumably the LM logits when no labels are given; confirm for the
        model in use.
        """
        shifted = self.model._shift_right(decoder_input_ids)
        outputs = self.model.forward(
            input_ids=input_ids,
            decoder_input_ids=shifted,
        )
        return outputs[0]
hf_minilm_model = "t5-small"

# Tokenize the sample encoder/decoder sentences once and reuse them below.
# This avoids loading a second tokenizer, possibly for a different
# hard-coded model name, just to re-tokenize the same text.
test_input = prepare_sentence_tokens(hf_minilm_model)
input_ids, decoder_input_ids = test_input

model = HfMaskedLM(hf_minilm_model)

# HfMaskedLM.forward already shifts decoder_input_ids right (prepending the
# decoder start token, which is the pad token for T5), so the raw tokenizer
# output is passed straight through.
# Eager reference run: only the forward values are needed, so skip building
# an autograd graph.
with torch.no_grad():
    outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
print("model(test_input): ")
print(outputs)

from shark.shark_importer import import_with_fx

# Import the model through SHARK's FX-based importer, passing the inputs
# positionally as a tuple.
inputs = (input_ids, decoder_input_ids)
mlir_module, func_name = import_with_fx(model, inputs)
# fx_g = make_fx( | |
# model, | |
# decomposition_table=get_decompositions( | |
# [ | |
# torch.ops.aten.embedding_dense_backward, | |
# torch.ops.aten.native_layer_norm_backward, | |
# torch.ops.aten.select_backward, | |
# torch.ops.aten.norm.ScalarOpt_dim, | |
# torch.ops.aten.native_group_norm, | |
# torch.ops.aten.upsample_bilinear2d.vec, | |
# torch.ops.aten.split.Tensor, | |
# torch.ops.aten.split_with_sizes, | |
# torch.ops.aten._to_copy, | |
# torch.ops.aten._softmax, | |
# ] | |
# ), | |
# )(input_ids=input_ids, decoder_input_ids=decoder_input_ids) | |
# ➜ t5small git:(main) ✗ python t5small.py | |
# /home/chi/src/ubuntu20/shark/SHARK/shark.venv/lib/python3.10/site-packages/transformers/models/t5/tokenization_t5_fast.py:156: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5. | |
# For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`. | |
# - Be aware that you SHOULD NOT rely on t5-small automatically truncating your input to 512 when padding/encoding. | |
# - If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding. | |
# - To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value. | |
# warnings.warn( | |
# model(test_input): | |
# tensor([[[-32.0646, -14.4777, -22.1295, ..., -60.9851, -60.9998, -60.8735], | |
# [-36.0011, -13.1010, -15.8745, ..., -51.4694, -51.4420, -51.5958], | |
# [-41.3448, -14.2873, -19.3936, ..., -53.9813, -54.0932, -53.9458], | |
# [-33.9671, -9.6313, -12.7659, ..., -48.1181, -48.1343, -48.0340]]], | |
# grad_fn=<UnsafeViewBackward0>) | |
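# Traceback (most recent call last):
#   File "/home/chi/src/ubuntu20/shark/SHARK/tank/pytorch/t5small/t5small.py", line 55, in <module>
#     fx_g = make_fx(
# TypeError: make_fx.<locals>.wrapped() got an unexpected keyword argument 'input_ids'
# NOTE: the function returned by make_fx rejected keyword arguments here; call it
# positionally instead, e.g. `)(input_ids, decoder_input_ids)`.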
# print("fx_g.graph: ") | |
# print(fx_g.graph) | |
# fx_g.graph.set_codegen(torch.fx.graph.CodeGen()) | |
# fx_g.recompile() | |
# def strip_overloads(gm): | |
# """ | |
# Modifies the target of graph nodes in :attr:`gm` to strip overloads. | |
# Args: | |
# gm(fx.GraphModule): The input Fx graph module to be modified | |
# """ | |
# for node in gm.graph.nodes: | |
# if isinstance(node.target, torch._ops.OpOverload): | |
# node.target = node.target.overloadpacket | |
# gm.recompile() | |
# strip_overloads(fx_g) | |
# ts_g = torch.jit.script(fx_g) | |
# print("ts_g.graph: ") | |
# # print(ts_g.graph) | |
# module = torch_mlir.compile( | |
# ts_g, | |
# (input_ids, decoder_input_ids), | |
# torch_mlir.OutputType.TOSA, | |
# use_tracing=True, | |
# verbose=False, | |
# ) | |
# # module.dump() | |
# import os | |
# mlir_str = module.operation.get_asm() | |
# dir=tempfile.gettempdir() | |
# with open(os.path.join(dir, "t5small_torch_tosa_0210_transformers4.21.2.mlir"), "w") as mlir_file: | |
# mlir_file.write(mlir_str) | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment