Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Last active December 16, 2022 07:36
Show Gist options
  • Save AmosLewis/bb6e3a0ad9fd1705c9f9d42a2eefbb88 to your computer and use it in GitHub Desktop.
Save AmosLewis/bb6e3a0ad9fd1705c9f9d42a2eefbb88 to your computer and use it in GitHub Desktop.
module {
  // Worked example of a TOSA lowering for a gather_nd-style torch.aten.gather.
  // Each length-3 row in the last dimension of %arg1 is a full (d0, d1, d2)
  // coordinate into the 1x4x3 %arg0, so the 1x4x2x3 index tensor selects a
  // 1x4x2 result. The lowering:
  //   1. flattens the data to tosa.gather's [N, K, C] layout,
  //   2. turns every coordinate into a row-major linear offset,
  //   3. performs a single tosa.gather, and
  //   4. reshapes the result back.
  //
  // The comment after each op shows the value it produces for the sample
  // inputs written out under %0 and %1.
  func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.vtensor<[1,4,2,3],i32>) -> !torch.vtensor<[1,4,2],f32> {
    %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,4,3],f32> -> tensor<1x4x3xf32>
    // %0 = [[[ 1,  2,  3],
    //        [ 4,  5,  6],
    //        [ 7,  8,  9],
    //        [10, 11, 12]]]                      shape 1x4x3
    %1 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,4,2,3],i32> -> tensor<1x4x2x3xi32>
    // %1 = [[[[0, 0, 0],
    //         [0, 0, 0]],
    //        [[0, 1, 1],
    //         [0, 1, 0]],
    //        [[0, 2, 2],
    //         [0, 2, 1]],
    //        [[0, 3, 2],
    //         [0, 3, 1]]]]                       shape 1x4x2x3
    // Presumably the `dim` and `sparse_grad` operands of aten.gather. Neither is
    // referenced by the TOSA ops below. NOTE(review): confirm they can be dropped.
    %int-1 = torch.constant.int -1
    %false = torch.constant.bool false
    // Flatten the data to tosa.gather's values layout [N, K, C] = [1, 12, 1].
    %2 = "tosa.reshape"(%0) {new_shape = [1, 12, 1]} : (tensor<1x4x3xf32>) -> tensor<1x12x1xf32>
    // %2 = [[[1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12]]]
    //                                            shape 1x12x1
    // One coordinate per row: 1*4*2 = 8 coordinates with 3 components each.
    %3 = "tosa.reshape"(%1) {new_shape = [8, 3]} : (tensor<1x4x2x3xi32>) -> tensor<8x3xi32>
    // %3 = [[0, 0, 0],
    //       [0, 0, 0],
    //       [0, 1, 1],
    //       [0, 1, 0],
    //       [0, 2, 2],
    //       [0, 2, 1],
    //       [0, 3, 2],
    //       [0, 3, 1]]                           shape 8x3
    // Row-major strides of the 1x4x3 data: [4*3, 3, 1] = [12, 3, 1].
    %4 = "tosa.const"() {value = dense<[12, 3, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
    // %4 = [12, 3, 1]                            shape 3
    // Scale each coordinate component by its stride.
    // NOTE(review): the TOSA spec's broadcast rule requires operands of equal
    // rank. A strict verifier may need %4 as tensor<1x3xi32> (or the
    // tosa-make-broadcastable pass) here. Confirm for the target TOSA version.
    %5 = "tosa.mul"(%3, %4) {shift = 0 : i32} : (tensor<8x3xi32>, tensor<3xi32>) -> tensor<8x3xi32>
    // %5 = [[12*0, 3*0, 1*0],
    //       [12*0, 3*0, 1*0],
    //       [12*0, 3*1, 1*1],
    //       [12*0, 3*1, 1*0],
    //       [12*0, 3*2, 1*2],
    //       [12*0, 3*2, 1*1],
    //       [12*0, 3*3, 1*2],
    //       [12*0, 3*3, 1*1]]                    shape 8x3
    // Sum the scaled components: linear offset = 12*d0 + 3*d1 + 1*d2.
    // tosa.reduce_sum keeps the reduced axis with size 1, hence 8x1.
    %6 = "tosa.reduce_sum"(%5) {axis = 1 : i64} : (tensor<8x3xi32>) -> tensor<8x1xi32>
    // %6 = [[12*0 + 3*0 + 1*0],
    //       [12*0 + 3*0 + 1*0],
    //       [12*0 + 3*1 + 1*1],
    //       [12*0 + 3*1 + 1*0],
    //       [12*0 + 3*2 + 1*2],
    //       [12*0 + 3*2 + 1*1],
    //       [12*0 + 3*3 + 1*2],
    //       [12*0 + 3*3 + 1*1]]                  shape 8x1
    // Reshape to tosa.gather's indices layout [N, W] = [1, 8].
    %7 = "tosa.reshape"(%6) {new_shape = [1, 8]} : (tensor<8x1xi32>) -> tensor<1x8xi32>
    // %7 = [[12*0 + 3*0 + 1*0,
    //        12*0 + 3*0 + 1*0,
    //        12*0 + 3*1 + 1*1,
    //        12*0 + 3*1 + 1*0,
    //        12*0 + 3*2 + 1*2,
    //        12*0 + 3*2 + 1*1,
    //        12*0 + 3*3 + 1*2,
    //        12*0 + 3*3 + 1*1]]                  shape 1x8
    // tosa.gather takes i32 indices. %7 is already i32 here (because %arg1 is
    // i32), so this cast is an identity. Its operand type must match %7.
    %8 = "tosa.cast"(%7) : (tensor<1x8xi32>) -> tensor<1x8xi32>
    // %8 = [[ 0,    (12*0 + 3*0 + 1*0)
    //         0,    (12*0 + 3*0 + 1*0)
    //         4,    (12*0 + 3*1 + 1*1)
    //         3,    (12*0 + 3*1 + 1*0)
    //         8,    (12*0 + 3*2 + 1*2)
    //         7,    (12*0 + 3*2 + 1*1)
    //        11,    (12*0 + 3*3 + 1*2)
    //        10]]   (12*0 + 3*3 + 1*1)           shape 1x8
    // %9[0][w][0] = %2[0][%8[0][w]][0]. Row k of %2 holds the value k + 1.
    %9 = "tosa.gather"(%2, %8) : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32>
    // %9 = [[[ 1],   (row 0)
    //        [ 1],   (row 0)
    //        [ 5],   (row 4)
    //        [ 4],   (row 3)
    //        [ 9],   (row 8)
    //        [ 8],   (row 7)
    //        [12],   (row 11)
    //        [11]]]  (row 10)                    shape 1x8x1
    // Restore the result shape: the index shape 1x4x2x3 without its last dim.
    %10 = "tosa.reshape"(%9) {new_shape = [1, 4, 2]} : (tensor<1x8x1xf32>) -> tensor<1x4x2xf32>
    // %10 = [[[ 1,  1],
    //         [ 5,  4],
    //         [ 9,  8],
    //         [12, 11]]]                         shape 1x4x2
    %11 = torch_c.from_builtin_tensor %10 : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32>
    return %11 : !torch.vtensor<[1,4,2],f32>
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment