Parameter server and actors
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table style=\"border: 2px solid white;\">\n",
"<tr>\n",
"<td style=\"vertical-align: top; border: 0px solid white\">\n",
"<h3>Client</h3>\n",
"<ul>\n",
" <li><b>Scheduler: </b>tcp://127.0.0.1:49897\n",
" <li><b>Dashboard: </b><a href='http://127.0.0.1:49898/status' target='_blank'>http://127.0.0.1:49898/status</a>\n",
"</ul>\n",
"</td>\n",
"<td style=\"vertical-align: top; border: 0px solid white\">\n",
"<h3>Cluster</h3>\n",
"<ul>\n",
" <li><b>Workers: </b>8</li>\n",
" <li><b>Cores: </b>8</li>\n",
" <li><b>Memory: </b>17.18 GB</li>\n",
"</ul>\n",
"</td>\n",
"</tr>\n",
"</table>"
],
"text/plain": [
"<Client: scheduler='tcp://127.0.0.1:49897' processes=8 cores=8>"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from distributed import Client\n",
"import numpy as np\n",
"client = Client()\n",
"client"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class Worker:\n", | |
" def __init__(self, model, n_models, worker_id):\n", | |
" self.model = model\n", | |
" self.grads = []\n", | |
" self.n_models = n_models\n", | |
" self.worker_id = worker_id\n", | |
" \n", | |
" def _model(self):\n", | |
" return self.model\n", | |
" \n", | |
" def compute(self):\n", | |
" self.grad = self.worker_id\n", | |
" self.grads += [self.grad]\n", | |
" return True\n", | |
" \n", | |
" def send(self, worker):\n", | |
" worker.recv(self.grad)\n", | |
" \n", | |
" def recv(self, grad):\n", | |
" self.grads += [grad]\n", | |
" \n", | |
" def reduce(self):\n", | |
" assert len(self.grads) == 4\n", | |
" self.model += sum(self.grads)\n", | |
" self.grads = []" | |
] | |
}, | |
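{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick local sanity check of the `Worker` logic, as a minimal sketch with plain Python objects and no Dask: four workers whose \"gradients\" equal their IDs should each finish one round with `model == 0 + 1 + 2 + 3 == 6`, matching the distributed output below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustration only: one round of the same compute/send/reduce cycle, run with plain local objects.\n",
"local = [Worker(model=0, n_models=4, worker_id=i) for i in range(4)]\n",
"for w in local:\n",
"    w.compute()\n",
"for w1 in local:\n",
"    for w2 in local:\n",
"        if w1 is not w2:\n",
"            w1.send(w2)\n",
"for w in local:\n",
"    w.reduce()\n",
"assert [w.model for w in local] == [6, 6, 6, 6]"
]
},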
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<Actor: Worker, key=Worker-d0b833da-1021-4801-8541-0ec59730ddb4>,\n",
" <Actor: Worker, key=Worker-ed163ba0-e19f-4482-be5c-0d7889ebe197>,\n",
" <Actor: Worker, key=Worker-aa643321-4ab1-4693-8c4a-b37da76ca536>,\n",
" <Actor: Worker, key=Worker-7afd49ea-9d63-48f8-963c-bf86fac1c5ec>]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = 0\n",
"n_models = 4\n",
"futures = [client.submit(Worker, model, n_models, i, actor=True)\n",
"           for i in range(n_models)]\n",
"workers = client.gather(futures)\n",
"workers"
]
},
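{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each element of `workers` is an `Actor` proxy bound to one Dask worker. Calling a method on the proxy returns an `ActorFuture` right away; `.result()` (or `client.gather`) blocks for the value. A minimal sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: a single actor method call returns an ActorFuture, not the value itself.\n",
"future = workers[0]._model()\n",
"future.result()  # the initial model, i.e. 0"
]
},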
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"results of client.gather() are ActorFutures? [<ActorFuture>, <ActorFuture>, <ActorFuture>, <ActorFuture>]\n",
"models = [6, 6, 6, 6]\n",
"models = [12, 12, 12, 12]\n",
"models = [18, 18, 18, 18]\n",
"models = [24, 24, 24, 24]\n"
]
}
],
"source": [
"for k in range(4):\n",
"    # calculate\n",
"    futures = [worker.compute() for worker in workers]\n",
"    client.gather(futures)\n",
"\n",
"    # communicate\n",
"    # (the model update happens inside each worker; a worker knows when it has received every gradient)\n",
"    # (this could be an all-reduce implementation if desired)\n",
"    futures = []\n",
"    for i, w1 in enumerate(workers):\n",
"        for j, w2 in enumerate(workers):\n",
"            if i == j:\n",
"                continue\n",
"            else:\n",
"                futures += [w1.send(w2)]\n",
"    client.gather(futures)\n",
"\n",
"    # update model\n",
"    futures = [worker.reduce() for worker in workers]\n",
"    client.gather(futures)\n",
"\n",
"    # quick test; make sure all models are the same\n",
"    futures = [worker._model() for worker in workers]\n",
"    if k == 0:\n",
"        print(\"results of client.gather() are ActorFutures?\", client.gather(futures))\n",
"\n",
"    models = [m.result() for m in futures]\n",
"    print(\"models =\", models)\n",
"    assert all(model == models[0] for model in models)"
]
},
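{
"cell_type": "markdown",
"metadata": {},
"source": [
"The exchange above is all-to-all: every worker sends its gradient to every other worker. The title also mentions a parameter server; with the same actor machinery, that variant would funnel every gradient through one stateful actor instead. A minimal sketch, where the `ParameterServer` class and its `push`/`pull` methods are illustrative additions rather than part of the run above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical parameter-server variant; ParameterServer, push, and pull are illustrative names.\n",
"class ParameterServer:\n",
"    def __init__(self, model, n_models):\n",
"        self.model = model\n",
"        self.n_models = n_models\n",
"        self.grads = []\n",
"\n",
"    def push(self, grad):\n",
"        # Collect one gradient; apply the summed update once every worker has reported.\n",
"        self.grads += [grad]\n",
"        if len(self.grads) == self.n_models:\n",
"            self.model += sum(self.grads)\n",
"            self.grads = []\n",
"\n",
"    def pull(self):\n",
"        return self.model\n",
"\n",
"ps = client.submit(ParameterServer, 0, n_models, actor=True).result()\n",
"for grad in range(n_models):  # stand-ins for the workers' gradients (their IDs)\n",
"    ps.push(grad).result()\n",
"ps.pull().result()  # 0 + 1 + 2 + 3 == 6 after one round"
]
},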
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<ActorFuture>, <ActorFuture>, <ActorFuture>, <ActorFuture>]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"grads = [worker.compute() for worker in workers]\n",
"grads"
]
},
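{
"cell_type": "markdown",
"metadata": {},
"source": [
"These `ActorFuture`s resolve the same way as before: call `.result()` on each one, or hand the list to `client.gather`. Since `compute` returns `True`, a minimal check looks like:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Block on the ActorFutures from compute(); each call returned True.\n",
"[g.result() for g in grads]"
]
},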
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
} |