Created
July 16, 2019 16:18
-
-
Save matthieubulte/db1420e03a4287f253bddecda349dbdf to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-07-16T16:10:34.057490Z", | |
"start_time": "2019-07-16T16:10:33.422048Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"import cupy as cp\n", | |
"import numpy as np\n", | |
"\n", | |
"import dask\n", | |
"from dask_cuda import LocalCUDACluster\n", | |
"from dask.distributed import Client\n", | |
"import dask.array as da\n", | |
"\n", | |
"# Handles to the current Dask client and cluster; dask_setup closes the\n", | |
"# previous ones (if any) and replaces them.\n", | |
"client = None\n", | |
"cluster = None\n", | |
"\n", | |
"class Plugin(object):\n", | |
" def setup(self, worker):\n", | |
" pass\n", | |
" \n", | |
" def teardown(self, worker):\n", | |
" pass\n", | |
" \n", | |
" def finish_action(self, result):\n", | |
" if type(result) is cp.ndarray:\n", | |
" result.device.synchronize() \n", | |
"\n", | |
"def dask_setup(n_gpus, mount_plugin):\n", | |
" global client, cluster\n", | |
" if client is not None:\n", | |
" client.close()\n", | |
" cluster.close()\n", | |
" \n", | |
" cluster = LocalCUDACluster(n_workers=n_gpus)\n", | |
" client = Client(cluster)\n", | |
"\n", | |
" if mount_plugin:\n", | |
" client.register_worker_plugin(Plugin())\n", | |
"\n", | |
" rs = da.random.RandomState()\n", | |
"\n", | |
" X = rs.random(size=(10**7, 10**2), chunks=(10**6, 10**2))\n", | |
" y = rs.random(size=(10**7,), chunks=(10**6,))\n", | |
" beta = rs.random(size=(X.shape[1],))\n", | |
"\n", | |
" X = X.map_blocks(cp.array, dtype=np.float64)\n", | |
" y = y.map_blocks(cp.array, dtype=np.float64)\n", | |
" beta = beta.map_blocks(cp.array, dtype=np.float64)\n", | |
" \n", | |
" return X, y, beta\n", | |
"\n", | |
"def dask_compute(X, y, beta):\n", | |
" Xbeta = X.dot(beta)\n", | |
" p = da.tanh(Xbeta * 0.5) * 0.5 + 0.5\n", | |
" loss = -da.log(1.0 - y + (2.0 * y - 1.0) * p).sum()\n", | |
" grad = da.dot(X.T, p - 1)\n", | |
"\n", | |
" dask.compute([loss, grad], optimize_graph=True);" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
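"# CuPy: single sync\n", | |
"\n", | |
"One GPU and $10^6$ rows (the Dask runs below use $10^7$ rows), with one `device.synchronize()` after the whole computation." | |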
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-07-16T16:10:34.425953Z", | |
"start_time": "2019-07-16T16:10:34.059327Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"# Current CUDA device; synchronize() below waits for the random fills to finish.\n", | |
"device = cp.cuda.device.Device()\n", | |
"\n", | |
"# NOTE(review): 10**6 rows here vs 10**7 rows in dask_setup, so the CuPy\n", | |
"# timings below cover a tenth of the data used by the Dask runs.\n", | |
"X = cp.random.random((10**6, 10**2))\n", | |
"y = cp.random.random(10**6)\n", | |
"beta = cp.random.random(10**2)\n", | |
"\n", | |
"device.synchronize()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-07-16T16:10:36.405481Z", | |
"start_time": "2019-07-16T16:10:34.428019Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"203 ms ± 37.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"\n", | |
"# All operations are issued back to back; the single synchronize() at the\n", | |
"# end makes %%timeit include the GPU work they queued.\n", | |
"Xbeta = X.dot(beta)\n", | |
"p = cp.tanh(Xbeta * 0.5) * 0.5 + 0.5\n", | |
"loss = -cp.log(1.0 - y + (2.0 * y - 1.0) * p).sum()\n", | |
"grad = cp.dot(X.T, p - 1)\n", | |
"\n", | |
"device.synchronize()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
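"# CuPy: one sync per operation\n", | |
"\n", | |
"Same data and computation as above, with `device.synchronize()` after every elementary operation." | |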
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-07-16T16:10:50.338595Z", | |
"start_time": "2019-07-16T16:10:36.407386Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"172 ms ± 31.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"\n", | |
"# Same computation as the single-sync cell, split into one elementary\n", | |
"# operation per line with a device.synchronize() after each one.\n", | |
"Xbeta = X.dot(beta)\n", | |
"device.synchronize()\n", | |
"\n", | |
"half_Xbeta = Xbeta * 0.5\n", | |
"device.synchronize()\n", | |
"\n", | |
"tanh_half_Xbeta = cp.tanh(half_Xbeta)\n", | |
"device.synchronize()\n", | |
"\n", | |
"p_shifted = tanh_half_Xbeta * 0.5\n", | |
"device.synchronize()\n", | |
"\n", | |
"p = p_shifted + 0.5\n", | |
"device.synchronize()\n", | |
"\n", | |
"two_y = 2.0*y\n", | |
"device.synchronize()\n", | |
"\n", | |
"centered_y = two_y - 1.0\n", | |
"device.synchronize()\n", | |
"\n", | |
"cyp = centered_y * p\n", | |
"device.synchronize()\n", | |
"\n", | |
"cyp_shifted = cyp + 1.0\n", | |
"device.synchronize()\n", | |
"\n", | |
"# err == 1 - y + (2y - 1) * p, the argument of log in the single-sync cell.\n", | |
"err = cyp_shifted - y\n", | |
"device.synchronize()\n", | |
"\n", | |
"min_loss = cp.log(err)\n", | |
"device.synchronize()\n", | |
"\n", | |
"loss = -min_loss.sum()\n", | |
"device.synchronize()\n", | |
"\n", | |
"p_m_1 = p - 1.0\n", | |
"device.synchronize()\n", | |
"\n", | |
"grad = cp.dot(X.T, p_m_1)\n", | |
"device.synchronize()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
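"# Dask: 2 GPUs | sync: True\n", | |
"\n", | |
"`sync: True` means `Plugin` is registered on the workers (`mount_plugin=True`), so each CuPy task result is synchronized." | |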
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-07-16T16:10:51.884942Z", | |
"start_time": "2019-07-16T16:10:50.340292Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"X, y, beta = dask_setup(2, True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-07-16T16:12:00.140109Z", | |
"start_time": "2019-07-16T16:10:51.887189Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"8.47 s ± 532 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"# One loss + gradient evaluation: 2 workers, Plugin registered.\n", | |
"dask_compute(X, y, beta)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Dask: 2 GPUs | sync: False" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-07-16T16:12:00.858834Z", | |
"start_time": "2019-07-16T16:12:00.142603Z" | |
}, | |
"scrolled": true | |
}, | |
"outputs": [], | |
"source": [ | |
"X, y, beta = dask_setup(2, False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-07-16T16:12:58.991122Z", | |
"start_time": "2019-07-16T16:12:00.861210Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"7.1 s ± 72.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"# One loss + gradient evaluation: 2 workers, no Plugin.\n", | |
"dask_compute(X, y, beta)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Dask: 1 GPU | sync: True" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-07-16T16:12:59.669510Z", | |
"start_time": "2019-07-16T16:12:58.992900Z" | |
}, | |
"scrolled": true | |
}, | |
"outputs": [], | |
"source": [ | |
"X, y, beta = dask_setup(1, True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-07-16T16:15:02.912574Z", | |
"start_time": "2019-07-16T16:12:59.671177Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"15.3 s ± 29.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"# One loss + gradient evaluation: 1 worker, Plugin registered.\n", | |
"dask_compute(X, y, beta)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Dask: 1 GPU | sync: False" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-07-16T16:15:03.515171Z", | |
"start_time": "2019-07-16T16:15:02.914391Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"X, y, beta = dask_setup(1, False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-07-16T16:16:51.929743Z", | |
"start_time": "2019-07-16T16:15:03.516932Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"13.4 s ± 46.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"# One loss + gradient evaluation: 1 worker, no Plugin.\n", | |
"dask_compute(X, y, beta)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment