Created
July 9, 2016 13:31
-
-
Save therne/3b6f1db728b78d6125647229884a574b to your computer and use it in GitHub Desktop.
Batch data loader for minibatch training
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy | |
import numpy as np | |
class DataSet:
    """Minibatch loader over an indexable collection.

    Batches are drawn through a list of positions into ``data``, so shuffling
    and train/validation splitting never copy the underlying data.

    Attributes:
        name: free-form label for this set.
        data: backing collection; must support ``len()`` and integer indexing.
            Shared (not copied) with any set produced by ``split_dataset``.
        batch_size: number of items returned by each ``next_batch`` call.
        shuffle: if true, ``reset`` shuffles the index order in place with
            ``np.random.shuffle``.
        count: number of items belonging to this set (less than ``len(data)``
            after ``split_dataset``).
        indices: positions in ``data`` owned by this set, in draw order.
        current_index: position in ``indices`` where the next batch starts.
        num_batches: full batches per epoch (``count // batch_size``); a
            trailing partial batch is never returned.
    """

    # First position in ``data`` owned by this set. Class-level default so
    # instances created without it (e.g. older pickles) still behave as before;
    # only ``split_dataset`` sets a non-zero value on an instance.
    _offset = 0

    def __init__(self, data, batch_size=1, shuffle=True, name="dataset"):
        """Wrap ``data`` and start the first epoch.

        :param data: indexable collection supporting ``len()``.
        :param batch_size: items per batch; must be in ``1..len(data)``.
        :param shuffle: reshuffle the draw order on every ``reset``.
        :param name: label for this set.
        :raises AssertionError: if ``batch_size`` is not positive or exceeds
            the data size.
        """
        assert batch_size > 0, "batch size must be positive."
        assert batch_size <= len(data), "batch size cannot be greater than data size."
        self.name = name
        self.data = data
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.count = len(self.data)
        self.setup()

    def setup(self):
        """Rebuild the index list for this set's slice of ``data`` and start a new epoch."""
        # Positions are relative to the shared ``data``, so a split-off set
        # must start at its own offset rather than at 0.
        self.indices = list(range(self._offset, self._offset + self.count))
        self.current_index = 0
        self.num_batches = self.count // self.batch_size
        self.reset()

    def next_batch(self):
        """Return the next batch and advance the cursor.

        :return: list of ``batch_size`` items taken from ``data``.
        :raises AssertionError: if fewer than ``batch_size`` items remain in
            this epoch; call ``reset()`` to start a new one.
        """
        assert self.has_next_batch(), "End of epoch. Call 'reset()' to reset."
        start = self.current_index
        stop = start + self.batch_size
        batch = [self.data[i] for i in self.indices[start:stop]]
        self.current_index = stop
        return batch

    def has_next_batch(self):
        """Return True if a full batch remains in the current epoch."""
        return self.current_index + self.batch_size <= self.count

    def split_dataset(self, split_ratio):
        """Split off the trailing ``split_ratio`` fraction of this set.

        Example: ``split_ratio=0.3`` keeps the first 70% in this set and
        returns a new set holding the last 30%. The split follows the set's
        contiguous data order, not the current shuffled order. Data is not
        copied: both sets share ``data`` and differ only in their index lists.
        Both sets start a fresh epoch (reshuffled if ``shuffle`` is set).

        :param split_ratio: fraction of this set's items to move into the
            returned set.
        :return: the split-off set (e.g. a validation set).
        :raises AssertionError: if either resulting set would hold fewer than
            ``batch_size`` items. Nothing is modified in that case.
        """
        end_index = int(self.count * (1 - split_ratio))
        split_count = self.count - end_index
        # Validate both halves before mutating anything.
        assert split_count >= self.batch_size, "splitted data size cannot be smaller than batch size."
        assert end_index >= self.batch_size, "remaining data size cannot be smaller than batch size."

        # Shallow copy shares ``data`` (intended) but also the ``indices``
        # list object, so the copy must build its own list via setup().
        splitted = copy.copy(self)
        splitted._offset = self._offset + end_index
        splitted.count = split_count
        splitted.setup()

        self.count = end_index
        self.setup()
        return splitted

    def reset(self):
        """Rewind to the start of the epoch, shuffling ``indices`` in place if ``shuffle``."""
        self.current_index = 0
        if self.shuffle:
            np.random.shuffle(self.indices)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment