Last active
January 23, 2017 00:28
-
-
Save crusaderky/62832a5ffc72ccb3e0954021b0996fdf to your computer and use it in GitHub Desktop.
xarray Fast Weighted Sum
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from collections import defaultdict\n", | |
"import dask.array\n", | |
"import numpy\n", | |
"import xarray\n", | |
"import random" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def fastwsum(arrays, weights, blocksize=64):\n", | |
" \"\"\"Weighted sum of arrays.\n", | |
"\n", | |
" :param arrays:\n", | |
" sequence of xarray.DataArray objects\n", | |
" :param weights:\n", | |
" sequence of scalars of the same length as arrays\n", | |
" :param blocksize:\n", | |
" number of arrays to add together at once in dask (see below)\n", | |
" :returns:\n", | |
" single xarray.DataArray\n", | |
"\n", | |
" this function is functionally equivalent to::\n", | |
"\n", | |
" sum(a * w for a, w in zip(arrays, weights))\n", | |
"\n", | |
" but it is potentially much faster because:\n", | |
"\n", | |
" - the xarray broadcast/align magic is executed once instead of being repeated for every\n", | |
" subtotal\n", | |
" - attempt to broadcast as late as possible by calculating subtotals by dims\n", | |
" - 1 and 0 weights are optimized away\n", | |
" - there is one dask operation every <blocksize> arrays, instead of two per array (one to\n", | |
" multiply by the weight and another to add to the subtotal).\n", | |
" - in case of mixed dask/numpy addends, the numpy-only subtotal is added only once to the\n", | |
" dask graph\n", | |
"\n", | |
" The downside is that <blocksize> inputs must be made available by dask at the same time,\n", | |
" with consecutive increase in RAM occupation and (in case of dask.distributed) potentially\n", | |
" higher-than-needed data transfers over the network.\n", | |
" \"\"\"\n", | |
" assert len(arrays) == len(weights)\n", | |
"\n", | |
" # Attempt to broadcast as late as possible by calculating subtotals by dims.\n", | |
" group_by_dims = defaultdict(lambda: ([], []))\n", | |
" for a, w in zip(arrays, weights):\n", | |
" group_by_dims[a.dims][0].append(a)\n", | |
" group_by_dims[a.dims][1].append(w)\n", | |
"\n", | |
" if len(group_by_dims) > 1 and any(len(v[0]) > 1 for v in group_by_dims.values()):\n", | |
" subtotals = [\n", | |
" fastwsum(arrays, weights, blocksize=blocksize)\n", | |
" for (arrays, weights) in group_by_dims.values()\n", | |
" ]\n", | |
" return fastsum(subtotals, blocksize=blocksize)\n", | |
"\n", | |
" arrays = xarray.broadcast(*arrays)\n", | |
"\n", | |
" numpy_total = 0\n", | |
" dask_data = []\n", | |
" dask_weights = []\n", | |
" for array, weight in zip(arrays, weights):\n", | |
" if weight == 0:\n", | |
" pass\n", | |
" try:\n", | |
" array.data.dask\n", | |
" # xarray with dask backend\n", | |
" dask_data.append(array.data)\n", | |
" dask_weights.append(weight)\n", | |
" except AttributeError:\n", | |
" # xarray with numpy backend or scalar\n", | |
" if weight == 1:\n", | |
" numpy_total += array.data\n", | |
" else:\n", | |
" numpy_total += array.data * weight\n", | |
"\n", | |
" if numpy_total is 0 and len(dask_data) == 0:\n", | |
" # All weights are 0\n", | |
" scalar_coords = {\n", | |
" k: v for k, v in arrays[0].coords.items()\n", | |
" if v.shape == ()\n", | |
" }\n", | |
" return xarray.DataArray(0, coords=scalar_coords)\n", | |
"\n", | |
" while len(dask_data) > 1 or (len(dask_weights) > 0 and dask_weights[0] != 1):\n", | |
" assert len(dask_data) == len(dask_weights)\n", | |
" subtotals = []\n", | |
" for offset in range(0, len(dask_data), blocksize):\n", | |
" dask_data_slice = dask_data[offset:offset + blocksize]\n", | |
" dask_weights_slice = dask_weights[offset:offset + blocksize]\n", | |
" dtype = numpy.result_type(*[d.dtype for d in dask_data_slice])\n", | |
" subtotals.append(dask.array.map_blocks(\n", | |
" _fastwsum_kernel, *dask_data_slice, weights=dask_weights_slice, dtype=dtype))\n", | |
" dask_data = subtotals\n", | |
" dask_weights = [1] * len(dask_data)\n", | |
"\n", | |
" if numpy_total is not 0:\n", | |
" dask_data.append(numpy_total)\n", | |
" assert len(dask_data) in (1, 2)\n", | |
"\n", | |
" return xarray.DataArray(sum(dask_data), dims=arrays[0].dims, coords=arrays[0].coords)\n", | |
"\n", | |
"\n", | |
"def _fastwsum_kernel(*arrays, weights):\n", | |
" total = 0\n", | |
" for array, weight in zip(arrays, weights):\n", | |
" if weight == 1:\n", | |
" total += array\n", | |
" else:\n", | |
" total += array * weight\n", | |
" return total\n", | |
"\n", | |
"\n", | |
"def fastsum(arrays, blocksize=64):\n", | |
" \"\"\"Functionally equivalent to sum(*args), but faster.\n", | |
" All arguments must be :class:`xarray.DataArray` objects.\n", | |
" All notes for :func:`fastwsum` apply.\n", | |
" \"\"\"\n", | |
" return fastwsum(arrays, [1] * len(arrays), blocksize)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
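"# Benchmarks\n", | |
"\n", | |
"Each section compares the builtin `sum` with `fastsum` / `fastwsum` on 4000 random addends. It measures the time to build the result and, for dask, the time to compute it and the size of the resulting graph." | |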
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"def rand_addend(use_dask):\n", | |
" if random.randint(0, 9) > 0:\n", | |
" if use_dask:\n", | |
" data = dask.array.random.random((3, 100), chunks=10)\n", | |
" else:\n", | |
" data = numpy.random.random(300).reshape(3, 100)\n", | |
" return xarray.DataArray(data, dims=['time', 'scenario'], coords={'time': ['A', 'B', 'C']})\n", | |
" else:\n", | |
" return xarray.DataArray([1.1, 2.2, 3.3], dims=['time'], coords={'time': ['A', 'B', 'C']})\n", | |
"\n", | |
"def rand_weight():\n", | |
" w = random.random()\n", | |
" if w < .1:\n", | |
" return 0\n", | |
" elif w < .2:\n", | |
" return 1\n", | |
" else:\n", | |
" return w\n", | |
" \n", | |
"\n", | |
"addends = [rand_addend(True) for _ in range(4000)]\n", | |
"weights = [rand_weight() for _ in range(4000)]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Mixed dask+numpy - plain sum" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 20.4 s, sys: 104 ms, total: 20.5 s\n", | |
"Wall time: 20.5 s\n", | |
"CPU times: user 14.6 s, sys: 1.08 s, total: 15.7 s\n", | |
"Wall time: 14.6 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"75712" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Baseline: builtin sum. Times building the result, then computing it; the last\n", | |
"# line shows the number of keys in the resulting dask graph.\n", | |
"%time x = sum(addends)\n", | |
"%time x.compute()\n", | |
"len(x.data.dask)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 2.8 s, sys: 20 ms, total: 2.82 s\n", | |
"Wall time: 2.81 s\n", | |
"CPU times: user 8.87 s, sys: 760 ms, total: 9.63 s\n", | |
"Wall time: 8.86 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"72052" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Same measurements with fastsum.\n", | |
"%time x = fastsum(addends)\n", | |
"%time x.compute()\n", | |
"len(x.data.dask)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Mixed dask+numpy - weighted sum" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 32.5 s, sys: 56 ms, total: 32.5 s\n", | |
"Wall time: 32.6 s\n", | |
"CPU times: user 17.2 s, sys: 964 ms, total: 18.2 s\n", | |
"Wall time: 17.1 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"112092" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Baseline: naive weighted sum, i.e. one multiplication and one addition per addend.\n", | |
"%time x = sum(a * w for a, w in zip(addends, weights))\n", | |
"%time x.compute()\n", | |
"len(x.data.dask)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 2.8 s, sys: 8 ms, total: 2.81 s\n", | |
"Wall time: 2.81 s\n", | |
"CPU times: user 8.93 s, sys: 728 ms, total: 9.66 s\n", | |
"Wall time: 8.87 s\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"72052" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Same measurements with fastwsum.\n", | |
"%time x = fastwsum(addends, weights)\n", | |
"%time x.compute()\n", | |
"len(x.data.dask)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## numpy only - weighted sum" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"addends = [rand_addend(False) for _ in range(4000)]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 1.67 s, sys: 0 ns, total: 1.67 s\n", | |
"Wall time: 1.67 s\n" | |
] | |
} | |
], | |
"source": [ | |
"# Baseline: naive weighted sum. The addends are numpy-backed, so this times the full computation.\n", | |
"%time x = sum(a * w for a, w in zip(addends, weights))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 856 ms, sys: 4 ms, total: 860 ms\n", | |
"Wall time: 856 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"# Same measurement with fastwsum.\n", | |
"%time x = fastwsum(addends, weights)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python [conda root]", | |
"language": "python", | |
"name": "conda-root-py" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment