@niklasschmitz
Last active September 30, 2021 13:28
JAX - PyTorch dlpack conversion bug
```python
import torch
import torch.utils.dlpack
import jax
import jax.dlpack

# Enable 64-bit types; torch.arange produces int64 tensors.
jax.config.update("jax_enable_x64", True)


def jax2torch(x):
    # Convert a JAX array to a PyTorch tensor via a DLPack capsule.
    return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))


def torch2jax(x):
    # Convert a PyTorch tensor to a JAX array via a DLPack capsule.
    return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x))


# Non-contiguous (column-major) view: shape (3, 2), strides (1, 3).
a = torch.arange(6).reshape(2, 3).moveaxis(1, 0)
b = torch2jax(a)

print("stride", a.stride())
# stride (1, 3)

print(a)
# tensor([[0, 3],
#         [1, 4],
#         [2, 5]])

print(b)  # printed values match a
# [[0 3]
#  [1 4]
#  [2 5]]

print(a[1, 0], b[1, 0])  # but indexing does not match
# tensor(1) 2

print(torch.__version__)  # '1.9.0+cpu'
print(jax.__version__)  # '0.2.19'
```
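The mismatch can be reproduced from the strides alone. This is a pure-Python model, not JAX or PyTorch code: `storage_offset` and `row_major_strides` are hypothetical helpers, and the assumption (consistent with the printed `2`, but not confirmed by the gist) is that JAX's indexing path read the shared buffer as if it were row-major, ignoring the `(1, 3)` strides reported by `a.stride()`.

```python
def storage_offset(index, strides):
    # Flat position of an element in the underlying buffer (strides in elements).
    return sum(i * s for i, s in zip(index, strides))


def row_major_strides(shape):
    # Strides a consumer would assume for a C-contiguous array of this shape.
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


storage = list(range(6))  # buffer behind torch.arange(6)
shape = (3, 2)            # shape after .reshape(2, 3).moveaxis(1, 0)
strides = (1, 3)          # as printed by a.stride()

idx = (1, 0)
correct = storage[storage_offset(idx, strides)]                    # honours strides
naive = storage[storage_offset(idx, row_major_strides(shape))]     # ignores them
print(correct, naive)
```

Indexing with the real strides gives `1` (PyTorch's answer for `a[1, 0]`); indexing with assumed row-major strides `(2, 1)` gives `2`, which is what `b[1, 0]` returned.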
@niklasschmitz (Author)

The associated issue and its workaround are here: jax-ml/jax#7657 (comment)
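A minimal sketch of why a contiguous copy sidesteps the problem, assuming the workaround amounts to handing JAX a row-major tensor, e.g. `torch2jax(a.contiguous())`. That is an assumption; the linked comment is the authoritative fix. `make_contiguous` is a hypothetical pure-Python stand-in for `torch.Tensor.contiguous()`, which copies a non-contiguous tensor into C (row-major) order.

```python
from itertools import product


def storage_offset(index, strides):
    # Flat position of an element in the buffer (strides in elements).
    return sum(i * s for i, s in zip(index, strides))


def row_major_strides(shape):
    # Strides of a C-contiguous array with this shape.
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def make_contiguous(storage, shape, strides):
    # Read every element through the real strides, emitting them in row-major order.
    return [storage[storage_offset(idx, strides)]
            for idx in product(*(range(n) for n in shape))]


storage = list(range(6))  # buffer behind torch.arange(6)
shape = (3, 2)
strides = (1, 3)

packed = make_contiguous(storage, shape, strides)
# After the copy, a consumer that assumes row-major layout reads the right element.
value = packed[storage_offset((1, 0), row_major_strides(shape))]
print(packed, value)
```

The copy costs one pass over the data, but afterwards the strides are the default row-major ones, so a consumer that ignores strides and one that honours them agree.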
