Created
April 24, 2019 13:11
-
-
Save grafi-tt/cb90b7cbfee8dffb34c62be46df1f48e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/chainerx_cc/chainerx/native/reduce.h b/chainerx_cc/chainerx/native/reduce.h
index 2b61319de..73db5bad2 100644
--- a/chainerx_cc/chainerx/native/reduce.h
+++ b/chainerx_cc/chainerx/native/reduce.h
@@ -11,19 +11,83 @@ namespace chainerx {
namespace native { | |
namespace reduce_detail { | |
+constexpr int64_t ExpandLen = 8; | |
+constexpr int64_t SerialLen = 16; | |
+ | |
+template <typename In, typename ReductionImpl, int8_t InNdim, typename T, int64_t n> | |
+struct ExpandedPairwiseReduction { | |
+ T run(IndexIterator<InNdim>& it_in, int64_t& i_reduce, ReductionImpl&& impl) { | |
+ auto accum = ExpandedPairwiseReduction<In, ReductionImpl, InNdim, T, n / 2>::run(it_in, i_reduce, impl); | |
+ impl.Reduce(ExpandedPairwiseReduction<In, ReductionImpl, InNdim, T, n / 2>::run(it_in, i_reduce, impl), accum); | |
+ return accum; | |
+ } | |
+}; | |
+ | |
+template <typename In, typename ReductionImpl, int8_t InNdim, typename T> | |
+struct ExpandedPairwiseReduction<In, ReductionImpl, InNdim, 1> { | |
+ T run(IndexIterator<InNdim>& it_in, int64_t& i_reduce, ReductionImpl&& impl) { | |
+ return impl.MapIn(native_internal::StorageToDataType<const In>(arg.in[it_in++]), i_reduce++); | |
+ } | |
+}; | |
+ | |
+template <typename In, typename ReductionImpl, int8_t InNdim, typename T> | |
+T PairwiseReduction(IndexIterator<InNdim>& it_in, int64_t reduce_len, std::vector<T>& tree_accum, ReductionImpl&& impl) { | |
+ int64_t i_reduce = 0; | |
+ auto accum = impl.Identity(); | |
+ | |
+ bool first_loop = true; | |
+ while (i_reduce < reduce_len & -ExpandLen) { | |
+ if (first_loop) { | |
+ first_loop = false; | |
+ } else if (i_reduce & SerialLen * ExpandLen - 1 == 0) { | |
+ int i = 0; | |
+ int64_t i_reduce_tmp = i_reduce; | |
+ do { | |
+ impl.Reduce(tree_accum[i], accum); | |
+ tree_accum[i] = impl.Identity(); | |
+ ++i, i_reduce_tmp >>= 1; | |
+ } while (i_reduce_tmp & SerialLen * ExpandLen - 1 == 0); | |
+ tree_accum[i] = accum; | |
+ accum = impl.Identity(); | |
+ } | |
+ impl.Reduce(ExpandedPairwiseReduction<In, ReductionImpl, InNdim, T, ExpandLen>::run(it_in, i_reduce, impl), accum); | |
+ } | |
+ | |
+ while (i_reduce < reduce_len) { | |
+ impl.Reduce(impl.MapIn(native_internal::StorageToDataType<const In>(arg.in[it_in++]), i_reduce++), accum); | |
+ } | |
+ | |
+ for (T& leaf_accum : tree_accum) { | |
+ impl.Reduce(leaf_accum, accum); | |
+ leaf_accum = impl.Identity(); | |
+ } | |
+ return accum; | |
+} | |
+ | |
+inline int bits_of_index(int64_t n) { | |
+ if (n <= 0) return 0; | |
+ --n; | |
+ int64_t t; | |
+ int bits = 0; | |
+ if ((t = n >> 32)) bits += 32, n = t; | |
+ if ((t = n >> 16)) bits += 16, n = t; | |
+ if ((t = n >> 8)) bits += 8, n = t; | |
+ if ((t = n >> 4)) bits += 4, n = t; | |
+ if ((t = n >> 2)) bits += 2, n = t; | |
+ bits += static_cast<int>(n); // n is 0 or 1 | |
+ return bits; | |
+} | |
+ | |
template <typename In, typename Out, typename ReductionImpl, int8_t InNdim = kDynamicNdim, int8_t OutNdim = kDynamicNdim> | |
void ReductionKernel(ReductionKernelArg<In, Out, InNdim, OutNdim> arg, ReductionImpl&& impl) { | |
auto it_in = arg.in_indexer.It(0, arg.out_indexer.total_size()); | |
+ int64_t reduce_len = arg.out_indexer.total_size() / arg.in_indexer.total_size(); | |
+ std::vector<decltype(impl.Identity())> tree_accum(bits_of_index(reduce_len), impl.Identity()); | |
// Iterate over output dimensions | |
for (auto it_out = arg.out_indexer.It(0); it_out; ++it_out) { | |
- auto accum = impl.Identity(); | |
- | |
- int64_t i_reduce{0}; | |
- for (it_in.Restart(it_out.raw_index()); it_in; ++it_in, ++i_reduce) { | |
- impl.Reduce(impl.MapIn(native_internal::StorageToDataType<const In>(arg.in[it_in]), i_reduce), accum); | |
- } | |
- | |
+ it_in.Restart(it_out.raw_index()); | |
+ auto accum = PairwiseReduction<In, ReductionImpl, InNdim, decltype(impl.Identity())>(it_in, reduce_len, tree_accum, impl); | |
arg.out[it_out] = native_internal::DataToStorageType<Out>(impl.MapOut(accum)); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment