Last active June 14, 2018 16:16
Save shu-yusa/0423efb98996e02a822a6cddb4dd1c18 to your computer and use it in GitHub Desktop.
[TensorFlow] MirroredStrategyを用いて複数GPU計算を行う (Multi-GPU computation with MirroredStrategy) ref: https://qiita.com/shu-yusa/items/e93e934a14849541de78
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build a MirroredStrategy and hand it to the Estimator through RunConfig.
# It is passed as `train_distribute`, i.e. it is used for training.
distribution = tf.contrib.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(train_distribute=distribution)
classifier = tf.estimator.Estimator(config=config, model_fn=model_fn)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
ValueError: dataset_fn() must return a tf.data.Dataset when using a DistributionStrategy. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class InputFnProvider:
    """Provide ``tf.data`` input functions for MNIST training and evaluation.

    Both input functions return a ``tf.data.Dataset`` (not tensors taken from
    an iterator). An Estimator configured with a DistributionStrategy raises
    ``ValueError: dataset_fn() must return a tf.data.Dataset`` otherwise.

    The MNIST data is loaded eagerly when the provider is constructed.

    Args:
        train_batch_size: Number of examples per training batch.
        shuffle_buffer_size: Size of the shuffle buffer for training data.
            Defaults to 1000 (the previously hard-coded value).
        eval_batch_size: Number of examples per evaluation batch.
            Defaults to 1 (the previously hard-coded value).
    """

    def __init__(self, train_batch_size, *, shuffle_buffer_size=1000,
                 eval_batch_size=1):
        self.train_batch_size = train_batch_size
        self.shuffle_buffer_size = shuffle_buffer_size
        self.eval_batch_size = eval_batch_size
        self.__load_data()

    def __load_data(self):
        """Load the MNIST train/test splits into in-memory NumPy arrays."""
        mnist = tf.contrib.learn.datasets.load_dataset("mnist")
        self.train_data = mnist.train.images  # np.ndarray
        # Labels are cast to int32; presumably the dtype model_fn expects.
        # TODO(review): confirm against model_fn.
        self.train_labels = np.asarray(mnist.train.labels, dtype=np.int32)
        self.eval_data = mnist.test.images  # np.ndarray
        self.eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)

    def train_input_fn(self):
        """Return a shuffled, endlessly repeated, batched training Dataset."""
        dataset = tf.data.Dataset.from_tensor_slices(
            ({"x": self.train_data}, self.train_labels))
        # repeat() with no count never ends. The caller must bound training,
        # e.g. with Estimator.train(steps=...).
        return (dataset
                .shuffle(self.shuffle_buffer_size)
                .repeat()
                .batch(self.train_batch_size))

    def eval_input_fn(self):
        """Return a single-pass, batched Dataset for evaluation or prediction."""
        dataset = tf.data.Dataset.from_tensor_slices(
            ({"x": self.eval_data}, self.eval_labels))
        return dataset.batch(self.eval_batch_size)
# (omitted)
# Train the model. train_input_fn repeats its data forever, so `steps`
# is what bounds the length of training.
mnist_classifier.train(
    input_fn=input_fn_provider.train_input_fn,
    steps=10000)
# Evaluate the model and print results
eval_results = mnist_classifier.evaluate(input_fn=input_fn_provider.eval_input_fn)
print(eval_results)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE: old-style input_fn ending. It returns the iterator's next-element
# tensors instead of a tf.data.Dataset, which causes
# "ValueError: dataset_fn() must return a tf.data.Dataset" when a
# DistributionStrategy is used. Return `dataset` itself instead.
return dataset.make_one_shot_iterator().get_next() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment