# Created September 14, 2021 11:57
# Simple batch test client for Triton Inference Server.
import ast
import argparse
import cv2 as cv
# import json
import numpy as np
import os
# import tensorflow as tf
import time
import uuid
import tritonclient.grpc as grpcclient
# import tritonclient.http as httpclient
import tritonclient.utils.shared_memory as shm
from tritonclient.utils import *
# Command-line options. RawTextHelpFormatter keeps the "\n" line breaks in the
# --test_mode help; the default formatter would collapse them into one line.
parser = argparse.ArgumentParser(
    description="Triton ensemble model FLAG-NSFW-HORROR_onnx",
    formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument("--model_name", help="Model name", default="ensemble_prp_flag_nsfw_horror")
parser.add_argument(
    "--img_path",
    help="test image path or test directory path",
    default="/home/ducpv/gitCode/triton_handons/triton_inference/clients/python/data/flag.jpg",
)
# `choices` rejects unknown modes up front; otherwise `path` below would never
# be assigned and the script would fail later with a NameError.
parser.add_argument(
    "--test_mode",
    help="test data option\n-image/directory path(ip)\n-list image path(lip)\n-batch numpy array(np)\n-list numpy array(lnp) ",
    default="lip",
    choices=["ip", "lip", "np", "lnp"],
)
args = parser.parse_args()
test_mode = args.test_mode

# Build the test payload `path` for the selected mode. Random images are uint8
# because the client JPEG-encodes them with OpenCV, which expects 8-bit
# unsigned pixels (values 0..255 do not fit in int8).
if test_mode == "np":
    # Batched numpy array of shape (N, H, W, C). A single (H, W, C) image is
    # also accepted by the client, e.g.:
    #   path = np.random.randint(0, 256, size=(640, 640, 3), dtype=np.uint8)
    path = np.random.randint(0, 256, size=(3, 640, 640, 3), dtype=np.uint8)
elif test_mode == "lnp":
    # List of numpy images with different spatial sizes.
    path = [
        np.random.randint(0, 256, size=(640, 640, 3), dtype=np.uint8),
        np.random.randint(0, 256, size=(480, 480, 3), dtype=np.uint8),
        np.random.randint(0, 256, size=(645, 780, 3), dtype=np.uint8),
    ]
    print(path[0].shape, path[1].shape)
elif test_mode == "ip":
    # Single image path or a directory of images.
    path = args.img_path
elif test_mode == "lip":
    # List of image paths (the same image repeated to form a batch of 2).
    path = [args.img_path] * 2

# Other models used with this script: yolov5_horror_preprocessing, yolov5_nsfw_onnx
model_name = args.model_name
def read_cv_img(list_img_path, input_npdtype):
    """Load image files and re-encode them as a Triton BYTES batch.

    Each file is decoded with OpenCV and re-encoded in its own format
    (.jpg/.jpeg/.png, case-insensitive). No resize or colour conversion is
    done, so any preprocessing is left to the server.

    Args:
        list_img_path: iterable of image file paths.
        input_npdtype: numpy dtypes of the model inputs. Currently unused;
            kept so existing call sites keep working.

    Returns:
        np.ndarray of dtype object and shape (n, 1), with one encoded image
        (bytes) per row. n counts only the images that loaded successfully.
        Failures are printed and skipped, so n may be 0.
    """
    supported_exts = (".jpg", ".jpeg", ".png")
    list_img = []
    for p in list_img_path:
        # Re-encode using the file's own extension. Anything else is skipped,
        # because cv.imencode needs a known extension to choose an encoder.
        tail = os.path.splitext(p)[1].lower()
        if tail not in supported_exts:
            print("load img fail: ", p, f"unsupported extension {tail!r}")
            continue
        try:
            img = cv.imread(p)
            # cv.imread signals a missing/unreadable file by returning None.
            if img is None:
                raise ValueError("cv.imread returned None (missing or unreadable file)")
            ok, buf = cv.imencode(tail, img)
            if not ok:
                raise ValueError(f"cv.imencode failed for {tail}")
        except Exception as e:  # best-effort loader: report and keep going
            print("load img fail: ", p, e)
            continue
        list_img.append(buf.tobytes())
    return np.array(list_img, dtype=np.object_).reshape((len(list_img), 1))
with grpcclient.InferenceServerClient("localhost:6001") as client:
    # Read the model signature so input dtypes and output names match the server.
    model_metadata = client.get_model_metadata(model_name=model_name)
    input_name = [i.name for i in model_metadata.inputs]
    output_names = [o.name for o in model_metadata.outputs]
    input_dtype = [i.datatype for i in model_metadata.inputs]
    input_npdtype = [triton_to_np_dtype(dtype) for dtype in input_dtype]

    # Turn `path` into `array`: an object ndarray of shape (batch, 1) holding
    # one encoded image (bytes) per row, which the BYTES input below carries.
    # Accepted forms of `path`:
    #   - str: an image file or a directory of image files
    #   - np.ndarray: a single (H, W, C) image or an (N, H, W, C) batch
    #   - list: full image paths, or numpy arrays (sizes may differ)
    if isinstance(path, str):
        if os.path.isdir(path):
            img_path = [
                os.path.join(path, f)
                for f in os.listdir(path)
                if os.path.isfile(os.path.join(path, f))
            ]
        elif os.path.isfile(path):
            img_path = [path]
        else:
            raise FileNotFoundError(f"image path does not exist: {path}")
        array = read_cv_img(img_path, input_npdtype)
    elif isinstance(path, np.ndarray):
        if path.ndim == 3:
            images = [path]  # single image -> batch of one
        elif path.ndim == 4:
            images = list(path)
        else:
            raise ValueError(f"expected an (H, W, C) or (N, H, W, C) array, got shape {path.shape}")
        encoded = [cv.imencode(".jpg", img)[1].tobytes() for img in images]
        array = np.array(encoded, dtype=np.object_).reshape((len(encoded), 1))
    elif isinstance(path, list):
        if not path:
            raise ValueError("test input list is empty")
        if isinstance(path[0], str):
            array = read_cv_img(path, input_npdtype)
        elif isinstance(path[0], np.ndarray):
            encoded = [cv.imencode(".jpg", img)[1].tobytes() for img in path]
            array = np.array(encoded, dtype=np.object_).reshape((len(encoded), 1))
        else:
            raise TypeError(f"unsupported list element type: {type(path[0]).__name__}")
    else:
        raise TypeError(f"unsupported test input type: {type(path).__name__}")

    # Sending an empty batch would fail on the server and divide by zero below.
    if array.shape[0] == 0:
        raise ValueError("no images could be loaded; nothing to send")

    # Only the encoded-image tensor is sent. This assumes the model has one
    # BYTES input named "INPUT" -- NOTE(review): confirm against `input_name`.
    infer_input = grpcclient.InferInput("INPUT", array.shape, "BYTES")
    infer_input.set_data_from_numpy(array)
    inputs = [infer_input]
    outputs = [grpcclient.InferRequestedOutput(name) for name in output_names]

    print("Start triton inference...")
    # Preprocessing runs on the server, so this measures the whole request.
    t0 = time.perf_counter()
    response = client.infer(model_name, inputs, outputs=outputs)
    triton_t = time.perf_counter() - t0
    print("--> Inference time: ", triton_t)
    print(f"--> Throughput: {array.shape[0]/triton_t:.2f} img/sec")
    # Report every requested output instead of hard-coding output names.
    for name in output_names:
        print(name, response.as_numpy(name).shape)
