Created
February 27, 2013 00:00
-
-
Save quaat/5043557 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cstdlib>
#include <iostream>
#include <memory>
#include <vector>
/// Compile-time shape/layout descriptor for a dense tensor of rank <= 4.
///
/// @tparam T   element type, re-exported as `real`.
/// @tparam d1  extent of the first dimension.
/// @tparam d2, d3, d4  extents of the remaining dimensions (default 1).
///
/// Elements are addressed with the first index varying fastest
/// (column-major / Fortran order). indexOf() performs no bounds checking.
template <typename T, unsigned d1, unsigned d2 = 1, unsigned d3 = 1, unsigned d4 = 1>
struct TensorType
{
    // Plain typedef: `typename` is only valid before a qualified dependent
    // name (e.g. `T::value_type`); `typedef typename T real;` is rejected by
    // conforming compilers.
    typedef T real;

    /// Number of leading extents greater than 1.
    // NOTE(review): counting stops at the first extent equal to 1, so e.g.
    // TensorType<T, 1, 3> has rank 0 -- confirm this is intended.
    static constexpr unsigned rank =
        (d1 > 1 ? (d2 > 1 ? (d3 > 1 ? (d4 > 1 ? 4 : 3) : 2) : 1) : 0);

    // constexpr static data members are implicitly inline in C++17, so they
    // may be odr-used (bound to a const&, address taken) without an
    // out-of-class definition.
    static constexpr unsigned dim1 = d1;
    static constexpr unsigned dim2 = d2;
    static constexpr unsigned dim3 = d3;
    static constexpr unsigned dim4 = d4;

    /// Flat offset of element (i): identity mapping.
    static constexpr unsigned indexOf(unsigned const i) {
        return i;
    }
    /// Flat offset of element (i, j); i varies fastest.
    static constexpr unsigned indexOf(unsigned const i, unsigned const j) {
        return j * d1 + i;
    }
    /// Flat offset of element (i, j, k); i varies fastest, k slowest.
    static constexpr unsigned indexOf(unsigned const i, unsigned const j, unsigned const k) {
        return (k * d2 + j) * d1 + i;
    }
    /// Flat offset of element (i, j, k, l); i varies fastest, l slowest.
    static constexpr unsigned indexOf(unsigned const i, unsigned const j, unsigned const k, unsigned const l) {
        return ((l * d3 + k) * d2 + j) * d1 + i;
    }
};
/// Dense tensor whose element type and shape come from a TensorType-like
/// descriptor (must provide `real`, `dim1`..`dim4` and `indexOf` overloads).
///
/// Storage is held through a std::shared_ptr, so the implicitly generated
/// copy operations share the same elements: writes through one copy are
/// visible through every other copy.
///
/// Element access is bounds-checked only on the *flat* offset (via
/// std::vector::at, which throws std::out_of_range); an individual index
/// outside its extent is not detected if the flat offset stays in range.
template <typename Shape>
class Tensor
{
    typedef typename Shape::real real;
    typedef typename std::vector<real>::size_type size_type;
    std::shared_ptr<std::vector<real>> data;
public:
    /// Total number of elements (product of all four extents).
    static constexpr unsigned size = Shape::dim4 * Shape::dim3 * Shape::dim2 * Shape::dim1;

    /// Allocates `size` value-initialised elements (zero for arithmetic types).
    // The original allocated an *empty* vector, so every access threw.
    // static_cast yields a prvalue, so `size` is not odr-used by make_shared.
    Tensor()
        : data(std::make_shared<std::vector<real>>(static_cast<size_type>(size)))
    {}

    /// Allocates `size` elements, each initialised to @p init_value.
    // The original passed init_value as the vector's element *count*.
    // NOTE(review): left non-explicit for backward compatibility, so a `real`
    // converts implicitly to a Tensor.
    Tensor(real const & init_value)
        : data(std::make_shared<std::vector<real>>(static_cast<size_type>(size), init_value))
    {}

    /// Element at flat position @p i.
    real & operator()(unsigned const i) {
        return data->at(Shape::indexOf(i));
    }
    real const & operator()(unsigned const i) const {
        return data->at(Shape::indexOf(i));
    }

    /// Element (i, j).
    real & operator()(unsigned const i, unsigned const j) {
        return data->at(Shape::indexOf(i, j));
    }
    real const & operator()(unsigned const i, unsigned const j) const {
        return data->at(Shape::indexOf(i, j));
    }

    /// Element (i, j, k).
    real & operator()(unsigned const i, unsigned const j, unsigned const k) {
        return data->at(Shape::indexOf(i, j, k));
    }
    real const & operator()(unsigned const i, unsigned const j, unsigned const k) const {
        return data->at(Shape::indexOf(i, j, k));
    }

    /// Element (i, j, k, l).
    real & operator()(unsigned const i, unsigned const j, unsigned const k, unsigned const l) {
        return data->at(Shape::indexOf(i, j, k, l));
    }
    real const & operator()(unsigned const i, unsigned const j, unsigned const k, unsigned const l) const {
        return data->at(Shape::indexOf(i, j, k, l));
    }
};
template <typename real> | |
class Tensor<TensorType<real, 3, 3, 3> > | |
{ | |
typedef TensorType<real, 3, 3, 3> T; | |
std::shared_ptr<std::vector<real>> data; | |
public: | |
unsigned static const size = 27; | |
Tensor() | |
: data( new std::vector<real> ) | |
{} | |
Tensor( double const & init_value) | |
: data( new std::vector<real>(init_value) ) | |
{} | |
real &operator()(unsigned const i, unsigned const j, unsigned const k, unsigned const l) { | |
return data->at(T::indexOf(i, j, k)); | |
} | |
}; | |
template <typename T> | |
unsigned getSize(Tensor<T> const & t) | |
{ | |
return t.size; | |
} | |
template <> | |
unsigned getSize(Tensor<TensorType<double, 3, 3>> const & t) | |
{ | |
std::cout << "The double tensor 3 by 3 huh?" << std::endl; | |
return t.size; | |
} | |
template <typename real> | |
unsigned getSize(Tensor<TensorType<real, 3, 3, 3>> const & t) | |
{ | |
std::cout << "The tripple tensor 3 by 3 by 3 huh?" << std::endl; | |
return t.size; | |
} | |
int main(int, char**) | |
{ | |
typedef TensorType<double, 3,3> Tensor33; | |
Tensor<Tensor33> tensor2; | |
std::cout << getSize(tensor2) << std::endl; | |
typedef TensorType<float,3,3,3> Tensor333; | |
Tensor<Tensor333> tensor3; | |
std::cout << getSize(tensor3) << std::endl; | |
return EXIT_SUCCESS; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment