@mrocklin
Created February 7, 2017 16:42
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"<img src=\"http://dask.readthedocs.io/en/latest/_images/dask_horizontal.svg\"\n",
" align=\"right\"\n",
" width=\"30%\"\n",
" alt=\"Dask logo\">\n",
"\n",
"\n",
"<img src=\"https://camo.githubusercontent.com/ee91ac3c9f5ad840ebf70b54284498fe0e6ddb92/68747470733a2f2f7777772e74656e736f72666c6f772e6f72672f696d616765732f74665f6c6f676f5f7472616e73702e706e67\"\n",
" align=\"right\"\n",
" width=\"25%\"\n",
" alt=\"Tensorflow logo\">\n",
"\n",
"\n",
"Tensorflow and Dask\n",
"=======================\n",
"\n",
"1. Load up MNIST data in parallel across a Dask cluster\n",
"2. Filter and batch this data\n",
"3. Start distributed Tensorflow servers on top of Dask\n",
"4. Follow a [Tensorflow example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/dist_test/python/mnist_replica.py) to create a distributed training computation\n",
"5. Use Dask to distribute this job on Tensorflow servers\n",
"6. Feed Tensorflow servers from Dask workers\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"from distributed import LocalCluster, Client, local_client, progress\n",
"cluster = LocalCluster(nanny=True, n_workers=8, threads_per_worker=1)\n",
"c = Client(cluster)\n",
"c"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load Data\n",
"\n",
"We don't have much data, so we artificially inflate it by loading it ten times.\n",
"\n",
"This shows how to build a pair of large dask.arrays from generic Python functions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def get_mnist():\n",
"    from tensorflow.examples.tutorials.mnist import input_data\n",
"    mnist = input_data.read_data_sets('/tmp/mnist-data', one_hot=True)\n",
"    return mnist.train.images, mnist.train.labels\n",
"\n",
"import dask.array as da\n",
"from dask import delayed\n",
"\n",
"# Lazily load the same training set ten times to inflate the data\n",
"datasets = [delayed(get_mnist)() for i in range(10)]\n",
"images = [d[0] for d in datasets]\n",
"labels = [d[1] for d in datasets]\n",
"\n",
"# Wrap each delayed result as a dask.array with known shape and dtype\n",
"images = [da.from_delayed(im, shape=(55000, 784), dtype='float32') for im in images]\n",
"labels = [da.from_delayed(la, shape=(55000, 10), dtype='float32') for la in labels]\n",
"\n",
"images = da.concatenate(images, axis=0)\n",
"labels = da.concatenate(labels, axis=0)\n",
"\n",
"images"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"images, labels = c.persist([images, labels])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"images"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"plt.imshow(images[0].compute().reshape((28, 28)), cmap='gray')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"plt.imshow(images.mean(axis=0).compute().reshape((28, 28)), cmap='gray')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Rechunk to smaller batches and combine\n",
"\n",
"Tensorflow seems to prefer smaller bits of data. Let's shard our dataset a bit."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"images"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"images = images.rechunk((1000, 784))\n",
"labels = labels.rechunk((1000, 10))\n",
"\n",
"images = images.to_delayed().flatten().tolist()\n",
"labels = labels.to_delayed().flatten().tolist()\n",
"batches = [delayed([im, la]) for im, la in zip(images, labels)]\n",
"\n",
"batches = c.compute(batches)\n",
"progress(batches)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Launch Tensorflow servers alongside Dask workers"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from dask_tensorflow import start_tensorflow\n",
"tf_spec, dask_spec = start_tensorflow(c, ps=1, worker=4, scorer=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"tf_spec.as_dict()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"dask_spec"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Worker tasks create model, collect data from Dask"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import math\n",
"import tempfile\n",
"import time\n",
"from queue import Empty\n",
"\n",
"IMAGE_PIXELS = 28\n",
"hidden_units = 100\n",
"learning_rate = 0.01\n",
"sync_replicas = False\n",
"num_workers = len(dask_spec['worker'])\n",
"replicas_to_aggregate = num_workers\n",
"\n",
"def model(server):\n",
"    worker_device = \"/job:%s/task:%d\" % (server.server_def.job_name,\n",
"                                         server.server_def.task_index)\n",
"    task_index = server.server_def.task_index\n",
"    is_chief = task_index == 0\n",
"\n",
"    with tf.device(tf.train.replica_device_setter(\n",
"            worker_device=worker_device,\n",
"            ps_device=\"/job:ps/cpu:0\",\n",
"            cluster=tf_spec)):\n",
"\n",
"        global_step = tf.Variable(0, name=\"global_step\", trainable=False)\n",
"\n",
"        # Variables of the hidden layer\n",
"        hid_w = tf.Variable(\n",
"            tf.truncated_normal(\n",
"                [IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],\n",
"                stddev=1.0 / IMAGE_PIXELS),\n",
"            name=\"hid_w\")\n",
"        hid_b = tf.Variable(tf.zeros([hidden_units]), name=\"hid_b\")\n",
"\n",
"        # Variables of the softmax layer\n",
"        sm_w = tf.Variable(\n",
"            tf.truncated_normal(\n",
"                [hidden_units, 10],\n",
"                stddev=1.0 / math.sqrt(hidden_units)),\n",
"            name=\"sm_w\")\n",
"        sm_b = tf.Variable(tf.zeros([10]), name=\"sm_b\")\n",
"\n",
"        # Ops: located on the worker specified with task_index\n",
"        x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])\n",
"        y_ = tf.placeholder(tf.float32, [None, 10])\n",
"\n",
"        hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)\n",
"        hid = tf.nn.relu(hid_lin)\n",
"\n",
"        y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))\n",
"        cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))\n",
"\n",
"        opt = tf.train.AdamOptimizer(learning_rate)\n",
"\n",
"        if sync_replicas:\n",
"            # Uses the module-level replicas_to_aggregate and num_workers\n",
"            opt = tf.train.SyncReplicasOptimizer(\n",
"                opt,\n",
"                replicas_to_aggregate=replicas_to_aggregate,\n",
"                total_num_replicas=num_workers,\n",
"                name=\"mnist_sync_replicas\")\n",
"\n",
"        train_step = opt.minimize(cross_entropy, global_step=global_step)\n",
"\n",
"        if sync_replicas:\n",
"            local_init_op = opt.local_step_init_op\n",
"            if is_chief:\n",
"                local_init_op = opt.chief_init_op\n",
"\n",
"            ready_for_local_init_op = opt.ready_for_local_init_op\n",
"\n",
"            # Initial token and chief queue runners required by the sync_replicas mode\n",
"            chief_queue_runner = opt.get_chief_queue_runner()\n",
"            sync_init_op = opt.get_init_tokens_op()\n",
"\n",
"        init_op = tf.global_variables_initializer()\n",
"        train_dir = tempfile.mkdtemp()\n",
"\n",
"        if sync_replicas:\n",
"            sv = tf.train.Supervisor(\n",
"                is_chief=is_chief,\n",
"                logdir=train_dir,\n",
"                init_op=init_op,\n",
"                local_init_op=local_init_op,\n",
"                ready_for_local_init_op=ready_for_local_init_op,\n",
"                recovery_wait_secs=1,\n",
"                global_step=global_step)\n",
"        else:\n",
"            sv = tf.train.Supervisor(\n",
"                is_chief=is_chief,\n",
"                logdir=train_dir,\n",
"                init_op=init_op,\n",
"                recovery_wait_secs=1,\n",
"                global_step=global_step)\n",
"\n",
"        sess_config = tf.ConfigProto(\n",
"            allow_soft_placement=True,\n",
"            log_device_placement=False,\n",
"            device_filters=[\"/job:ps\", \"/job:worker/task:%d\" % task_index])\n",
"\n",
"        # The chief worker (task_index==0) session will prepare the session,\n",
"        # while the remaining workers will wait for the preparation to complete.\n",
"        if is_chief:\n",
"            print(\"Worker %d: Initializing session...\" % task_index)\n",
"        else:\n",
"            print(\"Worker %d: Waiting for session to be initialized...\" %\n",
"                  task_index)\n",
"\n",
"        sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)\n",
"\n",
"        if sync_replicas and is_chief:\n",
"            # Chief worker will start the chief queue runner and call the init op.\n",
"            sess.run(sync_init_op)\n",
"            sv.start_queue_runners(sess, [chief_queue_runner])\n",
"\n",
"        return sess, x, y_, train_step, global_step, cross_entropy\n",
"\n",
"\n",
"def worker_task():\n",
"    with local_client() as c:\n",
"        scores = c.channel('scores')\n",
"\n",
"        server = c.worker.tensorflow_server\n",
"        queue = c.worker.tensorflow_queue\n",
"\n",
"        sess, x, y_, train_step, global_step, cross_entropy = model(server)\n",
"\n",
"        step = None  # stays None if no batch is ever trained on\n",
"        # Keep training while there is no score yet or the latest score is above 1000\n",
"        while not scores or scores.data[-1] > 1000:\n",
"            try:\n",
"                batch = queue.get(timeout=0.5)\n",
"            except Empty:\n",
"                continue\n",
"\n",
"            train_feed = {x: batch[0], y_: batch[1]}\n",
"\n",
"            _, step = sess.run([train_step, global_step],\n",
"                               feed_dict=train_feed)\n",
"\n",
"        return step\n",
"\n",
"\n",
"def scoring_task():\n",
"    with local_client() as c:\n",
"        # Scores Channel\n",
"        scores = c.channel('scores', maxlen=10)\n",
"\n",
"        # Make Model\n",
"        server = c.worker.tensorflow_server\n",
"        sess, x, y_, train_step, global_step, cross_entropy = model(server)\n",
"\n",
"        # Testing Data\n",
"        from tensorflow.examples.tutorials.mnist import input_data\n",
"        mnist = input_data.read_data_sets('/tmp/mnist-data', one_hot=True)\n",
"        feed = {x: mnist.validation.images,\n",
"                y_: mnist.validation.labels}\n",
"\n",
"        # Main Loop: never returns, so it runs until the task is cancelled\n",
"        while True:\n",
"            score = sess.run(cross_entropy, feed_dict=feed)\n",
"            scores.append(float(score))\n",
"\n",
"            time.sleep(1)\n",
"\n",
"\n",
"def ps_task():\n",
"    with local_client() as c:\n",
"        c.worker.tensorflow_server.join()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Tell Tensorflow Servers what to do"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"ps_tasks = [c.submit(ps_task, workers=worker) \n",
" for worker in dask_spec['ps']]\n",
"\n",
"worker_tasks = [c.submit(worker_task, workers=addr, pure=False)\n",
" for addr in dask_spec['worker']]\n",
"\n",
"scorer_task = c.submit(scoring_task, workers=dask_spec['scorer'][0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Feed data to training workers"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from distributed.worker_client import get_worker\n",
"\n",
"def transfer_dask_to_tensorflow(batch):\n",
"    worker = get_worker()\n",
"    worker.tensorflow_queue.put(batch)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"dump = c.map(transfer_dask_to_tensorflow, batches, workers=dask_spec['worker'], pure=False)\n",
"progress(dump)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Listen to scores channel"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"scores = c.channel('scores', maxlen=10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"scores.data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Questions:\n",
"\n",
"1. Should we throttle dask-to-tensorflow data dumps?\n",
"2. How do we reclaim the model from the parameter servers?\n",
"3. What is the right way to test against testing data?"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
},
"widgets": {
"state": {
"99357fbaa78e47218d954444e0f6a4e9": {
"views": [
{
"cell_index": 22
}
]
},
"bc80b5de65d84a2789667f23f6f7cad1": {
"views": [
{
"cell_index": 11
}
]
}
},
"version": "1.2.0"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
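On question 1 in the notebook (whether to throttle the dask-to-tensorflow data dumps), one possible approach is a bounded queue, so that the producer blocks whenever the consumer falls behind. The sketch below uses only the Python standard library and does not touch `dask_tensorflow`; the notebook does not say whether `worker.tensorflow_queue` is bounded, so treat this as an illustration of the idea, not a description of that library. The names `producer`, `consumer`, and `run` are hypothetical.

```python
import queue
import threading
import time


def producer(q, batches):
    """Put batches on the queue; put() blocks while the queue is full."""
    for batch in batches:
        q.put(batch)
    q.put(None)  # sentinel: no more data


def consumer(q, seen, sizes):
    """Take batches off the queue, recording the queue size at each step."""
    while True:
        sizes.append(q.qsize())
        batch = q.get()
        if batch is None:
            break
        time.sleep(0.001)  # stand-in for one training step
        seen.append(batch)


def run(n_batches=50, maxsize=4):
    """Push n_batches through a queue holding at most maxsize items."""
    q = queue.Queue(maxsize=maxsize)
    seen, sizes = [], []
    t = threading.Thread(target=consumer, args=(q, seen, sizes))
    t.start()
    producer(q, range(n_batches))
    t.join()
    return seen, max(sizes)


if __name__ == "__main__":
    seen, peak = run()
    print(len(seen), "batches consumed; peak queue size", peak)
```

With a bounded queue, a task like `transfer_dask_to_tensorflow` would simply wait in `put` until the training loop catches up, at the cost of tying up a Dask worker thread while it waits.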