Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Last active September 7, 2023 00:40
Show Gist options
  • Save AmosLewis/979d1aca948b3cd821d6edadc160f610 to your computer and use it in GitHub Desktop.
Save AmosLewis/979d1aca948b3cd821d6edadc160f610 to your computer and use it in GitHub Desktop.
import torch
import torch_mlir
class Net(torch.nn.Module):
    """Minimal module wrapping a single ``torch.index_put`` call.

    Used as a reproducer for lowering ``aten.index_put`` with two 1-D index
    tensors to the StableHLO backend.
    """

    def forward(self, input, index1, index2, src):
        """Write ``src`` into ``input`` at positions ``(index1[k], index2[k])``.

        Args:
            input: Tensor to write into; ``torch.index_put`` is out-of-place,
                so it is not modified.
            index1: Indices along dimension 0.
            index2: Indices along dimension 1, paired element-wise with
                ``index1``.
            src: Values to write, one per index pair.

        Returns:
            A new tensor equal to ``input`` with the indexed entries replaced.
            ``accumulate=False`` means values overwrite rather than add.
        """
        # ``input`` shadows the builtin but is kept: it is the public
        # parameter name that keyword callers may rely on.
        return torch.index_put(
            input, indices=(index1, index2), values=src, accumulate=False
        )
# Example inputs: the index pairs (0, 1), (0, 2), (0, 3), (0, 4), (0, 0) are
# unique, so values 1..5 land in row 0 at columns 1, 2, 3, 4, 0. Eager result:
# [[5, 1, 2, 3, 4], [15, 16, 17, 18, 19], [20, 21, 22, 23, 24]].
net = Net()
src = torch.arange(1, 6)
index1 = torch.tensor([0, 0, 0, 0, 0])
index2 = torch.tensor([1, 2, 3, 4, 0])
# Named ``inp`` rather than ``input`` to avoid shadowing the builtin.
inp = torch.arange(10, 25, step=1, dtype=src.dtype).view(3, 5)
# Keep the eager module and the compiled MLIR module under separate names
# instead of rebinding one variable to two unrelated types.
compiled = torch_mlir.compile(
    net, [inp, index1, index2, src], output_type="stablehlo"
)
print(compiled.operation.get_asm())
# Torch-dialect IR for Net.forward (value-semantic `!torch.vtensor` types).
# NOTE(review): this is not what the print above emits (that call uses
# output_type="stablehlo"); presumably captured separately, e.g. with
# output_type="torch" — confirm.
# index_put is matched to `torch.aten.index_put.hacked_twin`, with the two
# index tensors packed into a `!torch.list<vtensor>` and accumulate=false.
'''
module attributes {torch.debug_module_name = "Net"} {
func.func @forward(%arg0: !torch.vtensor<[3,5],si64>, %arg1: !torch.vtensor<[5],si64>, %arg2: !torch.vtensor<[5],si64>, %arg3: !torch.vtensor<[5],si64>) -> !torch.vtensor<[3,5],si64> {
%false = torch.constant.bool false
%0 = torch.prim.ListConstruct %arg1, %arg2 : (!torch.vtensor<[5],si64>, !torch.vtensor<[5],si64>) -> !torch.list<vtensor>
%1 = torch.aten.index_put.hacked_twin %arg0, %0, %arg3, %false : !torch.vtensor<[3,5],si64>, !torch.list<vtensor>, !torch.vtensor<[5],si64>, !torch.bool -> !torch.vtensor<[3,5],si64>
return %1 : !torch.vtensor<[3,5],si64>
}
}
'''
# StableHLO IR for the compile call above (output_type="stablehlo").
# Lowering:
# - Each 5-element index tensor is reshaped to 5x1.
# - The two 5x1 tensors are concatenated on dim 1 into a 5x2 matrix of
#   (row, col) coordinates.
# - The values are reshaped to 5x1 and passed to `stablehlo.scatter`. Its
#   region returns the update operand (%arg5), i.e. overwrite, not accumulate.
# NOTE: %4 is a no-op reshape (5x2 -> 5x2) left in by the lowering.
'''
module attributes {torch.debug_module_name = "Net"} {
func.func @forward(%arg0: tensor<3x5xi64>, %arg1: tensor<5xi64>, %arg2: tensor<5xi64>, %arg3: tensor<5xi64>) -> tensor<3x5xi64> {
%0 = stablehlo.reshape %arg1 : (tensor<5xi64>) -> tensor<5x1xi64>
%1 = stablehlo.reshape %arg2 : (tensor<5xi64>) -> tensor<5x1xi64>
%2 = stablehlo.concatenate %0, %1, dim = 1 : (tensor<5x1xi64>, tensor<5x1xi64>) -> tensor<5x2xi64>
%3 = stablehlo.reshape %arg3 : (tensor<5xi64>) -> tensor<5x1xi64>
%4 = stablehlo.reshape %2 : (tensor<5x2xi64>) -> tensor<5x2xi64>
%5 = "stablehlo.scatter"(%arg0, %4, %3) ({
^bb0(%arg4: tensor<i64>, %arg5: tensor<i64>):
stablehlo.return %arg5 : tensor<i64>
}) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = false} : (tensor<3x5xi64>, tensor<5x2xi64>, tensor<5x1xi64>) -> tensor<3x5xi64>
return %5 : tensor<3x5xi64>
}
}
'''
import torch
import torch_mlir
class Net(torch.nn.Module):
    """Minimal module wrapping a single ``torch.index_put`` call.

    Used as a reproducer for lowering ``aten.index_put`` with two 1-D index
    tensors to the TOSA backend.
    """

    def forward(self, input, index1, index2, src):
        """Write ``src`` into ``input`` at positions ``(index1[k], index2[k])``.

        Args:
            input: Tensor to write into; ``torch.index_put`` is out-of-place,
                so it is not modified.
            index1: Indices along dimension 0.
            index2: Indices along dimension 1, paired element-wise with
                ``index1``.
            src: Values to write, one per index pair.

        Returns:
            A new tensor equal to ``input`` with the indexed entries replaced.
            ``accumulate=False`` means values overwrite rather than add.
        """
        # ``input`` shadows the builtin but is kept: it is the public
        # parameter name that keyword callers may rely on.
        return torch.index_put(
            input, indices=(index1, index2), values=src, accumulate=False
        )
# Example inputs: the index pairs (0, 1), (0, 2), (0, 3), (0, 4), (0, 0) are
# unique, so values 1..5 land in row 0 at columns 1, 2, 3, 4, 0. Eager result:
# [[5, 1, 2, 3, 4], [15, 16, 17, 18, 19], [20, 21, 22, 23, 24]].
net = Net()
src = torch.arange(1, 6)
index1 = torch.tensor([0, 0, 0, 0, 0])
index2 = torch.tensor([1, 2, 3, 4, 0])
# Named ``inp`` rather than ``input`` to avoid shadowing the builtin.
inp = torch.arange(10, 25, step=1, dtype=src.dtype).view(3, 5)
# Keep the eager module and the compiled MLIR module under separate names
# instead of rebinding one variable to two unrelated types.
compiled = torch_mlir.compile(
    net, [inp, index1, index2, src], output_type="tosa"
)
print(compiled.operation.get_asm())
# TOSA IR for the compile call above (output_type="tosa").
# Lowering:
# - The two index tensors are cast i64 -> i32 and concatenated into 5x2
#   (row, col) coordinates.
# - The coordinates are multiplied by the constant [[5, 1]] and summed along
#   axis 1. This gives flat offsets row*5 + col into the 3x5 input.
# - The input is reshaped to 1x15x1 and the values to 1x5x1, so one
#   `tosa.scatter` can write them.
# - The result is reshaped back to 3x5.
# NOTE(review): the i64 -> i32 index cast could lose information for indices
# outside the i32 range; irrelevant for this 3x5 example.
'''
module attributes {torch.debug_module_name = "Net"} {
func.func @forward(%arg0: tensor<3x5xi64>, %arg1: tensor<5xi64>, %arg2: tensor<5xi64>, %arg3: tensor<5xi64>) -> tensor<3x5xi64> {
%0 = "tosa.const"() <{value = dense<[[5, 1]]> : tensor<1x2xi32>}> : () -> tensor<1x2xi32>
%1 = "tosa.cast"(%arg1) : (tensor<5xi64>) -> tensor<5xi32>
%2 = "tosa.reshape"(%1) <{new_shape = array<i64: 5, 1>}> : (tensor<5xi32>) -> tensor<5x1xi32>
%3 = "tosa.cast"(%arg2) : (tensor<5xi64>) -> tensor<5xi32>
%4 = "tosa.reshape"(%3) <{new_shape = array<i64: 5, 1>}> : (tensor<5xi32>) -> tensor<5x1xi32>
%5 = "tosa.concat"(%2, %4) <{axis = 1 : i64}> : (tensor<5x1xi32>, tensor<5x1xi32>) -> tensor<5x2xi32>
%6 = "tosa.reshape"(%arg3) <{new_shape = array<i64: 1, 5, 1>}> : (tensor<5xi64>) -> tensor<1x5x1xi64>
%7 = "tosa.reshape"(%arg0) <{new_shape = array<i64: 1, 15, 1>}> : (tensor<3x5xi64>) -> tensor<1x15x1xi64>
%8 = "tosa.mul"(%5, %0) <{shift = 0 : i32}> : (tensor<5x2xi32>, tensor<1x2xi32>) -> tensor<5x2xi32>
%9 = "tosa.reduce_sum"(%8) <{axis = 1 : i64}> : (tensor<5x2xi32>) -> tensor<5x1xi32>
%10 = "tosa.reshape"(%9) <{new_shape = array<i64: 1, 5>}> : (tensor<5x1xi32>) -> tensor<1x5xi32>
%11 = "tosa.scatter"(%7, %10, %6) : (tensor<1x15x1xi64>, tensor<1x5xi32>, tensor<1x5x1xi64>) -> tensor<1x15x1xi64>
%12 = "tosa.reshape"(%11) <{new_shape = array<i64: 3, 5>}> : (tensor<1x15x1xi64>) -> tensor<3x5xi64>
return %12 : tensor<3x5xi64>
}
}
'''
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment