Created
January 7, 2021 16:59
-
-
Save rozeappletree/fe3161dfd234086e831f3aab1628a9e5 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train():
    """Run the TF1 training loop with periodic validation and checkpointing.

    Initializes all graph variables, then tries to resume from a checkpoint via
    ``load()``. Training runs from the resumed (or first) epoch up to
    ``args.num_epochs``:

    * Every ``PRINT_EVERY`` global iterations, the median of the batch losses
      since the previous report is recorded and shown on the progress bar.
    * After each epoch, the median of that epoch's batch losses is appended to
      ``history['train']['loss']``.
    * When ``i > args.start_valid`` and ``(i - args.start_valid)`` is a
      multiple of ``args.valid_step``, ``validation()`` is run. Its IoU and
      loss are appended to ``history['val']``, and a checkpoint is saved
      whenever the IoU beats the best seen so far (initial threshold 0.65).
    * A final checkpoint is saved after the last epoch.

    Relies on module-level names defined elsewhere in the script (``sess``,
    ``saver``, ``args``, ``history``, ``load``, ``load_batch``, ``validation``,
    the graph tensors ``img``/``label``/``is_training``, ``train_step``,
    ``sigmoid_cross_entropy_loss``, the data lists, ``num_batches``,
    ``PRINT_EVERY`` and ``CHECKPOINTS_DIR``).

    Returns:
        tuple[list, list]: ``(train_iter, train_loss)``, i.e. the global
        iteration numbers at which a windowed median loss was recorded and
        the corresponding median losses.
    """
    tf.global_variables_initializer().run()
    could_load, checkpoint_counter = load()
    if could_load:
        # The restored global step counts consumed batches; split it into
        # (epoch, batch offset within that epoch).
        start_epoch, start_batch_id = divmod(checkpoint_counter, num_batches)
        counter = checkpoint_counter
        print("[INFO] Checkpoint Load Success!")
    else:
        start_epoch = 0
        start_batch_id = 0
        counter = 1
        print("[INFO] Checkpoint load failed. Training from scratch...")

    train_iter = []
    train_loss = []
    # A checkpoint is only written when validation IoU exceeds this value;
    # it is raised to each new best IoU as training progresses.
    best_iou = 0.65

    print("==================================================================")
    print("[INFO] GENERAL INFORMATION")
    print("==================================================================")
    print("Total train image:{}".format(len(train_img)))
    print("Total validate image:{}".format(len(valid_img)))
    print("Total epoch:{}".format(args.num_epochs))
    print("Batch size:{}".format(args.batch_size))
    print("Learning rate:{}".format(args.learning_rate))
    print("==================================================================")
    print("[INFO] DATA AUGMENTATION")
    print("==================================================================")
    print("h_flip: {}".format(args.h_flip))
    print("v_flip: {}".format(args.v_flip))
    print("rotate: {}".format(args.rotation))
    print("clip size: {}".format(args.clip_size))
    print("==================================================================")
    print("[INFO] TRAINING STARTED")
    print("==================================================================")

    loss_window = []  # batch losses since the last PRINT_EVERY report
    for i in range(start_epoch, args.num_epochs):
        epoch_losses = []  # batch losses of the current epoch only
        id_list = np.random.permutation(len(train_img))
        batch_pbar = tqdm(range(start_batch_id, num_batches), desc=f"[TRAIN] Epoch {i}")
        for j in batch_pbar:
            batch_ids = [id_list[j * args.batch_size + k] for k in range(args.batch_size)]
            img_d = [train_img[k] for k in batch_ids]
            lab_d = [train_label[k] for k in batch_ids]
            x_batch, y_batch = load_batch(img_d, lab_d)
            feed_dict = {img: x_batch,
                         label: y_batch,
                         is_training: True}
            # Only the optimizer step and the loss are needed; the original
            # also fetched `pred` every step but never used the result.
            _, loss = sess.run([train_step, sigmoid_cross_entropy_loss], feed_dict=feed_dict)
            loss_window.append(loss)
            epoch_losses.append(loss)
            if counter % PRINT_EVERY == 0:
                tmp = np.median(loss_window)
                train_iter.append(counter)
                train_loss.append(tmp)
                batch_pbar.set_description(f"[TRAIN] Epoch {i} --- Iter {counter} --- Loss {tmp}")
                loss_window.clear()
            counter += 1
        # Only the first (possibly resumed) epoch starts mid-way.
        start_batch_id = 0
        # Median over this epoch's batches. Previously this used the
        # PRINT_EVERY window, which could span epochs or be empty (NaN).
        history['train']['loss'].append(
            np.median(epoch_losses) if epoch_losses else float("nan"))

        if i > args.start_valid and (i - args.start_valid) % args.valid_step == 0:
            val_iou, val_loss = validation()
            history['val']['iou'].append(val_iou)
            history['val']['loss'].append(val_loss)
            # Keep the model with the best validation IoU so far
            # (val_loss could be used as the criterion instead).
            if val_iou > best_iou:
                print(f"[INFO] Saving best model as checkpoint... val_iou: {val_iou}")
                saver.save(sess, f'{CHECKPOINTS_DIR}model.ckpt', global_step=counter, write_meta_graph=True)
                best_iou = val_iou

    # Final checkpoint once all epochs are done.
    # NOTE(review): this shares the 'model.ckpt' prefix with the best-IoU
    # saves; depending on the Saver's max_to_keep, later saves may evict the
    # best checkpoint. Confirm this is intended.
    saver.save(sess, f'{CHECKPOINTS_DIR}model.ckpt', global_step=counter)
    return train_iter, train_loss
def f_iou(predict, label):
    """Return the positive-class ``(intersection, union)`` pixel counts.

    A pixel is positive when its value equals 1. The intersection counts
    pixels that are positive in both ``predict`` and ``label``. The union is
    computed by inclusion-exclusion: predicted positives plus labelled
    positives minus the intersection. IoU is ``intersection / union``.
    """
    pred_pos = predict == 1
    true_pos = label == 1
    intersection = np.sum(np.logical_and(pred_pos, true_pos))
    union = np.sum(pred_pos) + np.sum(true_pos) - intersection
    return intersection, union
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment