Created
September 14, 2021 11:57
-
-
Save PhanDuc/55c9ccc11375fece8b8d84adffbe3591 to your computer and use it in GitHub Desktop.
Simple batch-inference test client for Triton
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
import argparse | |
import cv2 as cv | |
# import json | |
import numpy as np | |
import os | |
# import tensorflow as tf | |
import time | |
import uuid | |
import tritonclient.grpc as grpcclient | |
# import tritonclient.http as httpclient | |
import tritonclient.utils.shared_memory as shm | |
from tritonclient.utils import * | |
# ---------------------------------------------------------------------------
# Command-line options and construction of the test payload (`path`).
#
# `path` ends up as one of:
#   - str                  : an image file or a directory        (--test_mode ip)
#   - list[str]            : a list of image file paths          (--test_mode lip)
#   - np.ndarray (N,H,W,3) : a batch of random images            (--test_mode np)
#   - list[np.ndarray]     : random images of different sizes    (--test_mode lnp)
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="Triton ensemble model FLAG-NSFW-HORROR_onnx")
parser.add_argument("--model_name", help="Model name", default="ensemble_prp_flag_nsfw_horror")
parser.add_argument(
    "--img_path",
    help="test image path or test directory path",
    default="/home/ducpv/gitCode/triton_handons/triton_inference/clients/python/data/flag.jpg",
)
parser.add_argument(
    "--test_mode",
    help="test data option\n-image/directory path(ip)\n-list image path(lip)\n-batch numpy array(np)\n-list numpy array(lnp) ",
    default="lip",
)
args = parser.parse_args()
test_mode = args.test_mode


def _random_image(*shape):
    """Return random pixel data of the given shape as uint8.

    uint8 is used because int8 cannot represent the 0..255 pixel range.
    """
    return np.random.randint(0, 256, size=shape, dtype=np.uint8)


if test_mode == "np":
    # Batch of three equally sized random images.
    path = _random_image(3, 640, 640, 3)
    print(path.shape)
elif test_mode == "lnp":
    # Random images of different sizes, so they cannot be stacked in one array.
    path = [
        _random_image(640, 640, 3),
        _random_image(480, 480, 3),
        _random_image(645, 780, 3),
    ]
    print(path[0].shape, path[1].shape)
elif test_mode == "ip":
    # Single image file or a directory of images.
    path = args.img_path
elif test_mode == "lip":
    # The same image twice, to exercise batching with a list of paths.
    path = [args.img_path] * 2
else:
    # Fail here with a usage message instead of a NameError on `path` later.
    parser.error(f"unknown --test_mode {test_mode!r}; expected one of: ip, lip, np, lnp")

model_name = args.model_name
def read_cv_img(list_img_path, input_npdtype):
    """Load images from disk and re-encode each one to compressed bytes.

    Args:
        list_img_path: iterable of image file paths (strings).
        input_npdtype: unused; kept so existing call sites keep working.

    Returns:
        A numpy array of dtype ``object`` and shape ``(n, 1)`` holding one
        encoded image (``bytes``) per row, where ``n`` is the number of images
        that were read and encoded successfully. Files that cannot be read or
        encoded are reported on stdout and skipped.
    """
    supported_ext = (".jpg", ".jpeg", ".png")
    list_img = []
    for p in list_img_path:
        # Pixels are kept exactly as cv.imread returns them: no colour
        # conversion and no resize are applied on the client side.
        try:
            img = cv.imread(p)
        except cv.error as e:
            print("load img fail: ", p, e)
            continue
        if img is None:
            # cv.imread reports a missing/unreadable file by returning None.
            print("load img fail: ", p, "cv.imread could not read the file")
            continue

        # Re-encode in the file's own format. The extension is matched
        # case-insensitively, and anything else falls back to JPEG so that a
        # value left over from a previous iteration is never reused.
        tail = os.path.splitext(p)[1].lower()
        if tail not in supported_ext:
            tail = ".jpg"

        try:
            ok, encoded = cv.imencode(tail, img)
        except cv.error as e:
            print("load img fail: ", p, e)
            continue
        if not ok:
            print("load img fail: ", p, "cv.imencode failed")
            continue
        list_img.append(encoded.tobytes())

    # One row per image, one BYTES element per row.
    return np.array(list_img, dtype=np.object_).reshape((len(list_img), 1))
def _encode_arrays(images):
    """JPEG-encode each image array; return an (n, 1) object array of bytes."""
    encoded = [cv.imencode(".jpg", img)[1].tobytes() for img in images]
    return np.array(encoded, dtype=np.object_).reshape((len(encoded), 1))


def _list_image_files(location):
    """Return the image path(s) designated by a file or directory path."""
    if os.path.isdir(location):
        candidates = (os.path.join(location, f) for f in os.listdir(location))
        return [c for c in candidates if os.path.isfile(c)]
    if os.path.isfile(location):
        return [location]
    raise FileNotFoundError(f"image path does not exist: {location}")


def _build_batch(data, input_npdtype):
    """Turn the test payload into an (n, 1) object array of encoded images.

    Accepted payloads:
        - str: an image file or a directory of image files
        - np.ndarray: one image (H, W, C) or a batch (N, H, W, C)
        - list: full image paths, or numpy image arrays

    Raises:
        FileNotFoundError: a string payload is neither a file nor a directory.
        ValueError: an ndarray payload is neither 3-D nor 4-D.
        TypeError: the payload has any other type, or is an empty list.
    """
    if isinstance(data, str):
        return read_cv_img(_list_image_files(data), input_npdtype)
    if isinstance(data, np.ndarray):
        if data.ndim == 3:
            return _encode_arrays([data])
        if data.ndim == 4:
            return _encode_arrays(data)
        raise ValueError(f"expected a 3-D or 4-D image array, got shape {data.shape}")
    if isinstance(data, list) and data:
        if isinstance(data[0], str):
            return read_cv_img(data, input_npdtype)
        if isinstance(data[0], np.ndarray):
            return _encode_arrays(data)
    raise TypeError(f"unsupported test payload: {type(data).__name__}")


with grpcclient.InferenceServerClient("localhost:6001") as client:
    # Ask the server which tensors the model exposes.
    model_metadata = client.get_model_metadata(model_name=model_name)
    output_names = [o.name for o in model_metadata.outputs]
    input_dtype = [i.datatype for i in model_metadata.inputs]
    input_npdtype = [triton_to_np_dtype(dtype) for dtype in input_dtype]

    array = _build_batch(path, input_npdtype)
    if array.shape[0] == 0:
        # Every image failed to load, so there is nothing to send.
        raise SystemExit("No image could be prepared; nothing to send.")

    # Only the encoded image bytes are sent, as a single BYTES tensor named
    # "INPUT". No per-image shape or name tensors are built, because nothing
    # else is sent to the server.
    infer_input = grpcclient.InferInput("INPUT", array.shape, "BYTES")
    infer_input.set_data_from_numpy(array)
    inputs = [infer_input]

    # Request every output the model advertises.
    outputs = [grpcclient.InferRequestedOutput(name) for name in output_names]

    print("Start triton inference...")
    # Preprocessing happens on the server, so this times the whole round trip.
    t0 = time.perf_counter()
    response = client.infer(model_name, inputs, outputs=outputs)
    triton_t = time.perf_counter() - t0
    print("--> Inference time: ", triton_t)
    print(f"--> Throughput: {array.shape[0]/triton_t:.2f} img/sec")

    # The original comments noted shapes like (1, 25500, 8) for FLAG_OUTPUT
    # and (1, 25200, 9) for NSFW_OUTPUT -- TODO confirm against the model.
    for name in ("FLAG_OUTPUT", "NSFW_OUTPUT"):
        result = response.as_numpy(name)
        if result is None:
            # as_numpy returns None when the tensor is not in the response.
            print(f"{name}: not present in the response")
        else:
            print(result.shape)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment