Created
January 6, 2020 11:31
-
-
Save ywolff/39c913d6377db409a9b2f383a52e79e7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import threading | |
import time | |
from queue import Empty, Queue | |
import numpy as np | |
from flask import Flask, request as flask_request | |
from build_big_model import build_big_model | |
# Batching policy: dispatch a batch when it holds BATCH_SIZE items, or when
# the oldest queued item has been waiting longer than BATCH_TIMEOUT seconds.
BATCH_SIZE = 20
# Maximum time (seconds) the oldest request may wait before a partial batch runs.
BATCH_TIMEOUT = 0.5
# Polling interval (seconds) used both by the batcher's queue.get timeout and
# by request handlers waiting for their result.
CHECK_INTERVAL = 0.01
# Model is built once at import time and shared by the batching thread.
# NOTE(review): assumes build_big_model returns a Keras-style object with
# .predict — confirm against build_big_model's definition.
model = build_big_model()
# Cross-thread handoff: handler threads put request dicts here; the batching
# thread consumes them and writes results back onto the same dicts.
requests_queue = Queue()
app = Flask(__name__)
def handle_requests_by_batch():
    """Consume request dicts from ``requests_queue`` and run the model on
    them in batches, forever.

    A batch is dispatched once it holds ``BATCH_SIZE`` items, or once the
    oldest item has waited longer than ``BATCH_TIMEOUT`` seconds.  Results
    are published by setting the ``'output'`` key on each request dict,
    which the producing handler thread polls for.

    Intended to run in a single background thread (started at module level).
    """
    while True:
        requests_batch = []
        # Collect until the batch is full or the oldest entry times out.
        # FIX: '>=' (was '>'), so a batch never exceeds BATCH_SIZE items —
        # the old condition only stopped once the batch reached 21 entries.
        while not (
            len(requests_batch) >= BATCH_SIZE or
            (requests_batch and time.time() - requests_batch[0]['time'] > BATCH_TIMEOUT)
        ):
            try:
                # Short timeout keeps the loop responsive to the batch-timeout check.
                requests_batch.append(requests_queue.get(timeout=CHECK_INTERVAL))
            except Empty:
                continue
        # Batch is guaranteed non-empty here: the inner loop only exits when
        # the batch is full, or when it is non-empty and timed out.
        batch_inputs = np.array([request['input'] for request in requests_batch])
        batch_outputs = model.predict(batch_inputs)
        # Publish each result back onto its originating request dict; the
        # waiting handler thread sees the new 'output' key and returns.
        for request, output in zip(requests_batch, batch_outputs):
            request['output'] = output
threading.Thread(target=handle_requests_by_batch).start() | |
@app.route('/predict', methods=['POST'])
def predict():
    """Enqueue one input instance for batched inference and block until the
    batching thread publishes its result, then return it as JSON."""
    # Only the first instance of the payload is used by this endpoint.
    instance = np.array(flask_request.json['instances'][0])

    # Hand the job to the batching thread; it will attach the result under
    # the 'output' key of this same dict.
    job = {'input': instance, 'time': time.time()}
    requests_queue.put(job)

    # Poll for completion at the shared check interval.
    while True:
        if 'output' in job:
            break
        time.sleep(CHECK_INTERVAL)

    return {'predictions': job['output'].tolist()}
if __name__ == '__main__':
    # Start the Flask development server; model inference itself happens in
    # the background batching thread started at module level.
    app.run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment