Skip to content

Instantly share code, notes, and snippets.

@nqbao
Last active February 27, 2017 22:46
Show Gist options
Save nqbao/6764a505cc3dc2be1b23bc587d37bf87 to your computer and use it in GitHub Desktop.
TF gRPC Benchmark
import grpc.beta.implementations
import tensorflow_serving.apis.predict_pb2 as predict_pb2
import tensorflow_serving.apis.prediction_service_pb2 as prediction_service_pb2
import numpy as np
import tensorflow as tf
import threading
from timeit import timeit
# from concurrent.futures import ProcessPoolExecutor
class Runner:
    """Issue `runs` calls of an async `func`, at most `concurrent` in flight.

    `func` must be a zero-argument callable returning a future-like object
    that exposes `add_done_callback(cb)` and `result()` (e.g. a gRPC future).
    All counters are guarded by a single condition variable.
    """

    def __init__(self, func, runs, concurrent=100):
        self._condition = threading.Condition()  # guards every counter below
        self._concurrent = concurrent  # max requests allowed in flight
        self._active = 0               # launched but not yet completed
        self._runs = runs              # total requests to issue
        self._func = func              # zero-arg callable returning a future
        self._finished = 0             # completed = success + error
        self._success = 0
        self._error = 0

    def spawn(self):
        """Launch all runs, blocking whenever the concurrency cap is hit."""
        for _ in range(self._runs):
            self.launch()
            # Progress report roughly every 100 completions.
            if self._finished % 100 == 0 and self._finished > 0:
                print("Finish %s runs" % self._finished)
        print("done launching")

    def launch(self):
        """Start one request, waiting first for a free concurrency slot."""
        with self._condition:
            # BUG FIX: was `if`; must re-check the predicate after every
            # wakeup, since notify_all may wake us while the pool is full.
            while self._active >= self._concurrent:
                self._condition.wait()
            self._active += 1
        f = self._func()
        f.add_done_callback(self._done)

    def _done(self, f):
        """Completion callback: classify the outcome and release a slot."""
        with self._condition:
            try:
                f.result()  # raises if the request failed
                self._success += 1
            except Exception as ex:
                print(ex)
                self._error += 1
            self._active -= 1
            self._finished += 1
            # BUG FIX: notify_all (was notify). Both launch() (slot free)
            # and wait() (all finished) may block on this condition; waking
            # only one waiter can leave the other blocked forever.
            self._condition.notify_all()

    def wait(self):
        """Block until every run has finished."""
        with self._condition:
            while self._finished < self._runs:
                self._condition.wait()
def convert_to_request(inputs):
    """Build a TF Serving PredictRequest from a dict of array-like values.

    Each entry of `inputs` is converted to a numpy array and packed into the
    request as a tensor proto under its key. 64-bit dtypes are narrowed to
    their 32-bit counterparts — presumably because the served model declares
    32-bit inputs; confirm against the model's signature.
    """
    request = predict_pb2.PredictRequest()
    request.model_spec.name = "peng_unused_model_name"
    # BUG FIX: iteritems() is Python-2-only; items() is equivalent here and
    # works on both Python 2 and 3.
    for key, value in inputs.items():
        value = np.array(value)
        if value.dtype == np.float64:
            value = value.astype(np.float32)
        elif value.dtype == np.int64:
            value = value.astype(np.int32)
        request.inputs[key].CopyFrom(
            tf.contrib.util.make_tensor_proto(value, shape=value.shape))
    return request
def main():
    """CLI entry point: parse args, load the JSON payload, run the benchmark.

    Fires `-n` Predict RPCs against `--host:--port` with at most `-c`
    concurrent requests, then prints the success/error counts and the total
    wall-clock time of the run.
    """
    import argparse
    import json

    parser = argparse.ArgumentParser(description='TF gRPC Benchmark')
    parser.add_argument("--input-file", help="Path to input file (JSON format)", required=True)
    parser.add_argument("-n", default=1000, type=int)
    parser.add_argument("-c", default=10, type=int)
    parser.add_argument("--host", default="localhost")
    parser.add_argument("--port", default=9998, type=int)
    parser.add_argument("--max-timeout", default=1.0, type=float)
    ns = parser.parse_args()

    with open(ns.input_file, "r") as f:
        inputs = json.load(f)

    # FIX: parenthesized single-argument print is valid and prints the same
    # text on both Python 2 and Python 3 (the bare statement is 2-only).
    print("Running %s runs against %s:%s" % (ns.n, ns.host, ns.port))

    # Build the request proto once; every RPC reuses the same message.
    request_pb2 = convert_to_request(inputs)

    def predict():
        # NOTE(review): a fresh channel + stub is created per request, so
        # connection setup is part of what gets benchmarked — confirm this
        # is intended rather than reusing one channel for all calls.
        channel = grpc.beta.implementations.insecure_channel(ns.host, ns.port)
        stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
        return stub.Predict.future(request_pb2, ns.max_timeout)

    def run():
        runner = Runner(predict, ns.n, ns.c)
        runner.spawn()
        runner.wait()
        print("success %s / error %s" % (runner._success, runner._error))

    # Time a single end-to-end benchmark pass.
    print(timeit(run, number=1))


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment