-
-
Save zengqingfu1442/402cf0ef6976f07d5e3f56f84be611ee to your computer and use it in GitHub Desktop.
Triton gRPC shared memory client adapted from https://github.com/triton-inference-server/client/blob/main/src/python/examples/simple_grpc_shm_client.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import collections | |
import uuid | |
from typing import Dict | |
import numpy as np | |
import tritonclient | |
import tritonclient.grpc as grpcclient | |
import tritonclient.utils.shared_memory as shm | |
# Bundle describing one shared-memory region used for a request:
#   shared_data -- per-tensor descriptors bound to the region
#   shm_handle  -- local handle of the created region
#   shm_name    -- name the region is registered under on the server
#   names       -- tensor names, in the order they are laid out in the region
#   byte_sizes  -- byte size of each tensor, same order as ``names``
ShmHandle = collections.namedtuple(
    'ShmHandle',
    ['shared_data', 'shm_handle', 'shm_name', 'names', 'byte_sizes'],
)
# numpy scalar type -> Triton datatype string, see
# https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md#datatypes
# Keys are numpy scalar *types*, not ``np.dtype`` instances: they compare equal
# to ``array.dtype`` but do not hash alike, so match entries with
# ``array.dtype == key`` instead of indexing with ``array.dtype``.
# ``np.bool_`` is used rather than the ``np.bool`` alias, which NumPy 1.24
# removed (using it made this module fail at import time).
DTYPE_np2triton = {
    np.int64: 'INT64',
    np.int32: 'INT32',
    np.int16: 'INT16',
    np.int8: 'INT8',
    np.uint64: 'UINT64',
    np.uint32: 'UINT32',
    np.uint16: 'UINT16',
    np.uint8: 'UINT8',
    np.float64: 'FP64',
    np.float32: 'FP32',
    np.float16: 'FP16',
    np.bool_: 'BOOL',
}
class TritonClient:
    """Triton Inference Server client that passes tensors via system shared memory.

    CPU (system shared memory) and gRPC only. Every ``infer`` call creates one
    shared-memory region for the inputs and one for the outputs, registers both
    with the server, runs the request, and then unregisters and destroys both
    regions again -- also when the request fails.
    """

    def __init__(self, port: int, host: str = 'localhost'):
        """Create a gRPC client for the Triton server at ``host:port``.

        Parameters
        ----------
        port
            gRPC port of the Triton server.
        host
            Host of the Triton server; defaults to ``'localhost'``. System
            shared memory presumably requires the server to see the same
            shared-memory namespace as this process (same machine or a shared
            IPC namespace) -- confirm for your deployment.
        """
        self.client = grpcclient.InferenceServerClient(url=f'{host}:{port}')

    @staticmethod
    def _triton_dtype(array: np.ndarray) -> str:
        """Return the Triton datatype string for ``array.dtype``.

        ``DTYPE_np2triton`` is keyed by numpy scalar types, which compare equal
        to ``array.dtype`` but do not hash alike, so a linear scan with ``==``
        is used instead of ``DTYPE_np2triton[array.dtype]``.

        Raises
        ------
        ValueError
            If the dtype has no entry in ``DTYPE_np2triton``.
        """
        for np_dtype, triton_dtype in DTYPE_np2triton.items():
            if array.dtype == np_dtype:
                return triton_dtype
        raise ValueError(f'Unsupported datatype {array.dtype}')

    def _get_handle(self, d: Dict[str, np.ndarray], is_input: bool) -> 'ShmHandle':
        """Lay out the arrays of ``d`` back to back in one new shared-memory region.

        The region is created locally and registered with the server under a
        fresh random name. If writing or registering fails, the region is
        destroyed before the exception propagates, so nothing leaks.

        Parameters
        ----------
        d
            ``{tensor_name: array}``. For inputs the array contents are copied
            into the region; for outputs only ``array.nbytes`` is used, to
            reserve space.
        is_input
            Build ``InferInput`` descriptors (True) or ``InferRequestedOutput``
            descriptors (False).

        Returns
        -------
        ShmHandle
            Tensor descriptors bound to the region, the local region handle,
            the registered region name, and the tensor names with their byte
            sizes in region order.

        Raises
        ------
        ValueError
            If ``d`` is empty or an input array has an unsupported dtype.
        """
        if not d:
            raise ValueError('At least one tensor is required')
        names = list(d)
        data = list(d.values())
        byte_sizes = [array.nbytes for array in data]
        total_size = sum(byte_sizes)

        # The server tracks registered regions by name, so every region needs
        # a name no other in-flight request/client uses; a random UUID does
        # that without embedding host or time information (unlike uuid1).
        triton_shm_name = uuid.uuid4().hex
        shm_key = '/' + triton_shm_name

        # Build the descriptors before acquiring any resource: a failure here
        # (e.g. an unsupported dtype) then has nothing to clean up.
        shared_data = []
        offset = 0
        for name, array, byte_size in zip(names, data, byte_sizes):
            if is_input:
                tensor = grpcclient.InferInput(
                    name, list(array.shape), self._triton_dtype(array)
                )
            else:
                tensor = grpcclient.InferRequestedOutput(name)
            tensor.set_shared_memory(triton_shm_name, byte_size, offset=offset)
            shared_data.append(tensor)
            offset += byte_size

        shm_handle = shm.create_shared_memory_region(
            triton_shm_name, shm_key, byte_size=total_size
        )
        try:
            # Output placeholders (e.g. np.empty) carry no meaningful data, so
            # only input contents are copied into the region.
            if is_input:
                offset = 0
                for array, byte_size in zip(data, byte_sizes):
                    # ascontiguousarray is a no-op for contiguous arrays and
                    # guards against strided views (slices, transposes).
                    shm.set_shared_memory_region(
                        shm_handle, [np.ascontiguousarray(array)], offset=offset
                    )
                    offset += byte_size
            self.client.register_system_shared_memory(
                triton_shm_name, shm_key, byte_size=total_size
            )
        except BaseException:  # clean up, then re-raise unchanged
            shm.destroy_shared_memory_region(shm_handle)
            raise

        return ShmHandle(
            shared_data=shared_data, shm_handle=shm_handle, shm_name=triton_shm_name,
            names=names, byte_sizes=byte_sizes,
        )

    def _release_handle(self, handle: 'ShmHandle') -> None:
        """Unregister ``handle``'s region on the server and destroy it locally."""
        try:
            self.client.unregister_system_shared_memory(handle.shm_name)
        finally:
            # Destroy even if unregistering failed, so no local segment remains.
            shm.destroy_shared_memory_region(handle.shm_handle)

    @staticmethod
    def _read_outputs(results, output_handle: 'ShmHandle') -> Dict[str, np.ndarray]:
        """Copy every requested output out of ``output_handle``'s region.

        Outputs are located by walking the region with the same order and byte
        sizes used when it was laid out; dtype and shape come from the
        server's response metadata.

        Raises
        ------
        RuntimeError
            If the response does not contain a requested output.
        """
        final_output_dict = {}
        offset = 0
        for name, byte_size in zip(output_handle.names, output_handle.byte_sizes):
            value_pb = results.get_output(name)
            if value_pb is None:
                raise RuntimeError(f'Missing output `{name}`')
            value = shm.get_contents_as_numpy(
                output_handle.shm_handle,
                tritonclient.utils.triton_to_np_dtype(value_pb.datatype),
                value_pb.shape,
                offset=offset,
            )
            # ``value`` reads straight from the shared-memory region, which is
            # destroyed once the request finishes; using it afterwards crashed
            # the process (exit code 139), hence the copy.
            final_output_dict[name] = np.copy(value)
            offset += byte_size
        return final_output_dict

    def infer(
        self,
        model_name: str,
        input_dict: Dict[str, np.ndarray],
        output_dict: Dict[str, np.ndarray],
        model_version: str = '',
    ) -> Dict[str, np.ndarray]:
        """Run one inference request with inputs and outputs in shared memory.

        Parameters
        ----------
        model_name
            Name of the model on the server.
        input_dict
            {input_name: array}. Only dtypes listed in ``DTYPE_np2triton`` are
            supported.
        output_dict
            {output_name: array}
            The array is only used for its memory size (the space reserved for
            that output), so an uninitialised array (np.empty) is enough. It
            should be at least as large as the tensor the model returns.
        model_version
            The default value is an empty string, which means the server will
            choose a version based on the model and internal policy.

        Returns
        -------
        final_output_dict
            {output_name: array}, with dtype and shape as reported by the
            server. The arrays are independent copies.

        Raises
        ------
        ValueError
            If ``input_dict`` or ``output_dict`` is empty, or an input has an
            unsupported dtype.
        RuntimeError
            If the response lacks a requested output.

        Both shared-memory regions are unregistered and destroyed before this
        method returns or raises.
        """
        input_handle = self._get_handle(input_dict, is_input=True)
        try:
            output_handle = self._get_handle(output_dict, is_input=False)
            try:
                results = self.client.infer(
                    model_name=model_name,
                    model_version=model_version,
                    inputs=input_handle.shared_data,
                    outputs=output_handle.shared_data,
                )
                # Outputs are copied here, before the finally clauses destroy
                # the regions they live in.
                return self._read_outputs(results, output_handle)
            finally:
                self._release_handle(output_handle)
        finally:
            self._release_handle(input_handle)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment