Created
January 13, 2021 22:59
-
-
Save calebrob6/e38c1fdc2590bedc0e531d8355846e39 to your computer and use it in GitHub Desktop.
Different implementations and benchmarks for calculating the mode along an axis of an ndarray.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import scipy.stats" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"num_layers = 100\n", | |
"height = 4000\n", | |
"width = 4000\n", | |
"max_val = 16  # values are drawn from [0, max_val), i.e. max_val possible distinct values\n", | |
"\n", | |
"# Seed the global RNG so the benchmark input (and every cell's reference check) is reproducible\n", | |
"np.random.seed(0)\n", | |
"# NOTE: 100 * 4000 * 4000 = 1.6e9 elements; with randint's default integer dtype\n", | |
"# (int64 on most 64-bit Linux/macOS builds) this array needs roughly 12.8 GB of RAM\n", | |
"values = np.random.randint(0, max_val, size=(num_layers,height,width))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Goal: find the most frequent value (the mode) along the first axis of `values`" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"## Reference answer: the mode of the single pixel (0, 0), computed with np.unique so that it\n", | |
"## is independent of the methods benchmarked below (ties go to the smallest value)\n", | |
"vals, counts = np.unique(values[:, 0, 0], return_counts=True)\n", | |
"assert_against = vals[counts.argmax()]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 1min 7s, sys: 619 ms, total: 1min 8s\n", | |
"Wall time: 1min 8s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"def fast_mode(arr, minlength):\n", | |
"    '''Mode along axis 0 of an integer array, via one np.bincount histogram per column.\n", | |
"\n", | |
"    arr must hold non-negative integers (np.bincount rejects negatives) that are all\n", | |
"    < minlength, so every column's histogram has the same length. Ties resolve to the\n", | |
"    smallest value, because argmax returns the first maximum.\n", | |
"    '''\n", | |
"    # counts has shape (minlength,) + arr.shape[1:]: one histogram per column\n", | |
"    counts = np.apply_along_axis(np.bincount, 0, arr, minlength=minlength)\n", | |
"    return counts.argmax(axis=0)\n", | |
"\n", | |
"mode = fast_mode(values, max_val)\n", | |
"assert assert_against == mode[0,0]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 1min 30s, sys: 287 ms, total: 1min 30s\n", | |
"Wall time: 1min 30s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"def fast_mode(arr, minlength):\n", | |
"    '''Mode along axis 0 of a non-negative integer array, one column at a time.\n", | |
"\n", | |
"    Each 1-D column is histogrammed with np.bincount (padded to at least minlength\n", | |
"    bins) and reduced to its argmax immediately, so only the per-column mode is kept.\n", | |
"    Ties resolve to the smallest value, because argmax returns the first maximum.\n", | |
"    '''\n", | |
"    def column_mode(column):\n", | |
"        return np.bincount(column, minlength=minlength).argmax()\n", | |
"    return np.apply_along_axis(column_mode, 0, arr)\n", | |
"\n", | |
"mode = fast_mode(values, max_val)\n", | |
"assert assert_against == mode[0,0]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 1min 7s, sys: 8.58 s, total: 1min 16s\n", | |
"Wall time: 1min 16s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"def fast_mode(arr, minlength):\n", | |
"    '''Mode along axis 0, counting each candidate value with a vectorised equality test.\n", | |
"\n", | |
"    Makes one full pass over arr per candidate in range(minlength); values outside that\n", | |
"    range are never counted. Ties resolve to the smallest value (argmax takes the first).\n", | |
"    '''\n", | |
"    tallies = np.zeros((minlength,) + arr.shape[1:])\n", | |
"    for candidate in range(minlength):\n", | |
"        # per column: how many entries along axis 0 equal `candidate`\n", | |
"        tallies[candidate] = (arr == candidate).sum(axis=0)\n", | |
"    return tallies.argmax(axis=0)\n", | |
"\n", | |
"mode = fast_mode(values, max_val)\n", | |
"assert assert_against == mode[0,0]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 1min 33s, sys: 26.2 s, total: 1min 59s\n", | |
"Wall time: 1min 59s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"def fast_mode(ndarray, axis=0):\n", | |
"    '''Mode of ndarray along `axis`, computed by sorting and measuring runs of equal values.\n", | |
"\n", | |
"    Adapted from https://stackoverflow.com/questions/16330831/most-efficient-way-to-find-mode-in-numpy-array\n", | |
"\n", | |
"    Returns a (modes, counts) pair: the most frequent value along `axis` and how many\n", | |
"    times it occurs. Raises Exception for an empty array or an axis that does not exist.\n", | |
"    '''\n", | |
"    # Check inputs\n", | |
"    ndarray = np.asarray(ndarray)\n", | |
"    ndim = ndarray.ndim\n", | |
"    if ndarray.size == 1:\n", | |
"        return (ndarray[0], 1)\n", | |
"    elif ndarray.size == 0:\n", | |
"        raise Exception('Cannot compute mode on empty array')\n", | |
"    try:\n", | |
"        axis = range(ndarray.ndim)[axis]\n", | |
"    except (IndexError, TypeError):\n", | |
"        # IndexError: axis out of range; TypeError: axis is not an integer\n", | |
"        raise Exception('Axis \"{}\" incompatible with the {}-dimension array'.format(axis, ndim))\n", | |
"\n", | |
"    # For 1-D input np.unique(return_counts=True) (NumPy >= 1.9) suffices. Compare\n", | |
"    # (major, minor) as a tuple so that e.g. NumPy 2.0 is not mistaken for < 1.9.\n", | |
"    numpy_version = tuple(int(part) for part in np.__version__.split('.')[:2])\n", | |
"    if ndim == 1 and numpy_version >= (1, 9):\n", | |
"        modals, counts = np.unique(ndarray, return_counts=True)\n", | |
"        index = np.argmax(counts)\n", | |
"        return modals[index], counts[index]\n", | |
"\n", | |
"    # Sort array\n", | |
"    sort = np.sort(ndarray, axis=axis)\n", | |
"    # Create array to transpose along the axis and get padding shape\n", | |
"    transpose = np.roll(np.arange(ndim)[::-1], axis)\n", | |
"    shape = list(sort.shape)\n", | |
"    shape[axis] = 1\n", | |
"    # Create a boolean array along strides of unique values\n", | |
"    strides = np.concatenate([np.zeros(shape=shape, dtype='bool'),\n", | |
"                                 np.diff(sort, axis=axis) == 0,\n", | |
"                                 np.zeros(shape=shape, dtype='bool')],\n", | |
"                                axis=axis).transpose(transpose).ravel()\n", | |
"    # Count the stride lengths\n", | |
"    counts = np.cumsum(strides)\n", | |
"    counts[~strides] = np.concatenate([[0], np.diff(counts[~strides])])\n", | |
"    counts[strides] = 0\n", | |
"    # Get shape of padded counts and slice to return to the original shape\n", | |
"    shape = np.array(sort.shape)\n", | |
"    shape[axis] += 1\n", | |
"    shape = shape[transpose]\n", | |
"    slices = [slice(None)] * ndim\n", | |
"    slices[axis] = slice(1, None)\n", | |
"    # Reshape and compute final counts\n", | |
"    counts = counts.reshape(shape).transpose(transpose)[tuple(slices)] + 1\n", | |
"\n", | |
"    # Find maximum counts and return modals/counts\n", | |
"    slices = [slice(None, i) for i in sort.shape]\n", | |
"    del slices[axis]\n", | |
"    # list(...) makes .insert() work whether np.ogrid returns a list or a tuple\n", | |
"    # (its return type differs between NumPy versions)\n", | |
"    index = list(np.ogrid[tuple(slices)])\n", | |
"    index.insert(axis, np.argmax(counts, axis=axis))\n", | |
"    index = tuple(index)\n", | |
"    return sort[index], counts[index]\n", | |
"\n", | |
"mode, counts = fast_mode(values, axis=0)\n", | |
"assert assert_against == mode[0,0]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 9min 37s, sys: 156 ms, total: 9min 38s\n", | |
"Wall time: 9min 38s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"mode = scipy.stats.mode(values, axis=0)\n", | |
"# Check against the reference like the other methods. np.ravel(...)[0] picks pixel (0, 0)\n", | |
"# both for the keepdims=True result shape of older SciPy and keepdims=False of SciPy >= 1.11\n", | |
"assert assert_against == np.ravel(mode.mode)[0]" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3.5", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment