Last active
May 10, 2022 13:33
-
-
Save nihui/17a9a03e64730e6d6042d42432654dd5 to your computer and use it in GitHub Desktop.
int8 vector multiplication with Loongson MMI and MIPS MSA
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
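// Build: g++ mul.cpp -o mul -mmsa -mloongson-mmi -O3
// loongson_mmi.h comes from ncnn:
// https://github.com/Tencent/ncnn/blob/master/src/layer/mips/loongson_mmi.h
// Sample run on host ls2k (the binary was invoked as ./quant there). The first three
// lines are elapsed milliseconds for 10,000,000 calls of each kernel; the last line
// is the output lanes:
// root@ls2k:~/ncnn/build# ./quant
// mul_s8x8 385.743
// mul_s8x8_mmi 611.364
// mul_s8x8_msa 173.241
// -66 2 0 4 10 18 28 40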
#include <errno.h>
#include <limits.h>
#include <msa.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/time.h>

#include "loongson_mmi.h"
/* Current wall-clock time, in milliseconds, as reported by gettimeofday(). */
static double get_current_time()
{
    struct timeval now;
    gettimeofday(&now, NULL);

    const double ms_from_sec = now.tv_sec * 1000.0;
    const double ms_from_usec = now.tv_usec / 1000.0;
    return ms_from_sec + ms_from_usec;
}
/*
 * Scalar reference kernel: out[i] = vptr[i] * kptr[i] for the 8 lanes.
 * Each operand is promoted to int before multiplying, so no overflow is
 * possible (|product| <= 128 * 128).
 */
__attribute__((noinline))
static void mul_s8x8(const signed char* vptr, const signed char* kptr, int* out)
{
    for (int lane = 0; lane < 8; ++lane)
    {
        const int a = vptr[lane];
        const int b = kptr[lane];
        out[lane] = a * b;
    }
}
__attribute__((noinline)) | |
static void mul_s8x8_mmi(const signed char* vptr, const signed char* kptr, int* out) | |
{ | |
int8x8_t _v = __mmi_pldb_s(vptr); | |
int8x8_t _k = __mmi_pldb_s(kptr); | |
int8x8_t _zero = __mmi_pzerob_s(); | |
int8x8_t _extv = __mmi_pcmpgtb_s(_zero, _v); | |
int8x8_t _extk = __mmi_pcmpgtb_s(_zero, _k); | |
int16x4_t _v0 = (int16x4_t)__mmi_punpcklbh_s(_v, _extv); | |
int16x4_t _v1 = (int16x4_t)__mmi_punpckhbh_s(_v, _extv); | |
int16x4_t _k0 = (int16x4_t)__mmi_punpcklbh_s(_k, _extk); | |
int16x4_t _k1 = (int16x4_t)__mmi_punpckhbh_s(_k, _extk); | |
int16x4_t _s0l = __mmi_pmullh(_v0, _k0); | |
int16x4_t _s0h = __mmi_pmulhh(_v0, _k0); | |
int16x4_t _s1l = __mmi_pmullh(_v1, _k1); | |
int16x4_t _s1h = __mmi_pmulhh(_v1, _k1); | |
int32x2_t _s0 = (int32x2_t)__mmi_punpcklhw_s(_s0l, _s0h); | |
int32x2_t _s1 = (int32x2_t)__mmi_punpckhhw_s(_s0l, _s0h); | |
int32x2_t _s2 = (int32x2_t)__mmi_punpcklhw_s(_s1l, _s1h); | |
int32x2_t _s3 = (int32x2_t)__mmi_punpckhhw_s(_s1l, _s1h); | |
__mmi_pstw_s(out, _s0); | |
__mmi_pstw_s(out + 2, _s1); | |
__mmi_pstw_s(out + 4, _s2); | |
__mmi_pstw_s(out + 6, _s3); | |
} | |
__attribute__((noinline)) | |
static void mul_s8x8_msa(const signed char* vptr, const signed char* kptr, int* out) | |
{ | |
v16i8 _v = __msa_ld_b(vptr, 0); | |
v16i8 _k = __msa_ld_b(kptr, 0); | |
v8i16 _v01 = (v8i16)__msa_ilvr_b(__msa_clti_s_b(_v, 0), _v); | |
v8i16 _k01 = (v8i16)__msa_ilvr_b(__msa_clti_s_b(_k, 0), _k); | |
v8i16 _s01 = __msa_mulv_h(_v01, _k01); | |
v8i16 _exts01 = __msa_clti_s_h(_s01, 0); | |
v4i32 _s0 = (v4i32)__msa_ilvr_h(_exts01, _s01); | |
v4i32 _s1 = (v4i32)__msa_ilvl_h(_exts01, _s01); | |
__msa_st_w(_s0, out, 0); | |
__msa_st_w(_s1, out + 4, 0); | |
} | |
int main(int argc, char** argv) | |
{ | |
signed char vptr[8] = {33, -2, 1, 4, 5, 6, 7, 8}; | |
signed char kptr[8] = {-2, -1, 0, 1, 2, 3, 4, 5}; | |
if (argc == 9) | |
{ | |
vptr[0] = atoi(argv[0]); | |
vptr[1] = atoi(argv[1]); | |
vptr[2] = atoi(argv[2]); | |
vptr[3] = atoi(argv[3]); | |
vptr[4] = atoi(argv[4]); | |
vptr[5] = atoi(argv[5]); | |
vptr[6] = atoi(argv[6]); | |
vptr[7] = atoi(argv[7]); | |
kptr[0] = atoi(argv[0]); | |
kptr[1] = atoi(argv[1]); | |
kptr[2] = atoi(argv[2]); | |
kptr[3] = atoi(argv[3]); | |
kptr[4] = atoi(argv[4]); | |
kptr[5] = atoi(argv[5]); | |
kptr[6] = atoi(argv[6]); | |
kptr[7] = atoi(argv[7]); | |
} | |
int out[8]; | |
double t0 = get_current_time(); | |
for (int i = 0; i < 10000000; i++) | |
mul_s8x8(vptr, kptr, out); | |
double t1 = get_current_time(); | |
for (int i = 0; i < 10000000; i++) | |
mul_s8x8_mmi(vptr, kptr, out); | |
double t2 = get_current_time(); | |
for (int i = 0; i < 10000000; i++) | |
mul_s8x8_msa(vptr, kptr, out); | |
double t3 = get_current_time(); | |
fprintf(stderr, "mul_s8x8 %.3f\n", t1-t0); | |
fprintf(stderr, "mul_s8x8_mmi %.3f\n", t2-t1); | |
fprintf(stderr, "mul_s8x8_msa %.3f\n", t3-t2); | |
fprintf(stderr, "%d %d %d %d %d %d %d %d\n", out[0], out[1], out[2], out[3], out[4], out[5], out[6], out[7]); | |
return 0; | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.