Skip to content

Instantly share code, notes, and snippets.

@cyb70289
Created June 17, 2024 09:08
Show Gist options
  • Save cyb70289/61be2dc6ef6ab46bb9bb66c10dacf796 to your computer and use it in GitHub Desktop.
Save cyb70289/61be2dc6ef6ab46bb9bb66c10dacf796 to your computer and use it in GitHub Desktop.
onednn-reorder
diff --git a/src/cpu/aarch64/jit_uni_reorder.cpp b/src/cpu/aarch64/jit_uni_reorder.cpp
index 5c7c6d5..dc55c69 100644
--- a/src/cpu/aarch64/jit_uni_reorder.cpp
+++ b/src/cpu/aarch64/jit_uni_reorder.cpp
@@ -2680,6 +2680,55 @@ status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd,
return safe_ptr_assign(*reorder_pd, _pd.release());
}
+// Experimental override. While MY_REORDER is defined, omp_driver_1d() calls
+// my_reorder() below in the branch that would otherwise run
+// (*kernel_)(&base_params).
+// WARNING: my_reorder() never sees the reorder descriptor. It assumes a
+// fixed 128x256 f32 source. Any other problem reaching that branch can
+// read/write out of bounds or produce wrong results. Benchmarking use only.
+#define MY_REORDER
+
+#ifdef MY_REORDER
+// NOTE(review): if this point lies inside the file's namespaces and
+// <arm_neon.h> has not been included earlier, its declarations land inside
+// those namespaces. Consider including it with the other headers at the top
+// of the file -- TODO confirm.
+#include <arm_neon.h>
+
+// Reorders a row-major n_rows x n_cols f32 matrix into column panels that
+// are lanes_per_vector (4) floats wide. Panel p stores, for every row r, the
+// floats in[r][4p .. 4p+3] back to back:
+//   out[p * n_rows * 4 + r * 4 + j] = in[r * n_cols + p * 4 + j]
+// The matrix is processed in tiles of tile_height rows x tile_width columns.
+// All loads of a tile are issued before any of its stores.
+//
+// @param in   source, n_rows * n_cols floats in row-major order.
+// @param out  destination, n_rows * n_cols floats; assumed not to overlap
+//             `in`.
+//
+// noinline presumably keeps the routine as its own symbol (e.g. for
+// profiling). static gives it internal linkage; it is only used in this file.
+__attribute__((noinline)) static void my_reorder(const void *in, void *out) {
+    // XXX: only works for a 128x256 matrix; panel width = 4 floats.
+    constexpr int n_rows = 128;
+    constexpr int n_cols = 256;
+    constexpr int lanes_per_vector = sizeof(float32x4_t) / sizeof(float);
+
+    // Tile size (tunable). tile_width must be a whole number of vectors.
+    // tile_height must be even because the store loop is unrolled by 2.
+    constexpr int tile_width = 4;
+    constexpr int tile_height = 16;
+    // static_assert requires a message before C++17.
+    static_assert(n_cols % tile_width == 0
+                    && tile_width % lanes_per_vector == 0,
+            "n_cols must be a multiple of tile_width, and tile_width a "
+            "multiple of the vector width");
+    static_assert(n_rows % tile_height == 0 && tile_height % 2 == 0,
+            "n_rows must be a multiple of tile_height, and tile_height even");
+
+    // Vectors per tile row == number of output panels one tile touches.
+    constexpr int n = tile_width / lanes_per_vector;
+    // Distance (in floats) between the starts of two adjacent panels.
+    constexpr int panel_stride = n_rows * lanes_per_vector;
+
+    const float *inv = static_cast<const float *>(in);
+    for (int row = 0; row < n_rows; row += tile_height) {
+        float *outv[n];
+        // First panel, offset to where this block of rows begins.
+        outv[0] = static_cast<float *>(out) + row * lanes_per_vector;
+        for (int col = 0; col < n_cols; col += tile_width) {
+            for (int i = 1; i < n; ++i)
+                outv[i] = outv[i - 1] + panel_stride;
+
+            // Load the whole tile first...
+            float32x4_t v[tile_height][n];
+            for (int h = 0; h < tile_height; ++h) {
+                const float *src = inv + h * n_cols;
+                for (int c = 0; c < n; ++c)
+                    v[h][c] = vld1q_f32(src + c * lanes_per_vector);
+            }
+
+            // ...then write each vector column into its panel, 2 rows per step.
+            for (int c = 0; c < n; ++c) {
+                float *dst = outv[c];
+                for (int r = 0; r < tile_height; r += 2) {
+                    vst1q_f32(dst + r * lanes_per_vector, v[r][c]);
+                    vst1q_f32(dst + (r + 1) * lanes_per_vector, v[r + 1][c]);
+                }
+            }
+
+            inv += tile_width;
+            // Skip the n panels just written (n * panel_stride floats).
+            outv[0] += n_rows * tile_width;
+        }
+        // The column loop moved inv forward by exactly one row. Skip the
+        // remaining rows of the tile_height rows just consumed.
+        inv += (tile_height - 1) * n_cols;
+    }
+}
+#endif
+
void jit_uni_reorder_t::omp_driver_0d(int off, const char *in, char *out,
const float *src_scales, const float *dst_scales, int src_zp,
int dst_zp, int32_t *compensation_scratch) const {
@@ -2734,7 +2783,11 @@ void jit_uni_reorder_t::omp_driver_1d(int ithr, int nthr, int off,
(*kernel_)(&tail_params);
} else {
+#ifdef MY_REORDER
+ my_reorder(base_params.in, base_params.out);
+#else
(*kernel_)(&base_params);
+#endif
}
});
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment