Last active
August 8, 2017 18:10
-
-
Save krishpop/f352dcc7beeee5f14ef65ee8fc012f88 to your computer and use it in GitHub Desktop.
TensorFlow dataset class
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Code adapted from TensorFlow source example:
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/datasets/mnist.py
import numpy as np
class DataSet:
    """In-memory data set that hands out mini-batches.

    Arrays are passed as keyword arguments and stored as instance
    attributes under the same names. ``_data`` is required and ``_labels``
    is required when the set is labeled; both are indexed along axis 0.
    ``_test_data`` and ``_test_labels`` are optional and are only exposed
    through the matching properties.
    """

    def __init__(self, shuffle=True, labeled=True, **data_dict):
        """Store the arrays and optionally shuffle the training samples.

        Args:
            shuffle: Shuffle the samples now and again each time
                ``next_batch`` wraps around the end of the data.
            labeled: Whether ``_labels`` accompanies ``_data``.
            **data_dict: Named arrays; must contain ``_data`` (and
                ``_labels`` when labeled).

        Raises:
            AssertionError: If a required array is missing or the number
                of samples and labels differ.
        """
        # save() writes the bookkeeping attributes next to the arrays and
        # load() feeds the whole archive back in here. The saved flags win
        # over the keyword defaults (as they always did), but they must be
        # applied *before* validation so that an unlabeled data set can be
        # reloaded, and converted from 0-d arrays back to plain bools.
        if '_labeled' in data_dict:
            labeled = bool(data_dict.pop('_labeled'))
        if '_shuffle' in data_dict:
            shuffle = bool(data_dict.pop('_shuffle'))
        # These two are always recomputed below, so drop any saved copies.
        data_dict.pop('_num_samples', None)
        data_dict.pop('_index_in_epoch', None)

        # Raised explicitly (rather than via ``assert``) so the checks are
        # still performed when Python runs with -O.
        if '_data' not in data_dict:
            raise AssertionError("data_dict must contain '_data'")
        if labeled:
            if '_labels' not in data_dict:
                raise AssertionError(
                    "data_dict must contain '_labels' when labeled=True")
            if data_dict['_data'].shape[0] != data_dict['_labels'].shape[0]:
                raise AssertionError(
                    "'_data' and '_labels' must have the same number of "
                    "samples")

        self._labeled = labeled
        self._shuffle = shuffle
        self.__dict__.update(data_dict)
        self._num_samples = self._data.shape[0]
        self._index_in_epoch = 0
        if self._shuffle:
            self._shuffle_data()

    def __len__(self):
        """Return the number of training samples."""
        return len(self._data)

    @property
    def index_in_epoch(self):
        """Position of the next sample ``next_batch`` will return."""
        return self._index_in_epoch

    @property
    def num_samples(self):
        """Number of training samples (size of ``_data`` along axis 0)."""
        return self._num_samples

    @property
    def data(self):
        """Training samples in their current (possibly shuffled) order."""
        return self._data

    @property
    def labels(self):
        """Training labels; only available if ``_labels`` was supplied."""
        return self._labels

    @property
    def labeled(self):
        """Whether labels accompany the training samples."""
        return self._labeled

    @property
    def test_data(self):
        """Test samples; only available if ``_test_data`` was supplied."""
        return self._test_data

    @property
    def test_labels(self):
        """Test labels; only available if ``_test_labels`` was supplied."""
        return self._test_labels

    @classmethod
    def load(cls, filename):
        """Build a data set from an archive written by ``save``.

        Args:
            filename: Path of the ``.npz`` archive to read.

        Returns:
            A new instance of ``cls`` holding the stored arrays, with the
            saved ``labeled`` / ``shuffle`` flags restored.
        """
        # np.load returns a lazy NpzFile that keeps the file open; read
        # every array inside the context manager so the handle is closed.
        with np.load(filename) as archive:
            data_dict = {key: archive[key] for key in archive.files}
        return cls(**data_dict)

    def save(self, filename):
        """Write every instance attribute to a compressed ``.npz`` archive.

        Args:
            filename: Destination path. NumPy appends ``.npz`` when the
                name does not already end with it, so pass the full name
                to ``load``.
        """
        data_dict = self.__dict__
        np.savez_compressed(filename, **data_dict)

    def _shuffle_data(self):
        """Apply one random permutation to the samples (and labels)."""
        shuffled_idx = np.arange(self._num_samples)
        np.random.shuffle(shuffled_idx)
        self._data = self._data[shuffled_idx]
        if self._labeled:
            # Same index array, so samples and labels stay aligned.
            self._labels = self._labels[shuffled_idx]

    def next_batch(self, batch_size):
        """Return the next ``batch_size`` samples, wrapping at the end.

        When fewer than ``batch_size`` samples remain, the leftover tail is
        taken first, the data is reshuffled (if shuffling is enabled) and
        the batch is completed from the start of the new ordering.

        Args:
            batch_size: Number of samples to return; at most
                ``num_samples``.

        Returns:
            ``(data_batch, labels_batch)`` for a labeled set, otherwise
            just ``data_batch``.

        Raises:
            AssertionError: If ``batch_size`` exceeds ``num_samples``.
        """
        if batch_size > self._num_samples:
            raise AssertionError(
                "batch_size must not exceed the number of samples")
        start = self._index_in_epoch
        if start + batch_size > self._num_samples:
            # Take what is left of the current pass ...
            data_batch = self._data[start:]
            if self._labeled:
                labels_batch = self._labels[start:]
            remaining = batch_size - (self._num_samples - start)
            if self._shuffle:
                self._shuffle_data()
            # ... and top the batch up from the beginning of the next one.
            data_batch = np.concatenate(
                [data_batch, self._data[:remaining]], axis=0)
            if self._labeled:
                labels_batch = np.concatenate(
                    [labels_batch, self._labels[:remaining]], axis=0)
            self._index_in_epoch = remaining
        else:
            end = start + batch_size
            data_batch = self._data[start:end]
            if self._labeled:
                labels_batch = self._labels[start:end]
            self._index_in_epoch = end
        batch = (data_batch, labels_batch) if self._labeled else data_batch
        return batch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn import datasets | |
from sklearn.model_selection import train_test_split | |
# Load the iris samples (X) and their class labels (y).
iris = datasets.load_iris()
X = iris.data
y = iris.target
# train_test_split returns the splits grouped per input array, i.e.
# (X_train, X_test, y_train, y_test) -- not (X_train, y_train, ...).
train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.9)
data_dict = {
    '_data': train_X,
    '_labels': train_y,
    '_test_data': test_X,
    '_test_labels': test_y,
}
# The class is spelled ``DataSet`` (capital S).
iris_data = DataSet(**data_dict)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment