Created
November 30, 2020 06:53
-
-
Save PhanDuc/0dbc6a648202b8dc335f9f09019d539f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import io
import os
import sys
from typing import List

import numpy as np
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image, UnidentifiedImageError
from pydantic import BaseModel
from tensorflow.keras.models import load_model
# Load the Keras model once at import time so every request reuses it.
# The location can be overridden with the MODEL_PATH environment variable;
# the default is the path the script was originally written against.
filepath = os.environ.get(
    'MODEL_PATH',
    '/media/Another/Computer_Science_Project/fastapi_learning/keras_fastapi/saved_model',
)
model = load_model(filepath)
# Expected model input shape; the prediction route reads it as
# (batch, height, width, channels). ``model.input_shape`` is used rather than
# ``model.layers[0].input_shape`` because, for functional models, the first
# layer is an InputLayer whose shape may come back wrapped in a list.
input_shape = model.input_shape
class Prediction(BaseModel):
    """Response body of the ``POST /prediction/`` endpoint."""

    # Name of the uploaded file as reported by the upload.
    filename: str
    # MIME type reported for the upload.
    contenttype: str
    # Model output vector for the image; empty unless supplied.
    prediction: List[float] = []
    # Index of the largest entry in ``prediction``.
    likely_class: int
# The FastAPI application that the routes below are registered on.
app = FastAPI()


@app.get("/")
def root_route():
    """Tell clients that predictions are served by ``POST /prediction/``.

    The root route does no work of its own; it only returns a pointer to the
    real endpoint (previously the message wrongly advertised ``GET``).
    """
    return {'error': 'Use POST /prediction/ instead of the root route!'}
# Prediction route and its image-preprocessing helper.

# PIL image mode that yields the given number of channels per pixel.
_MODE_FOR_CHANNELS = {1: 'L', 3: 'RGB', 4: 'RGBA'}


def _prepare_input(pil_image, shape):
    """Turn a decoded PIL image into a one-sample batch for the model.

    ``shape`` is read as ``(batch, height, width, channels)``; the batch entry
    is ignored. The image is converted to the PIL mode matching ``channels``
    (when that count is in ``_MODE_FOR_CHANNELS``), resized, reshaped and
    scaled to the ``[0, 1]`` range.

    Returns:
        A float array of shape ``(1, height, width, channels)``.
    """
    height, width, channels = shape[1], shape[2], shape[3]
    mode = _MODE_FOR_CHANNELS.get(channels)
    if mode is not None and pil_image.mode != mode:
        # Normalise grayscale/RGBA/palette uploads so the reshape below finds
        # exactly ``channels`` values per pixel.
        pil_image = pil_image.convert(mode)
    # PIL sizes are (width, height) while Keras shapes are (height, width);
    # swapping them would make the reshape scramble non-square images.
    pil_image = pil_image.resize((width, height))
    pixels = np.array(pil_image).reshape((height, width, channels))
    return np.expand_dims(pixels / 255.0, axis=0)


@app.post("/prediction/", response_model=Prediction)
async def prediction_route(file: UploadFile = File(...)):
    """Run the module-level Keras ``model`` on an uploaded image.

    Returns:
        A dict matching ``Prediction``: the upload's filename and content type,
        the model's output vector for the image, and the index of its largest
        entry.

    Raises:
        HTTPException: 400 if the upload is not declared as ``image/*`` or PIL
            cannot decode it; 500 (with the error text as detail) for any other
            failure while reading, preprocessing or predicting.
    """
    # ``content_type`` may be None when the client sends no Content-Type.
    if not (file.content_type or '').startswith('image/'):
        raise HTTPException(status_code=400,
                            detail=f'File \'{file.filename}\' is not an image.'
                            )
    try:
        contents = await file.read()
        pil_image = Image.open(io.BytesIO(contents))
        batch = _prepare_input(pil_image, input_shape)
        # NOTE(review): model.predict is synchronous and blocks the event loop
        # while it runs; consider offloading it if concurrent latency matters.
        prediction = model.predict(batch)[0]
        # np.argmax returns a NumPy integer; cast so the response holds an int.
        likely_class = int(np.argmax(prediction))
    except UnidentifiedImageError as e:
        # Declared as an image but not decodable by PIL: a client error.
        raise HTTPException(
            status_code=400,
            detail=f'File \'{file.filename}\' could not be decoded as an image.',
        ) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {
        "filename": file.filename,
        "contenttype": file.content_type,
        "prediction": prediction.tolist(),
        "likely_class": likely_class,
    }
if __name__ == '__main__':
    # Recent uvicorn releases dropped the ``debug`` option and reject it with a
    # TypeError; debug-level logging is the closest supported setting.
    uvicorn.run(app, log_level='debug')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment