{
"cells": [
{
"cell_type": "markdown",
"source": [
"# MirroredStrategy"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 2,
"source": [
"import tensorflow as tf\n",
"from tensorflow.contrib.distribute.python import mirrored_strategy\n",
"from tensorflow.contrib.distribute.python import collective_all_reduce_strategy"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## 1. Create `MirroredStrategy`"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 3,
"source": [
"strategy = mirrored_strategy.MirroredStrategy(['/gpu:0', '/gpu:1'])"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"INFO:tensorflow:Device is available but not used by distribute strategy: /device:CPU:0\n",
"INFO:tensorflow:Device is available but not used by distribute strategy: /device:XLA_GPU:0\n",
"INFO:tensorflow:Device is available but not used by distribute strategy: /device:XLA_GPU:1\n",
"INFO:tensorflow:Device is available but not used by distribute strategy: /device:XLA_CPU:0\n",
"INFO:tensorflow:Configured nccl all-reduce.\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## 2. Create `MirroredVariable`\n",
"\n",
"Create a normal variable inside `strategy.scope()` to get a mirrored variable"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 4,
"source": [
"with tf.Session() as sess, strategy.scope():\n",
"    mirrored_var = tf.Variable(1.0, name='mirror')\n",
"    print(mirrored_var)\n",
"\n",
"    sess.run(tf.global_variables_initializer())\n",
"    print(sess.run(mirrored_var))"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"WARNING:tensorflow:From tensorflow/python/ops/resource_variable_ops.py:435: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Colocations handled automatically by placer.\n",
"MirroredVariable:{'/replica:0/task:0/device:GPU:0': <tf.Variable 'mirror:0' shape=() dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'mirror/replica_1:0' shape=() dtype=float32>}\n",
"1.0\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"### Unwrap a mirrored variable\n",
"\n",
"A `MirroredVariable` is composed of one variable copy per GPU (two copies in total here). You can use `strategy.unwrap` or `mirrored_var.get` to inspect each copy."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 5,
"source": [
"with tf.Session() as sess, strategy.scope():\n",
"    mirrored_var = tf.Variable(1.0, name='mirror')\n",
"    sess.run(tf.global_variables_initializer())\n",
"\n",
"    # Method 1: strategy.unwrap\n",
"    unwrap = strategy.unwrap(mirrored_var)\n",
"    print('unwrap returns a pair of Tensors:', unwrap)\n",
"    print('Eval unwrap:', sess.run(unwrap))\n",
"\n",
"    # Method 2: variable.get(device_name)\n",
"    var_dev1 = mirrored_var.get('/gpu:1')\n",
"    print('Eval var_on_dev1:', sess.run(var_dev1))"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"unwrap returns a pair of Tensors: (<tf.Variable 'mirror_1:0' shape=() dtype=float32>, <tf.Variable 'mirror_1/replica_1:0' shape=() dtype=float32>)\n",
"Eval unwrap: (1.0, 1.0)\n",
"Eval var_on_dev1: 1.0\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"### Replica context\n",
"\n",
"Inside a distribution strategy's scope, the context is \"cross-replica\" by default. Variables are `MirroredVariable`s and *identical* across replicas.\n",
"\n",
"But computations may diverge and produce different values on each device, e.g. when you invoke `call_for_each_replica(fn)`. This API switches into the per-replica (non-cross-replica) context; values produced there are wrapped as `PerReplica` and may differ across devices."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 6,
"source": [
"with tf.Session() as sess, strategy.scope():\n",
"    def get_replica_id():\n",
"        return tf.distribute.get_replica_context().replica_id_in_sync_group\n",
"    replica_id = strategy.extended.call_for_each_replica(get_replica_id)\n",
"    print(replica_id)\n",
"    print(sess.run(strategy.unwrap(replica_id)))"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"PerReplica:{'/replica:0/task:0/device:GPU:0': <tf.Tensor 'Const:0' shape=() dtype=int32>, '/replica:0/task:0/device:GPU:1': <tf.Tensor 'Const_1:0' shape=() dtype=int32>}\n",
"(0, 1)\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## 3. Update mirrored variable"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 7,
"source": [
"with strategy.scope():\n",
"    mirrored_var = tf.Variable(1.0, name='mirror')\n",
"    print(dir(mirrored_var))"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"['__abs__', '__add__', '__and__', '__class__', '__delattr__', '__dict__', '__dir__', '__div__', '__doc__', '__eq__', '__floordiv__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__invert__', '__le__', '__lt__', '__matmul__', '__mod__', '__module__', '__mul__', '__ne__', '__neg__', '__new__', '__or__', '__pow__', '__radd__', '__rand__', '__rdiv__', '__reduce__', '__reduce_ex__', '__repr__', '__rfloordiv__', '__rmatmul__', '__rmod__', '__rmul__', '__ror__', '__rpow__', '__rsub__', '__rtruediv__', '__rxor__', '__setattr__', '__sizeof__', '__str__', '__sub__', '__subclasshook__', '__truediv__', '__weakref__', '__xor__', '_add_variable_with_custom_getter', '_aggregation', '_as_graph_element', '_assign_func', '_checkpoint_dependencies', '_common_name', '_deferred_dependencies', '_gather_saveables_for_checkpoint', '_get_cross_replica', '_handle_deferred_dependencies', '_in_graph_mode', '_index', '_initializer_op', '_keras_initialized', '_lookup_dependency', '_maybe_initialize_checkpointable', '_name_based_attribute_restore', '_no_dependency', '_preload_simple_restoration', '_primary_var', '_restore_from_checkpoint_position', '_shared_name', '_should_act_as_resource_variable', '_single_restoration_from_checkpoint_position', '_track_checkpointable', '_unique_id', 'aggregation', 'assign', 'assign_add', 'assign_sub', 'devices', 'dtype', 'get', 'get_shape', 'graph', 'initializer', 'is_initialized', 'is_tensor_like', 'name', 'op', 'read_value', 'shape', 'to_proto']\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"As the listing above shows, a mirrored variable exposes three update methods: `assign`, `assign_add`, and `assign_sub`. The next cell demonstrates `assign`; a sketch of the other two follows it."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 8,
"source": [
"with tf.Session() as sess, strategy.scope():\n",
"    mirrored_var = tf.Variable(1.0, name='mirror')\n",
"    sess.run(tf.global_variables_initializer())\n",
"    print('mirrored_var:', sess.run(mirrored_var))\n",
"\n",
"    assigned_var = mirrored_var.assign(5.0)\n",
"    unwrap = strategy.unwrap(assigned_var)\n",
"    print('Eval assign and unwrap:', sess.run(unwrap))\n",
"    print('mirrored_var after updated:', sess.run(mirrored_var))"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"mirrored_var: 1.0\n",
"Eval assign and unwrap: (5.0, 5.0)\n",
"mirrored_var after updated: 5.0\n"
]
}
],
"metadata": {}
},
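{
"cell_type": "markdown",
"source": [
"The cell above only exercises `assign`. Below is a minimal, untested sketch (assuming the same TF 1.13 contrib API and the two-GPU `strategy` created earlier) of the other two update methods, `assign_add` and `assign_sub`, in the cross-replica context."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Sketch only: the expected values in the comments are assumptions, not captured output.\n",
"with tf.Session() as sess, strategy.scope():\n",
"    mirrored_var = tf.Variable(1.0, name='mirror')\n",
"    sess.run(tf.global_variables_initializer())\n",
"\n",
"    # In the cross-replica context the same delta is applied to every copy\n",
"    add_op = mirrored_var.assign_add(2.0)  # expect 3.0 on both GPUs\n",
"    print(sess.run(strategy.unwrap(add_op)))\n",
"\n",
"    sub_op = mirrored_var.assign_sub(0.5)  # expect 2.5 on both GPUs\n",
"    print(sess.run(strategy.unwrap(sub_op)))\n",
"\n",
"    print('final value:', sess.run(mirrored_var))"
],
"outputs": [],
"metadata": {}
},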
{
"cell_type": "markdown",
"source": [
"### A more complicated example\n",
"\n",
"[Link](https://github.com/tensorflow/tensorflow/blob/v1.13.2/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py#L440)"
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"### Allreduce"
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"We are going to assign a different value to the same `MirroredVariable` on each device.\n",
"- Explicitly specify an aggregation mode on the variable, such as `tf.VariableAggregation.MEAN`\n",
"- Use `strategy.extended.call_for_each_replica(fn)` to run the assignment separately on each replica"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 9,
"source": [
"with tf.Session() as sess, strategy.scope():\n",
"    mirrored_var = tf.Variable(1.0, name='mirror', aggregation=tf.VariableAggregation.MEAN)  # specify reduce op `MEAN`\n",
"    sess.run(tf.global_variables_initializer())\n",
"    print('mirrored_var:', sess.run(mirrored_var))\n",
"\n",
"    def model_fn():\n",
"        rep_id = tf.distribute.get_replica_context().replica_id_in_sync_group\n",
"        # GPU:0 assigns 0, GPU:1 assigns 1; with MEAN aggregation, mirrored_var becomes 0.5\n",
"        return mirrored_var.assign(tf.cast(rep_id, tf.float32))\n",
"\n",
"    assigned_var = strategy.extended.call_for_each_replica(model_fn)\n",
"    unwrap = strategy.unwrap(assigned_var)\n",
"    print('Eval assign and unwrap:', sess.run(unwrap))\n",
"    print('mirrored_var after updated:', sess.run(mirrored_var))"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"mirrored_var: 1.0\n",
"INFO:tensorflow:batch_all_reduce invoked for batches size = 1 with algorithm = nccl, num_packs = 1, agg_small_grads_max_bytes = 0 and agg_small_grads_max_group = 10\n",
"Eval assign and unwrap: (0.5, 0.5)\n",
"mirrored_var after updated: 0.5\n"
]
}
],
"metadata": {}
},
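{
"cell_type": "markdown",
"source": [
"`MEAN` is only one aggregation mode. Here is a minimal, untested sketch (same assumptions as above) of the `tf.VariableAggregation.SUM` variant: the per-replica values are summed before being written back, so the variable should end up at 0 + 1 = 1.0."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Sketch only: `SUM` aggregation adds the per-replica values before assigning\n",
"with tf.Session() as sess, strategy.scope():\n",
"    summed_var = tf.Variable(1.0, name='summed', aggregation=tf.VariableAggregation.SUM)\n",
"    sess.run(tf.global_variables_initializer())\n",
"\n",
"    def model_fn():\n",
"        rep_id = tf.distribute.get_replica_context().replica_id_in_sync_group\n",
"        # GPU:0 assigns 0, GPU:1 assigns 1; the sum 1.0 is written to both copies\n",
"        return summed_var.assign(tf.cast(rep_id, tf.float32))\n",
"\n",
"    assigned = strategy.extended.call_for_each_replica(model_fn)\n",
"    sess.run(strategy.unwrap(assigned))\n",
"    print(sess.run(summed_var))  # expect 1.0 under these assumptions"
],
"outputs": [],
"metadata": {}
},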
{
"cell_type": "markdown",
"source": [
"### Helper API"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 10,
"source": [
"from tensorflow.python.distribute import distribution_strategy_context\n",
"\n",
"with strategy.scope():\n",
"    print(distribution_strategy_context.get_distribution_strategy())\n",
"    print(distribution_strategy_context.has_distribution_strategy())\n",
"    print(distribution_strategy_context.in_cross_replica_context())\n",
"\n",
"    # `get_replica_context` returns None when in the cross-replica context\n",
"    print(distribution_strategy_context.get_replica_context())"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"<tensorflow.contrib.distribute.python.mirrored_strategy.MirroredStrategy object at 0x7f49f0389080>\n",
"True\n",
"True\n",
"None\n"
]
}
],
"metadata": {}
},
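{
"cell_type": "markdown",
"source": [
"For contrast, a minimal, untested sketch (same assumptions as above) of the same helpers inside `call_for_each_replica`, where the context is per-replica rather than cross-replica."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Sketch only: inside a replica function the context flips to per-replica\n",
"with strategy.scope():\n",
"    def replica_fn():\n",
"        print(distribution_strategy_context.in_cross_replica_context())  # expect False\n",
"        # `get_replica_context` is no longer None here\n",
"        print(distribution_strategy_context.get_replica_context())\n",
"\n",
"    strategy.extended.call_for_each_replica(replica_fn)"
],
"outputs": [],
"metadata": {}
},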
{
"cell_type": "markdown",
"source": [
"### Simple training example"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 11,
"source": [
"from tensorflow.python.eager import backprop\n",
"from tensorflow.python.training import optimizer\n",
"\n",
"class MiniModel(tf.keras.Model):\n",
"    def __init__(self):\n",
"        super(MiniModel, self).__init__(name=\"\")\n",
"        self.fc = tf.layers.Dense(1, kernel_initializer='ones', use_bias=False)\n",
"\n",
"    def call(self, inputs, training=True):\n",
"        return self.fc(inputs)\n",
"\n",
"with tf.Session() as sess, strategy.scope(), tf.variable_scope('', reuse=tf.AUTO_REUSE):\n",
"    mock_model = MiniModel()\n",
"\n",
"    def loss_fn(ctx):\n",
"        del ctx\n",
"        return mock_model(tf.ones([1, 10]))\n",
"\n",
"    gradients_fn = backprop.implicit_grad(loss_fn)\n",
"    gradients_fn = optimizer.get_filtered_grad_fn(gradients_fn)\n",
"    grads_and_vars = strategy.extended.call_for_each_replica(gradients_fn, args=(None,))\n",
"    print('grad and var of dense weights:\\n', grads_and_vars)\n",
"\n",
"    optimizer = tf.train.GradientDescentOptimizer(0.25)\n",
"    update_ops = optimizer._distributed_apply(strategy, grads_and_vars)\n",
"\n",
"    sess.run(tf.global_variables_initializer())\n",
"    sess.run(update_ops)\n",
"\n",
"    updated_var_values = sess.run(mock_model.variables)\n",
"    # weights start at 1.0 and get two updates of 0.25\n",
"    print('updated dense weights:\\n', updated_var_values)"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"grad and var of dense weights:\n",
" [(PerReplica({'/replica:0/task:0/device:GPU:0': <tf.Tensor 'MatMul_1:0' shape=(10, 1) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Tensor 'replica_1/MatMul_1:0' shape=(10, 1) dtype=float32>}), MirroredVariable({'/replica:0/task:0/device:GPU:0': <tf.Variable 'dense/kernel:0' shape=(10, 1) dtype=float32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'dense/kernel/replica_1:0' shape=(10, 1) dtype=float32>}))]\n",
"INFO:tensorflow:batch_all_reduce invoked for batches size = 1 with algorithm = nccl, num_packs = 1, agg_small_grads_max_bytes = 0 and agg_small_grads_max_group = 10\n",
"updated dense weights:\n",
" [array([[0.5],\n",
" [0.5],\n",
" [0.5],\n",
" [0.5],\n",
" [0.5],\n",
" [0.5],\n",
" [0.5],\n",
" [0.5],\n",
" [0.5],\n",
" [0.5]], dtype=float32)]\n"
]
}
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"# MultiWorkerMirroredStrategy (called CollectiveAllReduceStrategy in v1)"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 12,
"source": [
"strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(num_gpus_per_worker=2)"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"INFO:tensorflow:Device is available but not used by distribute strategy: /device:CPU:0\n",
"INFO:tensorflow:Device is available but not used by distribute strategy: /device:XLA_GPU:0\n",
"INFO:tensorflow:Device is available but not used by distribute strategy: /device:XLA_GPU:1\n",
"INFO:tensorflow:Device is available but not used by distribute strategy: /device:XLA_CPU:0\n",
"INFO:tensorflow:Configured nccl all-reduce.\n",
"INFO:tensorflow:CollectiveAllReduceStrategy with local_devices = ('/device:GPU:0', '/device:GPU:1')\n"
]
}
],
"metadata": {}
},
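{
"cell_type": "markdown",
"source": [
"A quick sanity check, as a minimal, untested sketch (it assumes the `num_replicas_in_sync` property is available on this contrib strategy): with a single worker and two local GPUs there should be two replicas in sync."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Sketch only: with one worker and two GPUs, two replicas stay in sync\n",
"print(strategy.num_replicas_in_sync)  # expect 2 under these assumptions"
],
"outputs": [],
"metadata": {}
},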
{
"cell_type": "markdown",
"source": [
"### Rerun all cells from **\"2. Create MirroredVariable\"**"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [],
"outputs": [],
"metadata": {}
}
],
"metadata": {
"orig_nbformat": 4,
"language_info": {
"name": "python",
"version": "3.6.13",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3.6.13 64-bit ('bigmodel': conda)"
},
"interpreter": {
"hash": "edb6864056ee8b47b1401676f7718f83fb3a883ab7a06a16a9767e47c957af6c"
}
},
"nbformat": 4,
"nbformat_minor": 2
}