Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Last active February 9, 2023 02:20
Show Gist options
  • Save AmosLewis/c90c1148a96291db93408b3fa39f9ae2 to your computer and use it in GitHub Desktop.
Save AmosLewis/c90c1148a96291db93408b3fa39f9ae2 to your computer and use it in GitHub Desktop.
import torch

# Advanced indexing with a single integer index tensor: t[i] selects whole
# rows of t, so the result shape is i.shape + t.shape[1:] == (2, 3) + (5,).
t = torch.tensor([
    [1, 2, 3, 4, 5],
    [6, 7, 8, 9, 10],
    [11, 12, 13, 14, 15],
    [16, 17, 18, 19, 20],
])  # shape (4, 5)
i = torch.tensor([
    [1, 2, 3],
    [3, 2, 1],
])  # shape (2, 3)
o = t[i]  # o[a, b, :] == t[i[a, b], :]
expected = torch.tensor([
    [[6, 7, 8, 9, 10], [11, 12, 13, 14, 15], [16, 17, 18, 19, 20]],
    [[16, 17, 18, 19, 20], [11, 12, 13, 14, 15], [6, 7, 8, 9, 10]],
])  # shape (2, 3, 5)
assert o.shape == (2, 3, 5)
assert torch.equal(o, expected)
# ---------- same computation in TensorFlow ----------
import tensorflow as tf

t = tf.constant([
    [1, 2, 3, 4, 5],
    [6, 7, 8, 9, 10],
    [11, 12, 13, 14, 15],
    [16, 17, 18, 19, 20],
])  # shape (4, 5)
i = tf.constant([
    [1, 2, 3],
    [3, 2, 1],
])  # shape (2, 3)

# tf.gather_nd reads the innermost dimension of the index tensor as a
# coordinate into t. A trailing dimension of size 1 is a one-element
# coordinate (a row number), which selects a whole row, matching torch's t[i].
i_expand = tf.expand_dims(i, axis=2)
# i_expand: shape (2, 3, 1), dtype int32
#   [[[1], [2], [3]],
#    [[3], [2], [1]]]

io = tf.gather_nd(t, i_expand)
expected = tf.constant([
    [[6, 7, 8, 9, 10], [11, 12, 13, 14, 15], [16, 17, 18, 19, 20]],
    [[16, 17, 18, 19, 20], [11, 12, 13, 14, 15], [6, 7, 8, 9, 10]],
])  # shape (2, 3, 5), dtype int32
assert io.shape == (2, 3, 5)
tf.debugging.assert_equal(io, expected)
@AmosLewis
Copy link
Author

AmosLewis commented Jan 16, 2023

