Last active
August 7, 2017 08:10
-
-
Save l225li/ed7ff0763d13aad8312dbe2c411fd6c6 to your computer and use it in GitHub Desktop.
This is a helper function for extracting convolutional features from images using a pre-trained VGG16 model. It makes it possible to run feature extraction over a large collection of images on a CPU-only machine.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv
import os
import shutil

import bcolz
import numpy as np
def save_array(fname, arr):
    """Store the array ``arr`` on disk as a bcolz carray rooted at ``fname``.

    Args:
        fname (str): root directory of the on-disk carray (a ``.dat`` path
            by this file's convention).
        arr: array data to store.
    """
    # Build the disk-backed carray and flush it in one go.
    bcolz.carray(arr, rootdir=fname, mode='w').flush()
def load_array(fname):
    """Read back the array previously stored at ``fname`` (.dat).

    Args:
        fname (str): root directory of the on-disk bcolz carray.

    Returns:
        The full contents of the stored carray, read with a ``[:]`` slice.
    """
    stored = bcolz.open(fname)
    return stored[:]
def extract_features(path_in, path_out, model):
    """Extract features with a convolutional model and save them to disk.

    Samples are read from ``path_in`` one at a time (batch size 1, not
    shuffled) and passed through ``model.predict``. The stacked predictions
    and the labels (``batches.classes``) are written to ``features.dat`` and
    ``labels.dat`` inside ``path_out`` and are also returned, so they can be
    used as input features for upper layer models.

    Args:
        path_in (str): directory path of input images
        path_out (str): directory path of output dat files; created if it
            does not exist. A trailing path separator is optional.
        model (Model): model with convolutional layers to extract features;
            only its ``predict`` method is used.

    Returns:
        tuple: ``(features, labels)`` where ``features`` holds one row per
        sample, in the order the batches were produced.

    Raises:
        ValueError: if no samples are found under ``path_in``.
    """
    # creates directory `path_out` if not already existed
    if not os.path.exists(path_out):
        os.makedirs(path_out)

    # NOTE(review): get_batches is not defined or imported in this file; it
    # must be provided by the surrounding environment. The code below relies
    # on the returned object supporting next() and exposing ``classes`` and
    # ``samples``.
    batches = get_batches(path_in, batch_size=1, shuffle=False)
    labels = batches.classes
    n = batches.samples
    if n == 0:
        # Without a single prediction the feature shape is unknown.
        raise ValueError('no samples found in {!r}'.format(path_in))

    # os.path.join keeps the outputs inside `path_out` whether or not the
    # caller supplied a trailing separator.
    save_array(os.path.join(path_out, 'labels.dat'), labels)

    # Each prediction is written straight into one preallocated array. This
    # avoids re-copying the growing result on every step and needs neither
    # a placeholder row nor temporary per-sample files.
    features = None
    for i in range(n):
        sample = np.asarray(model.predict(next(batches)[0]))
        if features is None:
            # The first prediction fixes the per-sample shape and dtype.
            features = np.empty((n,) + sample.shape[1:], dtype=sample.dtype)
        # batch_size=1, so each prediction is expected to contribute
        # exactly one row.
        features[i:i + 1] = sample

    save_array(os.path.join(path_out, 'features.dat'), features)
    return features, labels
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment