Last active
September 20, 2019 15:16
-
-
Save dmc5179/7d7e66667494f2f0379ac75a4ac78220 to your computer and use it in GitHub Desktop.
Mask R-CNN model serving
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Mask R-CNN | |
Copyright (c) 2019 | |
Licensed under the MIT License (see LICENSE for details) | |
------------------------------------------------------------ | |
Usage: run this file to start a Flask server (Flask's default port, 5000),
then POST a JSON body {"bucket_name": ..., "object_name": ...} to /detect.
The image is fetched from S3, detected ships are boxed and scored, and the
annotated image is uploaded to the same bucket as "splash_<object_name>.png".
""" | |
import os | |
import sys | |
import json | |
import datetime | |
import numpy as np | |
import skimage.draw | |
import warnings | |
# Suppress every Python warning for the whole process.
warnings.filterwarnings("ignore")

# Name of the single foreground class.
SHIP_CLASS_NAME = 'ship'

# Expected (square) input image size in pixels.
IMAGE_WIDTH = IMAGE_HEIGHT = 768
SHAPE = (IMAGE_WIDTH, IMAGE_HEIGHT)
# Root directory of the Mask_RCNN checkout. It is appended to sys.path so the
# local ``mrcnn`` package below can be imported. Override with MASK_RCNN_ROOT.
ROOT_DIR = os.environ.get("MASK_RCNN_ROOT", "/home/ec2-user/workspace/Mask_RCNN/")
sys.path.append(ROOT_DIR)  # To find local version of the library
from mrcnn.config import Config
from mrcnn import model as modellib, utils

# Path to the trained weights file loaded for inference. Override with
# MASK_RCNN_WEIGHTS.
# NOTE(review): despite the name, the default file name suggests an
# ASDC-trained checkpoint rather than the COCO weights; confirm.
COCO_WEIGHTS_PATH = os.environ.get(
    "MASK_RCNN_WEIGHTS", "/data/keras_model/mask_rcnn_asdc_gpu_0004.h5"
)

# Directory passed to MaskRCNN as ``model_dir`` (logs and checkpoints).
# Override with MASK_RCNN_LOGS_DIR.
DEFAULT_LOGS_DIR = os.environ.get("MASK_RCNN_LOGS_DIR", "/tmp/ship_logs/")
############################################################ | |
# Configurations | |
############################################################ | |
class AirbusShipDetectionChallengeGPUConfig(Config):
    """Mask R-CNN configuration for the Airbus Ship Detection Challenge dataset.

    Overrides values in the base mrcnn ``Config`` class. Adapted from
    https://github.com/samlin001/Mask_RCNN/blob/master/mrcnn/config.py
    """

    # https://www.kaggle.com/docs/kernels#technical-specifications
    NAME = 'ASDC_GPU'

    # Hardware layout: one GPU, one image on it per step.
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

    # Two classes: background and ship.
    NUM_CLASSES = 2

    # Both resize bounds are pinned to the expected image width.
    IMAGE_MIN_DIM = IMAGE_MAX_DIM = IMAGE_WIDTH

    # Training schedule.
    STEPS_PER_EPOCH = 300
    VALIDATION_STEPS = 50
    SAVE_BEST_ONLY = True

    # Minimum probability value to accept a detected instance;
    # ROIs scoring below this threshold are skipped.
    DETECTION_MIN_CONFIDENCE = 0.95

    # Non-maximum suppression threshold for detection, kept small so that
    # overlapping ROIs get merged.
    DETECTION_NMS_THRESHOLD = 0.05
def color_splash(image, mask):
    """Return *image* in grayscale except where any instance mask is set.

    image: RGB image [height, width, 3]
    mask: instance segmentation mask [height, width, instance count]

    All instances are treated as a single region. When there are zero
    instances, the whole image comes back in grayscale. The result is uint8.
    """
    # rgb2gray returns a float image, hence the * 255. gray2rgb re-expands it
    # to three channels so it can be mixed with the color original.
    gray = skimage.color.gray2rgb(skimage.color.rgb2gray(image)) * 255

    if mask.shape[-1] == 0:
        return gray.astype(np.uint8)

    # Union of all instance masks, kept 3-D so it broadcasts over RGB.
    any_instance = np.sum(mask, -1, keepdims=True) >= 1
    return np.where(any_instance, image, gray).astype(np.uint8)
def color_ships(image, rois, scores, output_dir="/data/kaggle/output"):
    """Draw detection boxes and their scores on an image and save it as PNG.

    image: source image, anything ``PIL.Image.open`` accepts (e.g. a path).
    rois: per-detection boxes, each ordered (y1, x1, y2, x2).
    scores: per-detection confidence values, parallel to ``rois``.
    output_dir: existing directory the PNG is written into.

    Returns the path of the written PNG. The name starts with
    ``splash_<timestamp>_`` and is made unique by ``tempfile.mkstemp``, so
    calls within the same second never overwrite each other's output.
    """
    import tempfile
    from PIL import Image, ImageDraw, ImageFont

    # convert() returns a new, fully loaded image, so the source file handle
    # can be released as soon as the RGBA copy exists.
    with Image.open(image) as src:
        source_img = src.convert("RGBA")

    draw = ImageDraw.Draw(source_img)
    font = ImageFont.load_default()
    for (y1, x1, y2, x2), score in zip(rois, scores):
        # PIL expects [(x0, y0), (x1, y1)], so swap the row/column order.
        draw.rectangle(((x1, y1), (x2, y2)), outline="red")
        draw.text((x1, y1), str(score), font=font, fill="green")

    prefix = "splash_{:%Y%m%dT%H%M%S}_".format(datetime.datetime.now())
    fd, file_name = tempfile.mkstemp(prefix=prefix, suffix=".png", dir=output_dir)
    os.close(fd)
    try:
        source_img.save(file_name, "PNG")
    except Exception:
        # Do not leave the empty placeholder behind on failure.
        os.remove(file_name)
        raise
    print("Saved to ", file_name)
    return file_name
def detect_and_color_splash(model, image_path=None):
    """Run ship detection on one local image and draw the detections.

    model: inference-mode mrcnn ``MaskRCNN`` with weights already loaded.
    image_path: path of the image on local disk (required).

    Returns the path of the annotated PNG produced by ``color_ships``, or None
    when the model detected nothing (no output file is written then).

    Raises ValueError if ``image_path`` is empty.
    """
    if not image_path:
        raise ValueError("image_path is required")

    print("Running on {}".format(image_path))
    image = skimage.io.imread(image_path)

    # Normalize to 3-channel RGB (mirrors mrcnn's Dataset.load_image):
    # promote grayscale and drop an alpha channel.
    if image.ndim != 3:
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[-1] == 4:
        image = image[..., :3]

    r = model.detect([image], verbose=0)[0]
    if len(r['rois']) == 0:
        return None
    return color_ships(image_path, r['rois'], r['scores'])
########################################################################## | |
from flask import Flask, request | |
import os | |
import boto3 | |
import botocore | |
# Flask application serving the /detect endpoint.
app = Flask(__name__)

# Low-level (functional API) S3 client used for downloads and uploads.
client = boto3.client('s3')

# Path of the service log file.
filename = "/tmp/serving.log"
@app.route('/detect', methods=['POST'])
def get_user():
    """Detect ships in an S3 image and upload an annotated copy.

    Expects a JSON body ``{"bucket_name": ..., "object_name": ...}``. The
    object is downloaded to a private temp file and run through the model.
    When ``detect_and_color_splash`` returns an output path, that PNG is
    uploaded to the same bucket as ``splash_<object_name>.png``. Local temp
    files are removed even if a step fails, and one line per request is
    appended to the log file at ``filename``.

    Returns "OK" on success, or a 400 response when the body is not a JSON
    object with non-empty string ``bucket_name`` and ``object_name``.
    """
    import tempfile

    content = request.get_json(silent=True)
    if not isinstance(content, dict):
        content = {}
    bucket_name = content.get('bucket_name')
    object_name = content.get('object_name')
    if not (isinstance(bucket_name, str) and bucket_name
            and isinstance(object_name, str) and object_name):
        return "JSON body with bucket_name and object_name is required", 400

    # NOTE(review): the model is rebuilt and its weights reloaded on every
    # request, which is slow. Caching it would need care with the Keras/TF
    # graph across Flask worker threads; confirm before changing.
    config = AirbusShipDetectionChallengeGPUConfig()
    model = modellib.MaskRCNN(mode="inference", config=config, model_dir=DEFAULT_LOGS_DIR)
    model.load_weights(COCO_WEIGHTS_PATH, by_name=True)

    # object_name is client-controlled. Joining it onto /tmp/ would let an
    # absolute or "../" key write (and later delete) arbitrary files, so the
    # object is downloaded into a fresh temp file, keeping only its extension.
    suffix = os.path.splitext(object_name)[1]
    fd, image_filename = tempfile.mkstemp(suffix=suffix, dir='/tmp/')
    os.close(fd)
    try:
        client.download_file(bucket_name, object_name, image_filename)
        splash_file = detect_and_color_splash(model, image_path=image_filename)
    finally:
        os.remove(image_filename)

    new_name = None
    if splash_file:
        new_name = "splash_" + object_name + ".png"
        try:
            client.upload_file(splash_file, bucket_name, new_name)
        finally:
            # Remove the temporary output hitbox image from local store.
            os.remove(splash_file)

    # ``with`` closes the log handle after each request.
    with open(filename, 'a') as log:
        log.write("{:%Y-%m-%dT%H:%M:%S} s3://{}/{} -> {}\n".format(
            datetime.datetime.now(), bucket_name, object_name,
            new_name or "no detections"))
    return "OK"
############################################################################## | |
if __name__ == '__main__':
    # Start Flask's built-in server on all interfaces (default port 5000).
    # NOTE(review): the built-in server is not intended for production use;
    # consider a WSGI server in front of this app.
    app.run(debug=False, host='0.0.0.0')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment