Skip to content

Instantly share code, notes, and snippets.

@quaat
Created February 27, 2013 00:00
Show Gist options
  • Save quaat/5043557 to your computer and use it in GitHub Desktop.
Save quaat/5043557 to your computer and use it in GitHub Desktop.
#include <cstdlib>
#include <iostream>
#include <memory>
#include <vector>
template <typename T, unsigned d1, unsigned d2 = 1, unsigned d3 = 1, unsigned d4 = 1>
struct TensorType
{
typedef typename T real;
static unsigned const rank = (d1 > 1 ? ( d2 > 1 ? (d3 > 1 ? (d4 > 1 ? 4 : 3) : 2) : 1) : 0);
static unsigned const dim1 = d1;
static unsigned const dim2 = d2;
static unsigned const dim3 = d3;
static unsigned const dim4 = d4;
static unsigned indexOf(unsigned const i) {
return i;
}
static unsigned indexOf(unsigned const i, unsigned const j) {
return j * d1 + i;
}
static unsigned indexOf(unsigned const i, unsigned const j, unsigned const k) {
return (k * d2 + j) * d1 + i;
}
static unsigned indexOf(unsigned const i, unsigned const j, unsigned const k, unsigned const l) {
return ((l * d3 + k) * d2 + j) * d1 + i;
}
};
template <typename TensorType>
class Tensor
{
typedef typename TensorType::real real;
std::shared_ptr<std::vector<real>> data;
public:
unsigned static const size = TensorType::dim4 * TensorType::dim3 * TensorType::dim2 * TensorType::dim1;
Tensor()
: data( new std::vector<real> )
{}
Tensor( real const & init_value)
: data( new std::vector<real>(init_value) )
{}
real &operator()(unsigned const i) {
return data->at(TensorType::indexOf(i));
}
real &operator()(unsigned const i, unsigned const j) {
return data->at(TensorType::indexOf(i,j));
}
real &operator()(unsigned const i, unsigned const j, unsigned const k ) {
return data->at(TensorType::indexOf(i, j, k));
}
real &operator()(unsigned const i, unsigned const j, unsigned const k, unsigned const l) {
return data->at(TensorType::indexOf(i, j, k, l));
}
};
template <typename real>
class Tensor<TensorType<real, 3, 3, 3> >
{
typedef TensorType<real, 3, 3, 3> T;
std::shared_ptr<std::vector<real>> data;
public:
unsigned static const size = 27;
Tensor()
: data( new std::vector<real> )
{}
Tensor( double const & init_value)
: data( new std::vector<real>(init_value) )
{}
real &operator()(unsigned const i, unsigned const j, unsigned const k, unsigned const l) {
return data->at(T::indexOf(i, j, k));
}
};
template <typename T>
unsigned getSize(Tensor<T> const & t)
{
return t.size;
}
template <>
unsigned getSize(Tensor<TensorType<double, 3, 3>> const & t)
{
std::cout << "The double tensor 3 by 3 huh?" << std::endl;
return t.size;
}
template <typename real>
unsigned getSize(Tensor<TensorType<real, 3, 3, 3>> const & t)
{
std::cout << "The tripple tensor 3 by 3 by 3 huh?" << std::endl;
return t.size;
}
int main(int, char**)
{
typedef TensorType<double, 3,3> Tensor33;
Tensor<Tensor33> tensor2;
std::cout << getSize(tensor2) << std::endl;
typedef TensorType<float,3,3,3> Tensor333;
Tensor<Tensor333> tensor3;
std::cout << getSize(tensor3) << std::endl;
return EXIT_SUCCESS;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment