@Kulbear · Created June 21, 2022 08:45
A PyTorch port of tensorflow.gather_nd with batch_dim support.
import torch
import tensorflow as tf
import time
import numpy as np
def gather_nd_torch(params, indices, batch_dim=1):
""" A PyTorch porting of tensorflow.gather_nd
This implementation can handle leading batch dimensions in params, see below for detailed explanation.
The majority of this implementation is from Michael Jungo @ https://stackoverflow.com/a/61810047/6670143
I just ported it compatible to leading batch dimension.
Args:
params: a tensor of dimension [b1, ..., bn, g1, ..., gm, c].
indices: a tensor of dimension [b1, ..., bn, x, m]
batch_dim: indicate how many batch dimension you have, in the above example, batch_dim = n.
Returns:
gathered: a tensor of dimension [b1, ..., bn, x, c].
Example:
>>> batch_size = 5
>>> inputs = torch.randn(batch_size, batch_size, batch_size, 4, 4, 4, 32)
>>> pos = torch.randint(4, (batch_size, batch_size, batch_size, 12, 3))
>>> gathered = gather_nd_torch(inputs, pos, batch_dim=3)
>>> gathered.shape
torch.Size([5, 5, 5, 12, 32])
>>> inputs_tf = tf.convert_to_tensor(inputs.numpy())
>>> pos_tf = tf.convert_to_tensor(pos.numpy())
>>> gathered_tf = tf.gather_nd(inputs_tf, pos_tf, batch_dims=3)
>>> gathered_tf.shape
TensorShape([5, 5, 5, 12, 32])
>>> gathered_tf = torch.from_numpy(gathered_tf.numpy())
>>> torch.equal(gathered_tf, gathered)
True
"""
    batch_dims = params.size()[:batch_dim]  # [b1, ..., bn]
    batch_size = int(np.prod(batch_dims))  # b1 * ... * bn
    c_dim = params.size()[-1]  # c
    grid_dims = params.size()[batch_dim:-1]  # [g1, ..., gm]
    n_indices = indices.size(-2)  # x
    n_pos = indices.size(-1)  # m

    # Reshape the leading batch dims into a single batch dim.
    params = params.reshape(batch_size, *grid_dims, c_dim)
    indices = indices.reshape(batch_size, n_indices, n_pos)

    # Build the gather indices: one index tensor per grid dim, plus a
    # batch enumeration so each row gathers from its own data point.
    # Create the enumeration on the same device as the inputs so the
    # function also works on CUDA tensors.
    batch_enumeration = torch.arange(batch_size, device=indices.device).unsqueeze(1)
    gather_dims = [indices[:, :, i] for i in range(len(grid_dims))]
    gather_dims.insert(0, batch_enumeration)
    # Index with a tuple; indexing with a plain list is deprecated.
    gathered = params[tuple(gather_dims)]

    # Reshape back to the shape with the leading batch dims.
    gathered = gathered.reshape(*batch_dims, n_indices, c_dim)
    return gathered
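

# Alternative sketch (not part of the original gist): the same gather can
# be expressed by flattening the grid dims and computing linear offsets,
# which avoids building one index tensor per grid dimension. The function
# name and structure below are illustrative assumptions, not an existing
# PyTorch or TensorFlow API.
def gather_nd_torch_flat(params, indices, batch_dim=1):
    batch_dims = params.size()[:batch_dim]
    grid_dims = params.size()[batch_dim:-1]
    c_dim = params.size()[-1]
    batch_size = int(np.prod(batch_dims))
    n_indices = indices.size(-2)
    n_pos = indices.size(-1)
    # In row-major layout, grid element (i1, ..., im) sits at linear
    # offset i1 * (g2*...*gm) + i2 * (g3*...*gm) + ... + im.
    params_flat = params.reshape(batch_size, -1, c_dim)  # [B, g1*...*gm, c]
    idx = indices.reshape(batch_size, n_indices, n_pos)
    strides = idx.new_tensor(
        [int(np.prod(grid_dims[i + 1:])) for i in range(n_pos)])
    flat_idx = (idx * strides).sum(-1)  # [B, x]
    batch_idx = torch.arange(batch_size, device=idx.device).unsqueeze(1)
    gathered = params_flat[batch_idx, flat_idx]  # [B, x, c]
    return gathered.reshape(*batch_dims, n_indices, c_dim)
# Sanity check against the reference implementation above, e.g.:
# torch.equal(gather_nd_torch_flat(inputs, pos, 3), gather_nd_torch(inputs, pos, 3))  # => True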
batch_size = 5
inputs = torch.randn(batch_size, batch_size, batch_size, 4, 4, 4, 32)
pos = torch.randint(4, (batch_size, batch_size, batch_size, 12, 3))
s = time.time()
for i in range(1000):
    gathered = gather_nd_torch(inputs, pos, batch_dim=3)
time_used = (time.time() - s) / 1000
print('Torch time used:', time_used)
# Again, verify that it's identical to TensorFlow's output
inputs_tf = tf.convert_to_tensor(inputs.numpy())
pos_tf = tf.convert_to_tensor(pos.numpy())
# This time with tf.gather_nd and batch_dims=3
s = time.time()
for i in range(1000):
    gathered_tf = tf.gather_nd(inputs_tf, pos_tf, batch_dims=3)
time_used = (time.time() - s) / 1000
print('TF time used:', time_used)
gathered_tf = torch.from_numpy(gathered_tf.numpy())
torch.equal(gathered_tf, gathered) # => True
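
# Note: the timing loops above measure CPU execution. On CUDA tensors,
# kernel launches are asynchronous, so a fair benchmark needs explicit
# synchronization. A minimal sketch, assuming a CUDA device is available:
if torch.cuda.is_available():
    inputs_cu, pos_cu = inputs.cuda(), pos.cuda()
    torch.cuda.synchronize()
    s = time.time()
    for i in range(1000):
        gathered_cu = gather_nd_torch(inputs_cu, pos_cu, batch_dim=3)
    torch.cuda.synchronize()
    print('Torch CUDA time used:', (time.time() - s) / 1000)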