Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save AmosLewis/9d55e5f89de1abd0c1491c8e90beadb9 to your computer and use it in GitHub Desktop.
Save AmosLewis/9d55e5f89de1abd0c1491c8e90beadb9 to your computer and use it in GitHub Desktop.
#loc = loc(unknown)
// Torch-dialect IR for the e2e test module EmbeddingBagDenseBackwardModule
// (torch_mlir_e2e_test/test_suite/backprop.py:324, per #loc1 below).
//
// Every op is tagged with #loc2, i.e. it came from
// aten::_embedding_bag_dense_backward. The function builds a 2x2 f32 zero
// tensor and index-puts %arg0 into it using %arg1 as the index tensor with
// accumulate=true. %arg2 and %arg3 are accepted but never referenced.
module attributes {torch.debug_module_name = "EmbeddingBagDenseBackwardModule"} {
  func.func @forward(%arg0: !torch.vtensor<[3,2],f32> loc(unknown),
                     %arg1: !torch.vtensor<[3],si64> loc(unknown),
                     %arg2: !torch.vtensor<[3],si64> loc(unknown),
                     %arg3: !torch.vtensor<[1],si64> loc(unknown)) -> !torch.vtensor<[2,2],f32> {
    // Destination buffer: zeros([2, 2], dtype=6). ScalarType 6 is float32,
    // which matches the f32 element type of the result.
    %dim = torch.constant.int 2 loc(#loc2)
    %f32_dtype = torch.constant.int 6 loc(#loc2)
    %nil = torch.constant.none loc(#loc2)
    %shape = torch.prim.ListConstruct %dim, %dim : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc2)
    %zeros = torch.aten.zeros %shape, %f32_dtype, %nil, %nil, %nil : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,2],f32> loc(#loc2)

    // A single index tensor, so the put indexes dimension 0 of %zeros.
    %indices = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[3],si64>) -> !torch.list<vtensor> loc(#loc2)

    // Trailing operands follow the _index_put_impl schema (accumulate, unsafe).
    %accumulate = torch.constant.bool true loc(#loc2)
    %unsafe = torch.constant.bool false loc(#loc2)
    %result = torch.aten._index_put_impl %zeros, %indices, %arg0, %accumulate, %unsafe : !torch.vtensor<[2,2],f32>, !torch.list<vtensor>, !torch.vtensor<[3,2],f32>, !torch.bool, !torch.bool -> !torch.vtensor<[2,2],f32> loc(#loc2)
    return %result : !torch.vtensor<[2,2],f32> loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc1 = loc("/home/nithin/torch-mlir/build-debug/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/backprop.py":324:15)
#loc2 = loc("aten::_embedding_bag_dense_backward"(#loc1))
@AmosLewis
Copy link
Author

AmosLewis commented Sep 8, 2023

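Running `torch-mlir-opt --convert-torch-to-linalg /nodclouddata/chi/src/models/t5/slicecopy/EmbeddingBagDenseBackwardsModule_indexput.mlir` produces the output below. `torch.aten.zeros` is lowered to `tensor.empty` + `linalg.fill`, but `torch.aten._index_put_impl` is left unconverted in the torch dialect: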

// Output of `torch-mlir-opt --convert-torch-to-linalg` (verbatim; comments added).
// The module name matches the input IR earlier in this gist, which is presumably
// the input file; the exact file contents are not shown here.
module attributes {torch.debug_module_name = "EmbeddingBagDenseBackwardModule"} {
  func.func @forward(%arg0: !torch.vtensor<[3,2],f32>, %arg1: !torch.vtensor<[3],si64>, %arg2: !torch.vtensor<[3],si64>, %arg3: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,2],f32> {
    %int2 = torch.constant.int 2
    // %none and %int6 are not referenced by any op below in this output.
    %none = torch.constant.none
    %int6 = torch.constant.int 6
    %true = torch.constant.bool true
    %false = torch.constant.bool false
    // %0 is also no longer referenced: the lowered zeros below reads %int2 directly.
    %0 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int>
    // torch.aten.zeros was lowered. The size operands are converted to i64, then to index...
    %1 = torch_c.to_i64 %int2
    %2 = torch_c.to_i64 %int2
    %3 = arith.index_cast %1 : i64 to index
    %4 = arith.index_cast %2 : i64 to index
    // ...then a dynamically-shaped tensor is created, filled with 0.0, and cast to static 2x2.
    %cst = arith.constant 0.000000e+00 : f32
    %5 = tensor.empty(%3, %4) : tensor<?x?xf32>
    %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<?x?xf32>) -> tensor<?x?xf32>
    %cast = tensor.cast %6 : tensor<?x?xf32> to tensor<2x2xf32>
    // Bridge back to !torch.vtensor because the consumer below is still a torch op.
    %7 = torch_c.from_builtin_tensor %cast : tensor<2x2xf32> -> !torch.vtensor<[2,2],f32>
    %8 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[3],si64>) -> !torch.list<vtensor>
    // NOTE: torch.aten._index_put_impl was NOT converted by this pass; it remains in the torch dialect.
    %9 = torch.aten._index_put_impl %7, %8, %arg0, %true, %false : !torch.vtensor<[2,2],f32>, !torch.list<vtensor>, !torch.vtensor<[3,2],f32>, !torch.bool, !torch.bool -> !torch.vtensor<[2,2],f32>
    return %9 : !torch.vtensor<[2,2],f32>
  }
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment