@gngdb
Last active April 29, 2020 19:32
I often come across the problem of having an array of indexes that I want to match to elements along a dimension. It's solved by gather.
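Here is a minimal NumPy sketch of the idea. It assumes NumPy 1.15 or later, which is where `np.take_along_axis` was added. `gather_last_axis` is a hypothetical helper name and not part of any library. With one index per row, `np.take_along_axis` follows the same rule as PyTorch's `gather` with `dim=1`: `out[r, c] = x[r, index[r, c]]`.

```python
import numpy as np


def gather_last_axis(x, idx):
    """Hypothetical helper: return x[r, idx[r]] for every row r of a 2-D array.

    Same values as torch.gather(X, 1, I.view(-1, 1)).squeeze(1).
    """
    # take_along_axis needs an index with as many dimensions as x,
    # so add a trailing axis of length 1: shape (n,) -> (n, 1).
    picked = np.take_along_axis(x, idx[:, None], axis=1)
    # Drop that length-1 axis again: shape (n, 1) -> (n,).
    return picked[:, 0]


np.random.seed(0)
x = np.random.randn(10, 4)
i = np.random.randint(0, 4, size=(10,))

# Matches both the explicit loop and the x[np.arange(10), i] version.
loop = np.array([row[j] for row, j in zip(x, i)])
assert np.array_equal(gather_last_axis(x, i), loop)
assert np.array_equal(gather_last_axis(x, i), x[np.arange(10), i])
print(gather_last_axis(x, i).shape)  # (10,)
```

Adding the trailing axis with `idx[:, None]` does the same job as `I.view(10,1)` in the PyTorch version. No explicit `np.arange` row index is needed.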
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A problem that seems to come up a lot, at least for me, is I have a array/tensor with some number of dimensions, I'll say two:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 1.1017507 , -1.14121454, 0.75730627, -0.61367408],\n",
" [ 1.18171565, -2.78205682, -0.13073507, -1.20702314],\n",
" [ 0.8001157 , -1.35673238, 1.51762854, -0.12859282],\n",
" [-0.70555315, 0.3134686 , 0.31257787, 0.08894421],\n",
" [ 0.59701958, -0.06539089, -0.98216347, 1.26345614],\n",
" [-0.51548571, -1.23917265, 1.85099668, -0.51009503],\n",
" [ 0.5139998 , -2.02818922, 1.37309187, -0.37839921],\n",
" [-0.60952342, -0.80467214, -0.9038227 , -0.82328037],\n",
" [-2.066902 , 0.26316633, -0.32242237, 1.11637655],\n",
" [ 0.6413155 , -0.92080204, 1.02428943, 0.90786728]])"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = np.random.randn(10,4)\n",
"x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And I have a vector of indexes:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([3, 0, 0, 3, 2, 0, 1, 3, 3, 3])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"i = np.random.randint(0, 4, size=(10,))\n",
"i"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And I want to index the trailing dimension array/tensor `x` with the respective indexes `i`. A vectorised operation equivalent to:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-0.6136740811801799\n",
"1.1817156478930582\n",
"0.80011569707947\n",
"0.08894421483959186\n",
"-0.9821634720158681\n",
"-0.5154857103938508\n",
"-2.0281892222359406\n",
"-0.8232803664398745\n",
"1.1163765542607507\n",
"0.9078672770992279\n"
]
}
],
"source": [
"for r, j in zip(x, i):\n",
" print(r[j])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One option is to pair the indexes in `i` with indexes indicating which row each should correspond to:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-0.61367408 1.18171565 0.8001157 0.08894421 -0.98216347 -0.51548571\n",
" -2.02818922 -0.82328037 1.11637655 0.90786728]\n"
]
}
],
"source": [
"indexed = x[np.arange(10), i]\n",
"print(indexed)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's a bit ugly though, and it feels wrong to populate an `np.arange` array to make this work. I thought maybe this was a job for `numpy.take`? But, that doesn't do what I expect:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[-0.61367408, 1.1017507 , 1.1017507 , -0.61367408, 0.75730627,\n",
" 1.1017507 , -1.14121454, -0.61367408, -0.61367408, -0.61367408],\n",
" [-1.20702314, 1.18171565, 1.18171565, -1.20702314, -0.13073507,\n",
" 1.18171565, -2.78205682, -1.20702314, -1.20702314, -1.20702314],\n",
" [-0.12859282, 0.8001157 , 0.8001157 , -0.12859282, 1.51762854,\n",
" 0.8001157 , -1.35673238, -0.12859282, -0.12859282, -0.12859282],\n",
" [ 0.08894421, -0.70555315, -0.70555315, 0.08894421, 0.31257787,\n",
" -0.70555315, 0.3134686 , 0.08894421, 0.08894421, 0.08894421],\n",
" [ 1.26345614, 0.59701958, 0.59701958, 1.26345614, -0.98216347,\n",
" 0.59701958, -0.06539089, 1.26345614, 1.26345614, 1.26345614],\n",
" [-0.51009503, -0.51548571, -0.51548571, -0.51009503, 1.85099668,\n",
" -0.51548571, -1.23917265, -0.51009503, -0.51009503, -0.51009503],\n",
" [-0.37839921, 0.5139998 , 0.5139998 , -0.37839921, 1.37309187,\n",
" 0.5139998 , -2.02818922, -0.37839921, -0.37839921, -0.37839921],\n",
" [-0.82328037, -0.60952342, -0.60952342, -0.82328037, -0.9038227 ,\n",
" -0.60952342, -0.80467214, -0.82328037, -0.82328037, -0.82328037],\n",
" [ 1.11637655, -2.066902 , -2.066902 , 1.11637655, -0.32242237,\n",
" -2.066902 , 0.26316633, 1.11637655, 1.11637655, 1.11637655],\n",
" [ 0.90786728, 0.6413155 , 0.6413155 , 0.90786728, 1.02428943,\n",
" 0.6413155 , -0.92080204, 0.90786728, 0.90786728, 0.90786728]])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.take(x, i, axis=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What about [`take_along_axis`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.take_along_axis.html)? That doesn't do what I expect either:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[-0.61367408, 1.1017507 , 1.1017507 , -0.61367408, 0.75730627,\n",
" 1.1017507 , -1.14121454, -0.61367408, -0.61367408, -0.61367408],\n",
" [-1.20702314, 1.18171565, 1.18171565, -1.20702314, -0.13073507,\n",
" 1.18171565, -2.78205682, -1.20702314, -1.20702314, -1.20702314],\n",
" [-0.12859282, 0.8001157 , 0.8001157 , -0.12859282, 1.51762854,\n",
" 0.8001157 , -1.35673238, -0.12859282, -0.12859282, -0.12859282],\n",
" [ 0.08894421, -0.70555315, -0.70555315, 0.08894421, 0.31257787,\n",
" -0.70555315, 0.3134686 , 0.08894421, 0.08894421, 0.08894421],\n",
" [ 1.26345614, 0.59701958, 0.59701958, 1.26345614, -0.98216347,\n",
" 0.59701958, -0.06539089, 1.26345614, 1.26345614, 1.26345614],\n",
" [-0.51009503, -0.51548571, -0.51548571, -0.51009503, 1.85099668,\n",
" -0.51548571, -1.23917265, -0.51009503, -0.51009503, -0.51009503],\n",
" [-0.37839921, 0.5139998 , 0.5139998 , -0.37839921, 1.37309187,\n",
" 0.5139998 , -2.02818922, -0.37839921, -0.37839921, -0.37839921],\n",
" [-0.82328037, -0.60952342, -0.60952342, -0.82328037, -0.9038227 ,\n",
" -0.60952342, -0.80467214, -0.82328037, -0.82328037, -0.82328037],\n",
" [ 1.11637655, -2.066902 , -2.066902 , 1.11637655, -0.32242237,\n",
" -2.066902 , 0.26316633, 1.11637655, 1.11637655, 1.11637655],\n",
" [ 0.90786728, 0.6413155 , 0.6413155 , 0.90786728, 1.02428943,\n",
" 0.6413155 , -0.92080204, 0.90786728, 0.90786728, 0.90786728]])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.take_along_axis(x, i.reshape(1,-1), axis=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What about [`gather`](https://pytorch.org/docs/stable/torch.html?highlight=gather#torch.gather)?"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"X = torch.tensor(x)\n",
"I = torch.tensor(i)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.6137],\n",
" [ 1.1817],\n",
" [ 0.8001],\n",
" [ 0.0889],\n",
" [-0.9822],\n",
" [-0.5155],\n",
" [-2.0282],\n",
" [-0.8233],\n",
" [ 1.1164],\n",
" [ 0.9079]], dtype=torch.float64)"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.gather(X, 1, I.view(10,1))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That looks good, it's nice and simple. [Someone's also written it for numpy](https://stackoverflow.com/a/46204790):"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"def gather_numpy(self, dim, index):\n",
" \"\"\"\n",
" Gathers values along an axis specified by dim.\n",
" For a 3-D tensor the output is specified by:\n",
" out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0\n",
" out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1\n",
" out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2\n",
"\n",
" :param dim: The axis along which to index\n",
" :param index: A tensor of indices of elements to gather\n",
" :return: tensor of gathered values\n",
" \"\"\"\n",
" idx_xsection_shape = index.shape[:dim] + index.shape[dim + 1:]\n",
" self_xsection_shape = self.shape[:dim] + self.shape[dim + 1:]\n",
" if idx_xsection_shape != self_xsection_shape:\n",
" raise ValueError(\"Except for dimension \" + str(dim) +\n",
" \", all dimensions of index and self should be the same size\")\n",
" if index.dtype != np.dtype('int_'):\n",
" raise TypeError(\"The values of index must be integers\")\n",
" data_swaped = np.swapaxes(self, 0, dim)\n",
" index_swaped = np.swapaxes(index, 0, dim)\n",
" gathered = np.choose(index_swaped, data_swaped)\n",
" return np.swapaxes(gathered, 0, dim)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[-0.61367408],\n",
" [ 1.18171565],\n",
" [ 0.8001157 ],\n",
" [ 0.08894421],\n",
" [-0.98216347],\n",
" [-0.51548571],\n",
" [-2.02818922],\n",
" [-0.82328037],\n",
" [ 1.11637655],\n",
" [ 0.90786728]])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gather_numpy(x, 1, i.reshape(10,1))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
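This short sketch shows why the `np.take` and `np.take_along_axis(x, i.reshape(1,-1), axis=1)` cells return a `(10, 10)` array. It also shows which index shape gives the gathered values instead. It uses plain NumPy (1.15 or later for `np.take_along_axis`) with a fixed seed, so the numbers differ from the notebook's.

```python
import numpy as np

np.random.seed(0)
x = np.random.randn(10, 4)
i = np.random.randint(0, 4, size=(10,))

# np.take with axis=1 picks the columns listed in i for *every* row,
# so taken[r, c] == x[r, i[c]] and the shape is (10, 10).
taken = np.take(x, i, axis=1)
assert taken.shape == (10, 10)
assert np.array_equal(taken, x[:, i])

# An index of shape (1, 10) is broadcast down all 10 rows, giving the same grid.
along_row = np.take_along_axis(x, i.reshape(1, -1), axis=1)
assert np.array_equal(along_row, taken)

# The wanted values x[r, i[r]] sit on the diagonal of that grid...
wanted = x[np.arange(10), i]
assert np.array_equal(np.diag(taken), wanted)

# ...and an index of shape (10, 1), one entry per row, gives them directly,
# matching the shape of torch.gather(X, 1, I.view(10, 1)).
along_col = np.take_along_axis(x, i.reshape(-1, 1), axis=1)
assert along_col.shape == (10, 1)
assert np.array_equal(along_col[:, 0], wanted)
```

The `(10, 10)` grids hold `x[r, i[c]]` for every pair `(r, c)`, and the wanted values lie on the diagonal. A `(10, 1)` index makes `np.take_along_axis` behave like `torch.gather` here, without relying on `np.choose`.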