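# ---------------------------------------------------------------------------
# Training/validation loop for a medical image segmentation model.
# The fragment assumes the following objects are created earlier in the full
# script; the lines below are only a hypothetical sketch of that setup, not
# the original author's configuration:
#
#   model      = SomeSegmentationNet().cuda()        # any 2D segmentation network
#   optimizer  = torch.optim.Adam(model.parameters(), lr=1e-3)
#   scheduler  = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
#   train_loader, val_loader                          # DataLoaders yielding dicts
#                                                     # with "input" and "gt" tensors
#   num_epochs = ...
# ---------------------------------------------------------------------------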
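# `threshold_predictions` and `accuracy` are called below but not defined in
# this gist. The two functions here are plausible minimal stand-ins, assuming
# binary segmentation masks; they are not the original implementations.
def threshold_predictions(predictions, threshold=0.5):
    """Binarize soft segmentation outputs at `threshold`."""
    preds = predictions.copy()
    preds[preds < threshold] = 0.0
    preds[preds >= threshold] = 1.0
    return preds


def accuracy(predictions, targets, threshold=0.5):
    """Pixel-wise accuracy between thresholded predictions and ground truth."""
    preds = threshold_predictions(predictions, threshold)
    gts = threshold_predictions(targets, threshold)
    return float((preds == gts).mean())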
import time

import numpy as np
import torch
from torchvision import utils as vutils
from tqdm import tqdm

# The mt_ aliases are assumed to refer to the medicaltorch package.
from medicaltorch import losses as mt_losses
from medicaltorch import metrics as mt_metrics

for epoch in tqdm(range(1, num_epochs + 1)):
    start_time = time.time()  # epoch timer (elapsed time is not reported in this fragment)
    # Current learning rate (get_last_lr() replaces the older get_lr() call).
    lr = scheduler.get_last_lr()[0]

    ### Training
    model.train()
    train_loss_total = 0.0
    num_steps = 0

    for i, batch in enumerate(train_loader):
        input_samples, gt_samples = batch["input"], batch["gt"]
        var_input = input_samples.cuda()
        # `async=True` is no longer valid (async is a reserved keyword); use non_blocking.
        var_gt = gt_samples.cuda(non_blocking=True)

        preds = model(var_input)
        loss = mt_losses.dice_loss(preds, var_gt)
        train_loss_total += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        num_steps += 1

        # Every 5 epochs, build image grids of the inputs, predictions and
        # ground truth. Each call overwrites `grid_img`; the logging calls
        # (e.g. to TensorBoard) that consume these grids are omitted in this gist.
        if epoch % 5 == 0:
            grid_img = vutils.make_grid(input_samples,
                                        normalize=True,
                                        scale_each=True)
            grid_img = vutils.make_grid(preds.data.cpu(),
                                        normalize=True,
                                        scale_each=True)
            grid_img = vutils.make_grid(gt_samples,
                                        normalize=True,
                                        scale_each=True)

    # Step the LR scheduler once per epoch, after the optimizer updates.
    scheduler.step()

    train_loss_total_avg = train_loss_total / num_steps
    # Training accuracy is estimated on the last training batch only.
    train_acc = accuracy(preds.cpu().detach().numpy(),
                         var_gt.cpu().detach().numpy())

    ### Validation
    model.eval()
    val_loss_total = 0.0
    num_steps = 0

    metric_fns = [mt_metrics.dice_score,
                  mt_metrics.hausdorff_score,
                  mt_metrics.precision_score,
                  mt_metrics.recall_score,
                  mt_metrics.specificity_score,
                  mt_metrics.intersection_over_union,
                  mt_metrics.accuracy_score]
    metric_mgr = mt_metrics.MetricManager(metric_fns)

    for i, batch in enumerate(val_loader):
        input_samples, gt_samples = batch["input"], batch["gt"]

        with torch.no_grad():
            var_input = input_samples.cuda()
            var_gt = gt_samples.cuda(non_blocking=True)

            preds = model(var_input)
            loss = mt_losses.dice_loss(preds, var_gt)
            val_loss_total += loss.item()

        # Metrics computation: binarize the predictions and drop the channel axis.
        gt_npy = gt_samples.numpy().astype(np.uint8)
        gt_npy = gt_npy.squeeze(axis=1)

        preds = preds.data.cpu().numpy()
        preds = threshold_predictions(preds)
        preds = preds.astype(np.uint8)
        preds = preds.squeeze(axis=1)

        metric_mgr(preds, gt_npy)
        num_steps += 1

    metrics_dict = metric_mgr.get_results()
    metric_mgr.reset()

    val_loss_total_avg = val_loss_total / num_steps

    print('\nTrain loss: {:.4f}, Training Accuracy: {:.4f}'.format(
        train_loss_total_avg, train_acc))
    print('Val Loss: {:.4f}, Validation Accuracy: {:.4f}'.format(
        val_loss_total_avg, metrics_dict["accuracy_score"]))