Last active
June 3, 2023 17:06
-
-
Save PatWie/07b2962f75446250c138f686a6e3da0d to your computer and use it in GitHub Desktop.
MultiIndex
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
g++ main.cc -std=c++17 && ./a.out
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef INDEX_H_ | |
#define INDEX_H_ | |
#include <cstddef> | |
#include <tuple> | |
namespace internal {
/**
 * Compile-time product of a run of axis extents.
 *
 * While TSkip > 0 the helper advances TPos and decrements TSkip; the incoming
 * TRemaining is ignored on those steps and recomputed from TRank and TPos.
 * Once TSkip reaches 0 it multiplies
 *   dimensions_[TPos] * ... * dimensions_[TPos + TRemaining - 1].
 * Consequently, when TSkip starts at >= 1 the product always runs up to the
 * last axis: dimensions_[TPos + TSkip] * ... * dimensions_[TRank - 1], i.e.
 * the row-major pitch (stride) of axis TPos + TSkip - 1.
 */
template <size_t TRank, size_t TSkip, size_t TPos, size_t TRemaining>
struct pitch_helper {
  constexpr size_t call(const size_t dimensions_[TRank]) const {
    // After this step the position is TPos + 1, so TRank - (TPos + 1) axes
    // are left to multiply if skipping ends here.
    return pitch_helper<TRank, TSkip - 1, TPos + 1, TRank - TPos - 1>().call(
        dimensions_);
  }
};

// Skipping done: multiply the current extent and continue with the next axis.
template <size_t TRank, size_t TPos, size_t TRemaining>
struct pitch_helper<TRank, 0, TPos, TRemaining> {
  constexpr size_t call(const size_t dimensions_[TRank]) const {
    return dimensions_[TPos] *
           pitch_helper<TRank, 0, TPos + 1, TRemaining - 1>().call(dimensions_);
  }
};

// Empty run: the neutral element of the product.
template <size_t TRank, size_t TPos> struct pitch_helper<TRank, 0, TPos, 0> {
  constexpr size_t call(const size_t /*dimensions_*/[TRank]) const {
    return 1;
  }
};

/**
 * Flattened (row-major) offset of a coordinate.
 *
 * TRemaining is the number of coordinates still to consume, including `v`.
 * Each coordinate is multiplied by the pitch of its axis
 * (axis = TRank - TRemaining) and the products are summed.
 */
template <size_t TRank, size_t TRemaining, class T, class... Ts>
struct position_helper {
  constexpr size_t call(const size_t dimensions_[TRank], T v, Ts... is) const {
    return v * pitch_helper<TRank, TRank - TRemaining + 1, 0, TRank>().call(
                   dimensions_) +
           position_helper<TRank, TRemaining - 1, Ts...>().call(dimensions_,
                                                                is...);
  }
};

// Last coordinate: the innermost axis has pitch 1.
template <size_t TRank, size_t TRemaining, class T>
struct position_helper<TRank, TRemaining, T> {
  constexpr size_t call(const size_t /*dimensions_*/[TRank], T v) const {
    return v;
  }
};

/**
 * Inverse of position_helper: splits a flattened index into per-axis
 * coordinates, outermost axis first.
 *
 * TPos is the axis handled by this step and TRemaining the number of axes
 * left, including TPos. Coordinates found so far are accumulated by value in
 * `indices` and returned together as a std::tuple.
 */
template <size_t TRank, size_t TPos, size_t TRemaining>
struct unflatten_helper {
  template <class... Ts>
  static constexpr auto call(const size_t dimensions_[TRank],
                             size_t flattenedIndex, Ts... indices) noexcept {
    // Number of flat positions covered by one step along axis TPos.
    const size_t pitch =
        pitch_helper<TRank, 1, TPos, TRank - 1>().call(dimensions_);
    return unflatten_helper<TRank, TPos + 1, TRemaining - 1>::call(
        dimensions_, flattenedIndex % pitch, indices...,
        flattenedIndex / pitch);
  }
};

// Last axis: its pitch is 1, so the remainder is the coordinate itself.
template <size_t TRank, size_t TPos> struct unflatten_helper<TRank, TPos, 1> {
  template <class... Ts>
  static constexpr auto call(const size_t /*dimensions_*/[TRank],
                             size_t flattenedIndex, Ts... indices) noexcept {
    return std::make_tuple(indices..., flattenedIndex);
  }
};
}  // namespace internal
template <size_t TRank> struct BaseNdIndex { | |
protected: | |
size_t dimensions_[TRank]; | |
public: | |
template <class... Ts> | |
explicit constexpr inline BaseNdIndex(size_t i0, Ts... is) noexcept | |
: dimensions_{i0, is...} {} | |
/** | |
* Check whether given coordinate is in range. | |
*/ | |
template <class... Ts> | |
constexpr inline bool valid(size_t i0, Ts... is) const { | |
static_assert(size_t(1) + sizeof...(Ts) == TRank, | |
"Number of dimensions does not match rank! " | |
"YOU_MADE_A_PROGAMMING_MISTAKE"); | |
return valid_impl<0, Ts...>(i0, is...); | |
} | |
/** | |
* Return the number of axes. | |
* @return number of axes | |
*/ | |
constexpr inline size_t rank() const { return TRank; } | |
/** | |
* Return the dimension for a given axis. | |
* | |
* const size_t D = my_nd_array.template dim<1>(); | |
* | |
* @return dimension for given axis | |
*/ | |
template <size_t TAxis> constexpr inline size_t dim() const { | |
static_assert(TAxis < TRank, "axis < rank failed"); | |
return dimensions_[TAxis]; | |
} | |
/** | |
* Unflatten a flattened index and retrieve the corresponding | |
* indices for each dimension. | |
* | |
* size_t i, j, k; | |
* idx.unflatten(flattenedIndex, i, j, k); | |
* | |
* @param flattenedIndex the flattened index to unflatten | |
* @param indices references to variables to store the indices | |
*/ | |
constexpr inline auto unflatten(size_t flattenedIndex) const noexcept { | |
return internal::unflatten_helper<TRank, 0, TRank>::call(dimensions_, | |
flattenedIndex); | |
} | |
private: | |
template <size_t TNum, class... Ts> | |
constexpr inline bool valid_impl(size_t i0, Ts... is) const { | |
return (i0 < dimensions_[TNum]) && valid_impl<TNum + 1, Ts...>(is...); | |
} | |
template <size_t TNum, typename T> | |
constexpr inline bool valid_impl(T i0) const { | |
return (i0 < dimensions_[TRank - 1]); | |
} | |
protected: | |
template <class... Ts> | |
constexpr inline size_t index_(size_t i0, Ts... is) const { | |
return internal::position_helper<TRank, TRank, size_t, Ts...>().call( | |
dimensions_, i0, is...); | |
} | |
}; | |
/** | |
* Create an index object. | |
* | |
* The index object can handle various dimensions. | |
* | |
* auto idx = NdIndex<4>(B, H, W, C); | |
* auto TPos = idx(b, h, w, c); | |
* | |
* @param rank in each dimensions. | |
*/ | |
template <size_t TRank> struct NdIndex : public BaseNdIndex<TRank> { | |
public: | |
template <class... Ts> | |
explicit constexpr inline NdIndex(size_t i0, Ts... is) noexcept | |
: BaseNdIndex<TRank>(i0, is...) { | |
static_assert(size_t(1) + sizeof...(Ts) == TRank, | |
"Number of dimensions does not match rank! " | |
"YOU_MADE_A_PROGAMMING_MISTAKE"); | |
} | |
/** | |
* Get flattened index for a given position. | |
* | |
* auto idx = NdIndex<4>(10, 20, 30, 40); | |
* size_t actual = idx(1, 2, 3, 4); | |
* size_t expected = 1 * (20 * 30 * 40) + 2 * (30 * 40) + 3 * (40) + 4; | |
*/ | |
template <class... Ts> size_t inline operator()(size_t i0, Ts... is) const { | |
static_assert(size_t(1) + sizeof...(Ts) == TRank, | |
"Number of dimensions does not match rank! " | |
"YOU_MADE_A_PROGAMMING_MISTAKE"); | |
return this->index_(i0, is...); | |
} | |
/** | |
* Get dimension for a given axis. | |
* | |
* auto idx = NdIndex<4>(10, 20, 30, 40); | |
* size_t actual = idx[1]; // is 20 | |
*/ | |
template <class... Ts> size_t inline operator[](size_t i0) const { | |
return BaseNdIndex<TRank>::dimensions_[i0]; | |
} | |
}; | |
////////////// | |
#endif // INDEX_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef INDEX_H_ | |
#define INDEX_H_ | |
#include <cstddef> | |
namespace internal {
/**
 * Compile-time product of a run of axis extents.
 *
 * While TSkip > 0 the helper advances TPos and decrements TSkip; the incoming
 * TRemaining is ignored on those steps and recomputed from TRank and TPos.
 * Once TSkip reaches 0 it multiplies
 *   dimensions_[TPos] * ... * dimensions_[TPos + TRemaining - 1].
 * Consequently, when TSkip starts at >= 1 the product always runs up to the
 * last axis: dimensions_[TPos + TSkip] * ... * dimensions_[TRank - 1], i.e.
 * the row-major pitch (stride) of axis TPos + TSkip - 1.
 */
template <size_t TRank, size_t TSkip, size_t TPos, size_t TRemaining>
struct pitch_helper {
  constexpr size_t call(const size_t dimensions_[TRank]) const {
    // After this step the position is TPos + 1, so TRank - (TPos + 1) axes
    // are left to multiply if skipping ends here.
    return pitch_helper<TRank, TSkip - 1, TPos + 1, TRank - TPos - 1>().call(
        dimensions_);
  }
};

// Skipping done: multiply the current extent and continue with the next axis.
template <size_t TRank, size_t TPos, size_t TRemaining>
struct pitch_helper<TRank, 0, TPos, TRemaining> {
  constexpr size_t call(const size_t dimensions_[TRank]) const {
    return dimensions_[TPos] *
           pitch_helper<TRank, 0, TPos + 1, TRemaining - 1>().call(dimensions_);
  }
};

// Empty run: the neutral element of the product.
template <size_t TRank, size_t TPos> struct pitch_helper<TRank, 0, TPos, 0> {
  constexpr size_t call(const size_t /*dimensions_*/[TRank]) const {
    return 1;
  }
};

/**
 * Flattened (row-major) offset of a coordinate.
 *
 * TRemaining is the number of coordinates still to consume, including `v`.
 * Each coordinate is multiplied by the pitch of its axis
 * (axis = TRank - TRemaining) and the products are summed.
 */
template <size_t TRank, size_t TRemaining, class T, class... Ts>
struct position_helper {
  constexpr size_t call(const size_t dimensions_[TRank], T v, Ts... is) const {
    return v * pitch_helper<TRank, TRank - TRemaining + 1, 0, TRank>().call(
                   dimensions_) +
           position_helper<TRank, TRemaining - 1, Ts...>().call(dimensions_,
                                                                is...);
  }
};

// Last coordinate: the innermost axis has pitch 1.
template <size_t TRank, size_t TRemaining, class T>
struct position_helper<TRank, TRemaining, T> {
  constexpr size_t call(const size_t /*dimensions_*/[TRank], T v) const {
    return v;
  }
};

/**
 * Inverse of position_helper: writes the coordinate of axis TPos into
 * `index`, then recurses on the remainder to fill `indices` (the following
 * axes, in order). TRemaining counts the outputs still to be written,
 * including `index`.
 */
template <size_t TRank, size_t TPos, size_t TRemaining>
struct unflatten_helper {
  template <class... Ts>
  static constexpr void call(const size_t dimensions_[TRank],
                             size_t flattenedIndex, size_t &index,
                             Ts &...indices) noexcept {
    // Number of flat positions covered by one step along axis TPos.
    const size_t pitch =
        pitch_helper<TRank, 1, TPos, TRank - 1>().call(dimensions_);
    index = flattenedIndex / pitch;
    unflatten_helper<TRank, TPos + 1, TRemaining - 1>::call(
        dimensions_, flattenedIndex % pitch, indices...);
  }
};

// Last axis: its pitch is 1, so the remainder is the coordinate itself.
template <size_t TRank, size_t TPos> struct unflatten_helper<TRank, TPos, 1> {
  template <class... Ts>
  static constexpr void call(const size_t /*dimensions_*/[TRank],
                             size_t flattenedIndex, size_t &index,
                             Ts &.../*indices*/) noexcept {
    index = flattenedIndex;
  }
};
}  // namespace internal
template <size_t TRank> struct BaseNdIndex { | |
protected: | |
size_t dimensions_[TRank]; | |
public: | |
template <class... Ts> | |
explicit constexpr inline BaseNdIndex(size_t i0, Ts... is) noexcept | |
: dimensions_{i0, is...} {} | |
/** | |
* Check whether given coordinate is in range. | |
*/ | |
template <class... Ts> | |
constexpr inline bool valid(size_t i0, Ts... is) const { | |
static_assert(size_t(1) + sizeof...(Ts) == TRank, | |
"Number of dimensions does not match rank! " | |
"YOU_MADE_A_PROGAMMING_MISTAKE"); | |
return valid_impl<0, Ts...>(i0, is...); | |
} | |
/** | |
* Return the number of axes. | |
* @return number of axes | |
*/ | |
constexpr inline size_t rank() const { return TRank; } | |
/** | |
* Return the dimension for a given axis. | |
* | |
* const size_t D = my_nd_array.template dim<1>(); | |
* | |
* @return dimension for given axis | |
*/ | |
template <size_t TAxis> constexpr inline size_t dim() const { | |
static_assert(TAxis < TRank, "axis < rank failed"); | |
return dimensions_[TAxis]; | |
} | |
/** | |
* Unflatten a flattened index and retrieve the corresponding | |
* indices for each dimension. | |
* | |
* size_t i, j, k; | |
* idx.unflatten(flattenedIndex, i, j, k); | |
* | |
* @param flattenedIndex the flattened index to unflatten | |
* @param indices references to variables to store the indices | |
*/ | |
template <class... Ts> | |
constexpr inline void unflatten(size_t flattenedIndex, | |
Ts &...indices) const noexcept { | |
static_assert(sizeof...(Ts) == TRank, | |
"Number of indices does not match rank! " | |
"YOU_MADE_A_PROGAMMING_MISTAKE"); | |
internal::unflatten_helper<TRank, 0, TRank>::call( | |
dimensions_, flattenedIndex, indices...); | |
} | |
private: | |
template <size_t TNum, class... Ts> | |
constexpr inline bool valid_impl(size_t i0, Ts... is) const { | |
return (i0 < dimensions_[TNum]) && valid_impl<TNum + 1, Ts...>(is...); | |
} | |
template <size_t TNum, typename T> | |
constexpr inline bool valid_impl(T i0) const { | |
return (i0 < dimensions_[TRank - 1]); | |
} | |
protected: | |
template <class... Ts> | |
constexpr inline size_t index_(size_t i0, Ts... is) const { | |
return internal::position_helper<TRank, TRank, size_t, Ts...>().call( | |
dimensions_, i0, is...); | |
} | |
}; | |
/** | |
* Create an index object. | |
* | |
* The index object can handle various dimensions. | |
* | |
* auto idx = NdIndex<4>(B, H, W, C); | |
* auto TPos = idx(b, h, w, c); | |
* | |
* @param rank in each dimensions. | |
*/ | |
template <size_t TRank> struct NdIndex : public BaseNdIndex<TRank> { | |
public: | |
template <class... Ts> | |
explicit constexpr inline NdIndex(size_t i0, Ts... is) noexcept | |
: BaseNdIndex<TRank>(i0, is...) { | |
static_assert(size_t(1) + sizeof...(Ts) == TRank, | |
"Number of dimensions does not match rank! " | |
"YOU_MADE_A_PROGAMMING_MISTAKE"); | |
} | |
/** | |
* Get flattened index for a given position. | |
* | |
* auto idx = NdIndex<4>(10, 20, 30, 40); | |
* size_t actual = idx(1, 2, 3, 4); | |
* size_t expected = 1 * (20 * 30 * 40) + 2 * (30 * 40) + 3 * (40) + 4; | |
*/ | |
template <class... Ts> size_t inline operator()(size_t i0, Ts... is) const { | |
static_assert(size_t(1) + sizeof...(Ts) == TRank, | |
"Number of dimensions does not match rank! " | |
"YOU_MADE_A_PROGAMMING_MISTAKE"); | |
return this->index_(i0, is...); | |
} | |
/** | |
* Get dimension for a given axis. | |
* | |
* auto idx = NdIndex<4>(10, 20, 30, 40); | |
* size_t actual = idx[1]; // is 20 | |
*/ | |
template <class... Ts> size_t inline operator[](size_t i0) const { | |
return BaseNdIndex<TRank>::dimensions_[i0]; | |
} | |
}; | |
////////////// | |
#endif // INDEX_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "index.h" | |
#include <cassert> | |
#include <iostream> | |
int main() { | |
size_t dim_b = 100; | |
size_t dim_h = 200; | |
size_t dim_w = 300; | |
size_t dim_c = 400; | |
size_t b = 10; | |
size_t h = 101; | |
size_t w = 13; | |
size_t c = 87; | |
auto index = NdIndex<4>(dim_b, dim_h, dim_w, dim_c); | |
size_t flattened_index = index(b, h, w, c); | |
size_t b_ = 0, h_ = 0, w_ = 0, c_ = 0; | |
index.unflatten(flattened_index, b_, h_, w_, c_); | |
// C++17 | |
auto [b_, h_, w_, c_] = index.unflatten(flattened_index); | |
assert(b == b_); | |
assert(h == h_); | |
assert(w == w_); | |
assert(c == c_); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment