PyTorch random sampler for bigger-than-memory arrays (dask, zarr, xarray, etc.) that gives you randomness while keeping the speed benefits of sequential reads. It chooses a random location, then takes an ordered run of indices, e.g. [[1,2,3],[9,10,11],[4,5,6]]. This way you get the speed of a sequential read.
"""
Pytorch sampler that samples ordered indices from unordered sequences.
Good for use with dask and RNN's, because
1. Dask will slow down if sampling between chunks, so we must do one chunk at a time
2. RNN's need sequences so we must have seqences e.g. 1,2,3
3. But RNN's train better with batches that are uncorrelated so we want each batch to be sequence from a different part of a chunk.
For example, given each chunk is `range(12)`. Our seq_len is 3. We might end up with these indices:
- [[1,2,3],[9,10,11],[4,5,6]]
Usage:
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
sampler=SequenceInChunkSampler(train_dataset, seq_len=batch_size, chunksize=batch_size*100)
)
"""
import numpy as np
import torch
import torch.utils.data.sampler
class SequenceInChunkSampler(torch.utils.data.sampler.Sampler):
    """Samples elements in ordered sequences, with the sequences drawn in random order within each chunk.

    Arguments:
        data_source (Dataset): dataset to sample from
        seq_len (int): length of each ordered sequence
        chunksize (int): length of cached data to take random sequences from

    url: https://gist.github.com/wassname/8ae1f64389c2aaceeb84fcd34c3651c3
    """

    def __init__(self, data_source, seq_len=6, chunksize=6000):
        assert chunksize % seq_len == 0, "chunksize should be a multiple of seq_len"
        assert len(data_source) > chunksize
        self.data_source = data_source
        self.seq_len = seq_len
        self.chunksize = chunksize

    def __iter__(self):
        max_i = len(self.data_source)
        # start index of each chunk
        chunk_idxs = np.arange(0, max_i, self.chunksize)
        for chunk_idx in chunk_idxs:
            # start index of each sequence within this chunk
            seqs = np.arange(
                chunk_idx, min(chunk_idx + self.chunksize, max_i), self.seq_len
            )
            # shuffle the order of the sequences, but keep each sequence itself ordered
            np.random.shuffle(seqs)
            for seq_i in seqs:
                for i in np.arange(seq_i, min(seq_i + self.seq_len, max_i)):
                    yield i

    def __len__(self):
        return len(self.data_source)
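

# --- Sketch: the kind of data source this sampler is meant for ----------------
# `ArrayDataset` below is a hypothetical helper (not part of the original gist):
# a thin Dataset around any chunked, array-like store. dask, zarr and xarray
# arrays all support len() and integer indexing like this, so something like
# `ArrayDataset(dask.array.from_zarr("data.zarr"))` should work. Reads stay fast
# because the sampler asks for contiguous indices within one chunk at a time.
class ArrayDataset(torch.utils.data.Dataset):
    def __init__(self, arr):
        self.arr = arr

    def __len__(self):
        return len(self.arr)

    def __getitem__(self, i):
        # np.asarray materialises just this row; for a lazy backend such as dask
        # this triggers a small read inside the chunk that is currently "hot"
        return torch.as_tensor(np.asarray(self.arr[i]))
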
if __name__ == '__main__':
    # Test
    seq_len = 3
    batch_size = 3
    chunksize = seq_len * batch_size * 2
    X_train = torch.arange(chunksize * 2).unsqueeze(-1)
    dataset_train = torch.utils.data.TensorDataset(X_train)
    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=SequenceInChunkSampler(dataset_train, seq_len=seq_len, chunksize=chunksize),
        batch_size=batch_size * seq_len,
        drop_last=True,
    )
    x, = next(iter(loader_train))

    # View: each row is one ordered sequence, and the rows come from
    # different parts of the same chunk
    xx = x.numpy().reshape(batch_size, seq_len)
    print(xx)
    # one possible output (sequence order is random):
    # [[15 16 17]
    #  [ 6  7  8]
    #  [ 9 10 11]]

    # Test
    assert (np.diff(xx, axis=1) == 1).all(), 'indices increase by one within each sequence'
    assert (np.diff(xx.ravel()) != 1).any(), "indices don't all increase by one across sequence borders"
    assert (xx.max() - xx.min()) <= chunksize, 'all indices in a batch come from the same chunk'
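
    # Sketch (assumption, not part of the original gist): feeding a batch to an RNN.
    # The DataLoader yields a flat (batch_size * seq_len, features) tensor, so
    # reshape it to (batch, seq, features) for an nn.LSTM with batch_first=True;
    # each row of the batch is then one ordered sequence.
    x_batch, = next(iter(loader_train))
    x_seq = x_batch.float().view(batch_size, seq_len, -1)  # (3, 3, 1)
    rnn = torch.nn.LSTM(input_size=1, hidden_size=4, batch_first=True)
    out, _ = rnn(x_seq)
    print(out.shape)  # torch.Size([3, 3, 4])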