Created September 14, 2020 19:59
Save BreadFish64/47b7274fae70c5a760b473c2cc68304e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Reverses the decimal digits of a 32-bit signed integer (LeetCode #7, "Reverse Integer"),
// returning 0 whenever the reversed value does not fit in an int.
//
// Digits beyond the lowest-order seven are peeled off with scalar integer math. The
// remaining (at most seven) digits are reversed in parallel across the eight float lanes of
// an AVX register. A float represents every integer below 2^24 exactly, which covers all
// 7-digit values.
class Solution {
public:
    // x: value whose decimal digits are reversed; the sign is preserved.
    // Returns the digit-reversed value, or 0 if it overflows int.
    // Note: gnu::target("avx2") compiles this function for AVX2 regardless of global flags,
    // so it must only be called on CPUs that support AVX2.
    [[gnu::target("avx2")]] int reverse(int x) const {
        const bool negative = x < 0;
        if (negative) {
            // -INT_MIN is not representable, and its reversal overflows anyway.
            if (x == std::numeric_limits<int>::min())
                return 0;
            x = -x;
        }
        if (x < 10)
            return negative ? -x : x;

        // Exact integer digit count (avoids relying on floating-point log10 rounding).
        int digit_count = 1;
        for (int rest = x / 10; rest != 0; rest /= 10)
            ++digit_count;

        // floats can only hold 7 digits exactly, so while more than seven digits remain,
        // peel off the lowest one. Each peeled digit becomes the next high-order digit of
        // the result: aux = aux * 10 + digit * 10^7. Both steps are overflow-checked,
        // because a 10-digit input can overflow on either the multiply or the add.
        int aux{0};
        while (digit_count > 7) {
            if (__builtin_mul_overflow(aux, 10, &aux) ||
                __builtin_add_overflow(aux, (x % 10) * 10000000, &aux))
                return 0;
            x /= 10;
            --digit_count;
        }

        // Lane i holds 10^i.
        const __m256 powers = _mm256_set_ps(1e7f, 1e6f, 1e5f, 1e4f, 1e3f, 1e2f, 1e1f, 1e0f);
        const __m256 ten = _mm256_set1_ps(10.0f);

        // Extract digits: lane i becomes floor(x / 10^i) mod 10, the i-th decimal digit of x.
        // This is 0 for every lane i >= digit_count.
        __m256 digits = _mm256_div_ps(_mm256_set1_ps(static_cast<float>(x)), powers);
        digits = _mm256_sub_ps(digits,
                               _mm256_mul_ps(_mm256_floor_ps(_mm256_div_ps(digits, ten)), ten));
        digits = _mm256_floor_ps(digits);

        // Lane i takes digit (digit_count - 1 - i), which reverses the digit order.
        // Lanes i >= digit_count get a negative index; permutevar8x32 uses only the low
        // 3 bits, so it wraps into [digit_count, 7], whose digits are all zero.
        const __m256i permute = _mm256_sub_epi32(_mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7),
                                                 _mm256_set1_epi32(8 - digit_count));
        digits = _mm256_permutevar8x32_ps(digits, permute);
        digits = _mm256_mul_ps(digits, powers);

        // Horizontal sum. hadd yields, from low lane to high:
        //   [d0+d1, d2+d3, g, g, d4+d5, d6+d7, g, g]
        // where g comes from the dummy second operand.
        const __m256 pairs = _mm256_hadd_ps(digits, /* dummy */ powers);
        // Move the useful 64-bit chunks (0 and 2) into the low 128 bits. Explicit casts keep
        // this valid whether the intrinsics are macros or inline functions.
        const __m256d gathered = _mm256_permute4x64_pd(_mm256_castps_pd(pairs), 0b11'01'10'00);
        __m128 vf4 = _mm256_castps256_ps128(_mm256_castpd_ps(gathered));
        vf4 = _mm_hadd_ps(vf4, vf4);
        __m128d vd2 = _mm_cvtps_pd(vf4);
        vd2 = _mm_hadd_pd(vd2, vd2);
        x = static_cast<int>(_mm_cvtsd_f64(vd2));

        // Combine the high-order digits with the reversed low part.
        if (__builtin_add_overflow(x, aux, &x))
            return 0;
        return negative ? -x : x;
    }
};
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.