Skip to content

Instantly share code, notes, and snippets.

@heiner
Last active May 19, 2020 16:37
Show Gist options
  • Save heiner/93dbe65e2f624aed1a6ba3c1e07eec2a to your computer and use it in GitHub Desktop.
Save heiner/93dbe65e2f624aed1a6ba3c1e07eec2a to your computer and use it in GitHub Desktop.
#include "skel.h"
#include <torch/extension.h>
namespace py = pybind11;
// Python module `skel`. It exposes the C++ `Skel` class as `skel.Skel`.
PYBIND11_MODULE(skel, m) {
  py::class_<Skel> skel_class(m, "Skel");

  // Zero-argument constructor: `skel.Skel()` in Python.
  skel_class.def(py::init<>());

  // `Skel.tensor(t)` forwards directly to the C++ member `Skel::tensor`.
  skel_class.def("tensor", &Skel::tensor);
}
import torch
import skel
def main():
    """Smoke-test the compiled ``skel`` extension.

    Prints the installed PyTorch version. Then builds a ``skel.Skel`` and
    prints what its ``tensor`` method returns for ``torch.arange(10)``.
    """
    print("PyTorch version", torch.__version__)
    skel_obj = skel.Skel()
    values = torch.arange(10)
    result = skel_obj.tensor(values)
    print(result)


if __name__ == "__main__":
    main()
# CXX=c++ python3 setup.py build
# or
# CXX=c++ pip install . -vv
import sys
import setuptools
from torch.utils import cpp_extension
# Toolchain settings. Each list starts empty and is filled only on platforms
# that need extra flags.
extra_compile_args = []
extra_link_args = []
extra_objects = []
libraries = []

if sys.platform == "darwin":
    # macOS: use libc++ and a 10.14 deployment target for both the compile
    # step and the link step.
    macos_flags = ["-stdlib=libc++", "-mmacosx-version-min=10.14"]
    extra_compile_args.extend(macos_flags)
    extra_link_args.extend(macos_flags)
# Linux and other platforms currently need no extra flags.

# The `skel` extension module, built from the class implementation and its
# pybind11 binding. "-std=c++17" comes first, ahead of any platform flags.
# NOTE(review): CppExtension appears to also append include_paths() itself,
# so these include dirs may be listed twice. That is harmless, but confirm
# before relying on it.
skel = cpp_extension.CppExtension(
    name="skel",
    sources=["skel.cc", "extension.cc"],
    include_dirs=cpp_extension.include_paths(),
    libraries=libraries,
    language="c++",
    extra_compile_args=["-std=c++17", *extra_compile_args],
    extra_link_args=extra_link_args,
    extra_objects=extra_objects,
)

# The build_ext command is overridden with torch's BuildExtension.
setuptools.setup(
    name="skel",
    ext_modules=[skel],
    cmdclass={"build_ext": cpp_extension.BuildExtension},
)
#include "skel.h"
// Returns `t + 1` as a new tensor.
// The parameter is taken by value to match the declaration in skel.h.
at::Tensor Skel::tensor(at::Tensor t) const {
  at::Tensor shifted = t + 1;
  return shifted;
}
#pragma once  // The header previously had no include guard.

#include <ATen/ATen.h>

/// Minimal stateless example class that operates on ATen tensors.
class Skel {
 public:
  /// Default-constructible and holds no state. The defaulted constructor
  /// replaces the old user-provided empty one (Rule of Zero).
  Skel() = default;

  /// @brief Transform a tensor. See the out-of-line definition for the
  ///        exact operation.
  /// @param t Input tensor, passed by value (an `at::Tensor` is a
  ///          reference-counted handle).
  /// @return A tensor result. Ignoring it is almost certainly a bug, so the
  ///         method is marked [[nodiscard]].
  [[nodiscard]] at::Tensor tensor(at::Tensor t) const;
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment