Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Last active February 9, 2023 02:20
Show Gist options
  • Save AmosLewis/c90c1148a96291db93408b3fa39f9ae2 to your computer and use it in GitHub Desktop.
Save AmosLewis/c90c1148a96291db93408b3fa39f9ae2 to your computer and use it in GitHub Desktop.
import torch
# Source tensor, shape (4, 5).
t = torch.tensor([
    [1, 2, 3, 4, 5],
    [6, 7, 8, 9, 10],
    [11, 12, 13, 14, 15],
    [16, 17, 18, 19, 20],
])  # 4*5
# Index tensor, shape (2, 3); every entry is a row number of `t`.
i = torch.tensor([
    [1, 2, 3],
    [3, 2, 1],
])  # 2*3
# Indexing with a single integer tensor gathers whole rows of `t`:
# o[a, b] == t[i[a, b]], so the result has shape i.shape + t.shape[1:].
o = t[i]
# o == torch.tensor([
#     [[ 6,  7,  8,  9, 10], [11, 12, 13, 14, 15], [16, 17, 18, 19, 20]],
#     [[16, 17, 18, 19, 20], [11, 12, 13, 14, 15], [ 6,  7,  8,  9, 10]],
# ])  # 2*3*5
########## The same operation in TensorFlow
import tensorflow as tf
# Source tensor, shape (4, 5).
t = tf.constant([
    [1, 2, 3, 4, 5],
    [6, 7, 8, 9, 10],
    [11, 12, 13, 14, 15],
    [16, 17, 18, 19, 20],
])  # 4*5
# Index tensor, shape (2, 3); every entry is a row number of `t`.
i = tf.constant([
    [1, 2, 3],
    [3, 2, 1],
])  # 2*3
# tf.gather_nd reads the last axis of its index tensor as a coordinate into
# `t`, so append an axis of size 1: each coordinate is then one row index.
i_expand = tf.expand_dims(i, axis=2)
# i_expand:
# <tf.Tensor: shape=(2, 3, 1), dtype=int32, numpy=
# array([[[1],
#         [2],
#         [3]],
#        [[3],
#         [2],
#         [1]]], dtype=int32)>
io = tf.gather_nd(t, i_expand)
# io:
# <tf.Tensor: shape=(2, 3, 5), dtype=int32, numpy=
# array([[[ 6,  7,  8,  9, 10],
#         [11, 12, 13, 14, 15],
#         [16, 17, 18, 19, 20]],
#        [[16, 17, 18, 19, 20],
#         [11, 12, 13, 14, 15],
#         [ 6,  7,  8,  9, 10]]], dtype=int32)>
@AmosLewis
Copy link
Author

torch-mlir-opt -convert-torch-to-tosa /tmp/IndexTensorStaticModule.mlir -mlir-print-ir-after-all -mlir-disable-threading --mlir-print-ir-before-all --debug

@AmosLewis
Copy link
Author

AmosLewis commented Feb 9, 2023

Multiple index tensors

# Source tensor, shape (4, 5).
t = torch.tensor([
    [1, 2, 3, 4, 5],
    [6, 7, 8, 9, 10],
    [11, 12, 13, 14, 15],
    [16, 17, 18, 19, 20],
])  # 4*5
# Index tensor, shape (2, 3).
i = torch.tensor([
    [1, 2, 3],
    [3, 2, 1],
])  # 2*3


# Two index tensors of the same shape are paired element-wise, one per
# dimension of `t`: oo[a, b] == t[i[a, b], i[a, b]].
oo = t[i, i]
# oo == tensor([[t[1, 1], t[2, 2], t[3, 3]],
#               [t[3, 3], t[2, 2], t[1, 1]]])

# The same lookup expressed through the ATen op directly.
oo_aten = torch.ops.aten.index(t, (i, i))
# oo_aten == tensor([[ 7, 13, 19],
#                    [19, 13,  7]])


# `t` has only two dimensions, so supplying a third index tensor is an error.
# Kept as a comment so the rest of the script still runs:
# oo = t[i, i, i]  # error: too many indices for a 2-D tensor


import tensorflow as tf

# Source tensor, shape (4, 5).
t = tf.constant([
    [1, 2, 3, 4, 5],
    [6, 7, 8, 9, 10],
    [11, 12, 13, 14, 15],
    [16, 17, 18, 19, 20],
])  # 4*5

# Index tensor, shape (2, 3).
i = tf.constant([
    [1, 2, 3],
    [3, 2, 1],
])  # 2*3

# Append a trailing axis so the two copies can be joined into coordinates.
i_expand = tf.expand_dims(i, axis=2)

# Join along the last axis to build (row, column) pairs, shape (2, 3, 2).
# tf.concat takes `axis`, not `dim` (`dim` is the PyTorch spelling).
ii = tf.concat((i_expand, i_expand), axis=2)

# The same coordinates written out as a literal.
ii = tf.constant([
    [[1, 1], [2, 2], [3, 3]],
    [[3, 3], [2, 2], [1, 1]],
])  # 2*3*2

# Each (row, column) pair selects one scalar of `t`.
iio = tf.gather_nd(t, ii)
# iio:
# <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
# array([[ 7, 13, 19],
#        [19, 13,  7]], dtype=int32)>

@AmosLewis
Copy link
Author

Multiple index tensors with different shapes

###############################################################################3
# Source tensor: the values 0..59 laid out with shape (5, 4, 3).
t = torch.arange(5 * 4 * 3).view(5, 4, 3)
# t:
# tensor([[[ 0,  1,  2],
#          [ 3,  4,  5],
#          [ 6,  7,  8],
#          [ 9, 10, 11]],
#
#         [[12, 13, 14],
#          [15, 16, 17],
#          [18, 19, 20],
#          [21, 22, 23]],
#
#         [[24, 25, 26],
#          [27, 28, 29],
#          [30, 31, 32],
#          [33, 34, 35]],
#
#         [[36, 37, 38],
#          [39, 40, 41],
#          [42, 43, 44],
#          [45, 46, 47]],
#
#         [[48, 49, 50],
#          [51, 52, 53],
#          [54, 55, 56],
#          [57, 58, 59]]])
# t.shape == torch.Size([5, 4, 3])

# Index tensors with different shapes. They were originally sampled with
# torch.randint(0, 3, (3, 3)) and torch.randint(0, 3, (3,)); the sampled
# values are pinned here so the outputs recorded below are reproducible.
i1 = torch.tensor([[1, 2, 0],
                   [0, 2, 0],
                   [0, 0, 2]])
i2 = torch.tensor([1, 1, 0])

# i2 (shape (3,)) is broadcast against i1 (shape (3, 3)), so
# o[a, b] == t[i1[a, b], i2[b]]; the last dimension of `t` is kept whole.
o = t[i1, i2]
# o:
# tensor([[[15, 16, 17],
#          [27, 28, 29],
#          [ 0,  1,  2]],
#
#         [[ 3,  4,  5],
#          [27, 28, 29],
#          [ 0,  1,  2]],
#
#         [[ 3,  4,  5],
#          [ 3,  4,  5],
#          [24, 25, 26]]])
# which, written in terms of `t`, is:
# tensor([[t[1,1], t[2,1], t[0,0]],
#
#         [t[0,1], t[2,1], t[0,0]],
#
#         [t[0,1], t[0,1], t[2,0]]])
# o.shape == torch.Size([3, 3, 3])

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment