A https://github.com/pytorch/serve compatible handler for huggingface/transformers sequence classification models.
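
Once this handler, the saved checkpoint, tokenizer files and (optionally) index_to_name.json are packaged into a model archive and registered with TorchServe, the handler is exercised through TorchServe's standard inference API. A minimal client-side sketch, assuming the hypothetical model name my_text_classifier and the default inference port 8080:

import requests

# POST raw UTF-8 text; the handler decodes, tokenizes and classifies it, returning
# the class name if index_to_name.json was packaged, otherwise the class index.
response = requests.post(
    "http://localhost:8080/predictions/my_text_classifier",
    data="Bloomberg has decided to publish a new report on the global economy.".encode("utf-8"),
)
print(response.json())
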
from abc import ABC
import json
import logging
import os

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class TransformersClassifierHandler(BaseHandler, ABC):
    """
    Transformers text classifier handler class. This handler takes a text (string)
    as input and returns the classification label based on the serialized
    transformers checkpoint.
    """

    def __init__(self):
        super(TransformersClassifierHandler, self).__init__()
        self.initialized = False

    def initialize(self, ctx):
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")

        # Read model serialize/pt file
        self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)

        self.model.to(self.device)
        self.model.eval()

        logger.debug('Transformer model from path {0} loaded successfully'.format(model_dir))

        # Read the mapping file, index to object name
        mapping_file_path = os.path.join(model_dir, "index_to_name.json")

        if os.path.isfile(mapping_file_path):
            with open(mapping_file_path) as f:
                self.mapping = json.load(f)
        else:
            logger.warning('Missing the index_to_name.json file. Inference output will not include class name.')
            # Fall back to None so inference() can still run without class names.
            self.mapping = None

        self.initialized = True
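
    # For reference (illustrative, not part of the original gist): index_to_name.json,
    # read in initialize() above, is expected to map stringified class indices to
    # human-readable labels, e.g. {"0": "Negative", "1": "Positive"}.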

    def preprocess(self, data):
        """ Very basic preprocessing code - only tokenizes.
        Extend with your own preprocessing steps as needed.
        """
        text = data[0].get("data")
        if text is None:
            text = data[0].get("body")
        sentences = text.decode('utf-8')
        logger.info("Received text: '%s'", sentences)

        inputs = self.tokenizer.encode_plus(
            sentences,
            add_special_tokens=True,
            return_tensors="pt"
        )
        return inputs

    def inference(self, inputs):
        """
        Predict the class of a text using a trained transformer model.
        """
        # NOTE: This makes the assumption that your model expects text to be tokenized
        # with "input_ids" and "token_type_ids" - which is true for some popular
        # transformer models, e.g. bert. If your transformer model expects different
        # tokenization, adapt this code to suit its expected input format.
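        # As an illustrative (hypothetical) adaptation for checkpoints that do not use
        # token_type_ids, the forward call below would typically become:
        #   self.model(inputs['input_ids'].to(self.device),
        #              attention_mask=inputs['attention_mask'].to(self.device))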
        prediction = self.model(
            inputs['input_ids'].to(self.device),
            token_type_ids=inputs['token_type_ids'].to(self.device)
        )[0].argmax().item()
        logger.info("Model predicted: '%s'", prediction)

        if self.mapping:
            prediction = self.mapping[str(prediction)]

        return [prediction]

    def postprocess(self, inference_output):
        # TODO: Add any needed post-processing of the model predictions here
        return inference_output


_service = TransformersClassifierHandler()


def handle(data, context):
    try:
        if not _service.initialized:
            _service.initialize(context)

        if data is None:
            return None

        data = _service.preprocess(data)
        data = _service.inference(data)
        data = _service.postprocess(data)

        return data
    except Exception as e:
        raise e
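
To exercise the handler without a running TorchServe instance, the context object it expects can be faked with anything exposing manifest and system_properties. A rough local smoke test; the model directory and sample sentence below are placeholders, not part of the gist:

import types

# "model_store/my_text_classifier" (placeholder path) must contain a saved sequence
# classification checkpoint and tokenizer for this to actually run.
mock_context = types.SimpleNamespace(
    manifest={},
    system_properties={"model_dir": "model_store/my_text_classifier", "gpu_id": 0},
)
print(handle([{"data": "This movie was surprisingly good.".encode("utf-8")}], mock_context))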