Created
August 10, 2023 17:42
-
-
Save MischaPanch/ff790dd72ebb241006e186f8d1a26e3d to your computer and use it in GitHub Desktop.
Accelerated torch dataset and dataloader
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import torch | |
from torch.utils.data import Dataset, Subset, TensorDataset | |
from typing import ( | |
Callable, | |
Iterator, | |
Literal, | |
Optional, | |
Sequence, | |
Tuple, | |
TypeVar, | |
Union, | |
) | |
def get_batch_boundaries(
    batch_size: int,
    len_data: int,
    last_batch: Literal["drop", "merge", "keep"] = "merge",
) -> np.ndarray:
    """Get the boundaries of batches for a given batch size and data length.

    Batch ``i`` spans the half-open index range
    ``[boundaries[i], boundaries[i + 1])``, so ``len(boundaries) - 1`` is the
    number of batches.

    :param batch_size: the size of each batch; must be positive
    :param len_data: the length of the data
    :param last_batch: one of "drop", "merge", or "keep", controlling what
        happens to a trailing batch that is smaller than ``batch_size``:

        - "drop": drop the last batch if it is smaller than batch_size
        - "merge": merge the last batch with the previous batch (if there is
          no previous batch, it is kept as is)
        - "keep": keep the last batch as is, even if it is smaller than batch_size
    :return: a numpy array of batch boundaries
    :raises ValueError: if ``batch_size`` is not positive or ``last_batch`` is
        not one of the allowed values
    """
    # Validate up front so that an invalid ``last_batch`` is rejected no matter
    # whether the data happens to divide evenly into batches.
    if last_batch not in ("drop", "merge", "keep"):
        raise ValueError(
            f"last_batch must be one of 'drop', 'merge', or 'keep', "
            f"but got {last_batch}"
        )
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, but got {batch_size}")
    if batch_size >= len_data:
        # All data fits into a single batch. Under "drop" that batch is only
        # kept when it is complete.
        if last_batch == "drop" and len_data < batch_size:
            return np.array([0])
        return np.array([0, len_data])
    batch_boundaries = np.arange(0, len_data + 1, batch_size)
    if len_data % batch_size == 0 or last_batch == "drop":
        # arange already ends at the last full batch
        return batch_boundaries
    if last_batch == "merge":
        # widen the last full batch so that it also covers the leftover rows
        batch_boundaries[-1] = len_data
        return batch_boundaries
    # "keep": append an extra, smaller final batch
    return np.append(batch_boundaries, len_data)
class Accelerated2DTensorDataset(Dataset):
    """
    Same logic as torch.utils.data.TensorDataset but avoids some overhead by
    retrieving from a single tensor and then slicing it, instead of retrieving from
    multiple tensors in a loop. Currently only supports up-to 2D tensors.

    All input tensors are concatenated column-wise into one 2D tensor at
    construction time. A lookup is then a single indexing operation plus cheap
    column selections.

    ``torch.hstack`` promotes all inputs to a common dtype. Every returned tensor
    is therefore cast back to the dtype of the corresponding input tensor.
    Values that are not exactly representable in the promoted dtype may already
    have lost precision during stacking, e.g. very large int64 values stacked
    with float32 features.

    1D inputs are returned as 1D again (a 0-d tensor for an integer index), as
    with ``TensorDataset``.
    """

    def __init__(self, *tensors: torch.Tensor) -> None:
        """
        :param tensors: tensors to be stacked. All tensors must have the same
            first dimension, and up to 2 dimensions.
        :raises ValueError: if no tensors are given, the first dimensions
            differ, or a tensor has more than 2 dimensions
        """
        if not tensors:
            raise ValueError("At least one tensor is required")
        self._len = tensors[0].shape[0]
        # Per input tensor: the column selector into the stacked tensor (an int
        # for 1D inputs, so the column dimension is dropped again on lookup,
        # or a slice for 2D inputs) and the dtype to restore.
        columns = []
        dtypes = []
        unsqueezed_tensors = []
        offset = 0
        for tensor in tensors:
            if tensor.shape[0] != self._len:
                raise ValueError(
                    "All tensors must have the same first dimension, "
                    f"but got {tensor.shape[0]} and {self._len}"
                )
            was_1d = len(tensor.shape) == 1
            if was_1d:
                tensor = tensor.unsqueeze(1)
            if len(tensor.shape) != 2:
                raise ValueError(
                    "All tensors must have up to 2 dimensions, "
                    f"but got {len(tensor.shape)}"
                )
            width = tensor.shape[1]
            columns.append(offset if was_1d else slice(offset, offset + width))
            dtypes.append(tensor.dtype)
            unsqueezed_tensors.append(tensor)
            offset += width
        self._stacked_tensors = torch.hstack(unsqueezed_tensors)
        self._columns = columns
        self._dtypes = dtypes

    def __getitem__(self, index) -> Tuple[torch.Tensor, ...]:
        # .to() is a no-op when the dtype already matches the stacked dtype
        return tuple(
            self._stacked_tensors[index, col].to(dtype)
            for col, dtype in zip(self._columns, self._dtypes)
        )

    def __len__(self) -> int:
        return self._len
# Container types that BatchDataLoader accepts as ``data``: things it can take
# ``len()`` of and slice, plus ``Subset`` wrappers around them.
SupportsBatching = Union[
    Sequence,
    np.ndarray,
    torch.Tensor,
    Subset,
    Accelerated2DTensorDataset,
    TensorDataset,
]

# Result type of a user-supplied ``collate_fn``.
T = TypeVar("T")
class BatchDataLoader:
    """Lightweight loader that yields contiguous slices of indexable data.

    Unlike ``torch.utils.data.DataLoader``, items are not fetched one by one and
    then collated. Each batch is obtained with a single slicing operation
    ``data[lower:upper]`` on the (possibly shuffled) data.
    """

    def __init__(
        self,
        data: SupportsBatching,
        batch_size: int,
        shuffle: bool = False,
        last_batch: Literal["drop", "merge", "keep"] = "merge",
        collate_fn: Optional[Callable[[Union[torch.Tensor, np.ndarray]], T]] = None,
    ) -> None:
        """A simple data loader that returns batches of data.

        :param data: the data to be loaded. If tensor-based, the batches will be
            tensors, otherwise they will be numpy arrays. Generic sequences are
            converted to a numpy array. A (possibly nested) ``Subset`` is resolved
            to the rows it selects from its underlying dataset. For
            ``TensorDataset`` / ``Accelerated2DTensorDataset`` each batch is the
            tuple of tensors returned by slicing the dataset.
        :param batch_size: the size of each batch
        :param shuffle: whether to shuffle the data before batching (re-done at the
            start of every iteration)
        :param last_batch: one of "drop", "merge", or "keep".
            - "drop": drop the last batch if it is smaller than batch_size
            - "merge": merge the last batch with the previous batch
            - "keep": keep the last batch as is, even if it is smaller than batch_size
        :param collate_fn: a function to apply to each batch before returning it
        """
        data = self._prepare_data(data)
        self._data = data
        # NOTE: batch boundaries are computed once here; reassigning
        # ``batch_size`` or ``last_batch`` afterwards does not change batching.
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.last_batch = last_batch
        self._boundary_idxs = get_batch_boundaries(
            batch_size, len(data), last_batch=last_batch
        )
        self._num_batches = len(self._boundary_idxs) - 1
        self.collate_fn = collate_fn or (lambda x: x)

    @classmethod
    def _prepare_data(cls, data):
        """Resolve Subsets to their selected rows and sequences to numpy arrays."""
        # Compose the indices of nested Subsets so that the final array maps
        # positions in the outermost Subset to rows of the innermost dataset.
        indices = None
        while isinstance(data, Subset):
            level = np.asarray(data.indices, dtype=np.int64)
            indices = level if indices is None else level[indices]
            data = data.dataset
        if isinstance(data, Sequence):
            data = np.array(data)
        if indices is not None:
            data = cls._take(data, indices)
        return data

    @staticmethod
    def _take(data, indices):
        """Return the rows of ``data`` at ``indices``, keeping its container type."""
        selected = data[indices]
        if isinstance(data, (TensorDataset, Accelerated2DTensorDataset)):
            # indexing these datasets yields a tuple of tensors; rebuild the dataset
            return type(data)(*selected)
        return selected

    # TODO: the generic annotation here is probably incorrect
    def __iter__(self) -> Iterator[Union[np.ndarray, torch.Tensor, T]]:
        if self.shuffle:
            self._shuffle_data()
        bounds = self._boundary_idxs
        for lower, upper in zip(bounds[:-1], bounds[1:]):
            yield self.collate_fn(self._data[lower:upper])

    def _shuffle_data(self) -> None:
        """Replace the stored data by a random permutation of its rows.

        Uses numpy's global random state.
        """
        self._data = self._take(self._data, np.random.permutation(len(self._data)))

    def __len__(self) -> int:
        return self._num_batches
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment