Created
October 27, 2017 18:00
-
-
Save killeent/2c128b4ed8ae084c3c7dfa94bbcfcaeb to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/torch/csrc/distributed/Module.cpp b/torch/csrc/distributed/Module.cpp
index a985509..293a4e1 100644
--- a/torch/csrc/distributed/Module.cpp
+++ b/torch/csrc/distributed/Module.cpp
@@ -186,8 +186,8 @@ THDTensorDescriptor THDPModule_makeDescriptor(PyObject *obj)
PyObject *type = (PyObject*)Py_TYPE(obj);
#define REGISTER_TH_DESCRIPTOR(TYPE, REAL) \
if (type == THP##TYPE##Class) \
- return at::CPU(REAL).unsafeTensorFromTH(((THP##TYPE*)obj)->cdata, true);
- /* return THDTensorDescriptor_newFromTH##TYPE(((THP##TYPE*)obj)->cdata); */
+ return THDTensorDescriptor_newFromTH##TYPE(((THP##TYPE*)obj)->cdata);
+ /* return at::CPU(REAL).unsafeTensorFromTH(((THP##TYPE*)obj)->cdata, true); */
REGISTER_TH_DESCRIPTOR(DoubleTensor, at::kDouble);
REGISTER_TH_DESCRIPTOR(FloatTensor, at::kFloat);
REGISTER_TH_DESCRIPTOR(LongTensor, at::kLong);
@@ -199,12 +199,12 @@ THDTensorDescriptor THDPModule_makeDescriptor(PyObject *obj)
#ifdef WITH_CUDA
#define REGISTER_THC_DESCRIPTOR(TYPE, REAL) \
if (type == THCP##TYPE##Class) \
- return at::CUDA(REAL).unsafeTensorFromTH(((THP##TYPE*)obj)->cdata, true);
- /* return THDTensorDescriptor_newFromTHCuda##TYPE((THCuda##TYPE*)(((torch::THPVoidTensor*)obj)->cdata)); */
+ return THDTensorDescriptor_newFromTHCuda##TYPE((THCuda##TYPE*)(((torch::THPVoidTensor*)obj)->cdata));
+ /* return at::CUDA(REAL).unsafeTensorFromTH(((THP##TYPE*)obj)->cdata, true); */
REGISTER_THC_DESCRIPTOR(DoubleTensor, at::kDouble);
if (type == THCPFloatTensorClass)
- return at::CUDA(at::kFloat).unsafeTensorFromTH((THCudaTensor*)(((torch::THPVoidTensor*)obj)->cdata), true);
- /* return THDTensorDescriptor_newFromTHCudaFloatTensor((THCudaTensor*)(((torch::THPVoidTensor*)obj)->cdata)); */
+ return THDTensorDescriptor_newFromTHCudaFloatTensor((THCudaTensor*)(((torch::THPVoidTensor*)obj)->cdata));
+ /* return at::CUDA(at::kFloat).unsafeTensorFromTH((THCudaTensor*)(((torch::THPVoidTensor*)obj)->cdata), true); */
REGISTER_THC_DESCRIPTOR(LongTensor, at::kLong);
REGISTER_THC_DESCRIPTOR(IntTensor, at::kInt);
REGISTER_THC_DESCRIPTOR(ShortTensor, at::kShort);
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment