Created October 25, 2017 17:51
-
-
Save zou3519/4c57da40fc287fa8a8399308412a5d0d to your computer and use it in GitHub Desktop.
Test the self-aliasing safety checks of scatter_add_ and index_add_
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
def test_index_add(cuda=False):
    """Check index_add_'s guard against a source that aliases the destination.

    Adding a distinct tensor ``b`` into ``a`` should succeed; adding ``a``
    into itself is expected to be rejected (PyTorch reports such argument
    check failures as ``RuntimeError``). Prints "OK" for each check that
    behaves as expected and "not OK" if the aliased call is accepted. The
    expected error is caught here so that a caller running several
    configurations in sequence is not aborted by the first one.

    Args:
        cuda: run the check on the GPU instead of the CPU. When CUDA is not
            available the check is skipped with a message instead of
            crashing. Defaults to False so the function can be called
            without arguments.
    """
    if cuda and not torch.cuda.is_available():
        print("skipped: CUDA not available")
        return
    a = torch.ones(5)
    b = torch.ones(5)
    i = torch.ones(5).long()  # every entry is index 1
    if cuda:
        a = a.cuda()
        b = b.cuda()
        i = i.cuda()
    a.index_add_(0, i, b)  # should not error: source is a separate tensor
    print("OK")
    try:
        a.index_add_(0, i, a)  # should error: source aliases the destination
    except RuntimeError:
        print("OK")
    else:
        print("not OK")
def test_scatter_add(cuda=False):
    """Check scatter_add_'s guard against a source that aliases the destination.

    Scattering a distinct tensor ``b`` into ``a`` should succeed; scattering
    ``a`` into itself is expected to be rejected (PyTorch reports such
    argument check failures as ``RuntimeError``). Prints "OK" for each check
    that behaves as expected and "not OK" if the aliased call is accepted.
    The expected error is caught here so that a caller running several
    configurations in sequence is not aborted by the first one.

    Args:
        cuda: run the check on the GPU instead of the CPU. When CUDA is not
            available the check is skipped with a message instead of
            crashing. Defaults to False so the function can be called
            without arguments.
    """
    if cuda and not torch.cuda.is_available():
        print("skipped: CUDA not available")
        return
    a = torch.ones(5)
    b = torch.ones(5)
    i = torch.ones(5).long()  # every entry is index 1
    if cuda:
        a = a.cuda()
        b = b.cuda()
        i = i.cuda()
    a.scatter_add_(0, i, b)  # should not error: source is a separate tensor
    print("OK")
    try:
        a.scatter_add_(0, i, a)  # should error: source aliases the destination
    except RuntimeError:
        print("OK")
    else:
        print("not OK")
if __name__ == "__main__":
    # Run each check on the GPU first (only when one is present) and then on
    # the CPU, matching the original call order; importing this module no
    # longer triggers the checks as a side effect.
    use_cuda_options = [True, False] if torch.cuda.is_available() else [False]
    for check in (test_index_add, test_scatter_add):
        for use_cuda in use_cuda_options:
            check(use_cuda)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.