Skip to content

Instantly share code, notes, and snippets.

@gu-ma
Created December 16, 2018 20:39
Show Gist options
  • Save gu-ma/c9dc7024b8a72388aca32e4e91ea143e to your computer and use it in GitHub Desktop.
Save gu-ma/c9dc7024b8a72388aca32e4e91ea143e to your computer and use it in GitHub Desktop.
Example of an API that serves a word-rnn-tensorflow model
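# Clone word-rnn-tensorflow and train a model with it,
# then save this file as 'api.py' next to its model.py (it does `from model import Model`) and run: `python api.py --save_dir='path/to/model'`
# Example request (URL-encode the prime text; do not wrap it in quotes): http://localhost:5002/generate?n=1000&prime=Life%20is%20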
#
#
import argparse
import os
from six.moves import cPickle
#
import tensorflow as tf
from model import Model
#
from flask import Flask, request
from flask_restful import Resource, Api
from flask import jsonify
# Flask application plus the flask_restful wrapper that routes URLs to Resources.
app = Flask(__name__)
api = Api(app)

# Command-line options. Parsed at import time so that `args` is available as a
# module-level name to the functions below.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--save_dir', type=str, default='save',
                    # This server only reads from the directory (config.pkl,
                    # words_vocab.pkl and the TF checkpoint); it never writes
                    # to it.
                    help='directory containing the trained model checkpoint, '
                         'config.pkl and words_vocab.pkl')
args = parser.parse_args()
def loadModel(save_dir):
    """Rebuild the model object and its vocabulary from *save_dir*.

    Unpickles the saved training configuration from ``config.pkl`` and a
    ``(words, vocab)`` pair from ``words_vocab.pkl``, then constructs
    ``Model`` from that configuration. No checkpoint weights are loaded here.

    NOTE(review): unpickling executes arbitrary code, so *save_dir* must be a
    trusted, operator-controlled directory.

    Returns:
        tuple: ``(model, words, vocab)``.
    """
    config_path = os.path.join(save_dir, 'config.pkl')
    vocab_path = os.path.join(save_dir, 'words_vocab.pkl')

    with open(config_path, 'rb') as config_file:
        saved_args = cPickle.load(config_file)
    with open(vocab_path, 'rb') as vocab_file:
        words, vocab = cPickle.load(vocab_file)

    # The second positional argument is presumably Model's inference-mode
    # flag -- confirm against model.Model.
    return Model(saved_args, True), words, vocab
# Per-graph cache of (variable-initializer op, Saver). Both constructors add
# new ops to the graph every time they are called, so creating them once per
# request would make the shared default graph grow for the life of the server.
_SESSION_OPS = {}


def _get_session_ops(graph):
    """Return ``(init_op, saver)`` for *graph*, building them on first use only."""
    ops = _SESSION_OPS.get(graph)
    if ops is None:
        # NOTE(review): two concurrent first requests may each build a set of
        # ops; that is harmless (one extra saver), unlike the old per-request
        # growth.
        with graph.as_default():
            ops = (tf.global_variables_initializer(),
                   tf.train.Saver(tf.global_variables()))
        _SESSION_OPS[graph] = ops
    return ops


def sampleModel(model, words, vocab, n, prime, sample, pick, width,
                save_dir=None):
    """Generate text with *model* using the latest checkpoint in *save_dir*.

    Opens a new session on the default graph, initialises all global
    variables, restores the most recent checkpoint found in *save_dir* (if
    any), and calls ``model.sample`` with the remaining arguments.

    Args:
        model, words, vocab: as returned by ``loadModel``.
        n, prime, sample, pick, width: forwarded unchanged to ``model.sample``.
        save_dir: checkpoint directory; defaults to the ``--save_dir``
            command-line option (``args.save_dir``).

    Returns:
        dict: ``{'sample': s}`` where ``s`` is what ``model.sample`` returned.
    """
    if save_dir is None:
        save_dir = args.save_dir
    graph = tf.get_default_graph()
    init_op, saver = _get_session_ops(graph)
    with tf.Session(graph=graph) as sess:
        sess.run(init_op)
        ckpt = tf.train.get_checkpoint_state(save_dir)
        # NOTE(review): when no checkpoint is found, sampling proceeds with
        # freshly initialised (untrained) weights rather than failing.
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        s = model.sample(sess, words, vocab, n, prime, sample, pick, width)
    print(s)
    result = {'sample': s}
    return result
class GenerateSample(Resource):
    """``GET /generate``: sample text from the loaded model.

    Query parameters (all optional; absent or empty values use the default):
      n      -- int, default 50 (presumably the number of words to generate)
      prime  -- str, default '' (text passed to the model as the prime)
      sample -- int, default 1, forwarded to ``model.sample``
      pick   -- int, default 1, forwarded to ``model.sample``
      width  -- int, default 4, forwarded to ``model.sample``

    Responds with ``{"sample": ...}``. A non-integer value for an integer
    parameter yields HTTP 400 with a ``message``.

    NOTE(review): relies on the module-level ``model``, ``words`` and
    ``vocab``, which are only assigned when this file runs as ``__main__``.
    """

    # (query parameter, default used when it is absent or empty)
    _INT_DEFAULTS = (('n', 50), ('sample', 1), ('pick', 1), ('width', 4))

    def get(self):
        params = {}
        for name, default in self._INT_DEFAULTS:
            raw = request.args.get(name)
            if not raw:
                params[name] = default
                continue
            try:
                params[name] = int(raw)
            except ValueError:
                # Malformed client input is a client error, not a server crash.
                return {'message': "query parameter '%s' must be an integer"
                                   % name}, 400
        prime = request.args.get('prime') or u''
        result = sampleModel(model, words, vocab, params['n'], prime,
                             params['sample'], params['pick'], params['width'])
        return jsonify(result)
api.add_resource(GenerateSample, '/generate')

if __name__ == '__main__':
    # Build the model and vocabulary once before serving. GenerateSample reads
    # these module-level names, so they exist only when run as a script.
    model, words, vocab = loadModel(args.save_dir)
    # Keep JSON keys in insertion order instead of sorting them.
    # NOTE(review): newer Flask releases ignore this config key in favour of
    # app.json.sort_keys -- confirm against the installed Flask version.
    app.config.update(JSON_SORT_KEYS=False)
    # Listen on all interfaces, port 5002.
    app.run(port=5002, host='0.0.0.0')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment