Created
March 21, 2023 07:43
-
-
Save jarutis/f57a3db7b4c37b59163a2ff5d8c8d54e to your computer and use it in GitHub Desktop.
YoloV8 Torchserve model handler
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Custom TorchServe model handler for YOLOv8 models. | |
""" | |
import base64
import io
import logging

import cv2
import numpy as np
import torch
import torchvision.transforms as tf
from PIL import Image
from ts.torch_handler.base_handler import BaseHandler
class ModelHandler(BaseHandler):
    """TorchServe model handler for YOLOv8 bounding-box models.

    ``preprocess`` turns request payloads into one batched float tensor and
    ``postprocess`` decodes the raw predictions into one list of detections
    per image (confidence filtering followed by OpenCV non-maximum
    suppression). Model loading and ``inference`` are inherited from
    ``BaseHandler``.
    """

    # Image size (px). Images are resized to img_size x img_size before inference.
    img_size = 640
    # Minimum best-class score for a candidate box to be kept.
    conf_threshold = 0.25
    # IoU threshold passed to cv2.dnn.NMSBoxes.
    iou_threshold = 0.45
    # ``eta`` argument passed to cv2.dnn.NMSBoxes.
    nms_eta = 0.5

    def preprocess(self, data):
        """Convert request payloads into a single batched float tensor.

        Args:
            data (list): one entry per request in the batch. Each entry is a
                dict whose ``"data"`` (or, for older TorchServe versions
                without an envelope, ``"body"``) value holds the image as a
                base64 string, raw encoded image bytes/bytearray, or a nested
                list of numbers.

        Returns:
            Tensor: shape [BATCH_SIZE, C, IMG_SIZE, IMG_SIZE] on ``self.device``.
            C is 3 for encoded images. For list payloads it is whatever the
            payload carries (assumed to be 3 -- TODO confirm with clients).

        Raises:
            ValueError: if a request has neither a ``"data"`` nor a ``"body"``.
            RuntimeError: from ``torch.stack`` if the per-image tensors end up
                with different shapes.
        """
        to_tensor = tf.ToTensor()
        resize = tf.Resize((self.img_size, self.img_size))
        images = []
        for row in data:
            image = row.get("data") or row.get("body")
            if image is None:
                raise ValueError("Request is missing both 'data' and 'body' fields")
            if isinstance(image, str):
                # The image was sent as a base64-encoded string.
                image = base64.b64decode(image)
            if isinstance(image, (bytearray, bytes)):
                # Force 3 channels. RGBA, palette or grayscale images would
                # otherwise yield 4- or 1-channel tensors.
                with Image.open(io.BytesIO(image)) as pil_image:
                    tensor = to_tensor(pil_image.convert("RGB"))
            else:
                # ToTensor only accepts PIL images / ndarrays, so a list
                # payload is converted directly. It is assumed to already be a
                # CHW float image (TODO confirm expected layout/range).
                tensor = torch.as_tensor(image, dtype=torch.float32)
            images.append(resize(tensor))
        return torch.stack(images).to(self.device)

    def postprocess(self, inference_output):
        """Decode raw model output into one detection list per input image.

        Args:
            inference_output: iterable over the batch. Each element is a tensor
                holding one image's predictions with shape
                [4 + NUM_CLASSES, NUM_CANDIDATES]. Rows 0-3 are box centre x,
                centre y, width and height; the remaining rows are per-class
                scores.

        Returns:
            list[list[dict]]: one list per image, as TorchServe requires one
            response entry per request. Each detection dict has ``class_id``,
            ``class_name``, ``confidence``, ``box`` (``[x, y, w, h]`` with x, y
            the top-left corner) and ``scale``.
        """
        results = []
        for output in inference_output:
            # ``.numpy()`` only works on CPU tensors that do not require grad.
            predictions = output.detach().cpu().numpy()
            detections = self._decode_predictions(predictions)
            logging.getLogger(__name__).debug("Detections: %s", detections)
            results.append(detections)
        return results

    def _decode_predictions(self, predictions):
        """Filter, box-convert and NMS one image's [4 + C, N] prediction array."""
        # One row per candidate box: [NUM_CANDIDATES, 4 + NUM_CLASSES].
        candidates = predictions.T
        class_scores = candidates[:, 4:]
        class_ids = class_scores.argmax(axis=1)
        confidences = class_scores.max(axis=1)

        keep = confidences >= self.conf_threshold
        if not keep.any():
            return []
        candidates = candidates[keep]
        class_ids = class_ids[keep]
        confidences = confidences[keep]

        # Centre-based (cx, cy, w, h) -> top-left-based (x, y, w, h).
        boxes = np.column_stack((
            candidates[:, 0] - 0.5 * candidates[:, 2],
            candidates[:, 1] - 0.5 * candidates[:, 3],
            candidates[:, 2],
            candidates[:, 3],
        ))
        kept = cv2.dnn.NMSBoxes(
            boxes.tolist(),
            confidences.tolist(),
            self.conf_threshold,
            self.iou_threshold,
            self.nms_eta,
        )
        # Depending on the OpenCV version the result is shaped (N,) or (N, 1),
        # or is an empty tuple; normalise to a flat array of ints.
        kept = np.asarray(kept, dtype=np.int64).reshape(-1)

        # ``self.mapping`` is expected to be set by BaseHandler (TODO confirm);
        # fall back to the numeric id when no label mapping is available.
        mapping = getattr(self, "mapping", None) or {}
        detections = []
        for index in kept:
            class_id = int(class_ids[index])
            detections.append({
                "class_id": class_id,
                "class_name": mapping.get(str(class_id), str(class_id)),
                "confidence": float(confidences[index]),
                "box": boxes[index].tolist(),
                # NOTE(review): img_size / 640 is 1.0 at the default img_size
                # and does not account for the original image size; boxes are
                # presumably in the resized img_size x img_size input space.
                "scale": self.img_size / 640,
            })
        return detections
Looking for a handler for segmentation in exchange :)
Hi, thank you for providing this! I am trying to use it with YOLOv8 and TorchServe, and for some reason I get the error "number of inputs mismatched". I've tried changing the handler script, but it doesn't seem to have any effect.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Based on https://github.com/IvanGarcia7/TORCHSERVER and https://github.com/ultralytics/ultralytics/tree/main/examples/YOLOv8-OpenCV-ONNX-Python. It seems to work fine for bounding boxes.