#loc = loc(unknown)
// Input IR (torch dialect) for the IndexTensorStaticModule e2e test: a 4x5 f32
// table %arg0 indexed by one 2x3 si64 index tensor %arg1 -- a single-index
// gather shaped like `t[i]`.
module attributes {torch.debug_module_name = "IndexTensorStaticModule"} {
  func.func @forward(%arg0: !torch.vtensor<[4,5],f32> loc(unknown), %arg1: !torch.vtensor<[2,3],si64> loc(unknown)) -> !torch.vtensor<[2,3,5],f32> {
    // aten.index.Tensor takes its indices as a list of optional tensors; the
    // list here has exactly one entry, so only dimension 0 of %arg0 is indexed.
    %0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[2,3],si64>) -> !torch.list<optional<vtensor>> loc(#loc1)
    // Result shape = index shape (2x3) followed by the un-indexed dim (5).
    %1 = torch.aten.index.Tensor %arg0, %0 : !torch.vtensor<[4,5],f32>, !torch.list<optional<vtensor>> -> !torch.vtensor<[2,3,5],f32> loc(#loc1)
    return %1 : !torch.vtensor<[2,3,5],f32> loc(#loc)
  } loc(#loc)
} loc(#loc)
// #loc1 is the Python source location (basic.py:1803) recorded for the index op.
#loc1 = loc("/home/chi/src/ubuntu20/shark/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/basic.py":1803:15)

Lowered with `torch-mlir-opt -convert-torch-to-tosa`, this becomes:

// Output of -convert-torch-to-tosa: aten.index.Tensor expressed as a single
// tosa.gather over a rank-3 table plus reshapes around it.
module attributes {torch.debug_module_name = "IndexTensorStaticModule"} {
  func.func @forward(%arg0: !torch.vtensor<[4,5],f32>, %arg1: !torch.vtensor<[2,3],si64>) -> !torch.vtensor<[2,3,5],f32> {
    // Unwrap the torch-dialect table into a builtin tensor for TOSA ops.
    %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[4,5],f32> -> tensor<4x5xf32>
    // The index list is no longer used by anything below; presumably removed
    // later as dead code.
    %1 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[2,3],si64>) -> !torch.list<optional<vtensor>>
    %2 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[2,3],si64> -> tensor<2x3xi64>
    // Narrow the i64 indices to i32, the index type tosa.gather uses below.
    %3 = "tosa.cast"(%2) : (tensor<2x3xi64>) -> tensor<2x3xi32>
    // Append a coordinate dimension (2x3 -> 2x3x1): one coordinate per lookup,
    // the same layout tf.gather_nd expects.
    %4 = "tosa.reshape"(%3) {new_shape = array<i64: 2, 3, 1>} : (tensor<2x3xi32>) -> tensor<2x3x1xi32>
    // tosa.gather reads a rank-3 [N, K, C] table: N=1, K=4 rows, C=5 columns.
    %5 = "tosa.reshape"(%0) {new_shape = array<i64: 1, 4, 5>} : (tensor<4x5xf32>) -> tensor<1x4x5xf32>
    // Flatten the lookups: 6 lookups x 1 coordinate each.
    %6 = "tosa.reshape"(%4) {new_shape = array<i64: 6, 1>} : (tensor<2x3x1xi32>) -> tensor<6x1xi32>
    // %7-%9: multiply each coordinate column by a stride and sum across columns,
    // presumably the general path for linearizing multi-column indices. With a
    // single column and a stride of 1, the values pass through unchanged.
    %7 = "tosa.const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
    %8 = "tosa.mul"(%6, %7) {shift = 0 : i32} : (tensor<6x1xi32>, tensor<1xi32>) -> tensor<6x1xi32>
    %9 = "tosa.reduce_sum"(%8) {axis = 1 : i64} : (tensor<6x1xi32>) -> tensor<6x1xi32>
    // Shape the 6 row indices as [N=1, W=6] for tosa.gather.
    %10 = "tosa.reshape"(%9) {new_shape = array<i64: 1, 6>} : (tensor<6x1xi32>) -> tensor<1x6xi32>
    // Gather 6 rows of 5 columns: [1,4,5] by [1,6] -> [1,6,5].
    %11 = "tosa.gather"(%5, %10) : (tensor<1x4x5xf32>, tensor<1x6xi32>) -> tensor<1x6x5xf32>
    // Restore the index shape plus the row width: 1x6x5 -> 2x3x5.
    %12 = "tosa.reshape"(%11) {new_shape = array<i64: 2, 3, 5>} : (tensor<1x6x5xf32>) -> tensor<2x3x5xf32>
    %13 = torch_c.from_builtin_tensor %12 : tensor<2x3x5xf32> -> !torch.vtensor<[2,3,5],f32>
    return %13 : !torch.vtensor<[2,3,5],f32>
  }
}

@AmosLewis
Copy link
Author

# Run only the torch-to-TOSA conversion on the dumped module, single-threaded,
# printing the IR before and after every pass. --debug prints LLVM_DEBUG
# output, which requires an assertions-enabled build.
torch-mlir-opt /tmp/IndexTensorStaticModule.mlir \
  --convert-torch-to-tosa \
  --mlir-disable-threading \
  --mlir-print-ir-before-all \
  --mlir-print-ir-after-all \
  --debug

@AmosLewis
Copy link
Author

AmosLewis commented Feb 9, 2023

Multiple index tensors (`t[i, i]`)

import torch

t = torch.tensor([
    [1, 2, 3, 4, 5],
    [6, 7, 8, 9, 10],
    [11, 12, 13, 14, 15],
    [16, 17, 18, 19, 20],
])  # shape (4, 5)
i = torch.tensor([
    [1, 2, 3],
    [3, 2, 1],
])  # shape (2, 3)

# One index tensor per dimension of t. They are combined element-wise:
#   oo[a, b] == t[i[a, b], i[a, b]], i.e.
#   [[t[1, 1], t[2, 2], t[3, 3]],
#    [t[3, 3], t[2, 2], t[1, 1]]]
oo = t[i, i]
assert torch.equal(oo, torch.tensor([[7, 13, 19], [19, 13, 7]]))

# The same operation spelled as the aten op that appears in torch-mlir IR
# (torch.aten.index.Tensor): the indices are passed as a list.
assert torch.equal(torch.ops.aten.index(t, [i, i]), oo)

# Passing more index tensors than t has dimensions (3 for a 2-D tensor) is an error.
try:
    t[i, i, i]
except IndexError:
    pass  # expected
else:
    raise AssertionError("t[i, i, i] was expected to raise IndexError")


import tensorflow as tf

t = tf.constant([
    [1, 2, 3, 4, 5],
    [6, 7, 8, 9, 10],
    [11, 12, 13, 14, 15],
    [16, 17, 18, 19, 20],
])  # shape (4, 5)

i = tf.constant([
    [1, 2, 3],
    [3, 2, 1],
])  # shape (2, 3)

# Build a (2, 3, 2) coordinate tensor whose last dimension is the (row, col)
# pair [i[a, b], i[a, b]]. This is the gather_nd equivalent of torch's t[i, i].
i_expand = tf.expand_dims(i, axis=2)  # shape (2, 3, 1)
# tf.concat takes `axis`; passing `dim=` raises TypeError.
ii = tf.concat((i_expand, i_expand), axis=2)

# The concatenation above equals this literal:
expected_ii = tf.constant([
    [[1, 1], [2, 2], [3, 3]],
    [[3, 3], [2, 2], [1, 1]],
])  # shape (2, 3, 2)
tf.debugging.assert_equal(ii, expected_ii)

iio = tf.gather_nd(t, ii)
# iio: shape (2, 3), dtype int32
tf.debugging.assert_equal(iio, tf.constant([[7, 13, 19], [19, 13, 7]]))

@AmosLewis
Copy link
Author

Multiple index tensors with different shapes (broadcast against each other)

###############################################################################
import torch

# 3-D source tensor with t[x, y, z] == 12 * x + 3 * y + z, e.g.
# t[0, 0] == [0, 1, 2], t[1, 1] == [15, 16, 17], t[4, 3] == [57, 58, 59].
t = torch.arange(5 * 4 * 3).view(5, 4, 3)
assert t.shape == torch.Size([5, 4, 3])

# Fixed index values. The original session drew these with
# torch.randint(0, 3, (3, 3)) and torch.randint(0, 3, (3,)), which is not
# reproducible without a seed, so the values it printed are hard-coded here.
i1 = torch.tensor([
    [1, 2, 0],
    [0, 2, 0],
    [0, 0, 2],
])  # shape (3, 3)
i2 = torch.tensor([1, 1, 0])  # shape (3,)

# i1 (3, 3) and i2 (3,) broadcast to a common (3, 3) index shape. They index the
# first two dims of t; the last dim is kept, so o has shape (3, 3, 3):
#   o[a, b, :] == t[i1[a, b], i2[b], :]
#   [[t[1,1], t[2,1], t[0,0]],
#    [t[0,1], t[2,1], t[0,0]],
#    [t[0,1], t[0,1], t[2,0]]]
o = t[i1, i2]
assert o.shape == torch.Size([3, 3, 3])
assert torch.equal(o, torch.tensor([
    [[15, 16, 17], [27, 28, 29], [0, 1, 2]],
    [[3, 4, 5], [27, 28, 29], [0, 1, 2]],
    [[3, 4, 5], [3, 4, 5], [24, 25, 26]],
]))

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment