Last active
October 30, 2019 14:08
-
-
Save heiner/c03aa5df92408974922aea7ec499405c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* Build & run with | |
# conda create -n bugenv python=3.7 | |
# conda activate bugenv | |
# pip install 'torch==1.3.0' | |
# CXX=c++ python3 setup.py build develop | |
# python run.py | |
*/ | |
#include <cxxabi.h>

#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <typeinfo>
#include <vector>

#include <torch/extension.h>
class Cat { | |
public: | |
at::Tensor cat(std::vector<at::Tensor> tensors, int dim) { | |
return at::cat(tensors, dim); | |
} | |
at::Tensor cat_with_catch_c10(std::vector<at::Tensor> tensors, int dim) { | |
try { | |
return cat(tensors, dim); | |
} catch (const c10::IndexError& e) { | |
std::cout << "c10::IndexError with " << e.what() << std::endl; | |
} catch (...) { | |
int status; | |
std::cout << "... exception! Let's see which kind: " | |
<< abi::__cxa_demangle( | |
abi::__cxa_current_exception_type()->name(), 0, 0, | |
&status) | |
<< std::endl; | |
} | |
return torch::zeros(0); | |
} | |
at::Tensor cat_with_catch_exception(std::vector<at::Tensor> tensors, | |
int dim) { | |
try { | |
return cat(tensors, dim); | |
} catch (const std::exception& e) { | |
std::cout << "std::exception with " << e.what() << std::endl; | |
} catch (...) { | |
int status; | |
std::cout << "... exception! Let's see which kind: " | |
<< abi::__cxa_demangle( | |
abi::__cxa_current_exception_type()->name(), 0, 0, | |
&status) | |
<< std::endl; | |
} | |
return torch::zeros(0); | |
} | |
private: | |
std::vector<torch::Tensor> results_; | |
}; | |
// Python bindings: the `tensorbug` extension module exposes the Cat class
// with a default constructor and its three concatenation methods.
PYBIND11_MODULE(tensorbug, m) {
  auto cat_class = py::class_<Cat>(m, "Cat");
  cat_class.def(py::init<>());
  cat_class.def("cat", &Cat::cat);
  cat_class.def("cat_with_catch_c10", &Cat::cat_with_catch_c10);
  cat_class.def("cat_with_catch_exception", &Cat::cat_with_catch_exception);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build & run with | |
# conda create -n bugenv python=3.7 | |
# conda activate bugenv | |
# pip install 'torch==1.3.0' | |
# CXX=c++ python3 setup.py build develop | |
# python run.py 0 # or 1 or 2 | |
# Bug reproduced (exceptions not caught as expected) on
# OS: Mac OSX 10.14.3 | |
# GCC version: Could not collect | |
# CMake version: version 3.12.2 | |
# Python version: 3.7 | |
# with: | |
# pip install 'torch==1.3.0' | |
# pip install 'torch==1.2.0' | |
# pip install 'torch==1.1.0' | |
# pip install 'torch==1.0.0' (needs s/c10::IndexError/c10::Error/) | |
# | |
# Works as expected (exceptions are caught) on
# OS: Ubuntu 18.04.2 LTS | |
# GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0 | |
# CMake version: version 3.12.2 | |
import sys | |
import torch | |
import tensorbug | |
catter = tensorbug.Cat()

# Which reproduction to run is taken from the first CLI argument (0, 1 or 2);
# without an argument, codepath 2 is run.
codepath = int(sys.argv[1]) if len(sys.argv) > 1 else 2

# Every codepath asks for dim=10 on a 1-D tensor so that at::cat throws.
if codepath == 0:
    ## Just plain breaks:
    catter.cat([torch.zeros(1)], 10)
elif codepath == 1:
    ## Semi-good: Catches exception but breaks anyway.
    ## Can at least be debugged.
    catter.cat_with_catch_c10([torch.zeros(1)], 10)
elif codepath == 2:
    ## Worst: Does not catch std::exception.
    ## Non-portable inspection shows it's an index error.
    catter.cat_with_catch_exception([torch.zeros(1)], 10)
else:
    # Previously an unknown codepath silently did nothing; fail loudly instead.
    sys.exit("usage: python run.py [0|1|2] (got codepath=%d)" % codepath)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build with | |
# CXX=c++ python3 setup.py build develop | |
import setuptools | |
import sys | |
from torch.utils import cpp_extension | |
# Extra compiler/linker flags are only added when building on macOS.
_on_macos = sys.platform == "darwin"
extra_compile_args = (
    ["-stdlib=libc++", "-mmacosx-version-min=10.14"] if _on_macos else []
)
extra_link_args = ["-stdlib=libc++"] if _on_macos else []

# Single C++ extension module built from bug.cc; C++17 is always requested,
# followed by any platform-specific flags.
tensorbug = cpp_extension.CppExtension(
    name="tensorbug",
    sources=["bug.cc"],
    language="c++",
    extra_compile_args=["-std=c++17", *extra_compile_args],
    extra_link_args=extra_link_args,
)

setuptools.setup(
    name="tensorbug",
    ext_modules=[tensorbug],
    cmdclass={"build_ext": cpp_extension.BuildExtension},
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment