PyTorch random sampler for bigger-than-memory arrays (dask, zarr, xarray, etc.) that gives you randomness while keeping the speed benefits of sequential reads. It picks a random location, then takes an ordered sequence from it, e.g. [[1,2,3],[9,10,11],[4,5,6]]. This way you get the speed of a sequential read.
""" | |
Pytorch sampler that samples ordered indices from unordered sequences. | |
Good for use with dask and RNN's, because | |
1. Dask will slow down if sampling between chunks, so we must do one chunk at a time | |
2. RNN's need sequences so we must have seqences e.g. 1,2,3 | |
3. But RNN's train better with batches that are uncorrelated so we want each batch to be sequence from a different part of a chunk. | |
For example, given each chunk is `range(12)`. Our seq_len is 3. We might end up with these indices: | |
- [[1,2,3],[9,10,11],[4,5,6]] | |
Usage: | |
train_loader = torch.utils.data.DataLoader( | |
train_dataset, | |
batch_size=batch_size, | |
sampler=SequenceInChunkSampler(train_dataset, seq_len=batch_size, chunksize=batch_size*100) | |
) | |
""" | |
import numpy as np
import torch
import torch.utils.data.sampler


class SequenceInChunkSampler(torch.utils.data.sampler.Sampler):
    """
    Samples elements sequentially within each sequence, but picks the sequences at random within a chunk.

    Arguments:
        data_source (Dataset): dataset to sample from
        seq_len (int): length of each ordered sequence
        chunksize (int): length of cached data to take random sequences from

    url: https://gist.github.com/wassname/8ae1f64389c2aaceeb84fcd34c3651c3
    """

    def __init__(self, data_source, seq_len=6, chunksize=6000):
        assert chunksize % seq_len == 0, "chunksize should be a multiple of seq_len"
        assert len(data_source) > chunksize
        self.data_source = data_source
        self.seq_len = seq_len
        self.chunksize = chunksize

    def __iter__(self):
        # Walk the dataset one chunk at a time, in order, so reads stay
        # sequential for chunked backends such as dask or zarr.
        chunk_idxs = np.arange(0, len(self.data_source), self.chunksize)
        max_i = len(self.data_source)
        for chunk_idx in chunk_idxs:
            # Start indices of the sequences inside this chunk, shuffled so
            # consecutive batches come from different parts of the chunk.
            seqs = np.arange(
                chunk_idx, min(chunk_idx + self.chunksize, max_i), self.seq_len
            )
            np.random.shuffle(seqs)
            for seq_i in seqs:
                # Yield each sequence in order.
                for i in np.arange(seq_i, min(seq_i + self.seq_len, max_i)):
                    yield i

    def __len__(self):
        return len(self.data_source)
if __name__ == '__main__':
    # Test
    seq_len = 3
    batch_size = 3
    chunksize = seq_len * batch_size * 2
    X_train = torch.arange(chunksize * 2).unsqueeze(-1)
    dataset_train = torch.utils.data.TensorDataset(X_train)
    loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=SequenceInChunkSampler(dataset_train, seq_len=seq_len, chunksize=chunksize),
        batch_size=batch_size * seq_len,
        drop_last=True,
    )
    x, = next(iter(loader_train))

    # View: reshape the flat batch into (batch_size, seq_len)
    x = x.numpy().T
    xx = x.reshape(batch_size, seq_len)
    print(xx)
    # e.g.
    # [[15 16 17]
    #  [ 6  7  8]
    #  [ 9 10 11]]

    # Test
    assert (np.diff(xx) == 1).all(), 'should increase by one within each sequence'
    assert (np.diff(x) != 1).any(), "shouldn't increase by one at sequence borders"
    assert (xx.max() - xx.min()) <= chunksize, 'total spread should be <= chunksize'
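
Below is a minimal sketch of the claimed bigger-than-memory use case, assuming a dask-backed array. The `DaskDataset` wrapper, the array shape, and the chunk sizes are hypothetical illustrations, not part of the gist. The idea is that when the sampler's chunksize matches the dask chunk size, each `__getitem__` call mostly hits data from the chunk dask is already reading.

# Hedged usage sketch (assumes dask is installed; DaskDataset is hypothetical).
import dask.array as da
import torch
import torch.utils.data


class DaskDataset(torch.utils.data.Dataset):
    """Hypothetical wrapper: one row of a dask array per index."""

    def __init__(self, darr):
        self.darr = darr

    def __len__(self):
        return self.darr.shape[0]

    def __getitem__(self, i):
        # Materialise only the requested row; since the sampler walks one
        # chunk at a time, most of these reads land in the same dask chunk.
        return torch.as_tensor(self.darr[i].compute())


seq_len, batch_size = 8, 32
# Example array: 100k rows of 16 features, chunked every 4096 rows.
# Aligning the sampler's chunksize with the dask chunk keeps reads sequential.
arr = da.random.random((100_000, 16), chunks=(4096, 16))
dataset = DaskDataset(arr)

loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size * seq_len,
    sampler=SequenceInChunkSampler(dataset, seq_len=seq_len, chunksize=4096),
    drop_last=True,
)

for batch in loader:
    # Each flat batch of batch_size*seq_len rows reshapes into ordered
    # (batch_size, seq_len, features) sequences, e.g. for an RNN.
    batch = batch.view(batch_size, seq_len, -1)
    break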