Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Last active August 16, 2023 03:27
Show Gist options
  • Save AmosLewis/48da7c4f5223ecd2b91658197397a073 to your computer and use it in GitHub Desktop.
Save AmosLewis/48da7c4f5223ecd2b91658197397a073 to your computer and use it in GitHub Desktop.
import torch
# a = torch.tensor([[0, 1, 2, 3]])
# a[..., 1:] = torch.tensor([4, 5, 6])
# = a[..., 1:4] = torch.tensor([4, 5, 6])
# = a[[0, 0, 0], [1, 2, 3]] = torch.tensor([4, 5, 6])  # tensor([[0, 4, 5, 6]])
# = torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]),  # input
#                            (torch.tensor([0, 0, 0]), torch.tensor([1, 2, 3])),  # indices
#                            torch.tensor([4, 5, 6]))  # value
# = torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]),  # input
#                            (None, torch.tensor([1, 2, 3]),),  # indices
#                            torch.tensor([4, 5, 6]))  # value
# Three index_put calls on the same 2x4 input, each writing [4, 5, 6] into
# columns 1..3. They differ only in how the row (dim 0) index is spelled.
_rows_2x4 = [[0, 1, 2, 3], [0, 1, 2, 3]]
_col_index = [1, 2, 3]
_new_values = [4, 5, 6]

# Row index given as None.
b = torch.ops.aten.index_put(
    torch.tensor(_rows_2x4),                    # input
    (None, torch.tensor(_col_index)),           # indices
    torch.tensor(_new_values),                  # value
)
# print(b)

# Row index spelled out in full as a 2x3 tensor.
c = torch.ops.aten.index_put(
    torch.tensor(_rows_2x4),                    # input
    (
        torch.tensor([[0, 0, 0], [1, 1, 1]]),
        torch.tensor(_col_index),
    ),                                          # indices
    torch.tensor(_new_values),                  # value
)
# print(c)

# Row index given as a 2x1 column, broadcast against the 3 column indices.
e = torch.ops.aten.index_put(
    torch.tensor(_rows_2x4),                    # input
    (
        torch.tensor([[0], [1]]),
        torch.tensor(_col_index),
    ),                                          # indices
    torch.tensor(_new_values),                  # value
)
# print(e)
# ---------------------------------------------------------------------------
# Example d: 2-D input, both dimensions indexed with broadcastable tensors.
# ---------------------------------------------------------------------------
src = torch.arange(0, 5)
print(src)
index2 = torch.tensor([1, 2, 3, 4, 0])
# NOTE: `input` shadows the Python builtin of the same name; the name is kept
# unchanged so the script's variable names stay as they were.
input = torch.arange(15, 30, step=1, dtype=src.dtype).view(3, 5)
print(input)
# tensor([[15, 16, 17, 18, 19],
#         [20, 21, 22, 23, 24],
#         [25, 26, 27, 28, 29]])

# Row index written out in full (3*5). It is superseded by the next assignment
# before being used; kept to show the fully expanded form.
index1 = torch.tensor([[0, 0, 0, 0, 0],
                       [1, 1, 1, 1, 1],
                       [2, 2, 2, 2, 2]])  # 3*5
# print(index1.shape)

# Row index as 1*3*1; this is the value actually passed to index_put below.
index1 = torch.tensor([[[0],
                        [1],
                        [2]]])  # 1*3*1
# print(index1.shape)
# index1 = torch.tensor([[0],
#                        [1],
#                        [2]])  # 3*1 bug (flagged by the original author; cause not recorded here)
# print(index1.shape)
d = torch.ops.aten.index_put(input, indices=(index1, index2), values=src, accumulate=False)
print(d)
# tensor([[4, 0, 1, 2, 3],
#         [4, 0, 1, 2, 3],
#         [4, 0, 1, 2, 3]])

# ---------------------------------------------------------------------------
# Example g: 4-D input, only dim 1 indexed; the other dims are None.
# ---------------------------------------------------------------------------
input2 = torch.arange(15, 45, step=1, dtype=src.dtype).view(3, 5, 1, 2)
src = torch.arange(0, 1)
# print(input2)
# index3 / index4 here are only referenced by the commented-out `f` call below
# and are reassigned before the `h` call.
index3 = torch.tensor([0])
index4 = torch.tensor([0, 1])
g = torch.ops.aten.index_put(input2, indices=(None, index2, None, None), values=src, accumulate=False)
# print(g)
# print(g.shape)
# f = torch.ops.aten.index_put(input2, indices=(None, index2, None, index4), values=src, accumulate=False)
# print(f)
# indices=torch.tensor([[0,1,2],[1, 2, 3, 4, 0],[0],[0,1]])
# indices_1=torch.tensor([[0],[1],[2]],[[1], [2], [3], [4], [0]],[[0]],[[0],[1]])

# ---------------------------------------------------------------------------
# Example h: the same update as g, with every None replaced by an explicit
# flat index of 3*5*1*2 = 30 elements. All four index tensors must have the
# same number of elements, otherwise index_put cannot broadcast them.
# ---------------------------------------------------------------------------
# [3,5,1,2]
# dim 0: each value repeated 5*1*2 = 10 times.
# Fix: the last row previously held only nine 2s (29 elements in total), which
# did not match the 30-element index2/index3/index4 below.
index1 = torch.tensor([
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
])
# dim 1: each value repeated 1*2 = 2 times, whole pattern tiled 3 times.
index2 = torch.tensor([
    1, 1, 2, 2, 3, 3, 4, 4, 0, 0,
    1, 1, 2, 2, 3, 3, 4, 4, 0, 0,
    1, 1, 2, 2, 3, 3, 4, 4, 0, 0,
])
# dim 2: the single value 0 repeated 2 times, tiled 3*5 = 15 times.
# Nested (3*5*2) spelling first; it is superseded by the flat form right after.
index3 = torch.tensor([
    [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]],
    [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]],
    [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]],
])
index3 = torch.tensor([
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
])
# dim 3: each value repeated 1 time, pattern [0, 1] tiled 3*5*1 = 15 times.
index4 = torch.tensor([
    0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
    0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
    0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
])
# g = torch.ops.aten.index_put(input2, indices=(None, index2, None, None), values=src, accumulate=False)
# print(g)
# print(g.shape)
h = torch.ops.aten.index_put(input2, indices=(index1, index2, index3, index4), values=src, accumulate=False)
print(h)
print(h.shape)  # ([3, 5, 1, 2])
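# Algorithm for removing None indices
# 1. Get the output shape: [3, 5, 1, 2].
# 2. Replace each None index with the full index range of its input dimension:
#    [[0, 1, 2], [1, 2, 3, 4, 0], [0], [0, 1]].
# 3. Compute the repeat coefficients for each dimension:
#      coefficient_after  (product of the later dims)   = [10, 2, 2, 1]
#      coefficient_before (product of the earlier dims) = [1, 3, 15, 15]
# 4. Build the raw indices by repeating each index element coefficient_after
#    times and tiling the result coefficient_before times.
# 5. Reshape the raw indices to the output shape.
# 6. Broadcast src/update to the output shape.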
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment