Last active
January 9, 2019 16:09
-
-
Save salman-ghauri/88ea842e356ab3c7207e601836c3e3c4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""A simple gRPC client to communicate with tf-server on the given port.
It was developed with some help from another gist, which I forgot to reference."""
import time
from argparse import ArgumentParser

import grpc
import scipy
import scipy.misc
from grpc.beta import implementations
from tensorflow.contrib.util import make_tensor_proto
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2
def parse_args():
    """Parse the command-line options of the prediction client.

    Returns:
        A ``(host, port, image_path, batch_mode)`` tuple. ``host`` and
        ``port`` are the two halves of ``--server`` (both strings),
        ``image_path`` is the ``--image_path`` value unchanged, and
        ``batch_mode`` is True when ``--batch_mode`` equals "true"
        (compared case-insensitively).

    Exits through ``parser.error`` (usage message, status 2) when
    ``--server`` is not of the form ``host:port``.
    """
    parser = ArgumentParser(
        description='Request a TensorFlow server for a prediction on the image')
    parser.add_argument('-s', '--server',
                        dest='server',
                        default='172.17.0.2:8500',
                        help='prediction service host:port')
    parser.add_argument('-i', '--image_path',
                        dest='image_path',
                        default='',
                        help='path to images folder')
    parser.add_argument('-b', '--batch_mode',
                        dest='batch_mode',
                        default='true',
                        help='send image as batch or one-by-one')
    args = parser.parse_args()

    # Split on the LAST colon only, so hosts that themselves contain colons
    # (e.g. a bracketed IPv6 literal such as "[::1]:8500") keep working, and
    # report a readable usage error instead of an unpacking ValueError when
    # either half is missing.
    host, sep, port = args.server.rpartition(':')
    if not (sep and host and port):
        parser.error(
            '--server must be given as host:port, got {!r}'.format(args.server))

    # Accept "True"/"TRUE" as well; anything other than "true" means False.
    batch_mode = args.batch_mode.strip().lower() == 'true'
    return host, port, args.image_path, batch_mode
def main():
    """Send one image to the prediction server and print the elapsed time.

    Reads the server address and image path from the command line, builds a
    ``PredictRequest`` for the "custom_detector" model using its
    "detection_signature" signature, sends it over an insecure gRPC channel
    and prints how long the request preparation plus the RPC took.
    """
    # parse command line arguments; the batch flag is parsed but this client
    # does not act on it.
    host, port, image_path, _batch_mode = parse_args()

    # Using the channel as a context manager closes it once the call is done
    # instead of leaving the connection open until interpreter exit.
    with grpc.insecure_channel(host + ':' + port) as channel:
        stub = prediction_service_pb2.PredictionServiceStub(channel)

        # The timer covers building the request (including reading the image
        # from disk) as well as the RPC itself.
        start = time.time()
        request = predict_pb2.PredictRequest()
        request.model_spec.name = "custom_detector"
        request.model_spec.signature_name = "detection_signature"

        # NOTE(review): scipy.misc.imread was deprecated and later removed
        # from SciPy -- confirm the pinned SciPy version still provides it.
        img = scipy.misc.imread(image_path)

        # Prepend a leading dimension of 1 so the single image is sent as a
        # batch of one.
        request.inputs['inputs'].CopyFrom(
            make_tensor_proto(img, shape=[1] + list(img.shape)))

        # Second positional argument is the RPC timeout. The response is
        # discarded on purpose: only the latency is reported.
        stub.Predict(request, 30.0)
        end = time.time()

    print('time elapased: {}'.format(end - start))
# Run the client only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment