Last active
February 27, 2017 22:46
-
-
Save nqbao/6764a505cc3dc2be1b23bc587d37bf87 to your computer and use it in GitHub Desktop.
TF gRPC Benchmark
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import grpc.beta.implementations | |
import tensorflow_serving.apis.predict_pb2 as predict_pb2 | |
import tensorflow_serving.apis.prediction_service_pb2 as prediction_service_pb2 | |
import numpy as np | |
import tensorflow as tf | |
import threading | |
from timeit import timeit | |
# from concurrent.futures import ProcessPoolExecutor | |
class Runner(object):
    """Throttle asynchronous calls so at most ``concurrent`` are in flight.

    ``func`` is invoked with no arguments once per run and must return a
    future-like object exposing ``add_done_callback(callback)`` and
    ``result()``.  A run counts as a success when ``result()`` returns and
    as an error when it raises an ``Exception``.
    """

    def __init__(self, func, runs, concurrent=100):
        """Store the run parameters and reset all counters.

        Args:
            func: Zero-argument callable that starts one call and returns
                its future.
            runs: Total number of calls that ``spawn`` starts.
            concurrent: Maximum number of unfinished calls allowed at once.

        Raises:
            ValueError: If ``concurrent`` is smaller than 1; no call could
                ever be started and ``launch`` would block forever.
        """
        if concurrent < 1:
            raise ValueError("concurrent must be at least 1")
        self._condition = threading.Condition()
        self._concurrent = concurrent
        self._active = 0
        self._runs = runs
        self._func = func
        self._finished = 0
        self._success = 0
        self._error = 0

    def spawn(self):
        """Start all runs, blocking whenever the concurrency cap is reached."""
        last_reported = 0
        for _ in range(self._runs):
            self.launch()
            # Unlocked snapshot: it only drives a best-effort progress line.
            finished = self._finished
            at_milestone = finished > 0 and finished % 100 == 0
            # Remember the last milestone so the same count is not printed
            # again while no further call has completed.
            if at_milestone and finished != last_reported:
                last_reported = finished
                print("Finish %s runs" % finished)
        print("done launching")

    def launch(self):
        """Start one call once a concurrency slot is free.

        Raises:
            Exception: Whatever ``func`` raises is propagated after the
                reserved slot has been released again.
        """
        with self._condition:
            # Re-check the predicate after every wake-up: a wait() may return
            # without a slot actually being free.
            while self._active >= self._concurrent:
                self._condition.wait()
            self._active += 1
        try:
            future = self._func()
        except Exception:
            # The call never started, so give the slot back before
            # propagating the failure.
            with self._condition:
                self._active -= 1
                self._condition.notify_all()
            raise
        future.add_done_callback(self._done)

    def _done(self, f):
        """Record the outcome of a finished future and wake any waiters."""
        with self._condition:
            try:
                f.result()
            except Exception as ex:
                print(ex)
                self._error += 1
            else:
                self._success += 1
            finally:
                # Always release the slot, even if result() raised something
                # outside Exception; otherwise wait() would never return.
                self._active -= 1
                self._finished += 1
                # launch() and wait() block on the same condition, so wake
                # every waiter and let each one re-check its own predicate.
                self._condition.notify_all()

    def wait(self):
        """Block until ``runs`` calls have finished."""
        with self._condition:
            while self._finished < self._runs:
                self._condition.wait()
def convert_to_request(inputs, model_name="peng_unused_model_name"):
    """Build a ``PredictRequest`` carrying every entry of ``inputs``.

    Args:
        inputs: Mapping from input name to a value that ``np.array`` can
            convert (for example nested lists loaded from JSON).
        model_name: Value written to ``request.model_spec.name``.  The
            default keeps the name this script has always sent.

    Returns:
        The populated ``predict_pb2.PredictRequest``.
    """
    request = predict_pb2.PredictRequest()
    request.model_spec.name = model_name
    # items() exists on both Python 2 and Python 3 mappings, whereas
    # iteritems() is Python 2 only.
    for key, value in inputs.items():
        value = np.array(value)
        # Downcast 64-bit arrays to their 32-bit counterparts.
        # NOTE(review): presumably the served model takes 32-bit tensors --
        # confirm against the model signature.
        if value.dtype == np.float64:
            value = value.astype(np.float32)
        elif value.dtype == np.int64:
            value = value.astype(np.int32)
        tensor_proto = tf.contrib.util.make_tensor_proto(
            value, shape=value.shape)
        request.inputs[key].CopyFrom(tensor_proto)
    return request
def main():
    """Parse the command line, replay one request ``-n`` times, and report.

    Reads the JSON file given by ``--input-file``, converts it to a single
    predict request, sends that request ``-n`` times with at most ``-c``
    calls in flight, then prints the success/error counts and the elapsed
    time reported by ``timeit``.
    """
    import argparse
    import json

    parser = argparse.ArgumentParser(description='TF gRPC Benchmark')
    parser.add_argument(
        "--input-file", help="Path to input file (JSON format)", required=True)
    parser.add_argument("-n", default=1000, type=int)
    parser.add_argument("-c", default=10, type=int)
    parser.add_argument("--host", default="localhost")
    parser.add_argument("--port", default=9998, type=int)
    parser.add_argument("--max-timeout", default=1.0, type=float)
    ns = parser.parse_args()

    with open(ns.input_file, "r") as f:
        inputs = json.load(f)

    print("Running %s runs against %s:%s" % (ns.n, ns.host, ns.port))

    request_pb2 = convert_to_request(inputs)

    # Create the channel and stub once and share them across all requests.
    # Building a fresh channel per request (as before) never released those
    # channels and added channel setup to every single call.
    channel = grpc.beta.implementations.insecure_channel(ns.host, ns.port)
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

    def predict():
        # Start one asynchronous Predict call and hand back its future.
        return stub.Predict.future(request_pb2, ns.max_timeout)

    def run():
        runner = Runner(predict, ns.n, ns.c)
        runner.spawn()
        runner.wait()
        print("success %s / error %s" % (runner._success, runner._error))

    # number=1: the whole benchmark is executed exactly once and timed.
    print(timeit(run, number=1))
# Run the benchmark only when this file is executed as a script, so that
# importing it (for example to reuse Runner) has no side effects.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment