# Build the validation loader. The two trailing positional args are
# presumably batch size (1) and worker count (4) — TODO confirm against
# CarDataloader's signature.
test_dataloader = CarDataloader(df, img_fol, mask_fol, mean, std, 'val', 1, 4)
ckpt_path = '/media/shashank/CE7E082A7E080DC1/PycharmProjects/object_detection/model_newloss.pth'

# Prefer the GPU, but fall back to CPU so the script still runs on machines
# without CUDA (a bare torch.device("cuda") errors out there on .to()).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# U-Net with a ResNet-18 encoder and a single output channel; activation=None
# means the model emits raw logits, so sigmoid is applied after the forward pass.
model = smp.Unet("resnet18", encoder_weights=None, classes=1, activation=None)
model.to(device)
model.eval()

# The map_location lambda returns each storage unchanged, i.e. tensors are
# deserialized on CPU; load_state_dict then copies them into the model's
# parameters, which already live on `device`.
state = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(state["state_dict"])

# start prediction
predictions = []  # NOTE(review): not used in the code visible here
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 15))
fig.suptitle('predicted_mask//original_mask')

# Inference only: disable autograd so no computation graph or intermediate
# activations are retained during the forward pass.
with torch.no_grad():
    for i, batch in enumerate(tqdm(test_dataloader)):
        images, mask_target = batch
        # Logits -> per-pixel probabilities, then to a host numpy array.
        batch_preds = torch.sigmoid(model(images.to(device)))
        batch_preds = batch_preds.detach().cpu().numpy()
        # Left: predicted probability map; right: ground-truth mask.
        # np.squeeze drops singleton batch/channel dims for imshow.
        ax1.imshow(np.squeeze(batch_preds), cmap='gray')
        ax2.imshow(np.squeeze(np.asarray(mask_target)), cmap='gray')
        # Only the first batch is visualized.
        break