# Source: GitHub gist smsharma/666acf6da10c3015c452b739742dfadf
# (last active August 19, 2022 21:36).
# NOTE: GitHub flagged this file as containing hidden or bidirectional Unicode
# text that may be interpreted or compiled differently than it appears. To
# review, open the file in an editor that reveals hidden Unicode characters.
import numpy as np
import torch
from torch.utils.data import Dataset
def load_and_check(filename, memmap=False):
    """Load a numpy array from ``filename``, optionally memory-mapped.

    Args:
        filename: Path (or file-like object) handed straight to ``np.load``.
        memmap: If true, the file is opened with ``mmap_mode="c"``
            (copy-on-write: in-memory assignments are not written back to
            disk). If false, the array is read fully into memory.

    Returns:
        Whatever ``np.load`` returns for the file; an ``np.memmap`` when
        ``memmap`` is true and the file holds a plain array.
    """
    # mmap_mode=None is np.load's default, i.e. a regular in-memory load.
    mmap_mode = "c" if memmap else None
    return np.load(filename, mmap_mode=mmap_mode)
class NumpyDataset(Dataset):
    """Dataset over equally long numpy arrays with explicit memmap support.

    Arrays that are ``np.memmap`` instances are stored as they are and turned
    into tensors one item at a time in ``__getitem__``. All other arrays are
    converted to tensors of ``dtype`` once, in the constructor.

    Args:
        *arrays: Numpy arrays that share the same length along axis 0.
        dtype: Torch dtype that every returned tensor is cast to.

    Raises:
        ValueError: If the arrays do not all have the same length along
            axis 0.
    """

    def __init__(self, *arrays, dtype=torch.float):
        self.dtype = dtype
        self.memmap = []  # per-array flag: True if the array is an np.memmap
        self.data = []  # np.memmap objects or already-converted tensors
        self.n = None  # common length along axis 0; None until first array

        for array in arrays:
            if self.n is None:
                self.n = array.shape[0]
            elif array.shape[0] != self.n:
                # Explicit raise rather than assert: asserts vanish under -O.
                raise ValueError(
                    "All arrays must have the same length along axis 0: "
                    "expected {}, got {}".format(self.n, array.shape[0])
                )

            is_memmap = isinstance(array, np.memmap)
            self.memmap.append(is_memmap)
            if is_memmap:
                # Keep the mapping itself; items are converted on access.
                self.data.append(array)
            else:
                self.data.append(torch.from_numpy(array).to(self.dtype))

    def __getitem__(self, index):
        """Return a tuple with one entry per array, taken at ``index``."""
        items = []
        for is_memmap, array in zip(self.memmap, self.data):
            if is_memmap:
                # np.array copies the selected item out of the mapping into
                # a regular in-memory array before the tensor conversion.
                item = np.array(array[index])
                items.append(torch.from_numpy(item).to(self.dtype))
            else:
                items.append(array[index])
        return tuple(items)

    def __len__(self):
        """Return the common length of the arrays (0 if none were given)."""
        return 0 if self.n is None else self.n
class NumpyDatasetIndividual(Dataset):
    """Dataset backed by a folder of individual ``.npy`` files.

    Item ``i`` is read from ``<root>/image_<i>.npy`` and
    ``<root>/params_<i>.npy`` each time it is accessed.

    Args:
        root: Folder that contains the ``image_*.npy`` / ``params_*.npy``
            files.
        num_files: Number of items in the dataset; reported by ``len()``.
    """

    def __init__(self, root, num_files):
        self.root = root
        self.num_files = num_files

    def __getitem__(self, index):
        """Load the tuple ``(image, params)`` for ``index``.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: If ``index`` is outside ``[-num_files, num_files)``.
        """
        # IndexError is what lets plain iteration (``for item in ds``) stop
        # cleanly instead of running into a missing file.
        if not -self.num_files <= index < self.num_files:
            raise IndexError(
                "index {} is out of range for dataset with {} files".format(
                    index, self.num_files
                )
            )
        if index < 0:
            index += self.num_files

        image = np.load("{}/image_{}.npy".format(self.root, index))
        params = np.load("{}/params_{}.npy".format(self.root, index))
        return image, params

    def __len__(self):
        """Return the number of items, as given to the constructor."""
        return self.num_files
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment