Last active
January 3, 2022 22:31
-
-
Save cgiosy/4c4727634fade14f610ef7238ace308d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// #define DONT_USE_SIMD
// #pragma GCC optimize("O3")
#pragma GCC target("avx2")
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>
#include <x86intrin.h>
#define INDEX_BIT_SIZE 14
// Fixed-width aliases: the SIMD kernel packs 16 coordinates per 256-bit
// register, so i16 must be exactly 16 bits wide.
using i16 = std::int16_t;
using i32 = std::int32_t;
using u32 = std::uint32_t;
using usize = std::size_t;

/// A 3D point; coordinates use the same i16 type as the point arrays.
struct Point3D { i16 x, y, z; };

/// Returns x*x as an unsigned 32-bit value.
///
/// The product is formed in unsigned arithmetic. Squaring in signed `int`
/// overflows (undefined behaviour) once |x| > 46340. The difference of two
/// i16 coordinates can reach 65535, whose square (4294836225) still fits in
/// u32, so the unsigned product is exact for every such difference.
u32 square(i32 x) {
    u32 const magnitude_bits = static_cast<u32>(x);
    return magnitude_bits * magnitude_bits;
}
int main() { | |
std::ios::sync_with_stdio(false); | |
usize point_num, query_num; | |
std::cin >> point_num >> query_num; | |
alignas(32) i16 points_x[point_num]; | |
alignas(32) i16 points_y[point_num]; | |
alignas(32) i16 points_z[point_num]; | |
for (usize i = 0; i < point_num; i += 1) | |
std::cin >> points_x[i] >> points_y[i] >> points_z[i]; | |
Point3D query_points[query_num]; | |
for (usize q = 0; q < query_num; q += 1) | |
std::cin >> query_points[q].x >> query_points[q].y >> query_points[q].z; | |
usize result[query_num]; | |
#pragma omp parallel for | |
for (usize q = 0; q < query_num; q += 1) { | |
Point3D const point = query_points[q]; | |
__m256i const low16 = _mm256_set1_epi32((1 << 16) - 1); | |
__m256i min_dists = _mm256_set1_epi32(~0); | |
__m256i indices0 = _mm256_set_epi32(0, 2, 4, 6, 8, 10, 12, 14); | |
__m256i indices1 = _mm256_set_epi32(1, 3, 5, 7, 9, 11, 13, 15); | |
usize i = 0; | |
#ifndef DONT_USE_SIMD | |
for (; i + 16 <= point_num; i += 16) { | |
__m256i const x = _mm256_sub_epi16(_mm256_load_si256((__m256i const*)(points_x + i)), _mm256_set1_epi16(point.x)); | |
__m256i const y = _mm256_sub_epi16(_mm256_load_si256((__m256i const*)(points_y + i)), _mm256_set1_epi16(point.y)); | |
__m256i const z = _mm256_sub_epi16(_mm256_load_si256((__m256i const*)(points_z + i)), _mm256_set1_epi16(point.z)); | |
__m256i const xs = _mm256_mullo_epi16(x, x); | |
__m256i const ys = _mm256_mullo_epi16(y, y); | |
__m256i const zs = _mm256_mullo_epi16(z, z); | |
__m256i const dists0 = _mm256_add_epi32(_mm256_and_si256(xs, low16), _mm256_add_epi32(_mm256_and_si256(ys, low16), _mm256_and_si256(zs, low16))); | |
__m256i const dists1 = _mm256_add_epi32(_mm256_srli_epi32(xs, 16), _mm256_add_epi32(_mm256_srli_epi32(ys, 16), _mm256_srli_epi32(zs, 16))); | |
min_dists = _mm256_min_epu32( | |
min_dists, | |
_mm256_min_epu32( | |
_mm256_or_si256(_mm256_slli_epi32(dists0, INDEX_BIT_SIZE), indices0), | |
_mm256_or_si256(_mm256_slli_epi32(dists1, INDEX_BIT_SIZE), indices1) | |
) | |
); | |
indices0 = _mm256_add_epi32(indices0, _mm256_set1_epi32(16)); | |
indices1 = _mm256_add_epi32(indices1, _mm256_set1_epi32(16)); | |
} | |
u32 min_dist = ~0; | |
{ | |
u32 dists[8]; | |
_mm256_storeu_si256((__m256i*)dists, min_dists); | |
for (usize j = 0; j < 8; j++) { | |
u32 const dist = dists[j]; | |
min_dist = min_dist < dist ? min_dist : dist; | |
} | |
} | |
#else | |
u32 min_dist = ~0; | |
#endif | |
for(; i < point_num; i += 1) { | |
u32 const dist = ( | |
square(points_x[i] - point.x) + | |
square(points_y[i] - point.y) + | |
square(points_z[i] - point.z) | |
) << INDEX_BIT_SIZE | i; | |
min_dist = min_dist < dist ? min_dist : dist; | |
} | |
result[q] = min_dist; | |
} | |
u32 ans = 0; | |
for (usize q = 0; q < query_num; q += 1) { | |
u32 const res = result[q]; | |
u32 const dist = res >> INDEX_BIT_SIZE; | |
usize const i = res & (1 << INDEX_BIT_SIZE) - 1; | |
std::cout << points_x[i] << ' ' << points_y[i] << ' ' << points_z[i] << ", " << dist << '\n'; | |
} | |
std::cout << ans; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment