source: pytorch/pytorch#1137 (comment)
To skip corrupted images when loading data with a PyTorch `DataLoader`, follow these steps:
1. Return `None` from `__getitem__()` if the image is corrupted:
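```python
def __getitem__(self, idx):
    try:
        # load_img is your own loading function; it should raise on unreadable files.
        img, label = load_img(idx)
    except Exception:
        # Corrupted or unreadable sample: return None so collate_fn can drop it.
        return None
    return [img, label]
```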
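A minimal, runnable sketch of step 1, assuming a loader that raises `OSError` for corrupted files (Pillow's `UnidentifiedImageError` is an `OSError` subclass, so this covers many unreadable images). The class name `SkipCorruptDataset` and the `fake_load_img` helper are hypothetical; the "files" are simulated with an in-memory dict so the example runs without PyTorch or real images. `DataLoader` accepts any map-style object with `__getitem__` and `__len__`, although in practice you would usually subclass `torch.utils.data.Dataset`.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple


class SkipCorruptDataset:
    """Map-style dataset that returns None for samples that fail to load.

    Hypothetical example: `loader` plays the role of load_img and is
    assumed to raise OSError for corrupted files.
    """

    def __init__(self, paths: Sequence[str], labels: Sequence[int],
                 loader: Callable[[str], Any]) -> None:
        self.paths = list(paths)
        self.labels = list(labels)
        self.loader = loader

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int) -> Optional[Tuple[Any, int]]:
        try:
            img = self.loader(self.paths[idx])
        except OSError:
            # Only swallow I/O and decoding errors; other bugs still surface.
            return None
        return img, self.labels[idx]


# Simulated "files": b.png stands for a corrupted image.
_FAKE_FILES = {"a.png": [0.1, 0.2], "b.png": None, "c.png": [0.5, 0.6]}


def fake_load_img(path: str) -> List[float]:
    data = _FAKE_FILES[path]
    if data is None:
        raise OSError(f"cannot identify image file {path!r}")
    return data


dataset = SkipCorruptDataset(["a.png", "b.png", "c.png"], [0, 1, 2], fake_load_img)
print([dataset[i] for i in range(len(dataset))])
# [([0.1, 0.2], 0), None, ([0.5, 0.6], 2)]
```

Catching a specific exception type rather than everything keeps genuine programming errors (a typo, a wrong index) from being silently turned into "corrupted" samples.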
2. Filter out the `None` values in `collate_fn()`:
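```python
from torch.utils.data.dataloader import default_collate

def collate_fn(batch):
    # Drop the samples that __getitem__ returned as None.
    batch = [sample for sample in batch if sample is not None]
    # If every sample was None, default_collate gets an empty list and raises.
    return default_collate(batch)
```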
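The same filtering can be written as a reusable wrapper around any collate function. In this sketch, `skip_none_collate` and `transpose_collate` are hypothetical names, and `transpose_collate` is a plain-Python stand-in for `default_collate` so the example runs without PyTorch. It also makes one assumed design choice: when every sample in a batch is corrupted, it returns `None` (leaving the training loop to skip that batch) instead of letting the base collate fail on an empty list.

```python
from functools import partial
from typing import Any, Callable, List, Optional, Sequence


def skip_none_collate(batch: Sequence[Any],
                      base_collate: Callable[[List[Any]], Any]) -> Optional[Any]:
    """Drop None samples, then delegate to `base_collate`.

    Returns None when every sample was dropped, because collate
    functions such as default_collate cannot handle an empty list.
    """
    kept = [sample for sample in batch if sample is not None]
    if not kept:
        return None
    return base_collate(kept)


def transpose_collate(samples: List[Any]) -> List[List[Any]]:
    """Toy stand-in for default_collate: [(x1, y1), (x2, y2)] -> [[x1, x2], [y1, y2]]."""
    return [list(field) for field in zip(*samples)]


collate = partial(skip_none_collate, base_collate=transpose_collate)
print(collate([("img0", 0), None, ("img2", 2)]))  # [['img0', 'img2'], [0, 2]]
print(collate([None, None]))                      # None
```

With PyTorch, `partial(skip_none_collate, base_collate=default_collate)` could be passed as `collate_fn`. Because both functions are defined at module level, the `partial` object stays picklable, which matters when `num_workers > 0`.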
3. Pass `collate_fn` to the `DataLoader`:
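```python
from torch.utils.data import DataLoader

# kwargs holds your usual DataLoader options (batch_size, shuffle, num_workers, ...).
train_loader = DataLoader(train_dataset, collate_fn=collate_fn, **kwargs)
test_loader = DataLoader(test_dataset, collate_fn=collate_fn, **kwargs)
```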
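One consequence worth checking: the batch indices are chosen before the samples are loaded, so dropping corrupted samples makes some batches smaller than `batch_size`, and a batch can end up empty. The sketch below imitates, in plain Python, how a non-shuffling loader with `drop_last=False` groups indices and then calls `collate_fn` on the fetched samples. `iterate_batches` is a hypothetical stand-in for illustration, not PyTorch's `DataLoader`, and the toy `skip_none_collate` returns `None` for an all-corrupted batch.

```python
from typing import Any, Callable, Iterator, List, Optional, Sequence


def skip_none_collate(batch: Sequence[Any]) -> Optional[List[Any]]:
    """Toy collate: drop None samples, return None if nothing is left."""
    kept = [sample for sample in batch if sample is not None]
    return kept or None


def iterate_batches(dataset: Sequence[Any], batch_size: int,
                    collate_fn: Callable[[List[Any]], Any]) -> Iterator[Any]:
    """Sequential batching similar to a DataLoader without shuffling (drop_last=False)."""
    for start in range(0, len(dataset), batch_size):
        indices = range(start, min(start + batch_size, len(dataset)))
        yield collate_fn([dataset[i] for i in indices])


# None entries stand for corrupted images.
samples = ["s0", None, "s2", None, None, None, "s6"]

sizes = []
for batch in iterate_batches(samples, batch_size=3, collate_fn=skip_none_collate):
    if batch is None:
        continue  # every sample in this batch was corrupted
    sizes.append(len(batch))
print(sizes)  # [2, 1]
```

In a real training loop this means per-batch quantities, such as a loss averaged over the batch, should use the actual batch length rather than the configured `batch_size`.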