# Scraped from a GitHub gist page (site navigation text removed).
# Gist: rnyak/d58b40eff9c3c7f37d6a80280af03a94
# Author: @rnyak
# Last active: September 9, 2021 15:33
# External dependencies
import os
from time import time

import cudf
import numpy as np
import nvtabular.inference.triton as nvt_triton
import tritonclient.grpc as grpcclient
#!curl -i triton:8000/v2/health/ready
import tritonhttpclient
# Create an HTTP client for the Triton server at localhost:8000 and probe its
# liveness endpoint.
#
# The probe runs only when the client was actually constructed: calling
# `triton_client.is_server_live()` after a failed construction would raise a
# NameError right after the failure had already been reported.
try:
    triton_client = tritonhttpclient.InferenceServerClient(url="localhost:8000", verbose=True)
except Exception as e:
    # Best-effort: report the failure and leave the name bound to None so
    # later references fail predictably instead of with a NameError.
    triton_client = None
    print("channel creation failed: " + str(e))
else:
    print("client created.")
    # The result of the liveness probe is not used by the rest of this script.
    triton_client.is_server_live()
# Synthetic request data: NUM_ROWS random values for each raw input column.
# This block needs numpy available as `np` (imported with the other
# dependencies at the top of the file).
NUM_ROWS = 10000
inputs = {
    # Integer columns: np.random.randint(low, high, size) draws from
    # [low, high), i.e. the upper bound is exclusive.
    'user_session': np.random.randint(1, 10000, NUM_ROWS),
    'product_id': np.random.randint(1, 51996, NUM_ROWS),
    'category_id': np.random.randint(0, 332, NUM_ROWS),
    'event_time_ts': np.random.randint(1570373000, 1670373390, NUM_ROWS),
    'prod_first_event_time_ts': np.random.randint(1570373000, 1570373382, NUM_ROWS),
    # Float column drawn uniformly from [0, 2750).
    'price': np.random.uniform(0, 2750, NUM_ROWS),
}
# Load the synthetic columns into a cuDF DataFrame.
df = cudf.DataFrame(inputs)

# Columns sent in the request, in the order they are selected from `df`.
batch_columns = [
    'category_id',
    'event_time_ts',
    'user_session',
    'prod_first_event_time_ts',
    'price',
    'product_id',
]

# Only the first three rows are used as the request batch.
batch = df[batch_columns].iloc[:3, :]
print(batch)

# Convert the batch into Triton gRPC input objects. Note that this rebinds
# `inputs`: from here on it holds the converted request inputs, not the dict
# of raw arrays.
inputs = nvt_triton.convert_df_to_triton_input(
    batch.columns,
    batch,
    grpcclient.InferInput,
)
# Logical output columns requested from the model.
output_names_org = [
    'user_session',
    'product_id-count',
    'product_id-list_seq',
    'et_dayofweek_sin-list_seq',
    'price_log_norm-list_seq',
    'relative_price_to_avg_categ_id-list_seq',
    'category_id-list_seq',
    'product_recency_days_log_norm-list_seq',
    'et_dayofweek_cos-list_seq',
    'day_index',
]

# Expand the logical names into the names actually requested: every column
# whose name contains '-list' is requested as two outputs,
# '<name>__values' followed by '<name>__nnzs'; all other columns keep their
# name unchanged. Relative order is preserved.
output_names = []
for name in output_names_org:
    if '-list' in name:
        output_names.extend((name + '__values', name + '__nnzs'))
    else:
        output_names.append(name)
# Name of the model the inference request is addressed to.
MODEL_NAME_NVT = "model_nvt"

# One requested-output descriptor per entry of `output_names`, in the same order.
outputs = [grpcclient.InferRequestedOutput(col) for col in output_names]
# Send the request to the model named by MODEL_NAME_NVT over gRPC
# (localhost:8001); the client is closed when the `with` block exits.
with grpcclient.InferenceServerClient("localhost:8001") as client:
    response = client.infer(MODEL_NAME_NVT, inputs, request_id="1", outputs=outputs)

# Print each requested output: its name, its values and their shape.
# NOTE(review): the original indentation was lost; the loop is placed after
# the `with` block because it only reads the already-received `response` —
# confirm this matches the intended structure.
for col in output_names:
    # Fetch the array once instead of converting it twice per column.
    values = response.as_numpy(col)
    print(col, values, values.shape)
# (GitHub page footer: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment")