Last active
July 16, 2021 14:39
-
-
Save danyashorokh/10e4ecad1f2e0a84b0050740a207a4f7 to your computer and use it in GitHub Desktop.
[KERAS] Feature extractors
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import cv2
import numpy as np
from keras.preprocessing import image
from keras.applications.vgg16 import VGG16, preprocess_input as preprocess_input_vgg
from keras.applications.inception_v3 import InceptionV3, preprocess_input as preprocess_input_inception
from keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input as preprocess_input_mobilenet
from keras.applications.resnet import ResNet50, preprocess_input as preprocess_input_resnet
from keras.models import Model, load_model
from keras.layers import Input, GlobalAveragePooling2D, GlobalMaxPooling2D
class FeatureExtractor():
    """Build a Keras model that turns a single image into a feature array.

    The backbone is either one of the registered ``keras.applications``
    networks loaded with ImageNet weights, or a model loaded from disk with
    ``load_model``. Optionally the output is taken from a named inner layer
    and/or followed by one global pooling layer.
    """

    def __init__(self, model_name='mobilenet_v2', weights=None, include_top=False, add_max=False,
                 add_avg=False, get_layer=None, convert_to_rgb=False):
        """Either load pretrained from imagenet, or load our saved
        weights from our own training.

        Args:
            model_name: Key into ``self.models``; selects the constructor,
                the input size and the preprocessing function.
            weights: ``None`` to build ``model_name`` with ImageNet weights,
                otherwise a value handed to ``keras.models.load_model``.
                In the latter case ``include_top`` is not used.
            include_top: Forwarded to the backbone constructor when
                ``weights`` is ``None``.
            add_max: Append a ``GlobalMaxPooling2D`` layer to the output.
            add_avg: Append a ``GlobalAveragePooling2D`` layer to the output.
            get_layer: Optional layer name; when given, the backbone is cut
                so that this layer's output becomes the model output.
            convert_to_rgb: When ``extract`` receives an array, convert it
                from BGR to RGB before preprocessing.

        Raises:
            ValueError: If ``model_name`` is not registered, or if both
                ``add_max`` and ``add_avg`` are true.
        """
        self.model_name = model_name
        self.weights = weights
        self.include_top = include_top
        self.get_layer = get_layer
        self.convert_to_rgb = convert_to_rgb
        self.add_max = add_max
        self.add_avg = add_avg

        # dict with load function, input size and preprocess input function
        self.models = {
            'mobilenet_v2': [MobileNetV2, (224, 224), preprocess_input_mobilenet],
            'vgg16': [VGG16, (224, 224), preprocess_input_vgg],
            # Legacy alias kept for backward compatibility: this key has
            # always built VGG16 (VGG19 is not imported by this file).
            'vgg19': [VGG16, (224, 224), preprocess_input_vgg],
            'resnet50': [ResNet50, (224, 224), preprocess_input_resnet],
            'inception_v3': [InceptionV3, (299, 299), preprocess_input_inception],
        }

        # check model name
        if self.model_name not in self.models:
            raise ValueError(f'Unknown model {self.model_name}. Use one of {self.models.keys()}')

        # check add layers flags: the two pooling heads are mutually exclusive
        self._check_pooling_flags(self.add_max, self.add_avg)

        self.get_model, self.target_size, self.preprocess_input = self.models[self.model_name]

        if weights is None:
            base_model = self.get_model(weights='imagenet', include_top=self.include_top)
        else:
            base_model = load_model(weights)

        # Optionally truncate the network at a named layer.
        if self.get_layer is not None:
            base_model = Model(
                inputs=base_model.input,
                outputs=base_model.get_layer(self.get_layer).output
            )

        # get output from base model
        x = base_model.output

        if self.add_avg:
            # add a global spatial average pooling layer
            x = GlobalAveragePooling2D()(x)
        if self.add_max:
            # add a global spatial max pooling layer
            x = GlobalMaxPooling2D()(x)

        # final model used by extract()
        self.model = Model(inputs=base_model.input, outputs=x)

    @staticmethod
    def _check_pooling_flags(add_max, add_avg):
        """Raise ValueError when both pooling layers are requested at once."""
        if add_max and add_avg:
            raise ValueError('Both flags (max, avg) are True')

    def extract(self, image_path):
        """Run the model on one image and return its features.

        Args:
            image_path: Either a filesystem path (``str`` or ``os.PathLike``)
                loaded through ``keras.preprocessing.image``, or an already
                decoded image array that is resized with ``cv2.resize``.

        Returns:
            The first (and only) entry of ``self.model.predict`` for the
            single-image batch.
        """
        if isinstance(image_path, (str, os.PathLike)):
            img = image.load_img(os.fspath(image_path), target_size=self.target_size)
            x = image.img_to_array(img)
        else:
            # Keras sizes are (height, width) while cv2.resize expects
            # (width, height); identical for the square sizes registered
            # above, but reversed explicitly so non-square sizes stay right.
            height, width = self.target_size
            x = cv2.resize(image_path, (width, height))
            if self.convert_to_rgb:
                x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)

        # The model expects a batch axis in front.
        x = np.expand_dims(x, axis=0)
        x = self.preprocess_input(x)

        features = self.model.predict(x)[0]
        return features
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment