Last active
August 7, 2018 06:59
-
-
Save wkcn/1d0151c898582541e6eb74f0162eaa87 to your computer and use it in GitHub Desktop.
DLPackTest
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import mxnet as mx
import numpy as np
import torch
from torch.utils import dlpack
def test_dlpack():
    """Round-trip MXNet NDArrays through DLPack and check the data survives.

    For every (dtype, shape) combination this exercises all four export entry
    points: the ``NDArray.to_dlpack_for_read`` / ``NDArray.to_dlpack_for_write``
    methods and the module-level ``mx.nd.to_dlpack_for_read`` /
    ``mx.nd.to_dlpack_for_write`` functions.  Each capsule is re-imported with
    ``mx.nd.from_dlpack``.  All source arrays and capsules are then deleted,
    and every imported array is checked to still hold the original values and
    dtype.
    """
    for dtype in [np.float32, np.int32]:
        for shape in [(3, 4, 5, 6), (2, 10), (15,)]:
            # Previously ``dtype`` was never used, so int32 was never tested.
            # Draw from a wide range so the int32 cast does not collapse every
            # element to 0 (uniform's default range is [0, 1)).
            a = mx.nd.random.uniform(low=-100, high=100, shape=shape)
            a = a.astype(dtype)
            a_np = a.asnumpy()

            # Method-style export: read-only view of ``a`` itself.
            pack = a.to_dlpack_for_read()
            b = mx.nd.from_dlpack(pack)

            # Method-style export for write, on a private copy so the
            # writable view does not alias ``a``.
            a_copy_method = a.copy()
            pack2 = a_copy_method.to_dlpack_for_write()
            c = mx.nd.from_dlpack(pack2)

            # Function-style export: read-only.
            pack3 = mx.nd.to_dlpack_for_read(a)
            d = mx.nd.from_dlpack(pack3)

            # Function-style export for write, again on a private copy.
            a_copy_func = a.copy()
            pack4 = mx.nd.to_dlpack_for_write(a_copy_func)
            e = mx.nd.from_dlpack(pack4)

            # Drop every source array and capsule.  The checks below then pass
            # only if the imported arrays keep the underlying memory alive.
            # Previously the second copy was never released, so ``e`` was
            # never exercised this way.
            del a, a_copy_method, a_copy_func, pack, pack2, pack3, pack4

            for imported in (b, c, d, e):
                imported_np = imported.asnumpy()
                assert imported_np.dtype == a_np.dtype, (
                    imported_np.dtype, a_np.dtype)
                mx.test_utils.assert_almost_equal(a_np, imported_np)
def test_dlpack_torch():
    """Import a PyTorch tensor into MXNet through DLPack and compare values.

    Only the torch -> MXNet direction is covered; the note at the end of the
    function explains why the reverse direction is not.
    """
    source = torch.tensor([1, 2, 3])
    capsule = dlpack.to_dlpack(source)
    imported = mx.nd.from_dlpack(capsule)
    mx.test_utils.assert_almost_equal(source.numpy(), imported.asnumpy())
    # PyTorch rejects DLPack tensors whose ``strides`` pointer is NULL, so the
    # MXNet -> torch direction is not tested here.  (Presumably this is what
    # MXNet exports; confirm against the MXNet version in use.)
if __name__ == "__main__":
    # Run the checks only when executed as a script.  At module level they
    # would also run on import, e.g. a second time under pytest.
    test_dlpack()
    test_dlpack_torch()
    print("OK")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment