Created
October 19, 2018 12:44
-
-
Save iacolippo/c278680717fb4622f958e37caa721fe3 to your computer and use it in GitHub Desktop.
RuntimeError: variable impl does not have is_contiguous Pytorch C++ extension
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <torch/extension.h> | |
#include <cmath> | |
#include <iostream> | |
#include <vector> | |
at::Tensor ex_forward( | |
at::Tensor input | |
) { | |
auto n_samples = input.size(0); | |
auto n_features = input.size(1); | |
auto G = n_features / 2; | |
auto M = 2; | |
at::Tensor temp = at::zeros({n_samples, G, 2}); | |
at::Tensor slice1 = input.slice(1, 0, n_features, 2) + input.slice(1, 1, n_features, 2); | |
at::Tensor slice2 = input.slice(1, 0, n_features, 2) - input.slice(1, 1, n_features, 2); | |
temp = at::stack({slice1, slice2}, 2); | |
auto res = temp; | |
for (auto dumb_idx = 0; dumb_idx < std::log2(n_features) + 1; dumb_idx++) { | |
temp = at::zeros({n_samples, G / 2, M * 2}); | |
slice1 = res.slice(2, 0, M, 2).slice(1, 0, G, 2); | |
slice2 = res.slice(2, 0, M, 2).slice(1, 1, G, 2); | |
auto mesh1 = at::meshgrid({at::_cast_Long(at::arange(0, n_samples, 1)), at::_cast_Long(at::arange(0, G/2, 1)), at::_cast_Long(at::arange(0, 2 * M, 4))}); | |
auto mesh2 = at::meshgrid({at::_cast_Long(at::arange(0, n_samples, 1)), at::_cast_Long(at::arange(0, G/2, 1)), at::_cast_Long(at::arange(1, 2 * M, 4))}); | |
temp.index_put_(mesh1, slice1 + slice2); | |
temp.index_put_(mesh2, slice1 - slice2); | |
slice1 = res.slice(2, 1, M, 2).slice(1, 0, G, 2); | |
slice2 = res.slice(2, 1, M, 2).slice(1, 1, G, 2); | |
mesh1 = at::meshgrid({at::_cast_Long(at::arange(0, n_samples, 1)), at::_cast_Long(at::arange(0, G/2, 1)), at::_cast_Long(at::arange(2, 2 * M, 4))}); | |
mesh2 = at::meshgrid({at::_cast_Long(at::arange(0, n_samples, 1)), at::_cast_Long(at::arange(0, G/2, 1)), at::_cast_Long(at::arange(3, 2 * M, 4))}); | |
temp.index_put_(mesh1, slice1 - slice2); | |
temp.index_put_(mesh2, slice1 + slice2); | |
res = temp; | |
G = G / 2; | |
M = M * 2; | |
} | |
at::Tensor output = temp.select(1, 0); // select index 0 along dim 1 | |
return output * pow(std::sqrt(n_features), -1); | |
} | |
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | |
m.def("forward", &ex_forward, "EX forward"); | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from setuptools import setup | |
from torch.utils.cpp_extension import CppExtension, BuildExtension | |
setup(name='ex', | |
ext_modules=[CppExtension('ex', ['ex.cpp'], extra_compile_args=['-g', '-O0'])], | |
cmdclass={'build_ext': BuildExtension} | |
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import ex | |
class EXATen(torch.autograd.Function): | |
@staticmethod | |
def forward(ctx, input): | |
outputs = ex.forward(input) | |
return outputs | |
x = torch.randn(3, 8) | |
y1 = EXATen.apply(x) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment