Last active
April 29, 2020 19:32
-
-
Save gngdb/55d549372afa9c9b4dd61cecb412b3a4 to your computer and use it in GitHub Desktop.
I often come across the problem that I have an array of indexes I want to match to elements along a dimension. It's solved by gather.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "A problem that seems to come up a lot, at least for me, is that I have an array/tensor with some number of dimensions, I'll say two:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[ 1.1017507 , -1.14121454, 0.75730627, -0.61367408],\n", | |
| " [ 1.18171565, -2.78205682, -0.13073507, -1.20702314],\n", | |
| " [ 0.8001157 , -1.35673238, 1.51762854, -0.12859282],\n", | |
| " [-0.70555315, 0.3134686 , 0.31257787, 0.08894421],\n", | |
| " [ 0.59701958, -0.06539089, -0.98216347, 1.26345614],\n", | |
| " [-0.51548571, -1.23917265, 1.85099668, -0.51009503],\n", | |
| " [ 0.5139998 , -2.02818922, 1.37309187, -0.37839921],\n", | |
| " [-0.60952342, -0.80467214, -0.9038227 , -0.82328037],\n", | |
| " [-2.066902 , 0.26316633, -0.32242237, 1.11637655],\n", | |
| " [ 0.6413155 , -0.92080204, 1.02428943, 0.90786728]])" | |
| ] | |
| }, | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
# A 10 x 4 array of standard-normal samples.
# NOTE(review): no random seed is set, so re-running this cell produces values
# different from the outputs stored in this notebook.
x = np.random.standard_normal((10, 4))
x
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "And I have a vector of indexes:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([3, 0, 0, 3, 2, 0, 1, 3, 3, 3])" | |
| ] | |
| }, | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
# One random column index per row of `x`, each drawn from [0, x.shape[1]).
# Sizes are derived from `x` instead of hard-coding 4 and 10, so this cell stays
# consistent if the shape of `x` changes.
i = np.random.randint(0, x.shape[1], size=x.shape[0])
i
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "And I want to index the trailing dimension of the array/tensor `x` with the respective indexes in `i`, one per row, using a vectorised operation equivalent to:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "-0.6136740811801799\n", | |
| "1.1817156478930582\n", | |
| "0.80011569707947\n", | |
| "0.08894421483959186\n", | |
| "-0.9821634720158681\n", | |
| "-0.5154857103938508\n", | |
| "-2.0281892222359406\n", | |
| "-0.8232803664398745\n", | |
| "1.1163765542607507\n", | |
| "0.9078672770992279\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
# Non-vectorised reference: from each row of x, print the element whose
# column is given by the matching entry of i.
for row, col in zip(x, i):
    print(row[col])
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "One option is to pair the indexes in `i` with indexes indicating which row each should correspond to:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "[-0.61367408 1.18171565 0.8001157 0.08894421 -0.98216347 -0.51548571\n", | |
| " -2.02818922 -0.82328037 1.11637655 0.90786728]\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
# Pair each column index in `i` with its row number. The row count is taken
# from `x` rather than hard-coded as 10.
indexed = x[np.arange(len(x)), i]
print(indexed)
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "That's a bit ugly, though, and it feels wrong to build an `np.arange` array just to make this work. I thought maybe this was a job for `numpy.take`, but that doesn't do what I want: with `axis=1` it selects the columns listed in `i` from every row, giving a 10 x 10 array:" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[-0.61367408, 1.1017507 , 1.1017507 , -0.61367408, 0.75730627,\n", | |
| " 1.1017507 , -1.14121454, -0.61367408, -0.61367408, -0.61367408],\n", | |
| " [-1.20702314, 1.18171565, 1.18171565, -1.20702314, -0.13073507,\n", | |
| " 1.18171565, -2.78205682, -1.20702314, -1.20702314, -1.20702314],\n", | |
| " [-0.12859282, 0.8001157 , 0.8001157 , -0.12859282, 1.51762854,\n", | |
| " 0.8001157 , -1.35673238, -0.12859282, -0.12859282, -0.12859282],\n", | |
| " [ 0.08894421, -0.70555315, -0.70555315, 0.08894421, 0.31257787,\n", | |
| " -0.70555315, 0.3134686 , 0.08894421, 0.08894421, 0.08894421],\n", | |
| " [ 1.26345614, 0.59701958, 0.59701958, 1.26345614, -0.98216347,\n", | |
| " 0.59701958, -0.06539089, 1.26345614, 1.26345614, 1.26345614],\n", | |
| " [-0.51009503, -0.51548571, -0.51548571, -0.51009503, 1.85099668,\n", | |
| " -0.51548571, -1.23917265, -0.51009503, -0.51009503, -0.51009503],\n", | |
| " [-0.37839921, 0.5139998 , 0.5139998 , -0.37839921, 1.37309187,\n", | |
| " 0.5139998 , -2.02818922, -0.37839921, -0.37839921, -0.37839921],\n", | |
| " [-0.82328037, -0.60952342, -0.60952342, -0.82328037, -0.9038227 ,\n", | |
| " -0.60952342, -0.80467214, -0.82328037, -0.82328037, -0.82328037],\n", | |
| " [ 1.11637655, -2.066902 , -2.066902 , 1.11637655, -0.32242237,\n", | |
| " -2.066902 , 0.26316633, 1.11637655, 1.11637655, 1.11637655],\n", | |
| " [ 0.90786728, 0.6413155 , 0.6413155 , 0.90786728, 1.02428943,\n", | |
| " 0.6413155 , -0.92080204, 0.90786728, 0.90786728, 0.90786728]])" | |
| ] | |
| }, | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
# take along axis=1 selects the columns listed in `i` from *every* row of `x`,
# so the result has one row per row of `x` and one column per entry of `i`.
x.take(i, axis=1)
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
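| "What about [`take_along_axis`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.take_along_axis.html)? With the indexes reshaped to a row, shape `(1, 10)`, that doesn't do what I want either: the single row of indexes broadcasts against every row of `x`, giving the same 10 x 10 array as `np.take`. (Reshaping them to a column, shape `(10, 1)`, as in the `gather` calls below, gives one element per row.)" | |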
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[-0.61367408, 1.1017507 , 1.1017507 , -0.61367408, 0.75730627,\n", | |
| " 1.1017507 , -1.14121454, -0.61367408, -0.61367408, -0.61367408],\n", | |
| " [-1.20702314, 1.18171565, 1.18171565, -1.20702314, -0.13073507,\n", | |
| " 1.18171565, -2.78205682, -1.20702314, -1.20702314, -1.20702314],\n", | |
| " [-0.12859282, 0.8001157 , 0.8001157 , -0.12859282, 1.51762854,\n", | |
| " 0.8001157 , -1.35673238, -0.12859282, -0.12859282, -0.12859282],\n", | |
| " [ 0.08894421, -0.70555315, -0.70555315, 0.08894421, 0.31257787,\n", | |
| " -0.70555315, 0.3134686 , 0.08894421, 0.08894421, 0.08894421],\n", | |
| " [ 1.26345614, 0.59701958, 0.59701958, 1.26345614, -0.98216347,\n", | |
| " 0.59701958, -0.06539089, 1.26345614, 1.26345614, 1.26345614],\n", | |
| " [-0.51009503, -0.51548571, -0.51548571, -0.51009503, 1.85099668,\n", | |
| " -0.51548571, -1.23917265, -0.51009503, -0.51009503, -0.51009503],\n", | |
| " [-0.37839921, 0.5139998 , 0.5139998 , -0.37839921, 1.37309187,\n", | |
| " 0.5139998 , -2.02818922, -0.37839921, -0.37839921, -0.37839921],\n", | |
| " [-0.82328037, -0.60952342, -0.60952342, -0.82328037, -0.9038227 ,\n", | |
| " -0.60952342, -0.80467214, -0.82328037, -0.82328037, -0.82328037],\n", | |
| " [ 1.11637655, -2.066902 , -2.066902 , 1.11637655, -0.32242237,\n", | |
| " -2.066902 , 0.26316633, 1.11637655, 1.11637655, 1.11637655],\n", | |
| " [ 0.90786728, 0.6413155 , 0.6413155 , 0.90786728, 1.02428943,\n", | |
| " 0.6413155 , -0.92080204, 0.90786728, 0.90786728, 0.90786728]])" | |
| ] | |
| }, | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
# `i[np.newaxis, :]` has shape (1, len(i)), so it broadcasts against every row
# of `x` and yields a full len(x) x len(i) array.
# NOTE(review): an index shaped (len(x), 1), e.g. `i[:, np.newaxis]`, would
# select one element per row instead.
np.take_along_axis(x, i[np.newaxis, :], axis=1)
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "What about PyTorch's [`gather`](https://pytorch.org/docs/stable/torch.html?highlight=gather#torch.gather)?" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
import torch

# Tensor copies of the NumPy arrays `x` (data) and `i` (indexes).
X, I = torch.tensor(x), torch.tensor(i)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([[-0.6137],\n", | |
| " [ 1.1817],\n", | |
| " [ 0.8001],\n", | |
| " [ 0.0889],\n", | |
| " [-0.9822],\n", | |
| " [-0.5155],\n", | |
| " [-2.0282],\n", | |
| " [-0.8233],\n", | |
| " [ 1.1164],\n", | |
| " [ 0.9079]], dtype=torch.float64)" | |
| ] | |
| }, | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
# Reshape the indexes to a column (one per row; the row count is inferred
# rather than hard-coded as 10) so gather picks out[r][0] = X[r][I[r]].
torch.gather(X, 1, I.view(-1, 1))
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "That looks good, it's nice and simple. [Someone's also written it for numpy](https://stackoverflow.com/a/46204790):" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
def gather_numpy(self, dim, index):
    """
    Gathers values along an axis specified by dim (a NumPy analogue of torch.gather).

    For a 3-D array the output is specified by:
        out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
        out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
        out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

    :param self: The array to gather from (anything np.asarray accepts).
    :param dim: The axis along which to index; negative values count from the end.
    :param index: An integer array of indices of elements to gather. It must have
        the same number of dimensions as ``self`` and the same size as ``self``
        in every dimension except ``dim``.
    :return: array of gathered values, with the same shape as ``index``.
    :raises ValueError: if ``dim`` is out of range, the shapes are incompatible,
        or any index lies outside ``[0, self.shape[dim])``.
    :raises TypeError: if ``index`` does not have an integer dtype.
    """
    data = np.asarray(self)
    index = np.asarray(index)
    if data.ndim != index.ndim:
        raise ValueError("index and self should have the same number of dimensions")
    if not -data.ndim <= dim < data.ndim:
        raise ValueError("dim " + str(dim) + " is out of range for an array with " +
                         str(data.ndim) + " dimensions")
    # Normalise negative axes; otherwise shape[dim + 1:] would be the whole
    # shape for dim == -1 and the cross-section comparison would be wrong.
    dim %= data.ndim
    idx_xsection_shape = index.shape[:dim] + index.shape[dim + 1:]
    self_xsection_shape = data.shape[:dim] + data.shape[dim + 1:]
    if idx_xsection_shape != self_xsection_shape:
        raise ValueError("Except for dimension " + str(dim) +
                         ", all dimensions of index and self should be the same size")
    # Accept any integer dtype (int32, uint8, ...), not only the platform int_.
    if not np.issubdtype(index.dtype, np.integer):
        raise TypeError("The values of index must be integers")
    # take_along_axis would silently wrap negative indices; keep a strict range
    # check that raises ValueError, as np.choose(mode='raise') did.
    if index.size and (index.min() < 0 or index.max() >= data.shape[dim]):
        raise ValueError("index values must be in the range [0, " +
                         str(data.shape[dim]) + ")")
    # Unlike np.choose, take_along_axis has no limit on the size of axis `dim`
    # (np.choose accepts at most 32 choices in NumPy 1.x).
    return np.take_along_axis(data, index, axis=dim)
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([[-0.61367408],\n", | |
| " [ 1.18171565],\n", | |
| " [ 0.8001157 ],\n", | |
| " [ 0.08894421],\n", | |
| " [-0.98216347],\n", | |
| " [-0.51548571],\n", | |
| " [-2.02818922],\n", | |
| " [-0.82328037],\n", | |
| " [ 1.11637655],\n", | |
| " [ 0.90786728]])" | |
| ] | |
| }, | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
# Gather one element per row along dim 1; the index is reshaped to a column,
# with the row count inferred rather than hard-coded as 10.
gather_numpy(x, 1, i.reshape(-1, 1))
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.6.10" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment