Created
January 19, 2024 15:27
-
-
Save Nanguage/363c5f8d9dd7b4db6eb9fcb7615554ca to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Test the triton server proxy.""" | |
from PIL import Image | |
import msgpack | |
import numpy as np | |
import requests | |
import gzip | |
import json | |
def get_config(server_url, model_name):
    """Fetch the configuration of a model from the triton server proxy.

    Args:
        server_url: Base URL of the server, e.g. ``http://127.0.0.1:9520``.
        model_name: Name of the model whose configuration is requested.

    Returns:
        The JSON-decoded body of the response.

    Raises:
        Exception: If the server answers with a non-success status code.
    """
    response = requests.get(
        f"{server_url}/public/services/triton-client/get_config",
        # Let requests build the query string so that model names containing
        # reserved characters (spaces, "&", "#", ...) are percent-encoded.
        params={"model_name": model_name},
    )
    if not response.ok:
        # Fail loudly instead of trying to JSON-decode an error page.
        raise Exception(
            f"Failed to get config of {model_name}: "
            f"{response.reason or response.text}"
        )
    return json.loads(response.content)
def encode_data(inputs):
    """Recursively replace numpy values with their imjoy-rpc representation.

    Args:
        inputs: A numpy array or numpy scalar, a (possibly nested) list,
            tuple or dict containing such values, or any other object.

    Returns:
        * numpy arrays/scalars become a dict with the keys ``_rtype``,
          ``_rvalue`` (raw bytes), ``_rshape`` and ``_rdtype``;
        * lists and tuples become a list of encoded items;
        * dicts become a dict with the same keys and encoded values;
        * everything else is returned unchanged.
    """
    if isinstance(inputs, (np.ndarray, np.generic)):
        return {
            "_rtype": "ndarray",
            "_rvalue": inputs.tobytes(),
            "_rshape": inputs.shape,
            "_rdtype": str(inputs.dtype),
        }
    if isinstance(inputs, (tuple, list)):
        # Tuples are deliberately flattened to lists, like the original code.
        return [encode_data(item) for item in inputs]
    if isinstance(inputs, dict):
        return {key: encode_data(value) for key, value in inputs.items()}
    return inputs
def decode_data(outputs):
    """Recursively turn imjoy-rpc encoded ndarrays back into numpy arrays.

    Args:
        outputs: A decoded msgpack payload: a dict, list, tuple or any
            other value.

    Returns:
        * a dict tagged with ``"_rtype": "ndarray"`` whose ``_rdtype`` is not
          ``"object"`` becomes a numpy array built from ``_rvalue`` with
          dtype ``_rdtype`` and shape ``_rshape``;
        * any other dict becomes a dict with the same keys and decoded
          values (this includes ndarray records of dtype ``"object"``);
        * lists and tuples become a list of decoded items;
        * everything else is returned unchanged.
    """
    if isinstance(outputs, dict):
        is_ndarray = (
            outputs.get("_rtype") == "ndarray"
            and outputs["_rdtype"] != "object"
        )
        if is_ndarray:
            # np.frombuffer wraps the buffer without copying it.
            flat = np.frombuffer(outputs["_rvalue"], dtype=outputs["_rdtype"])
            return flat.reshape(outputs["_rshape"])
        return {key: decode_data(value) for key, value in outputs.items()}
    if isinstance(outputs, (tuple, list)):
        return [decode_data(item) for item in outputs]
    return outputs
def execute(inputs, server_url, model_name, **kwargs):
    """Execute a model on the triton server.

    The supported kwargs are consistent with pyotritonclient
    https://github.com/oeway/pyotritonclient/blob/bc655a20fabc4611bbf3c12fb15439c8fc8ee9f5/pyotritonclient/__init__.py#L40-L50

    Args:
        inputs: Model inputs; numpy arrays inside are encoded before sending.
        server_url: Base URL of the server.
        model_name: Name of the model to execute.
        **kwargs: Extra arguments forwarded to the server in the request body.

    Returns:
        The decoded response, with encoded ndarrays converted back to numpy
        arrays.

    Raises:
        Exception: If the server answers with a non-success status code.
    """
    # Represent the numpy arrays with the imjoy_rpc encoding
    # See: https://github.com/imjoy-team/imjoy-rpc#data-type-representation
    payload = {
        **kwargs,
        "inputs": encode_data(inputs),
        "model_name": model_name,
    }
    # Encode the arguments as msgpack, compress them and send them via a
    # post request to the server
    compressed_data = gzip.compress(msgpack.dumps(payload))
    response = requests.post(
        f"{server_url}/public/services/triton-client/execute",
        data=compressed_data,
        headers={
            "Content-Type": "application/msgpack",
            "Content-Encoding": "gzip",
        },
    )
    if not response.ok:
        raise Exception(f"Failed to execute {model_name}: {response.reason or response.text}")
    # Decode the results from the response and convert the ndarray objects
    # into numpy arrays
    return decode_data(msgpack.loads(response.content))
if __name__ == "__main__":
    # Example: run the EfficientSAM encoder on a local image.
    server_url = "http://127.0.0.1:9520"
    model_name = "efficientsam-encoder"

    # Get the model config with information about inputs/outputs etc.
    config = get_config(server_url, model_name)

    # Force three channels so the transpose below also works for grayscale
    # or RGBA/CMYK files, and close the file handle once the pixels are read.
    with Image.open("tmp/dogs.jpg") as img:
        image = np.array(img.convert("RGB"))

    # (H, W, C) uint8 -> (1, C, H, W) float32 scaled to [0, 1]
    input_image = image.transpose(2, 0, 1)[None].astype(np.float32) / 255.0

    # Run inference with the efficientsam-encoder model
    results = execute(
        inputs=[input_image],
        server_url=server_url,
        model_name=model_name,
        decode_json=True,
    )
    embeddings = results["image_embeddings"]
    print(embeddings.shape)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment