Skip to content

Instantly share code, notes, and snippets.

@NegatioN
Created March 10, 2020 11:33
Show Gist options
  • Save NegatioN/1f63c3a79dfe13b183d413123d37d4fa to your computer and use it in GitHub Desktop.
Save NegatioN/1f63c3a79dfe13b183d413123d37d4fa to your computer and use it in GitHub Desktop.
A simple batched DataLoader setup that can operate directly on CUDA tensors.
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Sampler
class BatchSampler(Sampler):
    """Yield whole batches of indices over the 1-d range ``[0, num_samples)``.

    Each item produced by iteration is a ``list[int]`` of ``batch_size``
    indices; a final, shorter batch holds the remainder when ``num_samples``
    is not divisible by ``batch_size``. Meant to be passed as ``sampler=`` to
    a ``DataLoader`` built with ``batch_size=None`` so each list reaches the
    dataset's ``__getitem__`` in a single call.

    :param num_samples: total number of datapoints (1d data sequence) to be sampled from.
    :param batch_size: maximum number of indices per batch; must be positive.
    :param shuffle: if True, a fresh random permutation is drawn on every iteration.
    :raises ValueError: if ``batch_size`` is not positive or ``num_samples`` is negative.
    """

    def __init__(self, num_samples, batch_size, shuffle=True):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if num_samples < 0:
            raise ValueError(f"num_samples must be non-negative, got {num_samples}")
        self.num_samples = num_samples
        self.shuffle = shuffle
        self.batch_size = batch_size
        # Integer floor division: exact even where num_samples / batch_size
        # would lose precision as a float.
        self.n_full_batches = num_samples // batch_size
        # One past the last index that belongs to a full batch.
        self.last_full_batch_ind = self.n_full_batches * batch_size

    def __iter__(self):
        # int64 so index values cannot overflow for datasets beyond 2**31 - 1 items.
        items = np.arange(self.num_samples, dtype=np.int64)
        if self.shuffle:
            items = np.random.permutation(items)
        # Full batches come from one reshape; the remainder (if any) is appended.
        batches = items[:self.last_full_batch_ind].reshape(-1, self.batch_size).tolist()
        if self.last_full_batch_ind != self.num_samples:
            batches.append(items[self.last_full_batch_ind:].tolist())
        return iter(batches)

    def __len__(self):
        # Ceiling division: full batches plus one partial batch when there is a remainder.
        return -(-self.num_samples // self.batch_size)
class MyBatchSet(Dataset):
    """Map-style dataset whose ``__getitem__`` takes a whole list of indices.

    Paired with ``BatchSampler`` (which yields lists of indices), each call
    gathers an entire batch with one indexing operation on the stored tensor,
    so no per-sample fetching or collation is needed.

    TODO: placeholder implementation — extend ``__init__`` with whatever
    builds ``some_cuda_tensor`` for the real data.

    :param some_cuda_tensor: tensor indexed along its first dimension;
        presumably already on the target device (e.g. CUDA) — confirm at the
        call site.
    """

    def __init__(self, some_cuda_tensor=None):
        self.some_cuda_tensor = some_cuda_tensor

    def __len__(self):
        # Number of rows along the first dimension; used to size the sampler.
        return len(self.some_cuda_tensor)

    def __getitem__(self, inds):
        # takes a list of indices and returns all of those rows in one indexing call
        return self.some_cuda_tensor[inds]
# Usage sketch: `kwargs` and `batch_size` are placeholders the caller must define.
train_ds = MyBatchSet(**kwargs)
# Size the sampler from the dataset *instance*; len() on the class itself raises TypeError.
sampler = BatchSampler(len(train_ds), batch_size, shuffle=True)
# batch_size=None disables the DataLoader's automatic batching: each index list
# from `sampler` is passed straight to MyBatchSet.__getitem__. With the default
# batch_size=1 every batch would instead be re-wrapped and stacked into an
# extra leading dimension of size 1.
train_dl = DataLoader(train_ds,
                      sampler=sampler,
                      batch_size=None,
                      # Load in the main process without pinning; the data is
                      # presumably already on the GPU — confirm for your dataset.
                      num_workers=0,
                      pin_memory=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment