module {
  func.func @main(%arg0: !torch.vtensor<[64],f32>, %arg1: !torch.vtensor<[64],f32>, %arg2: !torch.vtensor<[64],f32>, %arg3: !torch.vtensor<[64],f32>, %arg4: !torch.vtensor<[64],f32>, %arg5: !torch.vtensor<[16,64],f32>, %arg6: !torch.vtensor<[64,64],f32>, %arg7: !torch.vtensor<[64,64],f32>, %arg8: !torch.vtensor<[64,64],f32>, %arg9: !torch.vtensor<[64,64],f32>, %arg10: !torch.vtensor<[256,64],f32>, %arg11: !torch.vtensor<[256,64],f32>, %arg12: !torch.vtensor<[64,256],f32>, %arg13: !torch.vtensor<[64,64],f32>, %arg14: !torch.vtensor<[64,64],f32>, %arg15: !torch.vtensor<[64,64],f32>, %arg16: !torch.vtensor<[64,64],f32>, %arg17: !torch.vtensor<[256,64],f32>, %arg18: !torch.vtensor<[256,64],f32>, %arg19: !torch.vtensor<[64,256],f32>, %arg20: !torch.vtensor<[16,64],f32>, %arg21: !torch.vtensor<[4096,8],complex<f32>>, %arg22: !torch.vtensor<[32,2048,4,16],f32>, %arg23: !torch.vtensor<[32,2048,4,16],f32>, %arg24: !torch.vtensor<[32,2048,4,16],f32>, %arg25: !torch.vtensor<[32,2048,4,16],f32>, %arg26: !torch.v

func.func private @forward(%arg0: !torch.vtensor<[20,100,35,45],f32>) -> !torch.vtensor<[20,100,35,45],f32> {
  %int0 = torch.constant.int 0
  %0 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>
  %int0_0 = torch.constant.int 0
  %int0_1 = torch.constant.int 0
  %cpu = torch.constant.device "cpu"
  %none = torch.constant.none
  %none_2 = torch.constant.none
  %1 = torch.aten.aten.empty.memory_format %0, %int0_0, %int0_1, %cpu, %none, %none_2 : !torch.list<int>, !torch.int, !torch.int, !torch.Device, !torch.none, !torch.none -> !torch.vtensor<[0],ui8>
  return %arg0 : !torch.vtensor<[20,100,35,45],f32>
}
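
For orientation, the dangling aten.empty above builds an empty uint8 tensor on the CPU: the size list is [%int0] = [0] and the dtype operand %int0_0 = 0 is the Torch scalar-type code for uint8. A one-line Python equivalent of that op (my reading of the codes, not part of the original repro):

import torch

# size list [%int0] -> [0]; dtype code 0 -> torch.uint8; device "cpu"
t = torch.empty(0, dtype=torch.uint8, device="cpu")
print(t.shape, t.dtype)  # torch.Size([0]) torch.uint8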

CUDA HAL Target:
  --iree-hal-cuda-dump-ptx                     - Dump ptx to the debug stream.
  --iree-hal-cuda-llvm-target-arch=<string>    - LLVM target chip.
  --iree-hal-cuda-llvm-target-feature=<string> - Use to set PTX version.
  --iree-hal-cuda-use-ptxas                    - It uses the ptxas compiler that is on the environment, compiles the generated PTX code with it, puts the cubin binary generated by ptxas into the executable. '--iree-hal-cuda-llvm-target-arch' is used as the target GPU. If passing additional parameters to ptxas is desired, the parameters flag can be used (e.g. '--iree-hal-cuda-use-ptxas-params=-v').
  --iree-hal-cuda-use-ptxas-from=<string>      - It uses the provided ptxas compiler, compiles the generated PTX code with it, puts the cubin binary generated by ptxas into the executable. '--iree-hal-cuda-llvm-target-arch' is used as th
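
Putting those flags together, a hedged example invocation (model.mlir and sm_80 are placeholders, and the flag set matches the help text above, not necessarily current IREE releases):

iree-compile model.mlir \
  --iree-hal-target-backends=cuda \
  --iree-hal-cuda-llvm-target-arch=sm_80 \
  --iree-hal-cuda-use-ptxas \
  -o model.vmfb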

import shark_turbine.aot as aot
import torch
import torch.nn as nn

class ExMod(nn.Module):
    def __init__(self):
        super().__init__()
        self.m = nn.BatchNorm2d(100)

    def forward(self, x):
        return self.m(x)  # assumed body: apply the wrapped BatchNorm2d
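
The script is truncated after the class definition; a plausible continuation, sketching the usual shark_turbine AOT export flow (aot.export and print_readable are my recollection of the API and may not match the version in use; the input shape matches the [20,100,35,45] signature in the dump above):

ex_input = torch.rand(20, 100, 35, 45)
exported = aot.export(ExMod(), ex_input)  # assumed aot.export entry point
exported.print_readable()                 # assumed helper: prints the Torch-dialect MLIR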

#loc = loc(unknown)
module attributes {torch.debug_module_name = "EmbeddingBagDenseBackwardModule"} {
  func.func @forward(%arg0: !torch.vtensor<[3,2],f32> loc(unknown), %arg1: !torch.vtensor<[3],si64> loc(unknown), %arg2: !torch.vtensor<[3],si64> loc(unknown), %arg3: !torch.vtensor<[1],si64> loc(unknown)) -> !torch.vtensor<[2,2],f32> {
    %int2 = torch.constant.int 2 loc(#loc2)
    %none = torch.constant.none loc(#loc2)
    %int6 = torch.constant.int 6 loc(#loc2)
    %true = torch.constant.bool true loc(#loc2)
    %false = torch.constant.bool false loc(#loc2)
    %0 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<int> loc(#loc2)
    %1 = torch.aten.zeros %0, %int6, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,2],f32> loc(#loc2)
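
The aten.zeros above allocates the [2,2] gradient buffer; dtype code 6 is torch.float32, so its Python-level equivalent (just for orientation, not part of the dump) is:

import torch

# sizes [%int2, %int2] -> [2, 2]; dtype code 6 -> torch.float32
grad_weight = torch.zeros([2, 2], dtype=torch.float32)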

import torch
import torch_mlir

class Net(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input, index1, src):
        return torch.index_put(input, indices=(index1,), values=src, accumulate=False)

m = Net()

# EXAMPLE 1
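
EXAMPLE 1 itself is cut off in the gist; a hedged sketch of how such a repro is typically driven through torch_mlir.compile (the example tensors are invented, and the "stablehlo" output type is an assumption about the torch-mlir version in use):

input = torch.zeros(1, 4, dtype=torch.int64)  # hypothetical example inputs
index1 = torch.tensor([0])
src = torch.tensor([[7, 8, 9, 10]])
module = torch_mlir.compile(m, [input, index1, src], output_type="stablehlo")
print(module)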

# In the first 3 cases, index2 has shape torch.Size([3])
# Case 1
input = torch.tensor([[0, 1, 2, 3]])
index1 = torch.tensor([[0]])
index2 = torch.tensor([1, 2, 3])
update = torch.tensor([4, 5, 6])
output = torch.ops.aten.index_put.hacked_twin(input, (index1, index2), update)
print("index1.shape: ", index1.shape)  # torch.Size([1, 1])
print("index2.shape: ", index2.shape)  # torch.Size([3])
print(output)  # tensor([[0, 4, 5, 6]])
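
The (1,1) index1 and (3,) index2 broadcast to a common (1,3) shape, so positions (0,1), (0,2), (0,3) receive 4, 5, 6. As far as I understand, hacked_twin differs from the regular overload only in requiring every index to be a real tensor (no None entries), so with all indices present the two should agree:

same = torch.ops.aten.index_put(input, (index1, index2), update)
print(torch.equal(output, same))  # expected: True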

import torch
import torch_mlir

class Net(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input, index1, index2, src):
        return torch.index_put(input, indices=(index1, index2), values=src, accumulate=False)

m = Net()
src = torch.arange(1, 6)
index1 = torch.tensor([0, 0, 0, 0, 0])
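
The gist cuts off after index1; a hypothetical completion (index2 and input are my guesses, chosen only so the shapes line up) that makes the two-index repro runnable:

index2 = torch.arange(5)                    # hypothetical: columns 0..4
input = torch.zeros(2, 5, dtype=src.dtype)  # hypothetical input
out = m(input, index1, index2, src)
print(out)  # row 0 becomes [1, 2, 3, 4, 5]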

(mlir_venv) ➜ torch-mlir git:(decompose) ✗ torch-mlir-opt --convert-torch-to-stablehlo /nodclouddata/chi/src/models/t5/slicecopy/test_indexputhackedtwin.mlir --debug
Args: torch-mlir-opt --convert-torch-to-stablehlo /nodclouddata/chi/src/models/t5/slicecopy/test_indexputhackedtwin.mlir --debug
Load new dialect in Context builtin
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ShapedType)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::MemRefLayoutAttrInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::TypedAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::ElementsAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::DistinctAttr)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::BytecodeOpInterface)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::SymbolOpInterface)

import torch

# a = torch.tensor([[0, 1, 2, 3]])
# a[..., 1:] = torch.tensor([4, 5, 6])
# = a[..., 1:4] = torch.tensor([4, 5, 6])
# = a[[0, 0, 0], [1, 2, 3]] = torch.tensor([4, 5, 6])   # tensor([[0, 4, 5, 6]])
# = torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]),                        # input
#                            (torch.tensor([0, 0, 0]), torch.tensor([1, 2, 3])),  # indices
#                            torch.tensor([4, 5, 6]))                             # values
# = torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]),                        # input
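
The derivation is cut off mid-call, but the chain of rewrites above is easy to verify directly; a short self-contained check of the slice-assignment / index_put equivalence:

a = torch.tensor([[0, 1, 2, 3]])
a[..., 1:] = torch.tensor([4, 5, 6])
b = torch.ops.aten.index_put(
    torch.tensor([[0, 1, 2, 3]]),
    (torch.tensor([0, 0, 0]), torch.tensor([1, 2, 3])),
    torch.tensor([4, 5, 6]),
)
print(torch.equal(a, b))  # True; both are tensor([[0, 4, 5, 6]])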