Leveraging transfer learning for image classification using Keras: VGG19 bottleneck features with a small fully-connected classifier on top, trained on the TensorFlow flowers dataset.
# Create the train/test/validation directory tree, one folder per flower class.
mkdir data
cd data
mkdir train
mkdir test
mkdir validation
cd train
mkdir roses
mkdir dandelion
mkdir sunflowers
mkdir tulips
mkdir daisy
cd ..
cd test
mkdir roses
mkdir dandelion
mkdir sunflowers
mkdir tulips
mkdir daisy
cd ..
cd validation
mkdir roses
mkdir dandelion
mkdir sunflowers
mkdir tulips
mkdir daisy
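The same layout can also be created from Python rather than the shell. A minimal sketch using os.makedirs, assuming it is run from the project root:

import os

# Mirror the mkdir commands above: data/<split>/<class> for every combination.
for split in ('train', 'test', 'validation'):
    for flower in ('roses', 'dandelion', 'sunflowers', 'tulips', 'daisy'):
        os.makedirs(os.path.join('data', split, flower), exist_ok=True)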
# Extract VGG19 bottleneck features and train a small fully-connected
# classifier on top of them.
from keras import applications
from keras.layers import Dropout, Flatten, Dense
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.utils import to_categorical
import pickle
import numpy as np

train_y, val_y, test_y = pickle.load(open('preprocess.p', 'rb'))
img_width, img_height = 224, 224
top_model_weights_path = 'bottleneck_fc_model.h5'
batch_size = 16
epochs = 50
num_classes = 5

def save_bottleneck_features():
    # Convolutional base only; the classifier head is trained separately.
    model = applications.VGG19(weights='imagenet', include_top=False)
    # Random augmentation for the training pass.
    train_datagen = ImageDataGenerator(
        rescale=1. / 255,
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        fill_mode='nearest')
    # Validation images are only rescaled, never augmented.
    val_datagen = ImageDataGenerator(rescale=1. / 255)

    # shuffle must be False so the saved features stay aligned with
    # generator.classes, which is saved alongside them as the labels.
    generator = train_datagen.flow_from_directory(
        './data/train',
        target_size=(img_width, img_height),
        batch_size=batch_size,
        class_mode=None,
        shuffle=False)
    steps = len(train_y) // batch_size
    bottleneck_features_train = model.predict_generator(generator, steps)
    np.save('bottleneck_features_train.npy', bottleneck_features_train)
    # Class indices in the same (sorted) order the generator yielded the
    # images; truncated to match the number of features actually computed.
    np.save('bottleneck_labels_train.npy', generator.classes[:steps * batch_size])

    generator = val_datagen.flow_from_directory(
        './data/validation',
        target_size=(img_width, img_height),
        batch_size=batch_size,
        class_mode=None,
        shuffle=False)
    steps = len(val_y) // batch_size
    bottleneck_features_val = model.predict_generator(generator, steps)
    np.save('bottleneck_features_val.npy', bottleneck_features_val)
    np.save('bottleneck_labels_val.npy', generator.classes[:steps * batch_size])

def train_top_model():
    train_data = np.load('bottleneck_features_train.npy')
    train_labels = to_categorical(np.load('bottleneck_labels_train.npy'), num_classes)
    val_data = np.load('bottleneck_features_val.npy')
    val_labels = to_categorical(np.load('bottleneck_labels_val.npy'), num_classes)

    model = Sequential()
    model.add(Flatten(input_shape=train_data.shape[1:]))
    model.add(Dense(256, activation='relu'))
    model.add(Dropout(0.5))
    # Five flower classes, so softmax with categorical cross-entropy
    # (binary_crossentropy would silently mis-train a multi-class model).
    model.add(Dense(num_classes, activation='softmax'))
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    model.fit(train_data, train_labels,
              epochs=epochs,
              batch_size=batch_size,
              validation_data=(val_data, val_labels))
    model.save(top_model_weights_path)
    print('Model saved')

save_bottleneck_features()  # comment out once the features have been cached
train_top_model()
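The scripts above never score the held-out test split. A minimal evaluation sketch, assuming the saved model and directory layout from above (the hard-coded 224, 16 and 5 simply mirror the settings used earlier):

import numpy as np
from keras import applications
from keras.models import load_model
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import to_categorical

# Recompute bottleneck features for the test images with the same VGG19 base.
base = applications.VGG19(weights='imagenet', include_top=False)
datagen = ImageDataGenerator(rescale=1. / 255)
generator = datagen.flow_from_directory(
    './data/test',
    target_size=(224, 224),
    batch_size=16,
    class_mode=None,
    shuffle=False)  # keep order so features align with generator.classes
steps = generator.samples // 16
features = base.predict_generator(generator, steps)
labels = to_categorical(generator.classes[:steps * 16], 5)

# Score the trained top model on the test features.
top_model = load_model('bottleneck_fc_model.h5')
loss, acc = top_model.evaluate(features, labels, batch_size=16)
print('Test accuracy: %.3f' % acc)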
# Download the TensorFlow flowers dataset and split it into the
# train/validation/test folders created above.
import os
from sklearn.model_selection import StratifiedShuffleSplit
import pickle
import tarfile
import shutil
from urllib.request import urlretrieve
from os.path import isfile, isdir
from tqdm import tqdm

dataset_folder_path = 'flower_photos'

class DLProgress(tqdm):
    # tqdm progress hook for urlretrieve.
    last_block = 0
    def hook(self, block_num=1, block_size=1, total_size=None):
        self.total = total_size
        self.update((block_num - self.last_block) * block_size)
        self.last_block = block_num

if not isfile('flower_photos.tar.gz'):
    with DLProgress(unit='B', unit_scale=True, miniters=1, desc='Flowers Dataset') as pbar:
        urlretrieve(
            'http://download.tensorflow.org/example_images/flower_photos.tgz',
            'flower_photos.tar.gz',
            pbar.hook)

if not isdir(dataset_folder_path):
    with tarfile.open('flower_photos.tar.gz') as tar:
        tar.extractall()

data_dir = 'flower_photos/'
contents = os.listdir(data_dir)
classes = [each for each in contents if os.path.isdir(data_dir + each)]

file_paths = []
labels = []
for each in classes:
    class_path = data_dir + each
    # Give every image a predictable name: <class>_<index>.jpg
    files = os.listdir(class_path)
    for i, file in enumerate(files):
        path = os.path.join(class_path, file)
        target = os.path.join(class_path, each + '_' + str(i) + '.jpg')
        os.rename(path, target)
    files = os.listdir(class_path)
    for file in files:
        file_paths.append(os.path.join(class_path, file))
        labels.append(each)
print(len(file_paths), len(labels))

# Stratified 80/20 split; the 20% is then halved into validation and test.
# A fixed random_state keeps the split reproducible across runs.
ss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
train_idx, val_idx = next(ss.split(file_paths, labels))
half_val_len = len(val_idx) // 2
val_idx, test_idx = val_idx[:half_val_len], val_idx[half_val_len:]
print('Train', len(train_idx))
print('Validation', len(val_idx))
print('Test', len(test_idx))

train_y = []
for i in train_idx:
    shutil.copy(file_paths[i], './data/train/' + labels[i])
    train_y.append(labels[i])
val_y = []
for j in val_idx:
    shutil.copy(file_paths[j], './data/validation/' + labels[j])
    val_y.append(labels[j])
test_y = []
for k in test_idx:
    shutil.copy(file_paths[k], './data/test/' + labels[k])
    test_y.append(labels[k])

pickle.dump([train_y, val_y, test_y], open('preprocess.p', 'wb'))
print('Data split into train, validation and test sets')
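As a quick sanity check, assuming the layout above, the per-class image counts in each split can be printed before feature extraction:

import os

# Count the copied images per class in every split folder.
for split in ('train', 'validation', 'test'):
    split_dir = os.path.join('data', split)
    counts = {c: len(os.listdir(os.path.join(split_dir, c)))
              for c in os.listdir(split_dir)}
    print(split, sum(counts.values()), counts)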