Skip to content

Instantly share code, notes, and snippets.

@kaityo256
Last active June 6, 2019 09:54
Show Gist options
  • Save kaityo256/c5e7a02eef60e98fe8b5b08638476825 to your computer and use it in GitHub Desktop.
Save kaityo256/c5e7a02eef60e98fe8b5b08638476825 to your computer and use it in GitHub Desktop.
pack data using AVX-512
/*
# Copyright H. Watanabe 2017
# Distributed under the Boost Software License, Version 1.0.
# (See accompanying file LICENSE_1_0.txt or copy at
# http://www.boost.org/LICENSE_1_0.txt)
*/
//------------------------------------------------------------------------
#include <immintrin.h>
#include <stdint.h>
#include <x86intrin.h>
#include <chrono>
#include <cstdio>
#include <iostream>
#include <random>
#include <string>
//------------------------------------------------------------------------
// Number of 32-bit elements in the benchmark input. The vectorized packers
// walk the input in steps of 16 elements (SIZE / 16 iterations).
const int SIZE = 131072;
//const int SIZE = 32;
// All buffers and lookup tables are aligned to 64 bytes (one 512-bit vector)
// so that aligned 256-bit and 512-bit accesses to their row starts are valid.
// The previous 32-byte alignment was insufficient for 512-bit aligned
// accesses, and the offset tables carried no alignment guarantee at all.
alignas(64) int data[SIZE];           // input; zero entries are dropped by the packers
alignas(64) int result[SIZE] = {};    // output of pack_serial (reference)
alignas(64) int result2[SIZE] = {};   // output of pack_512
alignas(64) int result3[SIZE] = {};   // output of pack_256
alignas(64) int result4[SIZE] = {};   // output of pack_512c
// offset8[8 * m + k]   = index of the k-th set bit of the 8-bit mask m.
alignas(64) int offset8[256 * 8] = {};
// offset16[16 * m + k] = index of the k-th set bit of the 16-bit mask m.
alignas(64) int offset16[256 * 256 * 16] = {};
//------------------------------------------------------------------------
// Debug helper: prints the sixteen 32-bit elements of x as signed decimals in
// memory order, each followed by a space, then a newline.
void
print_512i(__m512i x) {
  const int *lanes = (const int *)(&x);
  int lane = 0;
  while (lane < 16) {
    printf("%d ", lanes[lane]);
    ++lane;
  }
  printf("\n");
}
//------------------------------------------------------------------------
// Debug helper: prints the 16 bits of m as '0'/'1' characters, starting with
// bit 0 (least significant), followed by a newline.
void
print_mask(__mmask16 m) {
  for (int bit = 0; bit < 16; ++bit) {
    const bool is_set = (m & (1 << bit)) != 0;
    printf("%s", is_set ? "1" : "0");
  }
  printf("\n");
}
//------------------------------------------------------------------------
// Fills offset8: for every 8-bit mask value, the row of 8 ints starting at
// offset8[8 * mask] receives the positions of the set bits in ascending
// order. Slots past the number of set bits are not written.
void
make_offset8(void) {
  for (int mask = 0; mask < 256; ++mask) {
    int *row = offset8 + 8 * mask;
    int filled = 0;
    for (int bit = 0; bit < 8; ++bit) {
      if ((mask & (1 << bit)) == 0) {
        continue;
      }
      row[filled] = bit;
      ++filled;
    }
  }
}
//------------------------------------------------------------------------
// Fills offset16: for every 16-bit mask value, the row of 16 ints starting at
// offset16[16 * mask] receives the positions of the set bits in ascending
// order. Slots past the number of set bits are not written.
void
make_offset16(void) {
  const int mask_count = 256 * 256;
  for (int mask = 0; mask < mask_count; ++mask) {
    int *row = offset16 + 16 * mask;
    int filled = 0;
    for (int bit = 0; bit < 16; ++bit) {
      if ((mask & (1 << bit)) == 0) {
        continue;
      }
      row[filled] = bit;
      ++filled;
    }
  }
}
//------------------------------------------------------------------------
// Packs the non-zero elements of data[] contiguously into result2[] using a
// 512-bit permute driven by the offset16 lookup table.
//
// For each group of 16 inputs: build a mask of the non-zero lanes, fetch the
// table row listing the positions of the set bits, gather those lanes to the
// front with a permute, store all 16 lanes at the current write position and
// advance by the number of non-zero lanes.
void
pack_512(void) {
  int pos = 0;
  for (int i = 0; i < SIZE / 16; i++) {
    const __m512i vdata = _mm512_loadu_si512(data + i * 16);
    // Bit k of vmask is set when lane k of vdata is non-zero.
    const __mmask16 vmask = _mm512_test_epi32_mask(vdata, vdata);
    // Unaligned load: the table carries no alignment guarantee of its own.
    const __m512i voffset = _mm512_loadu_si512(offset16 + vmask * 16);
    const __m512i vout = _mm512_permutevar_epi32(voffset, vdata);
    // pos is an arbitrary element index, so result2 + pos is generally not
    // 64-byte aligned; an aligned store here can fault.
    _mm512_storeu_si512(result2 + pos, vout);
    pos += _mm_popcnt_u32(vmask);
  }
  // Each store writes 16 lanes but only the first popcount lanes are payload.
  // Later stores overwrite earlier filler; clear what the last store left.
  for (int i = pos; i < pos + 16 && i < SIZE; i++) {
    result2[i] = 0;
  }
}
//------------------------------------------------------------------------
void
pack_512c(void) {
int pos = 0;
for (int i = 0; i < SIZE / 16; i++) {
__m512i vdata = _mm512_loadu_si512(data + i * 16);
__mmask16 vmask = _mm512_test_epi32_mask(vdata, vdata);
_mm512_mask_compressstoreu_epi32(result4 + pos, vmask, vdata);
pos += _mm_popcnt_u32(vmask);
}
}
//------------------------------------------------------------------------
void
pack_256(void) {
int pos = 0;
for (int i = 0; i < SIZE / 16; i++) {
__m512i vdata = _mm512_loadu_si512(data + i * 16);
__mmask16 vmask = _mm512_test_epi32_mask(vdata, vdata);
__m256i vlow = _mm512_castsi512_si256(vdata);
__m256i vhigh = _mm512_extracti64x4_epi64(vdata, 1);
int mask_low = vmask & 255;
int mask_high = vmask >> 8;
__m256i voffset = _mm256_load_si256((__m256i const *)(offset8 + mask_low * 8));
__m256i vout_low = _mm256_permutevar8x32_epi32(vlow, voffset);
voffset = _mm256_load_si256((__m256i const *)(offset8 + mask_high * 8));
__m256i vout_high = _mm256_permutevar8x32_epi32(vhigh, voffset);
_mm256_store_si256((__m256i *)(result3 + pos), vout_low);
pos += _mm_popcnt_u32(mask_low);
_mm256_store_si256((__m256i *)(result3 + pos), vout_high);
pos += _mm_popcnt_u32(mask_high);
}
for (int i = pos; i < pos + 16 && i < SIZE; i++) {
result3[i] = 0;
}
}
//------------------------------------------------------------------------
// Scalar reference implementation: copies every non-zero element of data[]
// into result[], preserving order and leaving no gaps.
void
pack_serial(void) {
  int *out = result;
  for (const int *in = data; in != data + SIZE; ++in) {
    if (*in == 0) {
      continue;
    }
    *out = *in;
    ++out;
  }
}
//------------------------------------------------------------------------
// Debug helper: writes the first SIZE elements of v to std::cout, each
// followed by a space, then ends the line and flushes.
void
dump(int *v) {
  int idx = 0;
  while (idx < SIZE) {
    std::cout << v[idx] << " ";
    ++idx;
  }
  std::cout << std::endl;
}
//------------------------------------------------------------------------
// Returns true when result2, result3 and result4 all match the scalar
// reference in result[] element for element over the full SIZE range;
// returns false at the first mismatch.
bool
check(void) {
  int idx = 0;
  while (idx < SIZE) {
    const int expected = result[idx];
    const bool all_match = expected == result2[idx] &&
                           expected == result3[idx] &&
                           expected == result4[idx];
    if (!all_match) {
      return false;
    }
    ++idx;
  }
  return true;
}
//------------------------------------------------------------------------
// Calls pfunc 1000 times and prints "<name> <elapsed>[ms]" to std::cout,
// where <elapsed> is the total wall time of all calls in whole milliseconds.
//
// pfunc: function to benchmark; takes no arguments and returns nothing.
// name:  label printed in front of the timing.
void
measure(void(*pfunc)(), std::string name) {
  // steady_clock is monotonic; system_clock may be adjusted while the
  // benchmark runs, which would corrupt the measured interval.
  const auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < 1000; i++) {
    pfunc();
  }
  const auto end = std::chrono::steady_clock::now();
  const auto elapsed_ms =
      std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
  std::cout << name << " " << elapsed_ms << "[ms]" << std::endl;
}
//------------------------------------------------------------------------
// Benchmark driver: builds the lookup tables, fills data[] with a sparse
// increasing sequence, times the four packers and checks that the vectorized
// results agree with the scalar reference.
int
main(void) {
  make_offset8();
  make_offset16();

  // Fixed seed so every run sees the same input.
  std::mt19937 rng(1);
  std::uniform_int_distribution<int> coin(0, 1);
  // Each slot is selected with a coin flip; selected slots receive the running
  // counter. The counter starts at 0, so the first selected slot stores 0 and
  // is indistinguishable from an unselected slot.
  int next_value = 0;
  for (int *slot = data; slot != data + SIZE; ++slot) {
    if (coin(rng) == 0) {
      continue;
    }
    *slot = next_value;
    ++next_value;
  }
  //dump(data);

  measure(pack_serial, "Serial");
  measure(pack_512, "512 bit");
  measure(pack_256, "256 bit");
  measure(pack_512c, "512 bit + compress");
  //dump(result);
  //dump(result2);

  const bool all_agree = check();
  std::cout << (all_agree ? "OK." : "Failed.") << std::endl;
}
//------------------------------------------------------------------------
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment