# Source: GitHub Gist by @pashu123, created August 10, 2022
# Gist id: 8e0306cecbf82f6caac4618189854950
from PIL import Image
import requests
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from iree.compiler import tf as tfc
from iree.compiler import compile_str
from iree import runtime as ireert
import os
# Fixed input geometry for the example:
#   MAX_SEQUENCE_LENGTH - every sentence is padded/truncated to this many
#                         tokens by the tokenizer call in the __main__ section.
#   BATCH_SIZE          - declared for reference; not read by the visible code.
MAX_SEQUENCE_LENGTH, BATCH_SIZE = 512, 1
class AlbertModule(torch.nn.Module):
    """Wrap a Hugging Face masked-LM so that calling it returns only the logits.

    ``forward`` returns the ``logits`` tensor rather than the full model output
    object. Presumably this gives the tracer used by SharkImporter a plain
    tensor output; confirm against the importer's requirements.
    """

    def __init__(self, model_name="albert-base-v2"):
        """Load *model_name* via ``AutoModelForMaskedLM`` and switch it to eval mode.

        Args:
            model_name: Hugging Face model id or local path. Defaults to
                "albert-base-v2", the checkpoint this script was written for.
        """
        super().__init__()
        self.model = AutoModelForMaskedLM.from_pretrained(model_name)
        # eval() disables dropout so traced/compiled outputs are deterministic.
        self.model.eval()

    def forward(self, input_ids, attention_mask):
        """Run the masked-LM on one tokenized batch and return its logits.

        Args:
            input_ids: token-id tensor produced by the tokenizer.
            attention_mask: the tokenizer's ``attention_mask`` tensor for the
                same batch.

        Returns:
            The ``logits`` attribute of the wrapped model's output.
        """
        # Forward the real attention mask. The original passed input_ids here,
        # which silently ignored this argument and used token ids as mask values.
        return self.model(input_ids=input_ids, attention_mask=attention_mask).logits
if __name__ == "__main__":
    # Tokenize one sentence containing a [MASK] token, padded to the fixed
    # length the module is traced with (is_dynamic=False below).
    tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
    text = "This [MASK] is very tasty."
    encoded_inputs = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=MAX_SEQUENCE_LENGTH,
        return_tensors="pt",
    )
    input_ids = encoded_inputs["input_ids"]
    attention_mask = encoded_inputs["attention_mask"]
    inputs = (input_ids, attention_mask)

    # Import the torch module to MLIR by tracing it with the example inputs.
    mlir_importer = SharkImporter(
        AlbertModule(),
        inputs,
        frontend="torch",
    )
    albert_mlir, func_name = mlir_importer.import_mlir(
        is_dynamic=False, tracing_required=True
    )

    # Compile the linalg-dialect module with SHARK and run it on the inputs.
    shark_module = SharkInference(albert_mlir, func_name, mlir_dialect="linalg")
    shark_module.compile()
    output = shark_module.forward(inputs)
    out = torch.tensor(output)
    print(out)

    # Decode the highest-scoring token at every [MASK] position.
    # NOTE(review): assumes `out` holds the MLM logits shaped
    # (batch, seq_len, vocab_size), matching AlbertModule.forward. Confirm.
    mask_positions = (input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)
    predicted_ids = out[mask_positions].argmax(dim=-1)
    predicted = tokenizer.decode(predicted_ids)
    print(f"'>>> {text.replace(tokenizer.mask_token, predicted)}'")
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.