Skip to content

Instantly share code, notes, and snippets.

@dmc5179
Last active September 20, 2019 15:16
Show Gist options
  • Save dmc5179/7d7e66667494f2f0379ac75a4ac78220 to your computer and use it in GitHub Desktop.
Save dmc5179/7d7e66667494f2f0379ac75a4ac78220 to your computer and use it in GitHub Desktop.
Mask-r-cnn model serving
"""
Mask R-CNN
Copyright (c) 2019
Licensed under the MIT License (see LICENSE for details)
------------------------------------------------------------
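Usage: run this file directly to start a Flask server on all interfaces
(Flask's default port) that exposes a single endpoint:
    POST /detect  with JSON {"bucket_name": "<bucket>", "object_name": "<key>"}
The image is downloaded from S3, ships are detected with Mask R-CNN, and the
annotated image is uploaded back to the same bucket as "splash_<key>.png".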
"""
import os
import sys
import json
import datetime
import numpy as np
import skimage.draw
import warnings
# Silence every warning for the whole process.
# NOTE(review): presumably meant to hide noisy library deprecation warnings
# (TF/Keras/skimage); this also hides genuinely useful ones. Consider a
# narrower filter (category/module) - TODO confirm which warnings were targeted.
warnings.filterwarnings("ignore")
SHIP_CLASS_NAME = 'ship'
IMAGE_WIDTH = 768
IMAGE_HEIGHT = 768
# Stored as (width, height); both are 768, so the order is moot here.
SHAPE = (IMAGE_WIDTH, IMAGE_HEIGHT)
# Root directory of the Mask_RCNN checkout that provides the ``mrcnn`` package.
# Set MASK_RCNN_ROOT to point at a different checkout without editing the file;
# the default is the original hard-coded location.
ROOT_DIR = os.environ.get("MASK_RCNN_ROOT", "/home/ec2-user/workspace/Mask_RCNN/")
# Make the local ``mrcnn`` package importable (used by the imports that follow).
sys.path.append(ROOT_DIR)
from mrcnn.config import Config
from mrcnn import model as modellib, utils
# Path to the trained weights file loaded by the /detect handler.
# NOTE: despite its name, the default points at a file named like the ship
# (ASDC) model's checkpoint, not the COCO weights; the name is kept so existing
# references keep working. Override with the MASK_RCNN_WEIGHTS env variable.
COCO_WEIGHTS_PATH = os.environ.get(
    "MASK_RCNN_WEIGHTS", "/data/keras_model/mask_rcnn_asdc_gpu_0004.h5"
)
# Directory passed to MaskRCNN as ``model_dir`` (logs and checkpoints).
# Override with the MASK_RCNN_LOGS_DIR env variable.
DEFAULT_LOGS_DIR = os.environ.get("MASK_RCNN_LOGS_DIR", "/tmp/ship_logs/")
############################################################
# Configurations
############################################################
class AirbusShipDetectionChallengeGPUConfig(Config):
    """Mask R-CNN settings for the Airbus Ship Detection Challenge dataset.

    Only the attributes below differ from the base ``Config`` class; every
    other setting is inherited unchanged.
    Adapted from https://github.com/samlin001/Mask_RCNN/blob/master/mrcnn/config.py
    """

    # Naming follows the Kaggle kernel setup:
    # https://www.kaggle.com/docs/kernels#technical-specifications
    NAME = 'ASDC_GPU'

    # One GPU, one image per GPU.
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

    # Background + "ship".
    NUM_CLASSES = 1 + 1

    # Both resize bounds are pinned to the tile width (IMAGE_WIDTH).
    IMAGE_MIN_DIM = IMAGE_WIDTH
    IMAGE_MAX_DIM = IMAGE_WIDTH

    # Training-loop settings; presumably unused by the inference path in this
    # file. SAVE_BEST_ONLY is not a base-Config field as far as this file shows
    # - verify which training code reads it.
    STEPS_PER_EPOCH = 300
    VALIDATION_STEPS = 50
    SAVE_BEST_ONLY = True

    # Drop detected instances whose confidence is below this value.
    DETECTION_MIN_CONFIDENCE = 0.95

    # Non-maximum-suppression threshold; kept small so overlapping ROIs merge.
    DETECTION_NMS_THRESHOLD = 0.05
def color_splash(image, mask):
    """Apply a color-splash effect: keep color only where an instance mask is set.

    Args:
        image: RGB image, [height, width, 3].
        mask: instance segmentation masks, [height, width, instance_count].

    Returns:
        uint8 image: the original colors inside the union of all instance
        masks, and a grayscale version (still 3 RGB channels) everywhere else.
        With zero instances the whole image is grayscale.
    """
    # Import the submodule explicitly. The module-level ``import skimage.draw``
    # does not guarantee ``skimage.color`` is loaded; without this the code
    # relies on some other import having loaded it as a side effect.
    import skimage.color

    gray = skimage.color.gray2rgb(skimage.color.rgb2gray(image)) * 255
    if mask.shape[-1] == 0:
        return gray.astype(np.uint8)
    # Treat all instances as one: collapse the instance axis into a single layer.
    combined = np.sum(mask, -1, keepdims=True) >= 1
    return np.where(combined, image, gray).astype(np.uint8)
def color_ships(image, rois, scores, output_dir="/data/kaggle/output"):
    """Draw detection boxes and their scores on an image and save it as a PNG.

    Args:
        image: path (or file object) of the source image, opened with PIL.
        rois: sequence of boxes; each row is indexed as (y1, x1, y2, x2).
        scores: per-box scores aligned with ``rois``; drawn as ``str(score)``.
        output_dir: directory the annotated PNG is written to.

    Returns:
        Path of the newly created PNG file.
    """
    import tempfile

    from PIL import Image, ImageDraw, ImageFont

    # Convert inside a context manager so the source file handle is closed
    # promptly. convert() returns an independent, fully loaded image.
    with Image.open(image) as src:
        canvas = src.convert("RGBA")
    draw = ImageDraw.Draw(canvas)
    font = ImageFont.load_default()  # loop-invariant: load once

    for roi, score in zip(rois, scores):
        y1, x1, y2, x2 = roi[0], roi[1], roi[2], roi[3]
        # PIL takes [(x0, y0), (x1, y1)], i.e. (column, row) order.
        draw.rectangle(((x1, y1), (x2, y2)), outline="red")
        draw.text((x1, y1), str(score), font=font, fill="green")

    # A second-resolution timestamp alone collides when two requests finish
    # within the same second, and one request could then upload or delete the
    # other's image. mkstemp adds a unique suffix and creates the file atomically.
    prefix = "splash_{:%Y%m%dT%H%M%S}_".format(datetime.datetime.now())
    fd, file_name = tempfile.mkstemp(suffix=".png", prefix=prefix, dir=output_dir)
    os.close(fd)
    try:
        canvas.save(file_name, "PNG")
    except BaseException:
        os.remove(file_name)  # don't leave an empty placeholder behind
        raise
    print("Saved to ", file_name)
    return file_name
def detect_and_color_splash(model, image_path=None):
    """Run ship detection on one image file and save an annotated copy.

    Args:
        model: inference-mode model whose ``detect`` accepts a list of images
            and returns one result dict per image with 'rois' and 'scores'.
        image_path: path of the image to analyse. Required despite the default.

    Returns:
        Path of the PNG written by ``color_ships``. When nothing is detected,
        the PNG is an unannotated copy of the input, so callers always get a
        file back.

    Raises:
        ValueError: if ``image_path`` is empty or None.
    """
    import skimage.io  # explicit; don't rely on another import loading it

    # Raise explicitly: an ``assert`` disappears under ``python -O``.
    if not image_path:
        raise ValueError("image_path is required")

    print("Running on {}".format(image_path))
    image = skimage.io.imread(image_path)
    # detect() gets a one-image list; take the single result.
    result = model.detect([image], verbose=0)[0]
    # Draw unconditionally: with zero ROIs color_ships draws nothing but still
    # saves the image. This avoids returning an unbound name on empty results.
    return color_ships(image_path, result['rois'], result['scores'])
##########################################################################
from flask import Flask, request
import os
import boto3
import botocore
# Shared low-level S3 client, created once at import time and reused by every
# request for downloads and uploads.
client = boto3.client('s3')

# Path of the append-mode log file used by the /detect handler.
filename = "/tmp/serving.log"

# The Flask application; the /detect route is registered on it.
app = Flask(__name__)

# Buckets are not configured here: each request names its bucket in the JSON
# body, and results go back to that same bucket.
def _load_model():
    """Build an inference-mode Mask R-CNN and load the configured weights."""
    config = AirbusShipDetectionChallengeGPUConfig()
    model = modellib.MaskRCNN(mode="inference", config=config, model_dir=DEFAULT_LOGS_DIR)
    model.load_weights(COCO_WEIGHTS_PATH, by_name=True)
    return model


@app.route('/detect', methods=['POST'])
def get_user():
    """Handle POST /detect: annotate an S3 image with ship detections.

    Expects a JSON body ``{"bucket_name": ..., "object_name": ...}``. The object
    is downloaded, run through Mask R-CNN, and the annotated PNG is uploaded to
    the same bucket as ``"splash_" + object_name + ".png"``.

    Returns:
        ``"OK"`` on success, or a 400 response when the body is not JSON or
        lacks either key.
    """
    import tempfile

    # silent=True: bad or missing JSON yields None, so the 400 below covers it.
    content = request.get_json(silent=True)
    if (not isinstance(content, dict)
            or not content.get('bucket_name')
            or not content.get('object_name')):
        return "Expected a JSON body with 'bucket_name' and 'object_name'", 400
    bucket = content['bucket_name']
    key = content['object_name']

    # Append one line per request. The context manager closes the handle
    # (previously the file was opened per request and never closed).
    with open(filename, 'a') as log_file:
        log_file.write("{} detect s3://{}/{}\n".format(
            datetime.datetime.now().isoformat(), bucket, key))

    # NOTE(review): the model is rebuilt and its weights reloaded on every
    # request, which is slow. Caching it across requests needs care with
    # TensorFlow/Keras graph/session handling under Flask's worker threads.
    # TODO: confirm before caching.
    model = _load_model()

    splash_file = None
    try:
        # Download into a private temp dir that is removed even on failure.
        # Only the key's basename is used, so an untrusted key such as
        # "../../etc/x" or "/abs/path" cannot escape the directory, and keys
        # with "folder/" prefixes need no matching local directories.
        with tempfile.TemporaryDirectory() as tmp_dir:
            local_name = os.path.basename(key)
            if local_name in ('', '.', '..'):
                local_name = 'image'
            image_filename = os.path.join(tmp_dir, local_name)
            client.download_file(bucket, key, image_filename)
            # Run inference on the downloaded image.
            splash_file = detect_and_color_splash(model, image_path=image_filename)
        # Upload the hitbox image next to the original object.
        new_name = "splash_" + key + ".png"
        client.upload_file(splash_file, bucket, new_name)
    finally:
        # Remove the temporary hitbox image even if the upload failed.
        if splash_file is not None and os.path.exists(splash_file):
            os.remove(splash_file)
    return "OK"
##############################################################################
if __name__ == '__main__':
    # Listen on every network interface (0.0.0.0) with the debugger off,
    # using Flask's built-in server.
    app.run(debug=False, host='0.0.0.0')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment