A PyTorch port of tensorflow.gather_nd with batch_dims support.
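For context, a minimal sketch of the semantics being ported (the toy shapes below are my own illustration, not from the gist): with batch_dims = n, params and indices share their first n axes, and the last axis of indices addresses the grid dimensions, so with one batch dim and m = 1 each index simply picks a row from its own batch entry.

import tensorflow as tf

params = tf.reshape(tf.range(2 * 3 * 4), (2, 3, 4))  # [b=2, g=3, c=4]
indices = tf.constant([[[0], [2]], [[1], [1]]])      # [b=2, x=2, m=1]
out = tf.gather_nd(params, indices, batch_dims=1)    # out[b, x] == params[b, indices[b, x, 0]]
print(out.shape)  # (2, 2, 4)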
import torch
import tensorflow as tf
import time
import numpy as np


def gather_nd_torch(params, indices, batch_dim=1):
""" A PyTorch porting of tensorflow.gather_nd | |
This implementation can handle leading batch dimensions in params, see below for detailed explanation. | |
The majority of this implementation is from Michael Jungo @ https://stackoverflow.com/a/61810047/6670143 | |
I just ported it compatible to leading batch dimension. | |
Args: | |
params: a tensor of dimension [b1, ..., bn, g1, ..., gm, c]. | |
indices: a tensor of dimension [b1, ..., bn, x, m] | |
batch_dim: indicate how many batch dimension you have, in the above example, batch_dim = n. | |
Returns: | |
gathered: a tensor of dimension [b1, ..., bn, x, c]. | |
Example: | |
>>> batch_size = 5 | |
>>> inputs = torch.randn(batch_size, batch_size, batch_size, 4, 4, 4, 32) | |
>>> pos = torch.randint(4, (batch_size, batch_size, batch_size, 12, 3)) | |
>>> gathered = gather_nd_torch(inputs, pos, batch_dim=3) | |
>>> gathered.shape | |
torch.Size([5, 5, 5, 12, 32]) | |
>>> inputs_tf = tf.convert_to_tensor(inputs.numpy()) | |
>>> pos_tf = tf.convert_to_tensor(pos.numpy()) | |
>>> gathered_tf = tf.gather_nd(inputs_tf, pos_tf, batch_dims=3) | |
>>> gathered_tf.shape | |
TensorShape([5, 5, 5, 12, 32]) | |
>>> gathered_tf = torch.from_numpy(gathered_tf.numpy()) | |
>>> torch.equal(gathered_tf, gathered) | |
True | |
""" | |
    batch_dims = params.size()[:batch_dim]  # [b1, ..., bn]
    batch_size = int(np.prod(list(batch_dims)))  # b1 * ... * bn
    c_dim = params.size()[-1]  # c
    grid_dims = params.size()[batch_dim:-1]  # [g1, ..., gm]
    n_indices = indices.size(-2)  # x
    n_pos = indices.size(-1)  # m

    # reshape leading batch dims into a single batch dim
    params = params.reshape(batch_size, *grid_dims, c_dim)
    indices = indices.reshape(batch_size, n_indices, n_pos)

    # build gather indices: one set of grid coordinates per flattened batch entry
    # (this port assumes m == len(grid_dims), i.e. every grid dim is indexed)
    batch_enumeration = torch.arange(batch_size).unsqueeze(1)
    gather_dims = [indices[:, :, i] for i in range(len(grid_dims))]
    gather_dims.insert(0, batch_enumeration)
    # index with a tuple of tensors; indexing with a plain list of tensors is
    # deprecated in recent PyTorch versions
    gathered = params[tuple(gather_dims)]

    # reshape back to the original leading batch dims
    gathered = gathered.reshape(*batch_dims, n_indices, c_dim)
    return gathered
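
# Quick sanity check on toy shapes (my own illustrative example, not part of
# the original gist): with one batch dim and m == 1, gather_nd_torch reduces
# to a per-batch row gather.
toy_params = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)  # [b=2, g=3, c=4]
toy_idx = torch.tensor([[[0], [2]], [[1], [1]]])  # [b=2, x=2, m=1]
manual = torch.stack([toy_params[b, toy_idx[b, :, 0]] for b in range(2)])
assert torch.equal(gather_nd_torch(toy_params, toy_idx, batch_dim=1), manual)
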
batch_size = 5
inputs = torch.randn(batch_size, batch_size, batch_size, 4, 4, 4, 32)
pos = torch.randint(4, (batch_size, batch_size, batch_size, 12, 3))

s = time.time()
for i in range(1000):
    gathered = gather_nd_torch(inputs, pos, batch_dim=3)
time_used = (time.time() - s) / 1000
print('Torch time used:', time_used)
# Again, verify that it's identical to TensorFlow's output
inputs_tf = tf.convert_to_tensor(inputs.numpy())
pos_tf = tf.convert_to_tensor(pos.numpy())

# This time with tf.gather_nd, also using batch_dims=3
s = time.time()
for i in range(1000):
    gathered_tf = tf.gather_nd(inputs_tf, pos_tf, batch_dims=3)
time_used = (time.time() - s) / 1000
print('TF time used:', time_used)

gathered_tf = torch.from_numpy(gathered_tf.numpy())
print(torch.equal(gathered_tf, gathered))  # => True