Skip to content

Instantly share code, notes, and snippets.

@aiwithshekhar
Created December 13, 2019 20:07
Show Gist options
  • Save aiwithshekhar/9d3f693ede3294e543ee9abb03887c75 to your computer and use it in GitHub Desktop.
Save aiwithshekhar/9d3f693ede3294e543ee9abb03887c75 to your computer and use it in GitHub Desktop.
Inference on the validation dataset
# Run the trained U-Net on the first validation batch and plot the predicted
# mask next to the ground-truth mask.
#
# NOTE(review): CarDataloader, df, img_fol, mask_fol, mean, std, smp, plt, np
# and torch come from outside this snippet. The trailing positional arguments
# ('val', 1, 4) presumably select the phase, batch size and worker count —
# confirm against CarDataloader's signature.
test_dataloader = CarDataloader(df, img_fol, mask_fol, mean, std, 'val', 1, 4)
ckpt_path = '/media/shashank/CE7E082A7E080DC1/PycharmProjects/object_detection/model_newloss.pth'

# Fall back to CPU so the script still runs on a machine without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = smp.Unet("resnet18", encoder_weights=None, classes=1, activation=None)
model.to(device)
model.eval()

# Deserialize the checkpoint onto the CPU; load_state_dict then copies the
# weights into the model's parameters, which already live on `device`.
state = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(state["state_dict"])

# Only one batch is visualised, so take it directly instead of looping over
# the whole loader and breaking out after the first iteration.
images, mask_target = next(iter(test_dataloader))

# Pure inference: disable autograd so no computation graph is built or kept.
with torch.no_grad():
    batch_preds = torch.sigmoid(model(images.to(device)))
batch_preds = batch_preds.cpu().numpy()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 15))
fig.suptitle('predicted_mask//original_mask')
# np.squeeze drops singleton dimensions so imshow receives a 2-D image; this
# only yields a 2-D array when the batch holds a single one-channel mask
# (assumes batch size 1 — TODO confirm).
ax1.imshow(np.squeeze(batch_preds), cmap='gray')
ax2.imshow(np.squeeze(mask_target), cmap='gray')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment