Skip to content

Instantly share code, notes, and snippets.

@yzhliu
Created August 23, 2018 00:33
Show Gist options
Save yzhliu/bd571f356eba1411649fdf588ef5bf30 to your computer and use it in GitHub Desktop.
diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc
index 322d77b6..12146cca 100644
--- a/nnvm/src/top/nn/nn.cc
+++ b/nnvm/src/top/nn/nn.cc
@@ -641,6 +641,51 @@ inline bool LayoutTransformInferShape(const NodeAttrs& attrs,
return true;
}
+/*!
+ * \brief Compute the input (src layout) indices read for one output element.
+ *
+ * For each axis of \p src, the \p dst index of its super dimension is combined
+ * with the \p dst index of its sub dimension (when \p dst has one) as
+ * `major * dst_factor + minor`. When \p src splits the axis (positive split
+ * factor), the result is reduced with `/ src_factor` (super axis) or
+ * `% src_factor` (sub axis).
+ *
+ * \param src Layout string of the input tensor.
+ * \param dst Layout string of the output tensor.
+ * \param dst_indices Output index variables, ordered as the axes of \p dst.
+ * \return One index expression per axis of \p src, in \p src axis order.
+ */
+inline Array<Expr> layout_transform_func(const std::string& src,
+                                         const std::string& dst,
+                                         const Array<Var>& dst_indices) {
+  Layout src_layout(src);
+  Layout dst_layout(dst);
+  std::vector<Expr> dst_to_src_indices;
+  for (Layout::LayoutDim src_axis : src_layout) {
+    int dst_major_pos = dst_layout.indexof(Layout::to_superdim(src_axis));
+    int dst_minor_pos = dst_layout.indexof(Layout::to_subdim(src_axis));
+    int32_t src_factor = static_cast<int32_t>(src_layout.subsizeof(src_axis));
+    int32_t dst_factor = static_cast<int32_t>(dst_layout.subsizeof(src_axis));
+    // indexof() yields a negative position for a missing axis; fail loudly
+    // instead of indexing dst_indices out of range.
+    CHECK_GE(dst_major_pos, 0)
+        << "Cannot transform layout " << src << " to " << dst
+        << ": axis " << Layout::to_superdim(src_axis) << " is missing";
+
+    Expr src_index(dst_indices[dst_major_pos]);
+    if (dst_minor_pos >= 0) {
+      CHECK_GT(dst_factor, 0);
+      src_index = src_index * dst_factor + dst_indices[dst_minor_pos];
+    }
+    if (Layout::is_superdim(src_axis) && src_factor > 0) {
+      src_index = src_index / src_factor;
+    } else if (Layout::is_subdim(src_axis) && src_factor > 0) {
+      src_index = src_index % src_factor;
+    }
+    dst_to_src_indices.push_back(src_index);
+  }
+  return Array<Expr>(dst_to_src_indices);
+}
+
NNVM_REGISTER_OP(__layout_transform__)
.describe(R"code(Transform the input data layout.
@@ -686,28 +711,8 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w]
<< " to " << param.dst_layout;
return Array<Tensor> {
- topi::layout_transform(inputs[0], outputs[0]->shape, [&](const Array<Var>& dst_indices) {
- std::vector<Expr> dst_to_src_indices;
- for (Layout::LayoutDim src_axis : src_layout) {
- int dst_major_pos = dst_layout.indexof(Layout::to_superdim(src_axis));
- int dst_minor_pos = dst_layout.indexof(Layout::to_subdim(src_axis));
- int32_t src_factor = static_cast<int32_t>(src_layout.subsizeof(src_axis));
- int32_t dst_factor = static_cast<int32_t>(dst_layout.subsizeof(src_axis));
-
- Expr src_index(dst_indices[dst_major_pos]);
- if (dst_minor_pos >= 0) {
- CHECK_GT(dst_factor, 0);
- src_index = src_index * dst_factor + dst_indices[dst_minor_pos];
- }
- if (Layout::is_superdim(src_axis) && src_factor > 0) {
- src_index = src_index / src_factor;
- } else if (Layout::is_subdim(src_axis) && src_factor > 0) {
- src_index = src_index % src_factor;
- }
- dst_to_src_indices.push_back(src_index);
- }
- return Array<Expr>(dst_to_src_indices);
- })
+      topi::layout_transform<layout_transform_func>(
+          inputs[0], param.src_layout, param.dst_layout, outputs[0]->shape)
};
})
.set_support_level(1);
diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h
index 53b89979..af582985 100644
--- a/topi/include/topi/nn.h
+++ b/topi/include/topi/nn.h
@@ -488,15 +488,50 @@ using FLayoutIndicesTransform = std::function<Array<Expr>(const Array<Var>& indi
* \param tag output tensor tag.
* \return A tensor with shape \p dst_shape.
*/
inline Tensor layout_transform(const Tensor& src,
const Array<Expr>& dst_shape,
const FLayoutIndicesTransform& to_src_indices,
const std::string name = "layout_transform",
const std::string tag = kInjective) {
auto src_shape = src->shape;
return compute(
dst_shape, [&](const Array<Var>& dst_indices) {
return src(to_src_indices(dst_indices));
}, name, tag);
}
+
+/*!
+ * \brief Transform the layout of \p src with an index-mapping function that is
+ *        fixed as a template argument and driven by layout strings.
+ *
+ * Added as an overload so the FLayoutIndicesTransform-based layout_transform
+ * above keeps its signature and existing callers keep compiling.
+ *
+ * \tparam to_src_indices Maps (src_layout, dst_layout, output indices) to the
+ *         input indices read for that output element.
+ * \param src the source tensor.
+ * \param src_layout layout string of \p src.
+ * \param dst_layout layout string of the output tensor.
+ * \param dst_shape the output shape.
+ * \param name output tensor name.
+ * \param tag output tensor tag.
+ * \return A tensor with shape \p dst_shape.
+ */
+template<Array<Expr> (*to_src_indices)(const std::string&, const std::string&,
+                                       const Array<Var>&)>
+inline Tensor layout_transform(const Tensor& src,
+                               const std::string& src_layout,
+                               const std::string& dst_layout,
+                               const Array<Expr>& dst_shape,
+                               const std::string name = "layout_transform",
+                               const std::string tag = kInjective) {
+  // The lambda captures the layout strings by reference; like the overload
+  // above, this relies on compute() invoking the callback before returning.
+  return layout_transform(
+      src, dst_shape,
+      [&](const Array<Var>& dst_indices) {
+        return to_src_indices(src_layout, dst_layout, dst_indices);
+      },
+      name, tag);
+}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment