# Build the evaluation dataloader. The positional args after mean/std are
# presumably (phase, batch_size, num_workers) -- TODO confirm against
# CarDataloader's signature.
test_dataloader=CarDataloader(df,img_fol,mask_fol,mean,std,'val',1,4)

ckpt_path='/media/shashank/CE7E082A7E080DC1/PycharmProjects/object_detection/model_newloss.pth'

# Prefer the GPU, but fall back to CPU so the script still runs on machines
# without CUDA (a hard-coded "cuda" device raises at model.to()).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# U-Net with a ResNet-18 encoder, one output channel, and no final activation:
# the model emits raw logits, so a sigmoid must be applied at prediction time.
# encoder_weights=None because all weights come from the checkpoint below.
model = smp.Unet("resnet18", encoder_weights=None, classes=1, activation=None)
model.to(device)
model.eval()  # inference mode for dropout / batch-norm layers

# Deserialize every tensor onto CPU first; load_state_dict then copies the
# values into the parameters, which already live on `device`.
state = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(state["state_dict"])

# start prediction
# Run the model on the first batch only and show the predicted mask next to
# the ground-truth mask.
predictions = []
fig, (ax1,ax2)=plt.subplots(1,2,figsize=(15,15))
fig.suptitle('predicted_mask//original_mask')
# Disable autograd during inference: no computation graph is built, so
# memory use drops and .detach() is no longer needed.
with torch.no_grad():
    for i, batch in enumerate(tqdm(test_dataloader)):
        images,mask_target = batch
        # The model outputs raw logits (activation=None); the sigmoid maps
        # them to per-pixel probabilities in [0, 1].
        batch_preds = torch.sigmoid(model(images.to(device)))
        batch_preds = batch_preds.cpu().numpy()
        # np.squeeze drops the singleton batch/channel dims so imshow gets a
        # 2-D image -- assumes batch_size == 1 and a single mask channel.
        ax1.imshow(np.squeeze(batch_preds),cmap='gray')
        ax2.imshow(np.squeeze(mask_target),cmap='gray')
        break  # visualize a single batch only