Skip to content

Instantly share code, notes, and snippets.

@sadimanna
Last active December 3, 2020 06:15
Show Gist options
  • Select an option

  • Save sadimanna/698a96ebf4de5b2a207d43bc61e62586 to your computer and use it in GitHub Desktop.

Select an option

Save sadimanna/698a96ebf4de5b2a207d43bc61e62586 to your computer and use it in GitHub Desktop.
# K-fold cross-validation driver. For every fold it:
#   1. builds train/validation generators from the fold's rows of train_data,
#   2. creates and compiles a new model (with its own optimizer instance),
#   3. trains it, checkpointing the epoch with the best val_accuracy,
#   4. reloads that checkpoint and records its validation accuracy and loss
#      in VALIDATION_ACCURACY / VALIDATION_LOSS.
#
# Names used here but defined earlier in the script (not visible in this chunk):
#   kf               -- splitter with .split(X, y); presumably StratifiedKFold,
#                       since labels Y are passed (TODO confirm)
#   n, Y             -- sample count and per-sample labels passed to kf.split
#   train_data       -- DataFrame with "filename" and "label" columns
#   idg              -- presumably a Keras ImageDataGenerator (TODO confirm)
#   image_dir        -- directory the "filename" entries are relative to
#   create_new_model -- factory returning a Keras model to compile
#   opt              -- optimizer (instance or string name) for compile
#   num_epochs       -- epochs per fold
#   get_model_name   -- maps a fold number to a checkpoint file name
VALIDATION_ACCURACY = []
# Previously misspelled VALIDAITON_LOSS, which made the append below raise NameError.
VALIDATION_LOSS = []
save_dir = '/saved_models/'
fold_var = 1
for train_index, val_index in kf.split(np.zeros(n), Y):
    training_data = train_data.iloc[train_index]
    validation_data = train_data.iloc[val_index]
    train_data_generator = idg.flow_from_dataframe(
        training_data, directory=image_dir,
        x_col="filename", y_col="label",
        class_mode="categorical", shuffle=True)
    # Validation batches need no shuffling; a fixed order keeps evaluation
    # reproducible from run to run.
    valid_data_generator = idg.flow_from_dataframe(
        validation_data, directory=image_dir,
        x_col="filename", y_col="label",
        class_mode="categorical", shuffle=False)

    # CREATE NEW MODEL
    model = create_new_model()

    # COMPILE NEW MODEL
    # Give each fold its own optimizer. Re-using one optimizer instance would
    # carry its iteration count / slot variables (and any LR-schedule position)
    # from fold to fold, and newer TF optimizers refuse variables from a
    # different model. A string name already makes Keras build a new instance.
    if isinstance(opt, str):
        fold_opt = opt
    else:
        fold_opt = opt.__class__.from_config(opt.get_config())
    model.compile(loss='categorical_crossentropy',
                  optimizer=fold_opt,
                  metrics=['accuracy'])

    # CREATE CALLBACKS
    # Compute the checkpoint path once so the save below and the reload after
    # training refer to the same file.
    model_path = save_dir + get_model_name(fold_var)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        model_path,
        monitor='val_accuracy', verbose=1,
        save_best_only=True, mode='max')
    callbacks_list = [checkpoint]
    # There can be other callbacks, but just showing one because it involves the model name
    # This saves the best model

    # FIT THE MODEL
    history = model.fit(train_data_generator,
                        epochs=num_epochs,
                        callbacks=callbacks_list,
                        validation_data=valid_data_generator)
    # PLOT HISTORY
    # :
    # :

    # LOAD BEST MODEL to evaluate the performance of the model
    model.load_weights(model_path)
    # return_dict=True maps metric names to values directly, avoiding reliance
    # on the ordering of model.metrics_names.
    results = model.evaluate(valid_data_generator, return_dict=True)
    VALIDATION_ACCURACY.append(results['accuracy'])
    VALIDATION_LOSS.append(results['loss'])

    # Free graph/session state before building the next fold's model.
    tf.keras.backend.clear_session()
    fold_var += 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment