{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"<img src=\"http://dask.readthedocs.io/en/latest/_images/dask_horizontal.svg\"\n",
"     align=\"right\"\n",
"     width=\"30%\"\n",
"     alt=\"Dask logo\">\n",
"\n",
"\n",
"<img src=\"https://camo.githubusercontent.com/ee91ac3c9f5ad840ebf70b54284498fe0e6ddb92/68747470733a2f2f7777772e74656e736f72666c6f772e6f72672f696d616765732f74665f6c6f676f5f7472616e73702e706e67\"\n",
"     align=\"right\"\n",
"     width=\"25%\"\n",
"     alt=\"Tensorflow logo\">\n",
"\n",
"\n",
"Tensorflow and Dask\n",
"=======================\n",
"\n",
"1. Load up MNIST data in parallel across a Dask cluster\n",
"2. Filter and batch this data\n",
"3. Start distributed Tensorflow servers on top of Dask\n",
"4. Follow a [Tensorflow example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/dist_test/python/mnist_replica.py) to create a distributed training computation\n",
"5. Use Dask to distribute this job on Tensorflow servers\n",
"6. Feed Tensorflow servers from Dask workers\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"from distributed import LocalCluster, Client, local_client, progress\n",
"cluster = LocalCluster(nanny=True, n_workers=8, threads_per_worker=1)\n",
"c = Client(cluster)\n",
"c"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load Data\n",
"\n",
"We don't have much data, so we artifically inflate ten times.\n", | |
"\n", | |
"This shows how to build a pair of large dask.arrays from generic Python functions" | |
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def get_mnist():\n",
"    from tensorflow.examples.tutorials.mnist import input_data\n",
"    mnist = input_data.read_data_sets('/tmp/mnist-data', one_hot=True)\n",
"    return mnist.train.images, mnist.train.labels\n",
"\n",
"import dask.array as da\n",
"from dask import delayed\n",
"\n",
"datasets = [delayed(get_mnist)() for i in range(10)]\n",
"images = [d[0] for d in datasets]\n",
"labels = [d[1] for d in datasets]\n",
"\n",
"images = [da.from_delayed(im, shape=(55000, 784), dtype='float32') for im in images]\n",
"labels = [da.from_delayed(la, shape=(55000, 10), dtype='float32') for la in labels]\n",
"\n",
"images = da.concatenate(images, axis=0)\n",
"labels = da.concatenate(labels, axis=0)\n",
"\n",
"images"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"images, labels = c.persist([images, labels])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"images"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"plt.imshow(images[0].compute().reshape((28, 28)), cmap='gray')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"plt.imshow(images.mean(axis=0).compute().reshape((28, 28)), cmap='gray')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Rechunk to smaller batches and combine\n",
"\n",
"Tensorflow seems to prefer smaller bits of data. Lets shard our datset a bit." | |
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"images"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"images = images.rechunk((1000, 784))\n",
"labels = labels.rechunk((1000, 10))\n",
"\n",
"images = images.to_delayed().flatten().tolist()\n",
"labels = labels.to_delayed().flatten().tolist()\n",
"batches = [delayed([im, la]) for im, la in zip(images, labels)]\n",
"\n",
"batches = c.compute(batches)\n",
"progress(batches)"
]
},
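{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check we can pull a single batch back to the notebook and confirm its shape. This is a minimal sketch, assuming `batches` holds futures of `[images, labels]` pairs as constructed above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Sketch: gather one [images, labels] pair from the cluster\n",
"im, la = batches[0].result()\n",
"im.shape, la.shape  # expect ((1000, 784), (1000, 10))"
]
},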
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Launch Tensorflow servers alongside Dask workers"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from dask_tensorflow import start_tensorflow\n",
"tf_spec, dask_spec = start_tensorflow(c, ps=1, worker=4, scorer=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"tf_spec.as_dict()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"dask_spec"
]
},
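{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally we can check that each participating worker really has a Tensorflow server attached. A minimal sketch, assuming `dask_tensorflow` sets a `tensorflow_server` attribute on each worker (as the tasks below rely on) and that `Client.run` passes the worker object in via a `dask_worker` argument."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Sketch: map each worker address to whether a Tensorflow server is attached\n",
"def has_tf_server(dask_worker):\n",
"    return hasattr(dask_worker, 'tensorflow_server')\n",
"\n",
"c.run(has_tf_server)"
]
},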
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Worker tasks create model, collect data from Dask"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import math\n",
"import tempfile\n",
"import time\n",
"from queue import Empty\n",
"\n",
"IMAGE_PIXELS = 28\n",
"hidden_units = 100\n",
"learning_rate = 0.01\n",
"sync_replicas = False\n",
"replicas_to_aggregate = len(dask_spec['worker'])\n",
"\n",
"def model(server):\n",
"    worker_device = \"/job:%s/task:%d\" % (server.server_def.job_name,\n",
"                                         server.server_def.task_index)\n",
"    task_index = server.server_def.task_index\n",
"    is_chief = task_index == 0\n",
"\n",
"    with tf.device(tf.train.replica_device_setter(\n",
"            worker_device=worker_device,\n",
"            ps_device=\"/job:ps/cpu:0\",\n",
"            cluster=tf_spec)):\n",
"\n",
"        global_step = tf.Variable(0, name=\"global_step\", trainable=False)\n",
"\n",
"        # Variables of the hidden layer\n",
"        hid_w = tf.Variable(\n",
"            tf.truncated_normal(\n",
"                [IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],\n",
"                stddev=1.0 / IMAGE_PIXELS),\n",
"            name=\"hid_w\")\n",
"        hid_b = tf.Variable(tf.zeros([hidden_units]), name=\"hid_b\")\n",
"\n",
"        # Variables of the softmax layer\n",
"        sm_w = tf.Variable(\n",
"            tf.truncated_normal(\n",
"                [hidden_units, 10],\n",
"                stddev=1.0 / math.sqrt(hidden_units)),\n",
"            name=\"sm_w\")\n",
"        sm_b = tf.Variable(tf.zeros([10]), name=\"sm_b\")\n",
"\n",
"        # Ops: located on the worker specified with task_index\n",
"        x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])\n",
"        y_ = tf.placeholder(tf.float32, [None, 10])\n",
"\n",
"        hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)\n",
"        hid = tf.nn.relu(hid_lin)\n",
"\n",
"        y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))\n",
"        cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))\n",
"\n",
"        opt = tf.train.AdamOptimizer(learning_rate)\n",
"\n",
" if sync_replicas:\n", | |
" if replicas_to_aggregate is None:\n", | |
" replicas_to_aggregate = num_workers\n", | |
" else:\n", | |
" replicas_to_aggregate = replicas_to_aggregate\n", | |
"\n", | |
" opt = tf.train.SyncReplicasOptimizer(\n", | |
" opt,\n", | |
" replicas_to_aggregate=replicas_to_aggregate,\n", | |
" total_num_replicas=num_workers,\n", | |
" name=\"mnist_sync_replicas\")\n", | |
"\n", | |
" train_step = opt.minimize(cross_entropy, global_step=global_step)\n", | |
"\n", | |
" if sync_replicas:\n", | |
" local_init_op = opt.local_step_init_op\n", | |
" if is_chief:\n", | |
" local_init_op = opt.chief_init_op\n", | |
"\n", | |
" ready_for_local_init_op = opt.ready_for_local_init_op\n", | |
"\n", | |
" # Initial token and chief queue runners required by the sync_replicas mode\n", | |
" chief_queue_runner = opt.get_chief_queue_runner()\n", | |
" sync_init_op = opt.get_init_tokens_op()\n", | |
"\n", | |
" init_op = tf.global_variables_initializer()\n", | |
" train_dir = tempfile.mkdtemp()\n", | |
"\n", | |
" if sync_replicas:\n", | |
" sv = tf.train.Supervisor(\n", | |
" is_chief=is_chief,\n", | |
" logdir=train_dir,\n", | |
" init_op=init_op,\n", | |
" local_init_op=local_init_op,\n", | |
" ready_for_local_init_op=ready_for_local_init_op,\n", | |
" recovery_wait_secs=1,\n", | |
" global_step=global_step)\n", | |
" else:\n", | |
" sv = tf.train.Supervisor(\n", | |
" is_chief=is_chief,\n", | |
" logdir=train_dir,\n", | |
" init_op=init_op,\n", | |
" recovery_wait_secs=1,\n", | |
" global_step=global_step)\n", | |
"\n", | |
" sess_config = tf.ConfigProto(\n", | |
" allow_soft_placement=True,\n", | |
" log_device_placement=False,\n", | |
" device_filters=[\"/job:ps\", \"/job:worker/task:%d\" % task_index])\n", | |
"\n", | |
" # The chief worker (task_index==0) session will prepare the session,\n", | |
" # while the remaining workers will wait for the preparation to complete.\n", | |
" if is_chief:\n", | |
" print(\"Worker %d: Initializing session...\" % task_index)\n", | |
" else:\n", | |
" print(\"Worker %d: Waiting for session to be initialized...\" %\n", | |
" task_index)\n", | |
"\n", | |
" sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)\n", | |
"\n", | |
" if sync_replicas and is_chief:\n", | |
" # Chief worker will start the chief queue runner and call the init op.\n", | |
" sess.run(sync_init_op)\n", | |
" sv.start_queue_runners(sess, [chief_queue_runner])\n", | |
" \n", | |
" return sess, x, y_, train_step, global_step, cross_entropy\n", | |
" \n", | |
" \n", | |
"def worker_task():\n", | |
" with local_client() as c:\n", | |
" scores = c.channel('scores')\n", | |
" num_workers = replicas_to_aggregate = len(dask_spec['worker'])\n", | |
" \n", | |
" server = c.worker.tensorflow_server\n", | |
" queue = c.worker.tensorflow_queue\n", | |
"\n", | |
" sess, x, y_, train_step, global_step, cross_entropy = model(c.worker.tensorflow_server)\n", | |
" while not scores or scores.data[-1] > 1000:\n", | |
" try:\n", | |
" batch = queue.get(timeout=0.5)\n", | |
" except Empty:\n", | |
" continue\n", | |
" \n", | |
" train_feed = {x: batch[0], y_: batch[1]}\n", | |
"\n", | |
" _, step = sess.run([train_step, global_step], \n", | |
" feed_dict=train_feed)\n", | |
"\n", | |
" return step\n", | |
" \n", | |
"\n", | |
"def scoring_task():\n", | |
" with local_client() as c:\n", | |
" # Scores Channel\n", | |
" scores = c.channel('scores', maxlen=10)\n", | |
"\n", | |
" # Make Model\n", | |
" server = c.worker.tensorflow_server\n", | |
" sess, x, y_, train_step, global_step, cross_entropy = model(c.worker.tensorflow_server)\n", | |
" \n", | |
" # Testing Data\n", | |
" from tensorflow.examples.tutorials.mnist import input_data\n", | |
" mnist = input_data.read_data_sets('/tmp/mnist-data', one_hot=True)\n", | |
" feed = {x: mnist.validation.images, \n", | |
" y_: mnist.validation.labels}\n", | |
"\n", | |
" # Main Loop\n", | |
" while True:\n", | |
" score = sess.run(cross_entropy, feed_dict=feed)\n", | |
" scores.append(float(score))\n", | |
"\n", | |
" time.sleep(1)\n", | |
" return step\n", | |
"\n", | |
"\n", | |
"def ps_task():\n", | |
" with local_client() as c:\n", | |
" c.worker.tensorflow_server.join()" | |
] | |
}, | |
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Tell Tensorflow Servers what to do"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"ps_tasks = [c.submit(ps_task, workers=worker) \n", | |
" for worker in dask_spec['ps']]\n", | |
"\n", | |
"worker_tasks = [c.submit(worker_task, workers=addr, pure=False)\n", | |
" for addr in dask_spec['worker']]\n", | |
"\n", | |
"scorer_task = c.submit(scoring_task, workers=dask_spec['scorer'][0])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Feed data to training workers" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"from distributed.worker_client import get_worker\n", | |
"\n", | |
"def transfer_dask_to_tensorflow(batch):\n", | |
" worker = get_worker()\n", | |
" worker.tensorflow_queue.put(batch) " | |
] | |
}, | |
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"dump = c.map(transfer_dask_to_tensorflow, batches, workers=dask_spec['worker'], pure=False)\n",
"progress(dump)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Listen to scores channel"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"scores = c.channel('scores', maxlen=10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"scores.data"
]
},
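{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since matplotlib is already loaded we can also plot the recent scores. A minimal sketch, assuming `scores.data` holds the last ten cross entropy values appended by the scoring task."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Sketch: visualize the most recent validation cross entropy scores\n",
"plt.plot(list(scores.data))\n",
"plt.xlabel('recent score checks')\n",
"plt.ylabel('validation cross entropy')"
]
},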
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Questions:\n",
"\n",
"1. Should we throttle dask-to-tensorflow data dumps? (One possible approach is sketched below.)\n",
"2. How do we reclaim the model from the parameter servers?\n",
"3. What is the right way to test against testing data?"
]
},
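{
"cell_type": "markdown",
"metadata": {},
"source": [
"On question 1, one simple approach would be to submit the transfer tasks in fixed-size waves, waiting for each wave to finish before sending the next. This is only a sketch: `feed_in_waves` and its `wave_size` knob are hypothetical names, and it reuses `transfer_dask_to_tensorflow` from above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from distributed import wait\n",
"\n",
"def feed_in_waves(batches, wave_size=50):\n",
"    # Sketch: send `wave_size` batches at a time so that worker queues\n",
"    # never hold more than one wave of pending data\n",
"    futures = []\n",
"    for i in range(0, len(batches), wave_size):\n",
"        wave = c.map(transfer_dask_to_tensorflow,\n",
"                     batches[i:i + wave_size],\n",
"                     workers=dask_spec['worker'], pure=False)\n",
"        wait(wave)  # block until this wave has been placed on the queues\n",
"        futures.extend(wave)\n",
"    return futures"
]
}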
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
},
"widgets": {
"state": {
"99357fbaa78e47218d954444e0f6a4e9": {
"views": [
{
"cell_index": 22
}
]
},
"bc80b5de65d84a2789667f23f6f7cad1": {
"views": [
{
"cell_index": 11
}
]
}
},
"version": "1.2.0"
}
},
"nbformat": 4,
"nbformat_minor": 0
} |