Distributed training example in TensorFlow 1.13
import os
import tensorflow as tf


def get_model(model_name, input_tensor, is_train, pretrained):
    ...


def optimisation(label, logits, param_dict):
    ...


def get_data_provider(dataset_name, dataset_path, param_dict):
    ...

def test_distributed_training(model, dataset, param_dict):
    project_path = os.getcwd()
    # First, get the strategy that distributes the training.
    strategy = tf.distribute.MirroredStrategy()

    # Define the computation that takes place on each replica (GPU), with a
    # batch of examples taken from the dataset (each replica gets a different
    # batch, delivered by the iterator's get_next()).
    # Here we train the replicated (mirrored) model.
    # Variables, tensors, metrics, summaries, etc. are all created under the
    # strategy scope, so the corresponding nodes are created on every device.
    # ---------------------------------------------------------------------------
    def train_graph_replica(inputs):
        # `inputs` is the batch produced by get_next() on the iterator.
        image = inputs['image']
        label = inputs['label']
        # Because train_graph_replica runs under the strategy scope, the
        # variables it creates are mirrored across devices.
        model_fn = get_model(model, image, is_train=True, pretrained=False)
        loss_op, opt_op = optimisation(label, model_fn['feature_logits'],
                                       param_dict)
        # Return the loss evaluated after the optimisation step has run.
        with tf.control_dependencies([opt_op]):
            return tf.identity(loss_op)
    # ---------------------------------------------------------------------------

    with strategy.scope():
        # The dataset is created under the strategy scope.
        # All-reduce aggregates tensors across all devices by summing them and
        # makes the result available on every device; it is a synchronous operation.
        dp = get_data_provider(dataset, project_path, param_dict)
        dp.make_distributed(strategy)
        train_iterator = dp.get_train_dataset(batch_size=128,
                                              shuffle=1, prefetch=1)
        # train_graph_replica runs once per replica; its per-replica losses are
        # returned as a single distributed value.
        rep_loss = strategy.experimental_run(train_graph_replica, train_iterator)
        # Aggregate the per-replica losses into one mean loss.
        avg_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, rep_loss)
    with tf.Session() as sess:
        sess.run(tf.initializers.global_variables())
        dp.initialize_train(sess)
        for i in range(100):
            loss_aggregate = sess.run(avg_loss)
            print(loss_aggregate)

if __name__ == "__main__":
    model_name = "simple_convnet"
    dataset_name = "mnist"
    parameters = {'learning_rate': 0.001}
    test_distributed_training(model_name, dataset_name, parameters)
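
The bodies of get_model, optimisation and get_data_provider are elided above. As a point of reference, here is a minimal sketch of what such helpers could look like for the MNIST / simple-convnet case under TF 1.13. The names and signatures mirror the call sites above, but everything else (the convnet layers, the MNISTProvider class, and the use of strategy.make_dataset_iterator to split batches across devices) is an assumption for illustration, not the original implementation.

# --- Hypothetical stubs, not the author's original code ---------------------
import tensorflow as tf

def get_model(model_name, input_tensor, is_train, pretrained):
    # A small convnet; returns a dict so that model_fn['feature_logits'] works.
    x = tf.layers.conv2d(input_tensor, 32, 3, activation=tf.nn.relu)
    x = tf.layers.max_pooling2d(x, 2, 2)
    x = tf.layers.flatten(x)
    logits = tf.layers.dense(x, 10)
    return {'feature_logits': logits}

def optimisation(label, logits, param_dict):
    # Cross-entropy loss plus a single optimizer step.
    loss = tf.losses.sparse_softmax_cross_entropy(labels=label, logits=logits)
    opt = tf.train.AdamOptimizer(param_dict['learning_rate'])
    return loss, opt.minimize(loss)

class MNISTProvider:
    # Wraps tf.data and the strategy's distributed dataset iterator
    # (assumes the TF 1.13 make_dataset_iterator API).
    def __init__(self, dataset_path, param_dict):
        (x, y), _ = tf.keras.datasets.mnist.load_data()
        x = (x[..., None] / 255.0).astype('float32')
        self._data = {'image': x, 'label': y.astype('int64')}
        self._strategy = None
        self._iterator = None

    def make_distributed(self, strategy):
        self._strategy = strategy

    def get_train_dataset(self, batch_size, shuffle, prefetch):
        ds = tf.data.Dataset.from_tensor_slices(self._data)
        if shuffle:
            ds = ds.shuffle(10000)
        ds = ds.repeat().batch(batch_size).prefetch(prefetch)
        # make_dataset_iterator splits each batch across the mirrored devices.
        self._iterator = self._strategy.make_dataset_iterator(ds)
        return self._iterator

    def initialize_train(self, sess):
        # initialize() returns the iterator's per-device init ops.
        sess.run(self._iterator.initialize())

def get_data_provider(dataset_name, dataset_path, param_dict):
    return MNISTProvider(dataset_path, param_dict)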