Created
August 28, 2019 19:29
-
-
Save rayryeng/0b4906ef0fd8c82a83cf6414359f9147 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division, print_function | |
# coding=utf-8 | |
import sys | |
import os | |
import glob | |
import re | |
import numpy as np | |
# Keras | |
from keras.applications.imagenet_utils import preprocess_input, decode_predictions | |
from keras.models import load_model | |
from keras.preprocessing import image | |
# Flask utils | |
from flask import Flask, redirect, url_for, request, render_template | |
from werkzeug.utils import secure_filename | |
from gevent.pywsgi import WSGIServer | |
# Mod | |
import tensorflow as tf | |
# Define a flask app | |
app = Flask(__name__) | |
# Model saved with Keras model.save() | |
MODEL_PATH = 'models/your_model.h5' | |
# Load your trained model | |
# model = load_model(MODEL_PATH) | |
# model._make_predict_function() # Necessary | |
# print('Model loaded. Start serving...') | |
# You can also use pretrained model from Keras | |
# Check https://keras.io/applications/ | |
from keras.applications.resnet50 import ResNet50 | |
model = ResNet50(weights='imagenet') | |
graph = tf.get_default_graph() # Change | |
print('Model loaded. Check http://127.0.0.1:5000/') | |
def model_predict(img_path, model): | |
img = image.load_img(img_path, target_size=(224, 224)) | |
# Preprocessing the image | |
x = image.img_to_array(img) | |
# x = np.true_divide(x, 255) | |
x = np.expand_dims(x, axis=0) | |
# Be careful how your trained model deals with the input | |
# otherwise, it won't make correct prediction! | |
x = preprocess_input(x, mode='caffe') | |
with graph.as_default(): # Change | |
preds = model.predict(x) | |
return preds | |
@app.route('/', methods=['GET']) | |
def index(): | |
# Main page | |
return render_template('index.html') | |
@app.route('/predict', methods=['GET', 'POST']) | |
def upload(): | |
if request.method == 'POST': | |
# Get the file from post request | |
f = request.files['image'] | |
# Save the file to ./uploads | |
basepath = os.path.dirname(__file__) | |
file_path = os.path.join( | |
basepath, 'uploads', secure_filename(f.filename)) | |
f.save(file_path) | |
# Make prediction | |
preds = model_predict(file_path, model) | |
# Process your result for human | |
# pred_class = preds.argmax(axis=-1) # Simple argmax | |
pred_class = decode_predictions(preds, top=1) # ImageNet Decode | |
result = str(pred_class[0][0][1]) # Convert to string | |
return result | |
return None | |
if __name__ == '__main__': | |
# app.run(port=5002, debug=True) | |
# Serve the app with gevent | |
http_server = WSGIServer(('0.0.0.0', 5000), app) | |
http_server.serve_forever() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment