Created August 23, 2018 00:33
-
-
Save yzhliu/bd571f356eba1411649fdf588ef5bf30 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc | |
index 322d77b6..12146cca 100644 | |
--- a/nnvm/src/top/nn/nn.cc | |
+++ b/nnvm/src/top/nn/nn.cc | |
@@ -641,6 +641,47 @@ inline bool LayoutTransformInferShape(const NodeAttrs& attrs,
   return true;
 }
+/*!
+ * \brief Map output (dst-layout) indices of a layout transform back to the
+ *        matching input (src-layout) indices.
+ * \param src layout string of the input tensor.
+ * \param dst layout string of the output tensor.
+ * \param dst_indices index variables of the output, addressed by axis position in \p dst.
+ * \return one index expression per axis of \p src, in \p src axis order.
+ */
+inline Array<Expr> layout_transform_func(const std::string& src,
+                                         const std::string& dst,
+                                         const Array<Var>& dst_indices) {
+  Layout src_layout(src);
+  Layout dst_layout(dst);
+  std::vector<Expr> dst_to_src_indices;
+  for (Layout::LayoutDim src_axis : src_layout) {
+    int dst_major_pos = dst_layout.indexof(Layout::to_superdim(src_axis));
+    int dst_minor_pos = dst_layout.indexof(Layout::to_subdim(src_axis));
+    int32_t src_factor = static_cast<int32_t>(src_layout.subsizeof(src_axis));
+    int32_t dst_factor = static_cast<int32_t>(dst_layout.subsizeof(src_axis));
+    // indexof() yields a negative position for a missing axis (as the minor-axis
+    // test below relies on); guard the major axis before indexing dst_indices with it.
+    CHECK_GE(dst_major_pos, 0) << "Cannot transform layout " << src << " to " << dst
+                               << ": no matching axis for " << src_axis;
+
+    // Rebuild the unsplit index of this dimension from dst's major (and minor) parts.
+    Expr src_index(dst_indices[dst_major_pos]);
+    if (dst_minor_pos >= 0) {
+      CHECK_GT(dst_factor, 0);
+      src_index = src_index * dst_factor + dst_indices[dst_minor_pos];
+    }
+    // Re-split it according to how src factors this dimension.
+    if (Layout::is_superdim(src_axis) && src_factor > 0) {
+      src_index = src_index / src_factor;
+    } else if (Layout::is_subdim(src_axis) && src_factor > 0) {
+      src_index = src_index % src_factor;
+    }
+    dst_to_src_indices.push_back(src_index);
+  }
+  return Array<Expr>(dst_to_src_indices);
+}
+
NNVM_REGISTER_OP(__layout_transform__) | |
.describe(R"code(Transform the input data layout. | |
@@ -686,28 +711,7 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w] | |
<< " to " << param.dst_layout; | |
return Array<Tensor> { | |
- topi::layout_transform(inputs[0], outputs[0]->shape, [&](const Array<Var>& dst_indices) { | |
- std::vector<Expr> dst_to_src_indices; | |
- for (Layout::LayoutDim src_axis : src_layout) { | |
- int dst_major_pos = dst_layout.indexof(Layout::to_superdim(src_axis)); | |
- int dst_minor_pos = dst_layout.indexof(Layout::to_subdim(src_axis)); | |
- int32_t src_factor = static_cast<int32_t>(src_layout.subsizeof(src_axis)); | |
- int32_t dst_factor = static_cast<int32_t>(dst_layout.subsizeof(src_axis)); | |
- | |
- Expr src_index(dst_indices[dst_major_pos]); | |
- if (dst_minor_pos >= 0) { | |
- CHECK_GT(dst_factor, 0); | |
- src_index = src_index * dst_factor + dst_indices[dst_minor_pos]; | |
- } | |
- if (Layout::is_superdim(src_axis) && src_factor > 0) { | |
- src_index = src_index / src_factor; | |
- } else if (Layout::is_subdim(src_axis) && src_factor > 0) { | |
- src_index = src_index % src_factor; | |
- } | |
- dst_to_src_indices.push_back(src_index); | |
- } | |
- return Array<Expr>(dst_to_src_indices); | |
- }) | |
+ topi::layout_transform<layout_transform_func>(inputs[0], param.src_layout, param.dst_layout, outputs[0]->shape) | |
}; | |
}) | |
.set_support_level(1); | |
diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h | |
index 53b89979..af582985 100644 | |
--- a/topi/include/topi/nn.h | |
+++ b/topi/include/topi/nn.h | |
@@ -488,15 +488,41 @@ using FLayoutIndicesTransform = std::function<Array<Expr>(const Array<Var>& indi
  * \param tag output tensor tag.
  * \return A tensor with shape \p dst_shape.
  */
 inline Tensor layout_transform(const Tensor& src,
                                const Array<Expr>& dst_shape,
                                const FLayoutIndicesTransform& to_src_indices,
                                const std::string name = "layout_transform",
                                const std::string tag = kInjective) {
-  auto src_shape = src->shape;
   return compute(
     dst_shape, [&](const Array<Var>& dst_indices) {
       return src(to_src_indices(dst_indices));
     }, name, tag);
 }
+
+/*!
+ * \brief Transform the layout of a tensor using an index-mapping function given at compile time.
+ *
+ * Convenience overload of the FLayoutIndicesTransform version above; that overload is kept
+ * so existing callers passing a callback keep compiling.
+ *
+ * \tparam to_src_indices maps (src_layout, dst_layout, output indices) to input indices.
+ * \param src the source tensor.
+ * \param src_layout layout string of \p src, forwarded to \p to_src_indices.
+ * \param dst_layout layout string of the output, forwarded to \p to_src_indices.
+ * \param dst_shape the shape of the output tensor.
+ * \param name output tensor name.
+ * \param tag output tensor tag.
+ * \return A tensor with shape \p dst_shape.
+ */
+template<Array<Expr> (*to_src_indices)(const std::string&, const std::string&, const Array<Var>&)>
+inline Tensor layout_transform(const Tensor& src,
+                               const std::string& src_layout,
+                               const std::string& dst_layout,
+                               const Array<Expr>& dst_shape,
+                               const std::string name = "layout_transform",
+                               const std::string tag = kInjective) {
+  return layout_transform(src, dst_shape, [&](const Array<Var>& dst_indices) {
+    return to_src_indices(src_layout, dst_layout, dst_indices);
+  }, name, tag);
+}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.