Skip to content

Instantly share code, notes, and snippets.

@zou3519
Created October 25, 2017 17:51
Show Gist options
  • Save zou3519/4c57da40fc287fa8a8399308412a5d0d to your computer and use it in GitHub Desktop.
Save zou3519/4c57da40fc287fa8a8399308412a5d0d to your computer and use it in GitHub Desktop.
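Tests for the safety checks in `scatter_add_` and `index_add_`: each in-place op must accept a distinct source tensor but reject a source that aliases its own destination.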
import torch
def _check_alias_rejected(op_name, cuda):
    """Run the two safety checks for the in-place tensor op ``op_name``.

    Builds length-5 float tensors ``a`` (destination) and ``b`` (source) and
    an int64 index tensor whose entries are all 1, optionally on CUDA. Then:

    1. ``a.<op_name>(0, i, b)`` must succeed: the source is a distinct tensor.
    2. ``a.<op_name>(0, i, a)`` must raise ``RuntimeError``: the source is the
       destination itself, which the op is expected to reject.

    Each check prints one labelled "OK" / "not OK" line. Failures are reported
    instead of aborting, so callers can keep running further checks.

    Args:
        op_name: name of the in-place ``torch.Tensor`` method to exercise,
            e.g. ``"index_add_"`` or ``"scatter_add_"``.
        cuda: if true, allocate the tensors on the current CUDA device.

    Returns:
        True if both checks behaved as expected, False otherwise.
    """
    device = "cuda" if cuda else "cpu"
    a = torch.ones(5, device=device)
    b = torch.ones(5, device=device)
    i = torch.ones(5, dtype=torch.long, device=device)
    op = getattr(a, op_name)  # bound to ``a``, so every call writes into ``a``
    label = "{} ({})".format(op_name, device)
    passed = True

    # Check 1: a distinct source tensor should be accepted.
    try:
        op(0, i, b)
    except RuntimeError as err:
        print("{} with distinct source: not OK (raised: {})".format(label, err))
        passed = False
    else:
        print("{} with distinct source: OK".format(label))

    # Check 2: passing the destination as its own source should be rejected.
    # Only RuntimeError counts as the expected rejection; any other exception
    # type is unexpected and propagates to the caller.
    try:
        op(0, i, a)
    except RuntimeError:
        print("{} with aliased source: OK (rejected)".format(label))
    else:
        print("{} with aliased source: not OK (accepted)".format(label))
        passed = False

    return passed


def test_index_add(cuda=False):
    """Check that ``Tensor.index_add_`` rejects a source aliasing its destination.

    Args:
        cuda: run on the current CUDA device instead of the CPU.

    Returns:
        True if the non-aliased call succeeded and the aliased call raised
        ``RuntimeError``; False otherwise.
    """
    return _check_alias_rejected("index_add_", cuda)


def test_scatter_add(cuda=False):
    """Check that ``Tensor.scatter_add_`` rejects a source aliasing its destination.

    Args:
        cuda: run on the current CUDA device instead of the CPU.

    Returns:
        True if the non-aliased call succeeded and the aliased call raised
        ``RuntimeError``; False otherwise.
    """
    return _check_alias_rejected("scatter_add_", cuda)
if __name__ == "__main__":
    # Run each check on CUDA first, then CPU, matching the original order.
    # CUDA runs are skipped (not crashed) when no GPU is usable, so the CPU
    # checks still execute on machines without CUDA.
    cuda_available = torch.cuda.is_available()
    for check in (test_index_add, test_scatter_add):
        for cuda in (True, False):
            if cuda and not cuda_available:
                print("{} (cuda): skipped, CUDA not available".format(check.__name__))
                continue
            check(cuda)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment