Last active: December 3, 2020 06:15
Save sadimanna/698a96ebf4de5b2a207d43bc61e62586 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# K-fold cross-validation driver.
#
# For each fold this script:
#   1. builds training / validation image generators from `train_data`,
#   2. trains a freshly created model,
#   3. checkpoints the epoch with the best validation accuracy,
#   4. reloads that checkpoint and records its validation accuracy and loss.
#
# Names expected to be defined earlier in the file (not visible here):
#   kf               -- splitter with a scikit-learn style split(X, y); presumably
#                       StratifiedKFold since labels Y are passed (TODO confirm)
#   n, Y             -- row count of train_data and the labels used by kf.split
#   train_data       -- DataFrame with "filename" and "label" columns
#   idg, image_dir   -- image data generator and the directory holding images
#   create_new_model -- factory returning an uncompiled tf.keras model
#   opt              -- optimizer (instance or string identifier)
#   get_model_name   -- maps a fold number to a checkpoint file name
#   num_epochs, tf, np

VALIDATION_ACCURACY = []
VALIDATION_LOSS = []

# NOTE(review): absolute path at the filesystem root; confirm this is intended.
save_dir = '/saved_models/'
fold_var = 1

# np.zeros(n) is only a placeholder X: the splitter needs the sample count
# (and the labels Y), not the actual features.
for train_index, val_index in kf.split(np.zeros(n), Y):
    training_data = train_data.iloc[train_index]
    validation_data = train_data.iloc[val_index]

    train_data_generator = idg.flow_from_dataframe(
        training_data, directory=image_dir,
        x_col="filename", y_col="label",
        class_mode="categorical", shuffle=True)
    # Validation data is only scored, never fitted, so shuffling it buys
    # nothing; a fixed order keeps evaluation reproducible.
    valid_data_generator = idg.flow_from_dataframe(
        validation_data, directory=image_dir,
        x_col="filename", y_col="label",
        class_mode="categorical", shuffle=False)

    # CREATE NEW MODEL
    model = create_new_model()

    # COMPILE NEW MODEL
    # Give every fold its own optimizer built from `opt`'s configuration so
    # training state (iteration counter, momentum/slot variables) does not
    # leak from one fold into the next. Newer Keras optimizers also refuse to
    # be reused with a new set of model variables. A string identifier such
    # as 'adam' already yields a fresh optimizer on each compile() call.
    if isinstance(opt, str):
        fold_opt = opt
    else:
        fold_opt = opt.__class__.from_config(opt.get_config())
    model.compile(loss='categorical_crossentropy',
                  optimizer=fold_opt,
                  metrics=['accuracy'])

    # CREATE CALLBACKS
    # One path is used both for writing the checkpoint and for reading it back
    # below, so the two can never drift apart.
    checkpoint_path = save_dir + get_model_name(fold_var)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                    monitor='val_accuracy', verbose=1,
                                                    save_best_only=True, mode='max')
    callbacks_list = [checkpoint]
    # There can be other callbacks, but just showing one because it involves the model name
    # This saves the best model

    # FIT THE MODEL
    history = model.fit(train_data_generator,
                        epochs=num_epochs,
                        callbacks=callbacks_list,
                        validation_data=valid_data_generator)

    # PLOT HISTORY
    # :
    # :

    # LOAD BEST MODEL to evaluate the performance of the model
    model.load_weights(checkpoint_path)
    # return_dict=True maps metric names to values directly, instead of
    # relying on model.metrics_names lining up with the returned list.
    results = model.evaluate(valid_data_generator, return_dict=True)
    VALIDATION_ACCURACY.append(results['accuracy'])
    VALIDATION_LOSS.append(results['loss'])

    # Free graph/session state before building the next fold's model.
    tf.keras.backend.clear_session()

    fold_var += 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.