Skip to content

Instantly share code, notes, and snippets.

@mingrui
Created May 18, 2018 09:40
Show Gist options
  • Save mingrui/16c16452371800855ca3f7e4d70efc99 to your computer and use it in GitHub Desktop.
test dice
# source: https://github.com/shreyaspadhy/UNet-Zoo
def _loss_value(loss):
    """Return the scalar held by a one-element loss tensor as a Python number.

    Prefers ``.item()`` (available from PyTorch 0.4). Indexing a 0-dim tensor
    with ``.data[0]`` raises an IndexError on newer PyTorch releases, so
    ``.data[0]`` is kept only as a fallback for older versions.
    """
    if hasattr(loss, 'item'):
        return loss.item()
    return loss.data[0]


def _save_batch_arrays(tag, batch_idx, image, mask, output):
    """Save one batch's output, mask and input image as .npy files.

    Files are written to ./npy-files/out-files/ and named
    ``<args.save>-<tag>-<batch_idx>-{outs,masks,images}.npy``.
    The output and mask are stored as uint8 (``.byte()``); the image is stored as float.
    NOTE(review): the target directory is not created here; it must already exist.
    """
    def path(suffix):
        return './npy-files/out-files/{}-{}-{}-{}.npy'.format(
            args.save, tag, batch_idx, suffix)

    np.save(path('outs'), output.data.byte().cpu().numpy())
    np.save(path('masks'), mask.data.byte().cpu().numpy())
    np.save(path('images'), image.data.float().cpu().numpy())


def test(train_accuracy=False, save_output=False):
    """Evaluate ``model`` on one data split and print the mean criterion value.

    Relies on module-level ``args``, ``model``, ``criterion``,
    ``train_loader`` and ``test_loader`` defined elsewhere in the script.

    Args:
        train_accuracy: If True, evaluate on ``train_loader``; otherwise on
            ``test_loader``.
        save_output: If True, save each batch's rounded output, mask and
            image as .npy files under ./npy-files/out-files/.

    Returns:
        The per-batch ``criterion(output, mask)`` values summed and divided by
        the number of batches. The printed label calls this the average DICE
        coefficient. NOTE(review): whether it is a coefficient or a loss
        depends on ``criterion``; confirm.
    """
    loader = train_loader if train_accuracy else test_loader
    # Filename tag that keeps saved training-set batches apart from test-set ones.
    tag = 'unetsmall-train-batch' if train_accuracy else 'unetsmall-batch'
    test_loss = 0.0
    for batch_idx, (image, mask) in tqdm(enumerate(loader), total=len(loader)):
        if args.cuda:
            image, mask = image.cuda(), mask.cuda()
        # NOTE(review): `volatile=True` only disables autograd on PyTorch < 0.4;
        # newer versions ignore it with a warning. Use `torch.no_grad()` there.
        image, mask = Variable(image, volatile=True), Variable(mask, volatile=True)
        output = model(image)
        test_loss += _loss_value(criterion(output, mask))
        # Rounded in place after the loss is computed, so only the saved arrays
        # see the rounded predictions.
        output.data.round_()
        if save_output:
            _save_batch_arrays(tag, batch_idx, image, mask, output)
    test_loss /= len(loader)
    split = 'Training' if train_accuracy else 'Test'
    print('\n{} Set: Average DICE Coefficient: {:.4f}\n'.format(split, test_loss))
    return test_loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment