Skip to content

Instantly share code, notes, and snippets.

@ywolff
Created January 6, 2020 11:31
Show Gist options
  • Save ywolff/39c913d6377db409a9b2f383a52e79e7 to your computer and use it in GitHub Desktop.
Save ywolff/39c913d6377db409a9b2f383a52e79e7 to your computer and use it in GitHub Desktop.
import logging
import threading
import time
from queue import Empty, Queue

import numpy as np
from flask import Flask, request as flask_request

from build_big_model import build_big_model
# Tuning knobs for request batching.
BATCH_TIMEOUT = 0.5  # seconds, measured from the enqueue time of a batch's first request
BATCH_SIZE = 20  # batch-size limit checked by the batching worker
CHECK_INTERVAL = 0.01  # polling period, in seconds, used while waiting for results
# Shared module-level state used by the HTTP handler and the batching worker.
requests_queue = Queue()  # pending request dicts, consumed by the batching worker
model = build_big_model()  # built once, at import time
app = Flask(__name__)
def _collect_batch():
    """Block until a batch of queued requests is ready and return it as a list.

    A batch is closed once it holds ``BATCH_SIZE`` requests, or once
    ``BATCH_TIMEOUT`` seconds have passed since its first request was
    enqueued, whichever comes first.
    """
    # Block without polling until the first request arrives; its enqueue
    # timestamp (set by the HTTP handler) starts the batch timer.
    requests_batch = [requests_queue.get()]
    deadline = requests_batch[0]['time'] + BATCH_TIMEOUT
    while len(requests_batch) < BATCH_SIZE:
        remaining = deadline - time.time()
        try:
            if remaining > 0:
                # Wait for the next request, but no later than the deadline.
                item = requests_queue.get(timeout=remaining)
            else:
                # Deadline passed: still take requests that are already
                # waiting, so a backlog is served in full batches rather
                # than one item at a time.
                item = requests_queue.get_nowait()
        except Empty:
            break
        requests_batch.append(item)
    return requests_batch


def _run_batch(requests_batch):
    """Run the model on one batch and attach the results to each request dict.

    On success each request receives its row of the model output under
    ``'output'``. If stacking the inputs or the prediction fails, every
    request in the batch receives the exception under ``'error'`` and
    ``None`` under ``'output'``, so waiting handlers are released instead of
    waiting forever, and the worker thread survives.
    """
    try:
        batch_inputs = np.array([req['input'] for req in requests_batch])
        batch_outputs = model.predict(batch_inputs)
        if len(batch_outputs) != len(requests_batch):
            # zip() below would silently drop the unmatched requests,
            # leaving their handlers waiting forever.
            raise ValueError(
                'model returned {} outputs for {} inputs'.format(
                    len(batch_outputs), len(requests_batch)
                )
            )
    except Exception as exc:  # top-level boundary of the worker thread
        logging.getLogger(__name__).exception('Batch prediction failed')
        for req in requests_batch:
            # Set 'error' before 'output': handlers treat the presence of
            # 'output' as the completion signal.
            req['error'] = exc
            req['output'] = None
        return
    for req, output in zip(requests_batch, batch_outputs):
        req['output'] = output


def handle_requests_by_batch():
    """Serve queued prediction requests forever, one batch at a time.

    Meant to run on a background thread; it never returns.
    """
    while True:
        _run_batch(_collect_batch())
# Start the batching worker. It loops forever, so it must be a daemon thread:
# a non-daemon thread would keep the interpreter alive after the server stops.
threading.Thread(
    target=handle_requests_by_batch, name='batch-worker', daemon=True
).start()
@app.route('/predict', methods=['POST'])
def predict():
    """Predict one instance by handing it to the shared batching worker.

    Expects a JSON body of the form ``{"instances": [x, ...]}``; only the
    first instance is used. On success returns
    ``{"predictions": <model output row for that instance, as a list>}``.
    Responds 400 if the body does not have that shape, and 500 if the batch
    containing this request failed.
    """
    payload = flask_request.get_json(silent=True)
    instances = payload.get('instances') if isinstance(payload, dict) else None
    if not isinstance(instances, list) or not instances:
        return (
            {'error': "expected a JSON body with a non-empty 'instances' list"},
            400,
        )

    # 'time' is the enqueue timestamp the worker uses to time out a batch.
    job = {'input': np.array(instances[0]), 'time': time.time()}
    requests_queue.put(job)
    # The worker signals completion by adding 'output' to this dict.
    while 'output' not in job:
        time.sleep(CHECK_INTERVAL)
    if job.get('error') is not None:
        return {'error': 'prediction failed'}, 500
    return {'predictions': job['output'].tolist()}
if __name__ == '__main__':
    # Batching only works if requests are handled concurrently. With a single
    # request thread, each call would sit alone in its batch until
    # BATCH_TIMEOUT expires. Flask's dev server has defaulted to threaded=True
    # since 1.0; pass it explicitly because the design depends on it.
    app.run(threaded=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment