Pytorch Utilities
""" | |
Utility functions for PyTorch data manipulation and device handling. | |
This module provides various utility functions for handling PyTorch datasets, tensors, and devices. Functions include: | |
- Splitting datasets into train and test sets | |
- Concatenating multiple datasets | |
- Generating random batches from a dataset | |
- Converting data to a specific device | |
- Retrieving the default device | |
- Getting GPU device names | |
- Getting the name of a specific device | |
""" | |
###############################################################################
# BSD 3-Clause License
#
# Copyright (c) 2023, maharjun
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
###############################################################################
from typing import List

import torch


def train_test_data_split(dataset: torch.utils.data.Dataset, train_fraction: float, generator: torch.Generator):
    """
    Splits the given dataset into train and test sets based on the specified train_fraction.

    Parameters
    ----------
    dataset: torch.utils.data.Dataset
        Dataset to be split.
    train_fraction: float
        Fraction of the dataset to be used for training. Should be between 0 and 1.
    generator: torch.Generator
        Random number generator.

    Returns
    -------
    tuple
        A tuple containing two datasets, the train and test sets.
    """
    n_all = len(dataset)
    shuffle_inds = torch.randperm(n_all, generator=generator)
    n_train = int(n_all * train_fraction)
    return (dataset.__class__(*dataset[shuffle_inds[:n_train]]),
            dataset.__class__(*dataset[shuffle_inds[n_train:]]))
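
# Usage sketch (not part of the original gist): splits a TensorDataset 80/20.
# Assumes the dataset class can be reconstructed from the tuple returned by
# tensor indexing, as torch.utils.data.TensorDataset can.
def _example_train_test_split():
    features = torch.randn(100, 8)
    labels = torch.randint(0, 2, (100,))
    dataset = torch.utils.data.TensorDataset(features, labels)
    gen = torch.Generator().manual_seed(0)
    train_set, test_set = train_test_data_split(dataset, train_fraction=0.8, generator=gen)
    return len(train_set), len(test_set)  # (80, 20)
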
def concatenate_data(datasets: List[torch.utils.data.Dataset]):
    """
    Concatenates the given list of datasets.

    Parameters
    ----------
    datasets: List[torch.utils.data.Dataset]
        List of datasets to concatenate.

    Returns
    -------
    torch.utils.data.Dataset
        The concatenated dataset.
    """
    assert len(datasets) > 0, "At least one dataset should be given to concatenate"
    all_data_tuples = [dset[:] for dset in datasets]
    data_tuple_len = len(all_data_tuples[0])
    cat_data_tuple = tuple(torch.cat([x[i] for x in all_data_tuples])
                           for i in range(data_tuple_len))
    return datasets[0].__class__(*cat_data_tuple)
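
# Usage sketch (not part of the original gist): concatenates two TensorDatasets
# along the sample dimension. Assumes every dataset supports full slicing
# (dset[:]) and shares the same tuple structure.
def _example_concatenate_data():
    dset_a = torch.utils.data.TensorDataset(torch.randn(10, 4), torch.zeros(10))
    dset_b = torch.utils.data.TensorDataset(torch.randn(15, 4), torch.ones(15))
    combined = concatenate_data([dset_a, dset_b])
    return len(combined)  # 25
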
def random_batch_input(dataset: torch.utils.data.Dataset, batch_size: int, generator: torch.Generator):
    """
    Generator that yields batches of data from the given dataset with random shuffling.

    This generator does something that torch's built-in data loading does not: it reshuffles every
    epoch while keeping the batch size constant. A batch that straddles an epoch boundary is filled
    out with samples from the next epoch's permutation, with the indices re-permuted so that no
    sample appears twice within that batch.

    Parameters
    ----------
    dataset: torch.utils.data.Dataset
        Dataset to generate batches from.
    batch_size: int
        Number of samples per batch.
    generator: torch.Generator
        Random number generator.

    Yields
    ------
    tuple
        A tuple containing a data batch and a boolean flag indicating if the epoch has ended.
    """
    assert len(dataset) > 0, "The dataset must contain at least one sample to batch"
    num_data = len(dataset)
    assert batch_size <= num_data, "The batch size must be less than or equal to the size of the dataset"

    current_cursor = 0
    shuffle_inds = torch.randperm(num_data, device=generator.device, generator=generator)
    epoch_ended = False
    while True:
        end_cursor = current_cursor + batch_size
        end_cursor_first = min(end_cursor, num_data)
        batch_indices = shuffle_inds[current_cursor:end_cursor_first]
        if end_cursor >= num_data:
            # Re-permute the already-used indices and the tail of the next epoch's order so that
            # the samples borrowed from the next epoch cannot duplicate any sample in this batch
            perm1 = torch.randperm(current_cursor, device=generator.device, generator=generator)
            new_shuffle_inds = shuffle_inds.detach().clone()
            new_shuffle_inds[:current_cursor] = new_shuffle_inds[:current_cursor][perm1]
            perm2 = torch.randperm(2*num_data - end_cursor, device=generator.device, generator=generator)
            new_shuffle_inds[end_cursor-num_data:] = new_shuffle_inds[end_cursor-num_data:][perm2]
            batch_indices_second = new_shuffle_inds[:end_cursor-num_data]
            batch_indices = torch.cat([batch_indices, batch_indices_second], dim=0)
            shuffle_inds = new_shuffle_inds
            epoch_ended = True
        data_batch = dataset[batch_indices]
        yield data_batch, epoch_ended
        epoch_ended = False
        current_cursor = end_cursor % num_data
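
# Usage sketch (not part of the original gist): draws constant-size batches from
# a TensorDataset and uses the epoch_ended flag to count completed epochs.
def _example_random_batch_input():
    dataset = torch.utils.data.TensorDataset(torch.randn(103, 4), torch.randn(103))
    gen = torch.Generator().manual_seed(0)
    batches = random_batch_input(dataset, batch_size=32, generator=gen)
    epochs_seen = 0
    for step, ((x_batch, y_batch), epoch_ended) in enumerate(batches):
        # Batch size stays constant even when a batch straddles an epoch boundary
        assert x_batch.shape[0] == 32
        epochs_seen += int(epoch_ended)
        if epochs_seen == 2:
            break
    return step + 1  # number of batches drawn over two epochs
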
def convert_data_to_device(data, device: torch.device):
    """
    Converts the data (dict or tensor) to the specified device.

    Parameters
    ----------
    data: dict or torch.Tensor
        Data to be converted; nested dicts are converted recursively.
    device: torch.device
        Target device.

    Returns
    -------
    Data converted to the target device.
    """
    if isinstance(data, dict):
        return {key: convert_data_to_device(val, device) for key, val in data.items()}
    elif torch.is_tensor(data):
        return data.to(device=device)
    else:
        raise TypeError("Require either tensor or dict to convert")
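
# Usage sketch (not part of the original gist): moves a nested dict of tensors
# to a target device in one call (get_default_device is defined below).
def _example_convert_data_to_device():
    batch = {'inputs': torch.randn(8, 4), 'targets': {'labels': torch.zeros(8)}}
    return convert_data_to_device(batch, get_default_device())
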
def get_default_device():
    """
    Returns the default device available for PyTorch.

    Returns
    -------
    torch.device
        Default device for PyTorch.
    """
    # The device of a freshly constructed tensor reflects torch's current default device
    return torch.as_tensor([0., 1.0]).device
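
# Usage sketch (not part of the original gist): the result tracks torch's current
# default, e.g. torch.set_default_device('cuda') (PyTorch >= 2.0) makes newly
# created tensors, and hence this function, report a CUDA device.
def _example_get_default_device():
    return get_default_device()  # torch.device('cpu') unless a different default is active
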
def get_gpu_name_if_available(gpu_index=None):
    """
    Returns the GPU device name if available, otherwise returns 'cpu'.

    Parameters
    ----------
    gpu_index: int, optional
        Index of the GPU device to use, if available.

    Returns
    -------
    str
        GPU device name or 'cpu' if GPU is not available.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda and gpu_index is not None:
        device_name = 'cuda:{}'.format(gpu_index)
    elif use_cuda:
        device_name = 'cuda'
    else:
        device_name = 'cpu'
    return device_name
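
# Usage sketch (not part of the original gist): picks a device string that
# degrades gracefully to 'cpu' on machines without CUDA.
def _example_get_gpu_name_if_available():
    device = torch.device(get_gpu_name_if_available(gpu_index=0))
    model_input = torch.zeros(4, 4, device=device)
    return model_input.device  # cuda:0 if available, otherwise cpu
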
def get_device_name(device: torch.device):
    """
    Returns the name of the specified device object.

    Parameters
    ----------
    device: torch.device
        The device object.

    Returns
    -------
    str
        The name of the device.
    """
    assert isinstance(device, torch.device), "device must be a torch.device"
    # Compare against None explicitly: an index of 0 is falsy but still a valid device index
    if device.index is not None:
        return f'{device.type}:{device.index}'
    else:
        return device.type
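
# Usage sketch (not part of the original gist): round-trips a device through its
# string name, including the explicit index when one is set.
def _example_get_device_name():
    assert get_device_name(torch.device('cpu')) == 'cpu'
    assert get_device_name(torch.device('cuda', 0)) == 'cuda:0'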