Skip to content

Instantly share code, notes, and snippets.

@devforfu
Created November 8, 2020 18:59
Show Gist options
  • Save devforfu/4fc0d76a163320b8992a50a913159c6e to your computer and use it in GitHub Desktop.
Save devforfu/4fc0d76a163320b8992a50a913159c6e to your computer and use it in GitHub Desktop.
Create data loaders
def create_data_loaders(
    dataset_root: str,
    dataset_name: str,
    sample_transformer: Callable = None,
    target_transformer: Callable = None,
    num_workers: int = cpu_count(),
    batch_size: int = 32,
    download: bool = True,
    pin_memory: bool = False,
) -> Tuple[OrderedDict, Dict]:
    """Build 'train' and 'valid' data loaders for a registered dataset.

    The dataset class is looked up in ``DATASET_FACTORY`` by name and
    instantiated once per split (``train=True`` for 'train', ``train=False``
    for 'valid'). Each instance is then wrapped in a ``DataLoader``.

    Args:
        dataset_root: Forwarded to the dataset factory as ``root``.
        dataset_name: Key selecting the factory in ``DATASET_FACTORY``.
        sample_transformer: Forwarded to the factory as ``transform``.
        target_transformer: Forwarded to the factory as ``target_transform``.
        num_workers: Forwarded to both loaders. The default is the value of
            ``cpu_count()`` captured once, when the function is defined.
        batch_size: Forwarded to both loaders.
        download: Forwarded to the dataset factory.
        pin_memory: Forwarded to both loaders.

    Returns:
        A ``(loaders, meta)`` pair. ``loaders`` is an ``OrderedDict`` with
        the keys 'train' and 'valid' (in that order). ``meta`` is a dict with
        'encoder' (the dataset's ``class_to_idx``), 'decoder' (its inverse
        mapping) and 'classes' (the dataset's ``classes``).

    Raises:
        ValueError: If ``dataset_name`` is not a key of ``DATASET_FACTORY``.
    """
    if dataset_name not in DATASET_FACTORY:
        raise ValueError(
            f"dataset '{dataset_name}' is not among "
            f"available datasets: {list(DATASET_FACTORY)}"
        )

    make_dataset = DATASET_FACTORY[dataset_name]
    loaders = OrderedDict()
    last_dataset = None

    # Only the training split is shuffled and created with drop_last=True;
    # the validation split keeps its order and every sample.
    for split_name, training in (('train', True), ('valid', False)):
        last_dataset = make_dataset(
            root=dataset_root,
            transform=sample_transformer,
            target_transform=target_transformer,
            download=download,
            train=training,
        )
        loaders[split_name] = DataLoader(
            dataset=last_dataset,
            num_workers=num_workers,
            batch_size=batch_size,
            pin_memory=pin_memory,
            drop_last=training,
            shuffle=training,
        )

    # Class metadata is read from the dataset created last ('valid').
    # NOTE(review): presumably both splits expose the same class mapping --
    # confirm for every entry in DATASET_FACTORY.
    label_to_index = last_dataset.class_to_idx
    meta = {
        'encoder': label_to_index,
        'decoder': {index: label for label, index in label_to_index.items()},
        'classes': last_dataset.classes,
    }
    return loaders, meta
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment