Last active: September 9, 2021 15:33
-
-
Save rnyak/d58b40eff9c3c7f37d6a80280af03a94 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# External dependencies
import os
from time import time

import cudf
import numpy as np
import nvtabular.inference.triton as nvt_triton
import tritonclient.grpc as grpcclient
#!curl -i triton:8000/v2/health/ready
import tritonhttpclient
# Create an HTTP client for the Triton server on localhost:8000 and report
# whether the server is live. This client is only used for the health check;
# inference below goes through a separate gRPC client.
# NOTE(review): `tritonhttpclient` is the legacy package name; newer client
# releases expose the same class as `tritonclient.http` - confirm against the
# installed version.
try:
    triton_client = tritonhttpclient.InferenceServerClient(url="localhost:8000", verbose=True)
    print("client created.")
except Exception as e:
    print("channel creation failed: " + str(e))
    # Re-raise: without a client the liveness check below would fail with a
    # NameError on `triton_client`, hiding the actual connection error.
    raise
# Print the result explicitly; as a bare expression it is discarded outside a notebook.
print("server live:", triton_client.is_server_live())
# Build a synthetic table of raw event features, take the first few rows as a
# request batch, and convert that batch into Triton gRPC InferInput objects.
NUM_ROWS = 10000  # rows of random data to generate
BATCH_SIZE = 3    # rows actually sent in the inference request

# Random raw features; the value ranges are arbitrary test data.
# Named `raw_columns` (not `inputs`) so it is not confused with the Triton
# request inputs built at the end of this block.
raw_columns = {
    'user_session': np.random.randint(1, 10000, NUM_ROWS),
    'product_id': np.random.randint(1, 51996, NUM_ROWS),
    'category_id': np.random.randint(0, 332, NUM_ROWS),
    'event_time_ts': np.random.randint(1570373000, 1670373390, NUM_ROWS),
    'prod_first_event_time_ts': np.random.randint(1570373000, 1570373382, NUM_ROWS),
    'price': np.random.uniform(0, 2750, NUM_ROWS),
}
df = cudf.DataFrame(raw_columns)

# Columns included in the request batch, in this order.
INPUT_COLUMNS = [
    'category_id',
    'event_time_ts',
    'user_session',
    'prod_first_event_time_ts',
    'price',
    'product_id',
]
batch = df[INPUT_COLUMNS].iloc[:BATCH_SIZE, :]
print(batch)

# List of grpcclient.InferInput objects, one per batch column, consumed by
# client.infer(...) later in the script.
inputs = nvt_triton.convert_df_to_triton_input(batch.columns, batch, grpcclient.InferInput)
# Output columns produced by the NVTabular workflow model.
output_names_org = [
    'user_session',
    'product_id-count',
    'product_id-list_seq',
    'et_dayofweek_sin-list_seq',
    'price_log_norm-list_seq',
    'relative_price_to_avg_categ_id-list_seq',
    'category_id-list_seq',
    'product_recency_days_log_norm-list_seq',
    'et_dayofweek_cos-list_seq',
    'day_index',
]

# Any output whose name contains '-list' is requested as two tensors,
# '<name>__values' followed by '<name>__nnzs'; every other output keeps its
# name unchanged. Order follows output_names_org.
output_names = [
    base + suffix
    for base in output_names_org
    for suffix in (('__values', '__nnzs') if '-list' in base else ('',))
]
# Send `inputs` to the NVTabular model on Triton's gRPC endpoint (localhost:8001),
# requesting every name in `output_names`, then print each returned tensor
# together with its shape.
MODEL_NAME_NVT = "model_nvt"
outputs = [grpcclient.InferRequestedOutput(col) for col in output_names]

with grpcclient.InferenceServerClient("localhost:8001") as client:
    response = client.infer(MODEL_NAME_NVT, inputs, request_id="1", outputs=outputs)

for col in output_names:
    # Convert each output only once. as_numpy returns None when the response
    # has no output with that name, which would otherwise crash on `.shape`.
    arr = response.as_numpy(col)
    if arr is None:
        print(col, "missing from response")
        continue
    print(col, arr, arr.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.