Last active: November 19, 2016 11:33
-
-
Save ibab/03a024918955325d9ba63dce265f8870 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <array>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <type_traits>
#include <vector>
/// A dense tensor of rank D holding elements of type T in one contiguous
/// std::vector.
///
/// Memory layout: the FIRST index varies fastest (stride 1), i.e. the layout
/// is column-major (Fortran order). Element (i0, ..., i{D-1}) is stored at
/// data()[i0*s0 + ... + i{D-1}*s{D-1}] with s0 = 1 and
/// s{k} = dim0 * ... * dim{k-1}.
///
/// Note: T = bool inherits std::vector<bool>'s proxy-reference quirks; the
/// non-const operator() does not compile for it.
template <typename T,
          // Number of dimensions of tensor.
          int D>
class Tensor {
  static_assert(D >= 0, "Tensor rank must be non-negative");

  // Compile-time "all of" over a pack of bools (C++14 has no fold
  // expressions): the two packs are the same type iff every flag is true.
  template <bool...>
  struct BoolPack {};
  template <typename... Ts>
  using AllIntegral =
      std::is_same<BoolPack<true, std::is_integral<Ts>::value...>,
                   BoolPack<std::is_integral<Ts>::value..., true>>;

 public:
  /// Creates a tensor with the given extents (one per dimension); every
  /// element is initialised to T(0).
  ///
  /// Extents may be any integral type. They are converted to int explicitly,
  /// so e.g. std::size_t extents no longer hit brace-init narrowing errors.
  /// Extents must be non-negative (checked with assert in debug builds).
  template <typename... T2,
            // Number of sizes must be equal to number of dimensions.
            typename std::enable_if<sizeof...(T2) == D, int>::type = 0>
  Tensor(T2... args) : dims_{static_cast<int>(args)...} {
    static_assert(AllIntegral<T2...>::value,
                  "tensor extents must be integral");
    // Accumulate in size_t so large tensors do not overflow an int.
    std::size_t prod = 1;
    for (std::size_t i = 0; i < dims_.size(); ++i) {
      assert(dims_[i] >= 0 && "tensor extents must be non-negative");
      strides_[i] = prod;
      prod *= static_cast<std::size_t>(dims_[i]);
    }
    data_ = std::vector<T>(prod, T(0));
  }

  /// Index into the tensor by calling tensor(i, j, k).
  /// Indices may be any integral type; each must lie in [0, size of that
  /// dimension) (checked with assert in debug builds).
  template <typename... T2,
            // Number of indices must be equal to number of dimensions.
            typename std::enable_if<sizeof...(T2) == D, int>::type = 0>
  T& operator()(T2... args) {
    return data_[offset(args...)];
  }

  /// Read-only element access for const tensors; same index rules as above.
  template <typename... T2,
            typename std::enable_if<sizeof...(T2) == D, int>::type = 0>
  typename std::vector<T>::const_reference operator()(T2... args) const {
    return data_[offset(args...)];
  }

  /// Returns the extent of each dimension.
  std::array<int, D> sizes() const {
    return dims_;
  }

  /// Direct access to the flat element storage (layout described above).
  std::vector<T>& data() {
    return data_;
  }
  const std::vector<T>& data() const {
    return data_;
  }

 private:
  // Converts a D-dimensional index into a linear offset into data_.
  template <typename... T2>
  std::size_t offset(T2... args) const {
    static_assert(AllIntegral<T2...>::value,
                  "tensor indices must be integral");
    const std::array<int, D> idxs{static_cast<int>(args)...};
    std::size_t idx = 0;
    for (std::size_t i = 0; i < idxs.size(); ++i) {
      assert(idxs[i] >= 0 && idxs[i] < dims_[i] && "tensor index out of range");
      idx += strides_[i] * static_cast<std::size_t>(idxs[i]);
    }
    return idx;
  }

  std::vector<T> data_;
  std::array<int, D> dims_{};
  // strides_[k] = product of the extents of dimensions 0..k-1.
  std::array<std::size_t, D> strides_{};
};
template<typename T, typename... T2> | |
auto ones(T2... args) { | |
constexpr int D = sizeof...(T2); | |
Tensor<T, D> t(args...); | |
for (auto& x: t.data()) { | |
x = T(1); | |
} | |
return t; | |
} | |
int main() { | |
Tensor<double, 3> a(2, 2, 2); | |
auto b = ones<double>(2, 2, 2); | |
return 0; | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.