Last active
August 16, 2023 03:27
-
-
Save AmosLewis/48da7c4f5223ecd2b91658197397a073 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Scratch experiments with torch.ops.aten.index_put.
#
# Explores how None entries in the `indices` list behave, and checks a
# "None index removal" rewrite in which every None is replaced by an explicit,
# fully expanded index tensor so that index_put only sees tensor indices.
import math

import torch


def expand_none_indices(shape, indices):
    """Replace ``None`` entries of an ``index_put`` index list with explicit indices.

    Implements the "None index remove" algorithm: each ``None`` at dim ``d``
    becomes ``torch.arange(shape[d])`` and each tensor index is flattened to
    1-D. The returned tensors enumerate every combination of per-dim entries
    in row-major order: dim ``d``'s entries are each repeated ``after`` times
    (product of the later dims' index lengths), and that sequence is tiled
    ``before`` times (product of the earlier dims' index lengths).

    For ``shape == (3, 5, 1, 2)`` with ``(None, perm, None, None)`` the
    coefficients are after = [10, 2, 2, 1] and before = [1, 3, 15, 15].

    Only equivalent to the original ``index_put`` call when at most one entry
    is a tensor: several tensor indices are broadcast together by
    ``index_put`` (see ``c`` and ``e`` below), not crossed as done here.

    Args:
        shape: Sizes of the tensor being written into.
        indices: One entry per dim of ``shape``: ``None`` or an integer tensor.

    Returns:
        A list of 1-D index tensors of equal length (the product of the
        per-dim index lengths).

    Raises:
        ValueError: If ``len(indices) != len(shape)``.
    """
    if len(indices) != len(shape):
        raise ValueError(f"expected {len(shape)} indices, got {len(indices)}")
    per_dim = [
        torch.arange(size) if index is None else index.reshape(-1)
        for size, index in zip(shape, indices)
    ]
    lengths = [index.numel() for index in per_dim]
    expanded = []
    for dim, index in enumerate(per_dim):
        after = math.prod(lengths[dim + 1:])
        before = math.prod(lengths[:dim])
        expanded.append(index.repeat_interleave(after).repeat(before))
    return expanded


# Equivalent ways to write the same update (from the original notes):
#   a = torch.tensor([[0, 1, 2, 3]])
#   a[..., 1:] = torch.tensor([4, 5, 6])
#   gives the same result as a[..., 1:4] = torch.tensor([4, 5, 6])
#   gives the same result as a[[0, 0, 0], [1, 2, 3]] = torch.tensor([4, 5, 6])
#     -> tensor([[0, 4, 5, 6]])
#   gives the same result as torch.ops.aten.index_put(
#       torch.tensor([[0, 1, 2, 3]]),                        # input
#       (torch.tensor([0, 0, 0]), torch.tensor([1, 2, 3])),  # indices
#       torch.tensor([4, 5, 6]))                             # values
#   gives the same result as torch.ops.aten.index_put(
#       torch.tensor([[0, 1, 2, 3]]),                        # input
#       (None, torch.tensor([1, 2, 3])),                     # indices
#       torch.tensor([4, 5, 6]))                             # values

# A None index selects the whole dimension; values broadcast over it.
b = torch.ops.aten.index_put(
    torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]]),  # input
    (None, torch.tensor([1, 2, 3])),  # indices
    torch.tensor([4, 5, 6]),  # values
)
# The same positions written with explicit row indices, fully expanded (2x3)...
c = torch.ops.aten.index_put(
    torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]]),  # input
    (torch.tensor([[0, 0, 0], [1, 1, 1]]), torch.tensor([1, 2, 3])),  # indices
    torch.tensor([4, 5, 6]),  # values
)
# ...or as a 2x1 column that broadcasts against the (3,) column index to 2x3.
e = torch.ops.aten.index_put(
    torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]]),  # input
    (torch.tensor([[0], [1]]), torch.tensor([1, 2, 3])),  # indices
    torch.tensor([4, 5, 6]),  # values
)
# b, c and e all equal tensor([[0, 4, 5, 6], [0, 4, 5, 6]]).

src = torch.arange(0, 5)
print(src)
index2 = torch.tensor([1, 2, 3, 4, 0])
input1 = torch.arange(15, 30, step=1, dtype=src.dtype).view(3, 5)
print(input1)
# tensor([[15, 16, 17, 18, 19],
#         [20, 21, 22, 23, 24],
#         [25, 26, 27, 28, 29]])

# The row index may be fully expanded to 3x5
# ([[0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2]]), or given as a 1x3x1
# column that broadcasts against index2 (5,) to 1x3x5, as done here.
# NOTE: a plain 3x1 column ([[0], [1], [2]]) was marked "bug" in the original
# experiment; the context for that note is not visible in this file.
index1 = torch.tensor([[[0], [1], [2]]])  # 1x3x1
d = torch.ops.aten.index_put(input1, indices=(index1, index2), values=src, accumulate=False)
print(d)
# tensor([[4, 0, 1, 2, 3],
#         [4, 0, 1, 2, 3],
#         [4, 0, 1, 2, 3]])

input2 = torch.arange(15, 45, step=1, dtype=src.dtype).view(3, 5, 1, 2)
src = torch.arange(0, 1)
g = torch.ops.aten.index_put(input2, indices=(None, index2, None, None), values=src, accumulate=False)
# Adding a second tensor index fails here: index2 (5,) and tensor([0, 1]) (2,)
# cannot be broadcast together.
#   f = torch.ops.aten.index_put(
#       input2, indices=(None, index2, None, torch.tensor([0, 1])),
#       values=src, accumulate=False)

# None index removal for g: per-dim indices for input2 of shape [3, 5, 1, 2]
# are [[0, 1, 2], [1, 2, 3, 4, 0], [0], [0, 1]].
expanded = expand_none_indices(input2.shape, (None, index2, None, None))

# The same expansion written out by hand (3 * 5 * 1 * 2 = 30 positions).
# dim 0: each of 0, 1, 2 repeated 5*1*2 = 10 times.
index1 = torch.tensor([
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
])
# dim 1: each of [1, 2, 3, 4, 0] repeated 1*2 = 2 times, tiled 3 times.
index2 = torch.tensor([
    1, 1, 2, 2, 3, 3, 4, 4, 0, 0,
    1, 1, 2, 2, 3, 3, 4, 4, 0, 0,
    1, 1, 2, 2, 3, 3, 4, 4, 0, 0,
])
# dim 2: size 1, so every entry is 0 (the [3, 5, 2] nested form, flattened).
index3 = torch.tensor([
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
])
# dim 3: [0, 1] tiled 3*5*1 = 15 times.
index4 = torch.tensor([
    0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
    0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
    0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
])
# Guard against typos in the hand-written tensors (all four must be length 30).
for manual, derived in zip((index1, index2, index3, index4), expanded):
    torch.testing.assert_close(manual, derived)

h = torch.ops.aten.index_put(input2, indices=(index1, index2, index3, index4), values=src, accumulate=False)
print(h)
print(h.shape)  # torch.Size([3, 5, 1, 2])
# The explicit-index write must match the None-index write.
torch.testing.assert_close(h, g)

# None index removal algorithm
# 1 Get output shape [3, 5, 1, 2]
# 2 Get each None index from the input shape [[0, 1, 2], [1, 2, 3, 4, 0], [0], [0, 1]]
# 3 Get multiply coefficient_after [10, 2, 2, 1]
#   and multiply coefficient_before [1, 3, 15, 15]
# 4 Get raw indices by repeating each index element coefficient_after times,
#   then tiling the result coefficient_before times
# 5 Reshape raw indices to output shape
# 6 Broadcast src/update to output shape
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment