Skip to content

Instantly share code, notes, and snippets.

@gu-ma
Last active December 16, 2018 20:40
Show Gist options
  • Save gu-ma/39c267dcdd45ee671e54a7962eb7fc9b to your computer and use it in GitHub Desktop.
Save gu-ma/39c267dcdd45ee671e54a7962eb7fc9b to your computer and use it in GitHub Desktop.
Example of an API to serve a char-rnn-tensorflow model
# Clone and train: https://github.com/sherjilozair/char-rnn-tensorflow/
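# then create a file called 'api.py' in the root of that repository (it imports `model` from there) and run: `python api.py --save_dir='path/to/model'`
# you can then call http://0.0.0.0:5002/generate?n=1000&prime=Life%20is%20 (URL-encode the prime; quotes would become part of it)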
#
import argparse
import os
from six.moves import cPickle
#
import tensorflow as tf
from model import Model
#
from flask import Flask, request
from flask_restful import Resource, Api
from flask import jsonify
# Flask application plus the flask_restful Api that resources are added to.
app = Flask(__name__)
api = Api(app)

# Command-line options. They are parsed at import time because the rest of
# the module reads the module-level `args` (args.save_dir).
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    '--save_dir',
    type=str,
    default='save',
    help='model directory to store checkpointed models',
)
args = parser.parse_args()
def loadModel(save_dir):
    """Rebuild the inference model stored in `save_dir`.

    Two pickles are read from the directory:
      config.pkl       -- the saved argument object handed to Model(...)
      chars_vocab.pkl  -- a (chars, vocab) pair

    Note: unpickling can execute arbitrary code, so `save_dir` must be a
    trusted location.

    Returns:
        (model, chars, vocab), where model is Model(config, training=False).
    """
    def _unpickle(filename):
        # Load one pickle from save_dir, closing the file afterwards.
        with open(os.path.join(save_dir, filename), 'rb') as handle:
            return cPickle.load(handle)

    config = _unpickle('config.pkl')
    chars, vocab = _unpickle('chars_vocab.pkl')
    return Model(config, training=False), chars, vocab
# tf.train.Saver objects built by _get_saver, keyed by the graph they belong to.
_SAVERS = {}


def _get_saver():
    """Return a Saver over the default graph's global variables, built once.

    Constructing a tf.train.Saver (or calling
    tf.global_variables_initializer()) adds new ops to the graph, so doing
    it on every request makes the graph grow for as long as the server runs.
    """
    graph = tf.get_default_graph()
    saver = _SAVERS.get(graph)
    if saver is None:
        saver = _SAVERS[graph] = tf.train.Saver(tf.global_variables())
    return saver


def sampleModel(model, chars, vocab, n, prime, sample, save_dir=None):
    """Generate text with `model` and return it as {'sample': text}.

    For each call the weights are restored from the latest checkpoint in
    `save_dir` into a fresh tf.Session, then model.sample is run.

    Args:
        model: the inference model returned by loadModel.
        chars: character table loaded with the model; chars[0] is the
            fallback prime.
        vocab: vocabulary loaded with the model, passed to model.sample.
        n: number of characters to generate, passed to model.sample.
        prime: text used to prime the model; if empty (or None), chars[0]
            is used instead.
        sample: passed through unchanged to model.sample.
        save_dir: checkpoint directory; defaults to the --save_dir option
            (args.save_dir).

    Returns:
        dict with the single key 'sample' holding the generated text.

    Raises:
        IOError: if no checkpoint is found in `save_dir`.
    """
    if save_dir is None:
        save_dir = args.save_dir
    # Use chars[0] when no prime is given (described by the original author
    # as the most frequent character).
    if not prime:
        prime = chars[0]
    ckpt = tf.train.get_checkpoint_state(save_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
        # Fail with a clear message rather than sampling without trained weights.
        raise IOError('no model checkpoint found in %r' % (save_dir,))
    with tf.Session() as sess:
        # restore() assigns every variable in the Saver's var_list (all global
        # variables), so running an initializer first is unnecessary.
        _get_saver().restore(sess, ckpt.model_checkpoint_path)
        text = model.sample(sess, chars, vocab, n, prime, sample)
    # Normalise to text; under Python 2 model.sample may return a byte string.
    if isinstance(text, bytes):
        text = text.decode('utf-8')
    print(text)
    return {'sample': text}
class GenerateSample(Resource):
    """GET /generate: sample text from the loaded char-rnn model.

    Query parameters (all optional):
      n      -- number of characters to generate (default 50, 0..MAX_N)
      prime  -- text used to prime the model (default: empty string, which
                sampleModel replaces with chars[0])
      sample -- integer passed through to sampleModel (default 1)

    Malformed or out-of-range integers get a 400 response with a 'message'
    instead of an unhandled ValueError (HTTP 500).
    """

    # Upper bound on `n`. The query comes from untrusted clients, and without
    # a limit a single request can tie up the server indefinitely.
    MAX_N = 10000

    @staticmethod
    def _int_arg(name, default):
        """Return query parameter `name` as an int, or `default` if absent/empty.

        Raises:
            ValueError: if the parameter is present but not an integer.
        """
        raw = request.args.get(name)
        if not raw:
            return default
        try:
            return int(raw)
        except ValueError:
            raise ValueError(
                'query parameter %r must be an integer, got %r' % (name, raw))

    def get(self):
        try:
            n = self._int_arg('n', 50)
            sample = self._int_arg('sample', 1)
        except ValueError as err:
            return {'message': str(err)}, 400
        if not 0 <= n <= self.MAX_N:
            return {'message': 'query parameter n must be between 0 and %d'
                               % self.MAX_N}, 400
        prime = request.args.get('prime') or u''
        # model, chars and vocab are module globals set in the __main__ section.
        result = sampleModel(model, chars, vocab, n, prime, sample)
        return jsonify(result)
api.add_resource(GenerateSample, '/generate')

if __name__ == '__main__':
    # Emit JSON object keys in insertion order instead of sorting them.
    app.config.update(JSON_SORT_KEYS=False)
    # Load the model once at startup. GenerateSample.get looks up `model`,
    # `chars` and `vocab` as module globals, so they only exist when this
    # file is run as a script.
    model, chars, vocab = loadModel(args.save_dir)
    # NOTE(review): host='0.0.0.0' exposes Flask's development server on all
    # network interfaces.
    app.run(host='0.0.0.0', port=5002)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment