# GitHub Gist by @AmosLewis, created September 7, 2023 17:57
# https://gist.github.com/AmosLewis/943b2b61eaa687c46c8ce473cfa7556b
import torch

# Demonstrates how aten::index_put (the `hacked_twin` overload) broadcasts its
# index tensors.  Every case writes update = [4, 5, 6] into the (1, 4) tensor
# [[0, 1, 2, 3]] through an (index1, index2) pair.  The two index tensors are
# broadcast against each other, and in every case below the broadcast selects
# positions (0, 1), (0, 2), (0, 3), so each case prints tensor([[0, 4, 5, 6]]).


def run_index_put_case(index1, index2):
    """Run one index_put case, print both index shapes and the result, return it.

    Args:
        index1: Indices into dim 0 (rows) of the (1, 4) input tensor.
        index2: Indices into dim 1 (columns) of the (1, 4) input tensor.

    Returns:
        The tensor produced by ``torch.ops.aten.index_put.hacked_twin``.
    """
    # Build fresh tensors per call so the cases stay independent of each other.
    source = torch.tensor([[0, 1, 2, 3]])
    update = torch.tensor([4, 5, 6])
    result = torch.ops.aten.index_put.hacked_twin(source, (index1, index2), update)
    print("index1.shape: ", index1.shape)
    print("index2.shape: ", index2.shape)
    print(result)
    return result


# (index1, index2) pairs; shapes noted as index1 with index2.
# Expected result for every case: tensor([[0, 4, 5, 6]]).
INDEX_PUT_CASES = [
    # First 3 cases: index2 is torch.Size([3]).
    (torch.tensor([[0]]), torch.tensor([1, 2, 3])),          # Case 1: [1, 1] with [3]
    (torch.tensor([[0, 0, 0]]), torch.tensor([1, 2, 3])),    # Case 2: [1, 3] with [3]
    (torch.tensor([0, 0, 0]), torch.tensor([1, 2, 3])),      # Case 3: [3] with [3]
    # Next 3 cases: index2 is torch.Size([1, 3]).
    (torch.tensor([[0]]), torch.tensor([[1, 2, 3]])),        # Case 4: [1, 1] with [1, 3]
    (torch.tensor([[0, 0, 0]]), torch.tensor([[1, 2, 3]])),  # Case 5: [1, 3] with [1, 3]
    (torch.tensor([0, 0, 0]), torch.tensor([[1, 2, 3]])),    # Case 6: [3] with [1, 3]
]

for index1, index2 in INDEX_PUT_CASES:
    output = run_index_put_case(index1, index2)
# (GitHub page footer) Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.