@amrakm
Created September 7, 2022 20:50
torch_data_loader_w_corrupted_imgs.md

source: pytorch/pytorch#1137 (comment)

To skip corrupted images instead of crashing the `DataLoader`, follow these steps:

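1. Return `None` from `__getitem__()` if the image is corrupted:

```python
def __getitem__(self, idx):
    try:
        # load_img stands for your own loading code; it raises if the image is corrupted
        img, label = load_img(idx)
    except Exception:  # unlike a bare `except:`, this lets KeyboardInterrupt through
        return None
    return [img, label]
```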
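
A fuller sketch of step 1 as a self-contained dataset class, assuming each sample is an image file path plus a label. The class name `SkipCorruptDataset` and its `loader` and `errors` parameters are hypothetical; `loader` stands in for your decoding code (with Pillow, something like `lambda p: Image.open(p).convert("RGB")`). Only the exception types listed in `errors` are treated as corruption, so unrelated bugs still raise; `(OSError, ValueError)` is an assumed default that you may need to widen for your decoder. The class does not subclass `torch.utils.data.Dataset` only so that the sketch runs without PyTorch; subclass it in real code if you prefer.

```python
import logging
from typing import Any, Callable, Optional, Sequence, Tuple

logger = logging.getLogger(__name__)


class SkipCorruptDataset:
    """Map-style dataset whose __getitem__ returns None for unreadable samples."""

    def __init__(
        self,
        paths: Sequence[str],
        labels: Sequence[Any],
        loader: Callable[[str], Any],  # hypothetical: path -> decoded image
        errors: Tuple[type, ...] = (OSError, ValueError),  # assumed "corrupted" error types
    ):
        self.paths = paths
        self.labels = labels
        self.loader = loader
        self.errors = errors

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int) -> Optional[Tuple[Any, Any]]:
        try:
            img = self.loader(self.paths[idx])
        except self.errors as exc:
            # Log the bad file so it can be cleaned up later, then tell collate_fn to skip it
            logger.warning("Skipping corrupted image %s: %s", self.paths[idx], exc)
            return None
        return img, self.labels[idx]
```

With Pillow, files it cannot identify raise `PIL.UnidentifiedImageError`, a subclass of `OSError`, so they fall under the default. Note that a missing file (`FileNotFoundError`) is also an `OSError` and would be skipped and logged the same way.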

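2. Filter out the `None` values in `collate_fn()`, so a batch that contained corrupted images comes out smaller than `batch_size` instead of failing:

```python
from torch.utils.data.dataloader import default_collate

def collate_fn(batch):
    # Drop the samples that __getitem__ returned as None (corrupted images)
    batch = [sample for sample in batch if sample is not None]
    return default_collate(batch)
```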
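
If every image in a batch is corrupted, the `collate_fn` above passes an empty list to `default_collate`, which raises an `IndexError`. A minimal sketch of a variant that returns `None` for such a batch instead; the class name `SkipNoneCollate` is hypothetical, and it takes the underlying collate function as an argument so the sketch runs without PyTorch. The training loop then has to skip `None` batches.

```python
from typing import Any, Callable, List, Optional


class SkipNoneCollate:
    """Collate callable that drops None samples, then delegates to base_collate."""

    def __init__(self, base_collate: Callable[[List[Any]], Any]):
        self.base_collate = base_collate

    def __call__(self, batch: List[Any]) -> Optional[Any]:
        # Drop samples that __getitem__ flagged as corrupted
        batch = [sample for sample in batch if sample is not None]
        if not batch:
            # Every sample was corrupted: return None so the training loop can skip this batch
            return None
        return self.base_collate(batch)
```

In practice you would pass `collate_fn=SkipNoneCollate(default_collate)` to the `DataLoader`. It is written as a class rather than a closure so that it stays picklable when `DataLoader` workers are started with the `spawn` method (the default on Windows and macOS).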

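3. Pass `collate_fn` to each `DataLoader` (`kwargs` stands for your other arguments, such as `batch_size`, `shuffle` or `num_workers`):

```python
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, collate_fn=collate_fn, **kwargs)
test_loader = DataLoader(test_dataset, collate_fn=collate_fn, **kwargs)
```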
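
Because filtered batches can be smaller than `batch_size`, a plain average of per-batch losses gives the samples in short batches extra weight. A minimal sketch of an epoch loop that weights each batch by its actual size; `run_epoch` and `step_fn` (one optimization step that returns the batch's mean loss as a float) are hypothetical names, and the loop also skips any batch for which a custom `collate_fn` returned `None`. It accepts any iterable of `[images, labels]` batches, including a `DataLoader` built as above.

```python
from typing import Any, Callable, Iterable, Tuple


def run_epoch(
    loader: Iterable[Any],
    step_fn: Callable[[Any, Any], float],  # hypothetical: one training step, returns mean batch loss
) -> Tuple[float, int]:
    """Return (per-sample mean loss, number of samples used) over one pass of loader."""
    total_loss = 0.0
    n_samples = 0
    for batch in loader:
        if batch is None:  # a custom collate_fn signalled an all-corrupted batch
            continue
        images, labels = batch
        batch_loss = step_fn(images, labels)
        # Weight by the real batch size, which may be below batch_size after filtering
        total_loss += batch_loss * len(labels)
        n_samples += len(labels)
    mean_loss = total_loss / n_samples if n_samples else float("nan")
    return mean_loss, n_samples
```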