Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Last active September 8, 2023 04:46
Show Gist options
  • Save AmosLewis/800ca9c9cfbffd1f36a666761501ec78 to your computer and use it in GitHub Desktop.
import torch
import torch_mlir
class Net(torch.nn.Module):
    """Expose ``torch.index_put`` with a single index tensor as an ``nn.Module``.

    Wrapping the op in a module lets it be handed to ``torch_mlir.compile``,
    which this script uses (commented out) to inspect the StableHLO lowering.

    Args:
        accumulate: Forwarded to ``torch.index_put``. ``False`` (the default,
            identical to the previously hard-coded behaviour) overwrites the
            indexed positions with ``src``; ``True`` adds ``src`` into them.
    """

    def __init__(self, accumulate: bool = False) -> None:
        super().__init__()
        # Kept as a module attribute rather than a forward() argument so that
        # forward() keeps its (input, index1, src) signature, matching the
        # three example tensors passed to torch_mlir.compile in this script.
        # bool() so the attribute is always a real bool for TorchScript.
        self.accumulate = bool(accumulate)

    def forward(
        self, input: torch.Tensor, index1: torch.Tensor, src: torch.Tensor
    ) -> torch.Tensor:
        """Return a copy of ``input`` updated at positions ``index1`` (dim 0).

        ``torch.index_put`` is the out-of-place variant of ``index_put_``, so
        ``input`` itself is left unchanged. The parameter name ``input``
        shadows the builtin but is kept for interface compatibility.

        NOTE: per the PyTorch docs, when ``accumulate`` is False and
        ``index1`` contains duplicates, the value written to a repeated
        position is undefined.
        """
        return torch.index_put(
            input, indices=(index1,), values=src, accumulate=self.accumulate
        )


m = Net()
# EXAMPLE 1: accumulate=False with a repeated index (row 1 is targeted twice).
src = torch.arange(0, 6).view(3, 2)
print("src: ")
print(src)
input = torch.arange(1001, 1005, step=1, dtype=src.dtype).view(2, 2)
print("input: ")
print(input)
index1 = torch.tensor([0, 1, 1])
print("index1: ")
print(index1)
print("result: ")
print(torch.index_put(input, indices=(index1,), values=src, accumulate=False))
# Printed output (CPU tensors, one run):
# src:
# tensor([[0, 1],
#         [2, 3],
#         [4, 5]])
# input:
# tensor([[1001, 1002],
#         [1003, 1004]])
# index1:
# tensor([0, 1, 1])
# result:
# tensor([[0, 1],
#         [4, 5]])
# NOTE: PyTorch documents index_put with accumulate=False and duplicate
# indices as undefined behaviour, so row 1 may legitimately end up as either
# [2, 3] or [4, 5]. The "last write wins" result above is what was observed,
# not a guarantee, and a compiled lowering is not obliged to reproduce it.
#
# To inspect the StableHLO lowering, uncomment the two lines below. The result
# is bound to a new name on purpose: rebinding `m` would make any later
# torch_mlir.compile(m, ...) receive the compiled module instead of the Net.
# compiled = torch_mlir.compile(m, [input, index1, src], output_type="stablehlo")
# print(compiled.operation.get_asm())
# EXAMPLE 2: accumulate=True with a repeated index. Both src rows aimed at
# row 1 are added into it, so duplicates are summed instead of overwritten.
src = torch.arange(6).reshape(3, 2)
input = torch.arange(1001, 1005, dtype=src.dtype).reshape(2, 2)
index1 = torch.tensor([0, 1, 1])

# Each print emits the label line followed by the tensor on its own line(s).
print("src: ", src, sep="\n")
print("input: ", input, sep="\n")
print("index1: ", index1, sep="\n")
print(
    "result: ",
    torch.index_put(input, indices=(index1,), values=src, accumulate=True),
    sep="\n",
)
# Printed output:
# src:
# tensor([[0, 1],
#         [2, 3],
#         [4, 5]])
# input:
# tensor([[1001, 1002],
#         [1003, 1004]])
# index1:
# tensor([0, 1, 1])
# result:   row 0 = input[0] + src[0]; row 1 = input[1] + src[1] + src[2]
# tensor([[1001, 1003],
#         [1009, 1012]])
# EXAMPLE 3: accumulate=False with every src row aimed at row 0.
src = torch.arange(0, 6).view(3, 2)
print("src: ")
print(src)
input = torch.arange(1001, 1005, step=1, dtype=src.dtype).view(2, 2)
print("input: ")
print(input)
index1 = torch.tensor([0, 0, 0])
print("index1: ")
print(index1)
print("result: ")
print(torch.index_put(input, indices=(index1,), values=src, accumulate=False))
# Printed output (CPU tensors, one run):
# src:
# tensor([[0, 1],
#         [2, 3],
#         [4, 5]])
# input:
# tensor([[1001, 1002],
#         [1003, 1004]])
# index1:
# tensor([0, 0, 0])
# result:
# tensor([[   4,    5],
#         [1003, 1004]])
# NOTE: PyTorch documents index_put with accumulate=False and duplicate
# indices as undefined behaviour, so row 0 may legitimately end up as any of
# [0, 1], [2, 3] or [4, 5]. Only row 1 (never indexed) is guaranteed to be
# unchanged; a compiled lowering is not obliged to match the row 0 shown.
#
# To inspect the StableHLO lowering, uncomment the two lines below. The result
# is bound to a new name so `m` still refers to the Net instance afterwards.
# compiled = torch_mlir.compile(m, [input, index1, src], output_type="stablehlo")
# print(compiled.operation.get_asm())
# EXAMPLE 4: accumulate=True with every src row aimed at row 0. All three
# rows are summed into row 0; row 1 is never indexed and stays as-is.
src = torch.arange(6).reshape(3, 2)
input = torch.arange(1001, 1005, dtype=src.dtype).reshape(2, 2)
index1 = torch.tensor([0, 0, 0])

# Each print emits the label line followed by the tensor on its own line(s).
print("src: ", src, sep="\n")
print("input: ", input, sep="\n")
print("index1: ", index1, sep="\n")
print(
    "result: ",
    torch.index_put(input, indices=(index1,), values=src, accumulate=True),
    sep="\n",
)
# Printed output:
# src:
# tensor([[0, 1],
#         [2, 3],
#         [4, 5]])
# input:
# tensor([[1001, 1002],
#         [1003, 1004]])
# index1:
# tensor([0, 0, 0])
# result:   row 0 = input[0] + src[0] + src[1] + src[2]
# tensor([[1007, 1011],
#         [1003, 1004]])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment