Skip to content

Instantly share code, notes, and snippets.

@PanagiotisPtr
Created December 31, 2018 17:30
Show Gist options
  • Save PanagiotisPtr/b6379e403a25e2fcd5f643c1361d505c to your computer and use it in GitHub Desktop.
Save PanagiotisPtr/b6379e403a25e2fcd5f643c1361d505c to your computer and use it in GitHub Desktop.
Simple n-dimensional array class (Tensor) in C++
#ifndef TENSOR_H
#define TENSOR_H
#include <vector>
#include <ostream>
// Index base used by Tensor::operator[]: when true (the default) indices are
// 1-based ("math" style, valid range 1..n); when false they are 0-based.
// NOTE: `static` gives this variable internal linkage, so every translation
// unit that includes this header gets its own independent copy -- changing it
// in one .cpp file does not affect indexing performed in another. It is a
// plain (non-atomic) global read on every index operation, so do not modify
// it while other threads are indexing.
static bool math_indexing = true;

/// Rank-`size` tensor stored as a std::vector of rank-(size-1) tensors.
/// The recursion bottoms out at the Tensor<T, 1> specialization below.
template<typename T, size_t size>
class Tensor {
public:
    /// @param first  number of sub-tensors along this (outermost) dimension.
    /// @param args   forwarded to the Tensor<T, size-1> constructor; normally
    ///               the remaining extents, e.g. Tensor<int, 3>(2, 3, 4) is
    ///               2 x 3 x 4 with every element value-initialized.
    template<typename ...Args>
    Tensor(size_t first, Args&&... args)
        // Build one prototype sub-tensor and copy it `first` times.
        : data(first, Tensor<T, size-1>(args...)) {}

    /// Returns the sub-tensor at `idx` (1-based when math_indexing is true,
    /// 0-based otherwise) by reference, so chained writes such as
    /// `t[1][2] = x` modify this tensor and no sub-tensor is copied.
    /// @throws std::out_of_range if `idx` is outside the valid range
    ///         (including idx == 0 under 1-based indexing).
    Tensor<T, size-1>& operator[](size_t idx) {
        // Unsigned wrap-around turns idx 0 into SIZE_MAX under 1-based
        // indexing; at() rejects it instead of reading out of bounds.
        return data.at(idx - (math_indexing ? 1 : 0));
    }

    /// Read-only overload of operator[] for const tensors; same indexing rules.
    const Tensor<T, size-1>& operator[](size_t idx) const {
        return data.at(idx - (math_indexing ? 1 : 0));
    }

    /// Prints the tensor as "[ <sub>, <sub>, ]" (each sub-tensor is followed
    /// by ", ", including the last one).
    friend std::ostream& operator<<(std::ostream& out, const Tensor& t) {
        out << "[ ";
        // Bind by const reference: iterating by value deep-copied every sub-tensor.
        for (const Tensor<T, size-1>& sub : t.data)
            out << sub << ", ";
        out << "]";
        return out;
    }

private:
    std::vector<Tensor<T, size-1> > data;
};

/// Rank-1 base case: a flat std::vector<T>.
template<typename T>
class Tensor<T, 1> {
public:
    /// Creates `size` value-initialized elements (zero for arithmetic T).
    /// NOTE(review): this constructor is not `explicit`, so a size_t converts
    /// implicitly to Tensor<T, 1>; left as-is to avoid breaking callers.
    Tensor(size_t size) : data(size) {}

    /// Returns the element at `idx` (1-based when math_indexing is true,
    /// 0-based otherwise) by reference so it can be assigned through.
    /// The return type is std::vector<T>::reference, which is a proxy object
    /// (not bool&) when T is bool.
    /// @throws std::out_of_range if `idx` is outside the valid range
    ///         (including idx == 0 under 1-based indexing).
    typename std::vector<T>::reference operator[](size_t idx) {
        // Unsigned wrap-around turns idx 0 into SIZE_MAX under 1-based
        // indexing; at() rejects it instead of reading out of bounds.
        return data.at(idx - (math_indexing ? 1 : 0));
    }

    /// Read-only overload of operator[] for const tensors; same indexing rules.
    typename std::vector<T>::const_reference operator[](size_t idx) const {
        return data.at(idx - (math_indexing ? 1 : 0));
    }

    /// Prints the elements as "[ a, b, ]" (each element is followed by ", ",
    /// including the last one).
    friend std::ostream& operator<<(std::ostream& out, const Tensor<T, 1>& t) {
        out << "[ ";
        // Bind by const reference so class-type elements are not copied.
        for (const auto& val : t.data)
            out << val << ", ";
        out << "]";
        return out;
    }

private:
    std::vector<T> data;
};
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment