Created
June 17, 2024 09:08
-
-
Save cyb70289/61be2dc6ef6ab46bb9bb66c10dacf796 to your computer and use it in GitHub Desktop.
onednn-reorder
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/src/cpu/aarch64/jit_uni_reorder.cpp b/src/cpu/aarch64/jit_uni_reorder.cpp | |
index 5c7c6d5..dc55c69 100644 | |
--- a/src/cpu/aarch64/jit_uni_reorder.cpp | |
+++ b/src/cpu/aarch64/jit_uni_reorder.cpp | |
@@ -2680,6 +2680,55 @@ status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, | |
return safe_ptr_assign(*reorder_pd, _pd.release()); | |
} | |
+#define MY_REORDER | |
+ | |
+#ifdef MY_REORDER | |
+#include <arm_neon.h> | |
+__attribute__ ((noinline)) /* NOTE(review): presumably kept out-of-line so it stays a separate symbol when profiling - confirm */ | |
+void my_reorder(const void *in, void *out) { /* both buffers are accessed as float arrays below */ | |
+ // Experimental stand-in: copies a fixed 128x256 float matrix into panels of tile_width columns; with the current constants out[(c / 4) * n_rows * 4 + r * 4 + c % 4] = in[r * n_cols + c]. XXX: sizes are hard-coded, no other shape works. | |
+ constexpr int n_rows = 128; | |
+ constexpr int n_cols = 256; | |
+ constexpr int lanes_per_vector = sizeof(float32x4_t) / sizeof(float); /* floats held by one float32x4_t */ | |
+ | |
+ // Tile shape processed per inner-loop step; adjustable within the static_asserts below. | |
+ constexpr int tile_width = 4; /* columns per tile; must be a multiple of lanes_per_vector */ | |
+ constexpr int tile_height = 16; /* rows per tile; must be even */ | |
+ static_assert(n_cols % tile_width == 0 && \ | |
+ tile_width % lanes_per_vector == 0); | |
+ static_assert(n_rows % tile_height == 0 && tile_height % 2 == 0); | |
+ | |
+ const float *inv = (const float *)in; /* read cursor, walks the input tile by tile */ | |
+ for (int row = 0; row < n_rows; row += tile_height) { | |
+ constexpr int n = tile_width / lanes_per_vector; /* vectors per tile row */ | |
+ float *outv[n]; /* one write cursor per vector column of the tile */ | |
+ outv[0] = (float *)out + row * lanes_per_vector; /* rows inside a panel are lanes_per_vector floats apart */ | |
+ for (int col = 0; col < n_cols; col += tile_width) { | |
+ for (int i = 1; i < n; ++i) { | |
+ outv[i] = outv[i - 1] + n_rows * lanes_per_vector; /* each further vector column lands one full panel (n_rows vectors) later */ | |
+ } | |
+ float32x4_t v[tile_height][n]; /* stage the whole tile locally before storing */ | |
+ for (int h = 0; h < tile_height; ++h) { | |
+ const float *tmp = inv + h * n_cols; /* input rows are n_cols floats apart */ | |
+ for (int c = 0; c < n; ++c) { | |
+ v[h][c] = vld1q_f32(tmp + c * lanes_per_vector); | |
+ } | |
+ } | |
+ for (int c = 0; c < n; ++c) { | |
+ float *tmp = outv[c]; | |
+ for (int r = 0; r < tile_height; r += 2) { /* two stores per step; relies on tile_height being even */ | |
+ vst1q_f32(tmp + r * lanes_per_vector, v[r][c]); | |
+ vst1q_f32(tmp + (r + 1) * lanes_per_vector, v[r+1][c]); | |
+ } | |
+ } | |
+ inv += tile_width; /* next column tile in the input */ | |
+ outv[0] += n_rows * tile_width; /* skip the n panels just written */ | |
+ } | |
+ inv += (tile_height - 1) * n_cols; /* the column loop already advanced one full row; skip the rest of the tile rows */ | |
+ } | |
+} | |
+#endif | |
+ | |
void jit_uni_reorder_t::omp_driver_0d(int off, const char *in, char *out, | |
const float *src_scales, const float *dst_scales, int src_zp, | |
int dst_zp, int32_t *compensation_scratch) const { | |
@@ -2734,7 +2783,11 @@ void jit_uni_reorder_t::omp_driver_1d(int ithr, int nthr, int off, | |
(*kernel_)(&tail_params); | |
} else { | |
+#ifdef MY_REORDER | |
+ my_reorder(base_params.in, base_params.out); | |
+#else | |
(*kernel_)(&base_params); | |
+#endif | |
} | |
}); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment