Skip to content

Instantly share code, notes, and snippets.

@rozeappletree
Created January 7, 2021 16:59
Show Gist options
  • Save rozeappletree/fe3161dfd234086e831f3aab1628a9e5 to your computer and use it in GitHub Desktop.
def train():
    """Run the full training loop.

    Resumes from a checkpoint when one can be loaded, iterates over epochs
    and mini-batches feeding a TF1 session, logs a median loss every
    PRINT_EVERY steps via the tqdm bar, periodically validates, and saves
    the model whenever the validation IoU beats the best seen so far.

    NOTE(review): depends on module-level names defined elsewhere in the
    file: sess, saver, args, train_img, train_label, valid_img, num_batches,
    img, label, is_training, train_step, sigmoid_cross_entropy_loss, pred,
    load, load_batch, validation, history, PRINT_EVERY, CHECKPOINTS_DIR.
    """
    tf.global_variables_initializer().run()

    # Try to resume: counter is a global step, so epoch/batch offsets are
    # recovered from it by integer division against batches-per-epoch.
    could_load, checkpoint_counter = load()
    if could_load:
        start_epoch = checkpoint_counter // num_batches  # was C-style (int)(a / b)
        start_batch_id = checkpoint_counter - start_epoch * num_batches
        counter = checkpoint_counter
        print("[INFO] Checkpoint Load Success!")
    else:
        start_epoch = 0
        start_batch_id = 0
        counter = 1
        print("[INFO] Checkpoint load failed. Training from scratch...")

    train_iter = []   # global steps at which a loss point was recorded
    train_loss = []   # median loss at each of those steps
    IOU = 0.65        # best-so-far validation IoU; only better models are saved

    print("==================================================================")
    print("[INFO] GENERAL INFORMATION")
    print("==================================================================")
    # utils.count_params()
    print("Total train image:{}".format(len(train_img)))
    print("Total validate image:{}".format(len(valid_img)))
    print("Total epoch:{}".format(args.num_epochs))
    print("Batch size:{}".format(args.batch_size))
    print("Learning rate:{}".format(args.learning_rate))
    #print("Checkpoint step:{}".format(args.checkpoint_step))
    print("==================================================================")
    print("[INFO] DATA AUGMENTATION")
    print("==================================================================")
    print("h_flip: {}".format(args.h_flip))
    print("v_flip: {}".format(args.v_flip))
    print("rotate: {}".format(args.rotation))
    print("clip size: {}".format(args.clip_size))
    print("==================================================================")
    print("[INFO] TRAINING STARTED")
    print("==================================================================")

    loss_tmp = []  # losses since the last PRINT_EVERY report
    for i in range(start_epoch, args.num_epochs):
        epoch_time = time.time()
        # Fresh random sample order each epoch.
        id_list = np.random.permutation(len(train_img))
        batch_pbar = tqdm(range(start_batch_id, num_batches), desc=f"[TRAIN] Epoch {i}")
        for j in batch_pbar:
            img_d = []
            lab_d = []
            for ind in range(args.batch_size):
                sample_id = id_list[j * args.batch_size + ind]  # renamed: `id` shadows builtin
                img_d.append(train_img[sample_id])
                lab_d.append(train_label[sample_id])
            x_batch, y_batch = load_batch(img_d, lab_d)
            # x_batch/y_batch per debug note: (512, 512, 3) and (512, 512, 1)
            feed_dict = {img: x_batch,
                         label: y_batch,
                         is_training: True}
            _, loss, pred1 = sess.run([train_step, sigmoid_cross_entropy_loss, pred],
                                      feed_dict=feed_dict)
            loss_tmp.append(loss)
            if counter % PRINT_EVERY == 0:
                tmp = np.median(loss_tmp)
                train_iter.append(counter)
                train_loss.append(tmp)
                batch_pbar.set_description(f"[TRAIN] Epoch {i} --- Iter {counter} --- Loss {tmp}")
                loss_tmp.clear()
            counter += 1
        start_batch_id = 0  # only the resumed epoch starts mid-way

        # BUGFIX: loss_tmp may be empty right after a PRINT_EVERY clear;
        # np.median([]) would record nan (with a RuntimeWarning). Skip instead.
        if loss_tmp:
            history['train']['loss'].append(np.median(loss_tmp))
        # print(f'[DEBUG] Time taken for epoch {i}: {time.time() - epoch_time:.3f} seconds')

        # Validate every valid_step epochs once past the warm-up window.
        if i > args.start_valid:
            if (i - args.start_valid) % args.valid_step == 0:
                val_iou, val_loss = validation()
                history['val']['iou'].append(val_iou)
                history['val']['loss'].append(val_loss)
                # Best-model selection is IoU-based (could use val_loss instead).
                if val_iou > IOU:
                    print(f"[INFO] Saving best model as checkpoint... val_iou: {val_iou}")
                    saver.save(sess, f'{CHECKPOINTS_DIR}model.ckpt',
                               global_step=counter, write_meta_graph=True)
                    IOU = val_iou

    # Final save regardless of validation score.
    saver.save(sess, f'{CHECKPOINTS_DIR}model.ckpt', global_step=counter)
def f_iou(predict, label):
    """Return (intersection, union) foreground pixel counts.

    Both inputs are arrays where the value 1 marks the positive class.
    The caller divides intersection by union to obtain the IoU score.
    """
    pred_mask = predict == 1
    true_mask = label == 1
    intersection = np.sum(np.logical_and(pred_mask, true_mask))
    # |A ∪ B| == |A| + |B| − |A ∩ B|, computed directly as a logical OR.
    union = np.sum(np.logical_or(pred_mask, true_mask))
    return intersection, union
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment