# Gist by @pashu123, created February 14, 2023.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
import tempfile
import torch_mlir
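# This script runs Hugging Face's t5-small seq2seq model on a sample sentence and
# imports the traced FX graph into an MLIR module via SHARK's import_with_fx.
# The commented-out blocks below preserve an earlier make_fx / torch_mlir (TOSA)
# attempt and its run log.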
def prepare_sentence_tokens(hf_model: str):
    tokenizer = AutoTokenizer.from_pretrained(hf_model)
    input_ids = tokenizer(
        "Studies have been shown that owning a dog is good for you", return_tensors="pt"
    ).input_ids  # Batch size 1
    decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
    return input_ids, decoder_input_ids
class HfMaskedLM(torch.nn.Module):
    def __init__(self, model_name: str):
        super().__init__()
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,  # The pretrained model name.
        )
        self.model.eval()

    def forward(self, input_ids, decoder_input_ids):
        decoder_input_ids = self.model._shift_right(decoder_input_ids)
        return self.model.forward(input_ids=input_ids, decoder_input_ids=decoder_input_ids)[0]
hf_minilm_model = "t5-small"
# test_input = torch.randint(2, (1, 128))
test_input = prepare_sentence_tokens(hf_minilm_model)
model = HfMaskedLM(hf_minilm_model)
tokenizer = AutoTokenizer.from_pretrained("t5-small")
input_ids = tokenizer(
    "Studies have been shown that owning a dog is good for you", return_tensors="pt"
).input_ids  # Batch size 1
decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
# Preprocess: prepend decoder_input_ids with the start token, which is the pad token for T5Model.
# This is not needed for Hugging Face's T5ForConditionalGeneration, which does this internally via the labels arg.
# decoder_input_ids = model._shift_right(decoder_input_ids)
# outputs = model.generate(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
print("model(test_input): ")
print(outputs)
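# Illustrative only (not part of the original gist): greedily decode the logits back
# to text with the tokenizer defined above.
# pred_ids = outputs.argmax(dim=-1)  # (batch, decoder_len) token ids
# print(tokenizer.decode(pred_ids[0], skip_special_tokens=True))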
from shark.shark_importer import import_with_fx
inputs = (input_ids, decoder_input_ids)
mlir_module, func_name = import_with_fx(model, inputs)
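# Optional sketch (assumption: the returned mlir_module stringifies to MLIR assembly
# via str(), and the filename below is hypothetical; adjust to whatever
# import_with_fx actually returns):
# import os
# with open(os.path.join(tempfile.gettempdir(), "t5small_fx.mlir"), "w") as f:
#     f.write(str(mlir_module))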
# fx_g = make_fx(
#     model,
#     decomposition_table=get_decompositions(
#         [
#             torch.ops.aten.embedding_dense_backward,
#             torch.ops.aten.native_layer_norm_backward,
#             torch.ops.aten.select_backward,
#             torch.ops.aten.norm.ScalarOpt_dim,
#             torch.ops.aten.native_group_norm,
#             torch.ops.aten.upsample_bilinear2d.vec,
#             torch.ops.aten.split.Tensor,
#             torch.ops.aten.split_with_sizes,
#             torch.ops.aten._to_copy,
#             torch.ops.aten._softmax,
#         ]
#     ),
# )(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
# ➜ t5small git:(main) ✗ python t5small.py
# /home/chi/src/ubuntu20/shark/SHARK/shark.venv/lib/python3.10/site-packages/transformers/models/t5/tokenization_t5_fast.py:156: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.
# For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
# - Be aware that you SHOULD NOT rely on t5-small automatically truncating your input to 512 when padding/encoding.
# - If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
# - To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.
# warnings.warn(
# model(test_input):
# tensor([[[-32.0646, -14.4777, -22.1295, ..., -60.9851, -60.9998, -60.8735],
# [-36.0011, -13.1010, -15.8745, ..., -51.4694, -51.4420, -51.5958],
# [-41.3448, -14.2873, -19.3936, ..., -53.9813, -54.0932, -53.9458],
# [-33.9671, -9.6313, -12.7659, ..., -48.1181, -48.1343, -48.0340]]],
# grad_fn=<UnsafeViewBackward0>)
# Traceback (most recent call last):
# File "/home/chi/src/ubuntu20/shark/SHARK/tank/pytorch/t5small/t5small.py", line 55, in <module>
# fx_g = make_fx(
# TypeError: make_fx.<locals>.wrapped() got an unexpected keyword argument 'input_ids'
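# Note on the TypeError above: the callable returned by make_fx accepts only
# positional arguments, so the `input_ids=` / `decoder_input_ids=` kwargs are
# rejected. A likely fix (untested here) is to call it positionally, e.g.
# fx_g = make_fx(model, decomposition_table=...)(input_ids, decoder_input_ids)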
# print("fx_g.graph: ")
# print(fx_g.graph)
# fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
# fx_g.recompile()
# def strip_overloads(gm):
#     """
#     Modifies the target of graph nodes in :attr:`gm` to strip overloads.
#
#     Args:
#         gm (fx.GraphModule): The input FX graph module to be modified.
#     """
#     for node in gm.graph.nodes:
#         if isinstance(node.target, torch._ops.OpOverload):
#             node.target = node.target.overloadpacket
#     gm.recompile()
# strip_overloads(fx_g)
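# (strip_overloads replaces each node's OpOverload target with its OverloadPacket so
# that torch.jit.script can script the FX graph below.)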
# ts_g = torch.jit.script(fx_g)
# print("ts_g.graph: ")
# # print(ts_g.graph)
# module = torch_mlir.compile(
#     ts_g,
#     (input_ids, decoder_input_ids),
#     torch_mlir.OutputType.TOSA,
#     use_tracing=True,
#     verbose=False,
# )
# # module.dump()
# import os
# mlir_str = module.operation.get_asm()
# out_dir = tempfile.gettempdir()
# with open(os.path.join(out_dir, "t5small_torch_tosa_0210_transformers4.21.2.mlir"), "w") as mlir_file:
#     mlir_file.write(mlir_str)