Last active May 19, 2020 16:37
-
-
Save heiner/93dbe65e2f624aed1a6ba3c1e07eec2a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "skel.h" | |
#include <torch/extension.h> | |
namespace py = pybind11; | |
// Python module `skel`: exposes the C++ Skel class as `skel.Skel`.
PYBIND11_MODULE(skel, m) {
  py::class_<Skel> skel_class(m, "Skel");
  // Default constructor, callable from Python as skel.Skel().
  skel_class.def(py::init<>());
  // Bound method Skel.tensor, forwarding to the C++ Skel::tensor.
  skel_class.def("tensor", &Skel::tensor);
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import skel | |
def main():
    """Smoke-test the ``skel`` extension.

    Prints the installed PyTorch version, then prints the result of calling
    ``skel.Skel().tensor`` on ``torch.arange(10)``.
    """
    version = torch.__version__
    print("PyTorch version", version)
    instance = skel.Skel()
    sample = torch.arange(10)
    result = instance.tensor(sample)
    print(result)


if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# CXX=c++ python3 setup.py build | |
# or | |
# CXX=c++ pip install . -vv | |
import sys | |
import setuptools | |
from torch.utils import cpp_extension | |
# Per-platform build options, passed straight through to CppExtension below.
# Only macOS adds anything; every other platform (including Linux) uses none.
extra_compile_args = []
extra_link_args = []
extra_objects = []
libraries = []

if sys.platform == "darwin":
    # The same flags go on both the compile and link lines: build against
    # libc++ with a minimum deployment target of macOS 10.14.
    macos_flags = ["-stdlib=libc++", "-mmacosx-version-min=10.14"]
    extra_compile_args.extend(macos_flags)
    extra_link_args.extend(macos_flags)

# The C++ extension module importable as `skel`, built from the two listed
# sources in C++17 mode plus any platform flags collected above.
skel = cpp_extension.CppExtension(
    name="skel",
    sources=["skel.cc", "extension.cc"],
    include_dirs=cpp_extension.include_paths(),
    libraries=libraries,
    language="c++",
    extra_compile_args=["-std=c++17", *extra_compile_args],
    extra_link_args=extra_link_args,
    extra_objects=extra_objects,
)

# Use PyTorch's BuildExtension as the build_ext command.
setuptools.setup(
    name="skel",
    ext_modules=[skel],
    cmdclass={"build_ext": cpp_extension.BuildExtension},
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "skel.h" | |
// Returns a new tensor equal to `t` with the scalar 1 added (out of place;
// `t` itself is not modified).
at::Tensor Skel::tensor(at::Tensor t) const {
  at::Tensor result = t.add(1);
  return result;
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Guard against multiple inclusion: without it, including this header twice
// in one translation unit would redefine class Skel.
#pragma once

#include <ATen/ATen.h>

// Minimal example class exposed to Python through pybind11.
class Skel {
 public:
  // Stateless: the class has no members to initialize.
  Skel() = default;

  // Returns a tensor computed from `t`; defined out of line.
  at::Tensor tensor(at::Tensor t) const;
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.