Created
June 26, 2020 10:04
-
-
Save Voultapher/68ac02906b77b3023db443cc0be812e5 to your computer and use it in GitHub Desktop.
C++: generate nested loops from runtime loop bounds and a statically known loop depth.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// https://godbolt.org/z/-Zq5A9 | |
#include <array> | |
#include <tuple> | |
#include <utility> | |
namespace detail {

// One level of the generated loop nest.
//
// An instance stored at tuple position `Id` iterates `i` over [0, loop_max) and, for
// every value, hands control to the tuple element at position `Id + 1`, prepending `i`
// to the indices collected by the enclosing levels. A bound <= 0 yields no iterations.
class LoopFunc {
public:
  // constexpr, so that the constexpr call chain can actually be constant-evaluated.
  constexpr LoopFunc(std::ptrdiff_t loop_max) : _loop_max{loop_max} {}

  // Id:         position of this object inside `loop_funcs`.
  // loop_funcs: tuple holding every loop level followed by the wrapped user function.
  // ids:        indices of the enclosing loops, most recently entered loop first.
  template<std::ptrdiff_t Id, typename T, typename... Ts>
  constexpr auto call(const T& loop_funcs, Ts... ids) const -> void {
    for (std::ptrdiff_t i = 0; i < _loop_max; ++i) {
      std::get<Id + 1>(loop_funcs).template call<Id + 1>(loop_funcs, i, ids...);
    }
  }

private:
  std::ptrdiff_t _loop_max;
};

// Last tuple element: ends the recursion by invoking the user callable with all
// collected indices. Only a reference is stored, so the callable must outlive the wrapper.
template<typename Func>
class WrappedUserFunction {
public:
  // constexpr for the same reason as LoopFunc's constructor.
  constexpr WrappedUserFunction(Func& func) : _func{func} {}

  // Same shape as LoopFunc::call so both can be addressed uniformly through the tuple;
  // `Id` and the tuple itself are not needed here.
  template<std::ptrdiff_t Id, typename T, typename... Ts>
  constexpr auto call(const T& /*loop_funcs*/, Ts... ids) const -> void {
    _func(ids...);
  }

private:
  Func& _func;
};

// Builds the tuple of loop levels and starts the outermost one.
//
// func:        callable invoked as func(i0, i1, ..., iN-1).
// loop_bounds: loop_bounds[k] is the exclusive upper bound of index ik.
// Ids:         expected to be 0 .. N-1 (std::make_index_sequence<N>).
//
// With an empty `Ids` pack the tuple holds only the wrapper, so func() is invoked once.
template<std::size_t N, std::size_t... Ids, typename Func>
constexpr auto for_each(Func& func, std::array<std::ptrdiff_t, N> loop_bounds, std::index_sequence<Ids...>) -> void {
  // Reverse loop bounds but keep loop order same, to have boundN, boundN-1 ... -> func(i0, i1 ...).
  // Each level prepends its index, so the outermost loop (tuple element 0) must run over
  // the last bound for the innermost index to end up as the first argument.
  const std::tuple loop_funcs{LoopFunc{loop_bounds[N - Ids - 1]}..., WrappedUserFunction{func}};

  // Start loops.
  std::get<0>(loop_funcs).template call<0>(loop_funcs);
}

} // namespace detail
template<typename Func, std::size_t N> | |
constexpr auto for_each(Func&& func, std::array<std::ptrdiff_t, N> loop_bounds) -> void { | |
if constexpr (N == 0) { | |
// Implementation starts with first tuple element, need N >= 1. | |
return; | |
} | |
detail::for_each(func, loop_bounds, std::make_index_sequence<N>()); | |
} | |
// Open questions:
// How should the callable be passed (by reference, moved, perfectly forwarded)? What semantics are actually wanted, and which option compiles fastest?
// Declaration only; no definition appears in this file (presumably on purpose, so that
// calls stay opaque to the optimizer when inspecting the generated assembly).
template<typename... Indices> auto runtime_func(Indices... ids) -> void;
// Example: the compiler knows nothing about the loop bounds themselves, only how many there are.
// Generated loop nest: four levels whose bounds are only known at run time, each index
// combination being passed on to runtime_func.
auto runtime_opaque_loop_bounds(std::array<std::ptrdiff_t, 4> loop_bounds) -> void {
  const auto pass_to_runtime_func = [](auto... ids) {
    runtime_func(ids...);
  };
  for_each(pass_to_runtime_func, loop_bounds);
}
auto runtime_opaque_loop_bounds_native(std::array<std::ptrdiff_t, 4> loop_bounds) -> void { | |
for(std::ptrdiff_t i3 = 0; i3 < loop_bounds[3]; ++i3) { | |
for(std::ptrdiff_t i2 = 0; i2 < loop_bounds[2]; ++i2) { | |
for(std::ptrdiff_t i1 = 0; i1 < loop_bounds[1]; ++i1) { | |
for(std::ptrdiff_t i0 = 0; i0 < loop_bounds[0]; ++i0) { | |
runtime_func(i0,i1,i2,i3); | |
} | |
} | |
} | |
} | |
} | |
// With -O2, clang generates exactly the same code for both versions; gcc's output is nearly the same.
#include <iostream> | |
// Hand-written three-level loop nest used as a reference.
//
// Calls func(i0, i1, i2) for every i0 in [0, d0), i1 in [0, d1), i2 in [0, d2), with i0
// varying fastest and i2 slowest.
template<typename Func>
auto sanity_ref(Func&& func, std::ptrdiff_t d0, std::ptrdiff_t d1, std::ptrdiff_t d2) -> void {
  std::ptrdiff_t outer = 0;
  while (outer < d2) {
    std::ptrdiff_t middle = 0;
    while (middle < d1) {
      std::ptrdiff_t inner = 0;
      while (inner < d0) {
        func(inner, middle, outer);
        ++inner;
      }
      ++middle;
    }
    ++outer;
  }
}
// Prints every index combination produced by for_each, then the same combinations from
// the hand-written reference loops, so the two listings can be compared.
int main() {
  const auto print_all = [](auto... ids) { (std::cout << ... << ids) << '\n'; };

  // Spell out the element type: `std::array{2L, 1L, 3L}` deduces std::array<long, 3>,
  // which only matches for_each's std::array<std::ptrdiff_t, N> parameter on platforms
  // where std::ptrdiff_t happens to be long.
  const std::array<std::ptrdiff_t, 3> bounds{2, 1, 3};

  //for_each<2, 1, 3>([](auto... ids) { (std::cout << ... << ids) << '\n'; });
  for_each(print_all, bounds);

  //runtime_func(); // divide assembly, to compare output even with inlining.
  std::cout << "Reference\n";
  sanity_ref(print_all, bounds[0], bounds[1], bounds[2]);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment