Last active
October 16, 2023 14:29
-
-
Save chenshuo/1853b5a9785bdc1526edea70e0539154 to your computer and use it in GitHub Desktop.
MultiArray
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <assert.h>
#include <stddef.h>
#include <stdio.h>

#include <algorithm>
#include <array>
#include <functional>
#include <numeric>
/// Non-owning view over an N-dimensional array stored contiguously in
/// row-major order (the last index varies fastest).
///
/// Indexing with operator[] peels off the leading dimension and yields a view
/// of rank N-1 that points into the same storage. The view never allocates or
/// frees `data_`.
///
/// @tparam T element type.
/// @tparam N number of dimensions; must be positive.
template<typename T, int N>
class MultiArrayBase
{
 public:
  static_assert(N > 0, "MultiArrayBase needs at least one dimension");

  /// @param data pointer to the first element of the underlying storage.
  /// @param dim  extent of every dimension, outermost first.
  MultiArrayBase(T* data, const std::array<int, N>& dim)
    : data_(data),
      // Seed with size_t so the whole product is accumulated in size_t;
      // std::accumulate uses the type of the initial value as accumulator.
      next_(std::accumulate(dim.begin() + 1, dim.end(),
                            static_cast<size_t>(1),
                            std::multiplies<size_t>())),
      dim_(dim)
  {
  }

  /// Returns the sub-view at position `idx` of the leading dimension.
  /// Asserts that `idx` is within the leading extent.
  MultiArrayBase<T, N-1> operator[](unsigned idx)
  {
    assert(idx < static_cast<unsigned>(dim_[0]));
    return MultiArrayBase<T, N-1>(data_ + idx * next_, trailingDims());
  }

  /// Read-only counterpart of operator[]: the returned view exposes the
  /// elements as `const T`.
  MultiArrayBase<const T, N-1> operator[](unsigned idx) const
  {
    assert(idx < static_cast<unsigned>(dim_[0]));
    return MultiArrayBase<const T, N-1>(data_ + idx * next_, trailingDims());
  }

  T* data_;                        ///< first element covered by this view
  size_t next_;                    ///< elements per step of the leading index
  const std::array<int, N> dim_;   ///< extents, outermost first

 private:
  /// Extents of every dimension except the leading one.
  std::array<int, N-1> trailingDims() const
  {
    std::array<int, N-1> rest;
    std::copy(dim_.begin() + 1, dim_.end(), rest.begin());
    return rest;
  }
};
template<typename T> | |
class MultiArrayBase<T, 1> | |
{ | |
public: | |
MultiArrayBase(T* data, const std::array<int, 1>& dim) | |
: data_(data), next_(1), dim_(dim) | |
{ | |
} | |
T& operator[](unsigned idx) | |
{ | |
assert(idx < dim_[0]); | |
return data_[idx]; | |
} | |
T* const data_; | |
const size_t next_; | |
const std::array<int, 1> dim_; | |
}; | |
template<typename T, int N> | |
class MultiArray : public MultiArrayBase<T, N> | |
{ | |
public: | |
explicit MultiArray(const std::array<int, N>& dim) | |
: MultiArrayBase<T, N>(nullptr, dim) | |
{ | |
size_t sz = this->next_ * dim[0]; | |
this->data_ = new T[sz](); | |
} | |
~MultiArray() | |
{ | |
delete[] this->data_; | |
} | |
private: | |
// FIXME: | |
MultiArray(const MultiArray&) = delete; | |
void operator=(const MultiArray&) = delete; | |
}; | |
int main() | |
{ | |
std::array<int, 3> dim3{4,5,3}; | |
MultiArray<double, 3> arr3(dim3); | |
int value = 0; | |
for (int i = 0; i < dim3[0]; ++i) | |
for (int j = 0; j < dim3[1]; ++j) | |
for (int k = 0; k < dim3[2]; ++k) | |
arr3[i][j][k] = 10.0 * ++value; | |
for (int i = 0; i < dim3[0]; ++i) | |
for (int j = 0; j < dim3[1]; ++j) | |
for (int k = 0; k < dim3[2]; ++k) | |
printf("%.1f\n", arr3[i][j][k]); | |
arr3[3][5][2]; // fail | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment