Skip to content

Instantly share code, notes, and snippets.

@smsharma
Last active August 19, 2022 21:36
Show Gist options
  • Save smsharma/666acf6da10c3015c452b739742dfadf to your computer and use it in GitHub Desktop.
Save smsharma/666acf6da10c3015c452b739742dfadf to your computer and use it in GitHub Desktop.
import numpy as np
import torch
from torch.utils.data import Dataset
def load_and_check(filename, memmap=False):
    """Load a numpy array from ``filename``, optionally memory-mapped.

    Args:
        filename: Path (or file object) handed straight to ``np.load``.
        memmap: When truthy, the file is opened with ``mmap_mode="c"``
            (copy-on-write); otherwise it is loaded normally.

    Returns:
        Whatever ``np.load`` returns for the given file and mode.
    """
    # ``None`` is np.load's default mode, i.e. a regular in-memory load.
    mmap_mode = "c" if memmap else None
    return np.load(filename, mmap_mode=mmap_mode)
class NumpyDataset(Dataset):
    """Dataset over one or more equally long numpy arrays, with memmap support.

    Plain arrays are converted to tensors of ``dtype`` once, at construction.
    ``np.memmap`` arrays are stored as they are; only the requested entries
    are copied and converted when an item is accessed.

    Args:
        *arrays: Arrays that all have the same size along their first axis.
        dtype: Torch dtype that every returned tensor is converted to.

    Raises:
        ValueError: If the arrays differ in size along their first axis.
    """

    def __init__(self, *arrays, dtype=torch.float):
        super().__init__()
        self.dtype = dtype
        self.memmap = []  # one flag per array: True if it is an np.memmap
        self.data = []    # memmaps are stored raw, everything else as tensors
        # Without any arrays the dataset is empty; leaving ``None`` here
        # would make ``len()`` fail with a TypeError.
        self.n = arrays[0].shape[0] if arrays else 0

        for array in arrays:
            # Explicit raise instead of ``assert`` so the check survives -O.
            if array.shape[0] != self.n:
                raise ValueError(
                    "All arrays must have the same first dimension: "
                    "expected {}, got {}".format(self.n, array.shape[0])
                )

            is_memmap = isinstance(array, np.memmap)
            self.memmap.append(is_memmap)
            if is_memmap:
                self.data.append(array)
            else:
                self.data.append(torch.from_numpy(array).to(self.dtype))

    def __getitem__(self, index):
        """Return a tuple holding the entry at ``index`` of every array."""
        items = []
        for is_memmap, array in zip(self.memmap, self.data):
            if is_memmap:
                # np.array copies the selected entries out of the memmap
                # before they are wrapped in a tensor.
                chunk = np.array(array[index])
                items.append(torch.from_numpy(chunk).to(self.dtype))
            else:
                items.append(array[index])
        return tuple(items)

    def __len__(self):
        """Return the shared first-axis size of the arrays (0 if none)."""
        return self.n
class NumpyDatasetIndividual(Dataset):
    """Dataset backed by a folder of individual ``.npy`` files.

    Item ``i`` is read from ``<root>/image_<i>.npy`` and
    ``<root>/params_<i>.npy`` each time it is accessed.

    Args:
        root: Directory containing the files.
        num_files: Number of (image, params) pairs; reported as the length.
    """

    def __init__(self, root, num_files):
        super().__init__()
        self.root = root
        self.num_files = num_files

    def __getitem__(self, index):
        """Load the tuple ``(image, params)`` stored under ``index``.

        Raises:
            IndexError: If ``index`` is outside ``[0, num_files)``.
        """
        # Plain iteration over an object that only defines __getitem__ stops
        # on IndexError; without this check it would run past the last file
        # and fail with FileNotFoundError instead.
        if not 0 <= index < self.num_files:
            raise IndexError(
                "index {} is out of range for {} files".format(
                    index, self.num_files
                )
            )

        image = np.load(self.root + "/image_{}.npy".format(index))
        params = np.load(self.root + "/params_{}.npy".format(index))
        return image, params

    def __len__(self):
        """Return the number of file pairs given at construction."""
        return self.num_files
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment