Last active
January 2, 2020 13:17
-
-
Save DutchGhost/da521480fed920167e738a0eb91f0541 to your computer and use it in GitHub Desktop.
Collatz sequence with simd
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
mod simd {
    //! Minimal `Copy` wrapper over 128-bit x86_64 SIMD registers, viewed as arrays of lanes.
    //!
    //! Add/sub/and/xor/compare/broadcast use SSE2 intrinsics, which every x86_64 target
    //! provides. `mul` and `is_all_one` need SSE4.1 and `shl`/`shr` need AVX2; those
    //! intrinsics are only used when the feature is enabled at compile time (for example
    //! with `-C target-cpu=native`). Otherwise a portable lane-by-lane fallback producing
    //! the same results runs instead, so the safe API never executes an instruction the
    //! compilation target is not known to support.

    use core::arch::x86_64::{
        __m128i, _mm_add_epi32, _mm_and_si128, _mm_cmpeq_epi32, _mm_mullo_epi32, _mm_set1_epi32,
        _mm_setzero_si128, _mm_sllv_epi32, _mm_srlv_epi32, _mm_sub_epi32, _mm_test_all_ones,
        _mm_xor_si128,
    };

    /// Maps from an array to its corresponding Simd type.
    ///
    /// # Safety
    ///
    /// This module reinterprets values between `Self` and `Self::Simd` through a
    /// `#[repr(C)]` union, so implementors must guarantee that both types have the same
    /// size and that every bit pattern of one is a valid value of the other.
    pub unsafe trait SimdArray: Copy {
        /// Scalar type of a single lane.
        type Lane: Copy;
        /// Native register type holding all lanes.
        type Simd: Copy;
    }

    // SAFETY: `[u32; 4]` and `__m128i` are both 16 bytes of plain integer data, so every
    // bit pattern is valid for either type.
    unsafe impl SimdArray for [u32; 4] {
        type Lane = u32;
        type Simd = __m128i;
    }

    /// A native SIMD register tagged with the lane-array type `A` it represents.
    #[repr(C)]
    pub struct Simd<A: SimdArray> {
        simd: A::Simd,
    }

    impl<A: SimdArray> Copy for Simd<A> {}

    impl<A: SimdArray> Clone for Simd<A> {
        fn clone(&self) -> Self {
            *self
        }
    }

    /// Lane-wise vector operations.
    pub trait SimdExt {
        /// Lane-array type this vector corresponds to.
        type Array: SimdArray;
        /// Wraps a raw native register.
        fn from_simd(simd: <Self::Array as SimdArray>::Simd) -> Self;
        /// Builds a vector whose lanes are the elements of `array`.
        fn from_array(array: Self::Array) -> Self;
        /// Returns the lanes as an array.
        fn into_array(self) -> Self::Array;
        /// Returns a vector with every lane set to zero.
        fn new() -> Self;
        /// Lane-wise wrapping addition.
        fn add(self, other: Self) -> Self;
        /// Lane-wise multiplication keeping the low bits of each product (wrapping).
        fn mul(self, other: Self) -> Self;
        /// Lane-wise wrapping subtraction.
        fn sub(self, other: Self) -> Self;
        /// Bitwise AND.
        fn and(self, other: Self) -> Self;
        /// Lane-wise equality: all bits set where the lanes are equal, zero elsewhere.
        fn cmp(self, other: Self) -> Self;
        /// Lane-wise logical right shift by the per-lane counts in `other`;
        /// a count of at least the lane width yields zero.
        fn shr(self, other: Self) -> Self;
        /// Lane-wise left shift by the per-lane counts in `other`;
        /// a count of at least the lane width yields zero.
        fn shl(self, other: Self) -> Self;
        /// Bitwise XOR.
        fn xor(self, other: Self) -> Self;
        /// Returns a vector with every lane set to `lane`.
        fn set1(lane: <Self::Array as SimdArray>::Lane) -> Self;
    }

    impl Simd<[u32; 4]> {
        /// Portable fallback: combines `self` and `other` lane by lane with `f`.
        #[inline(always)]
        fn zip_lanes(self, other: Self, f: impl Fn(u32, u32) -> u32) -> Self {
            let mut lanes = self.into_array();
            for (lane, rhs) in lanes.iter_mut().zip(other.into_array().iter()) {
                *lane = f(*lane, *rhs);
            }
            Self::from_array(lanes)
        }
    }

    impl SimdExt for Simd<[u32; 4]> {
        type Array = [u32; 4];

        fn from_array(array: Self::Array) -> Self {
            // SAFETY: per the `SimdArray` impl, both union fields are 16 bytes of plain data.
            unsafe { Vector { array }.simd }
        }

        fn from_simd(simd: <Self::Array as SimdArray>::Simd) -> Self {
            Self { simd }
        }

        fn into_array(self) -> Self::Array {
            // SAFETY: per the `SimdArray` impl, both union fields are 16 bytes of plain data.
            unsafe { Vector { simd: self }.array }
        }

        #[inline(always)]
        fn new() -> Self {
            // SAFETY: SSE2 is part of the x86_64 baseline.
            Self::from_simd(unsafe { _mm_setzero_si128() })
        }

        #[inline(always)]
        fn add(self, other: Self) -> Self {
            // SAFETY: SSE2 is part of the x86_64 baseline.
            Self::from_simd(unsafe { _mm_add_epi32(self.simd, other.simd) })
        }

        #[inline(always)]
        fn mul(self, other: Self) -> Self {
            if cfg!(target_feature = "sse4.1") {
                // SAFETY: only taken when SSE4.1 is enabled for the compilation target.
                Self::from_simd(unsafe { _mm_mullo_epi32(self.simd, other.simd) })
            } else {
                // `_mm_mullo_epi32` keeps the low 32 bits of each product == wrapping_mul.
                self.zip_lanes(other, u32::wrapping_mul)
            }
        }

        #[inline(always)]
        fn sub(self, other: Self) -> Self {
            // SAFETY: SSE2 is part of the x86_64 baseline.
            Self::from_simd(unsafe { _mm_sub_epi32(self.simd, other.simd) })
        }

        #[inline(always)]
        fn and(self, other: Self) -> Self {
            // SAFETY: SSE2 is part of the x86_64 baseline.
            Self::from_simd(unsafe { _mm_and_si128(self.simd, other.simd) })
        }

        #[inline(always)]
        fn cmp(self, other: Self) -> Self {
            // SAFETY: SSE2 is part of the x86_64 baseline.
            Self::from_simd(unsafe { _mm_cmpeq_epi32(self.simd, other.simd) })
        }

        #[inline(always)]
        fn shr(self, other: Self) -> Self {
            if cfg!(target_feature = "avx2") {
                // SAFETY: only taken when AVX2 is enabled for the compilation target.
                Self::from_simd(unsafe { _mm_srlv_epi32(self.simd, other.simd) })
            } else {
                // `_mm_srlv_epi32` yields 0 for counts > 31; `checked_shr` is None there.
                self.zip_lanes(other, |value, count| value.checked_shr(count).unwrap_or(0))
            }
        }

        #[inline(always)]
        fn shl(self, other: Self) -> Self {
            if cfg!(target_feature = "avx2") {
                // SAFETY: only taken when AVX2 is enabled for the compilation target.
                Self::from_simd(unsafe { _mm_sllv_epi32(self.simd, other.simd) })
            } else {
                // `_mm_sllv_epi32` yields 0 for counts > 31; `checked_shl` is None there.
                self.zip_lanes(other, |value, count| value.checked_shl(count).unwrap_or(0))
            }
        }

        #[inline(always)]
        fn xor(self, other: Self) -> Self {
            // SAFETY: SSE2 is part of the x86_64 baseline.
            Self::from_simd(unsafe { _mm_xor_si128(self.simd, other.simd) })
        }

        #[inline(always)]
        fn set1(lane: <Self::Array as SimdArray>::Lane) -> Self {
            // The `as i32` cast only reinterprets the bits; the lane value is unchanged.
            // SAFETY: SSE2 is part of the x86_64 baseline.
            Self::from_simd(unsafe { _mm_set1_epi32(lane as i32) })
        }
    }

    /// Vectors with an all-zero constant.
    pub trait Zero {
        /// Every lane is zero.
        const ZERO: Self;
        /// Returns true when every lane is zero.
        fn is_zero(self) -> bool;
    }

    /// Vectors with all-ones / one-per-lane constants.
    pub trait One {
        /// Every bit of every lane is set.
        const ALL_ONE: Self;
        /// Every lane equals 1.
        const ONE: Self;
        /// Returns true when every bit of every lane is set.
        fn is_all_one(self) -> bool;
    }

    /// Reinterprets between a `Simd` value and its lane array (see `SimdArray`'s contract).
    #[repr(C)]
    union Vector<A: SimdArray> {
        simd: Simd<A>,
        array: A,
    }

    impl Zero for Simd<[u32; 4]> {
        const ZERO: Self = {
            let zero = Vector { array: [0; 4] };
            // SAFETY: plain 16-byte data, see `SimdArray`.
            unsafe { zero.simd }
        };

        #[inline(always)]
        fn is_zero(self) -> bool {
            // Lanes equal to zero compare as all-ones, so "all zero" == "all ones after cmp".
            self.cmp(Self::ZERO).is_all_one()
        }
    }

    impl One for Simd<[u32; 4]> {
        const ONE: Self = {
            let one = Vector { array: [1; 4] };
            // SAFETY: plain 16-byte data, see `SimdArray`.
            unsafe { one.simd }
        };

        const ALL_ONE: Self = {
            let all_one = Vector {
                array: [std::u32::MAX; 4],
            };
            // SAFETY: plain 16-byte data, see `SimdArray`.
            unsafe { all_one.simd }
        };

        #[inline(always)]
        fn is_all_one(self) -> bool {
            if cfg!(target_feature = "sse4.1") {
                // SAFETY: only taken when SSE4.1 is enabled for the compilation target.
                unsafe { _mm_test_all_ones(self.simd) == 1 }
            } else {
                self.into_array().iter().all(|&lane| lane == std::u32::MAX)
            }
        }
    }
}
/// A value that can be advanced by one step of the Collatz map.
pub trait CollatzTransform
where
    Self: Sized,
{
    /// Consumes `self` and returns the next value of the sequence,
    /// or `None` once the sequence has ended.
    fn transform(self) -> Option<Self>;
}
| use crate::simd::{One, Simd, SimdArray, SimdExt, Zero}; | |
| impl CollatzTransform for Simd<[u32; 4]> { | |
| #[inline(always)] | |
| fn transform(mut self) -> Option<Self> { | |
| let whois_one = self.cmp(Self::ONE).and(Self::ONE); | |
| self = self.sub(whois_one); | |
| // (0, 1, 1, 0) for even odd odd even | |
| let odd_bit = Self::ONE.and(self); | |
| // (0, 0xffff, 0xffff, 0) for even odd odd even | |
| let odd_mask = Self::ZERO.sub(odd_bit); | |
| // gives (0xfff, 0xfff, 0xffff, 0xfff) if isnt 0. | |
| // then xor gives (0, 0, 0, 0) | |
| let not_zero_mask = Self::ZERO.cmp(self).xor(Self::ALL_ONE); | |
| self = self.shr(Self::ONE.sub(odd_bit)); | |
| self = (self.shl(odd_bit).add(self.and(odd_mask))).add(odd_bit); | |
| self = self.and(not_zero_mask); | |
| if self.is_zero() { | |
| None | |
| } else { | |
| Some(self) | |
| } | |
| } | |
| } | |
/// Iterator over a Collatz sequence, starting from (and including) `current`.
///
/// Each `next` call yields the stored value and replaces it with its
/// `CollatzTransform::transform` successor; a stored `None` means exhausted.
struct CollatzIterator<T> {
    /// Value to yield on the next call, or `None` once the sequence is over.
    current: Option<T>,
}
| impl<T: CollatzTransform + Copy> Iterator for CollatzIterator<T> { | |
| type Item = T; | |
| fn next(&mut self) -> Option<Self::Item> { | |
| match self.current { | |
| None => None, | |
| Some(elem) => { | |
| self.current = elem.transform(); | |
| Some(elem) | |
| } | |
| } | |
| } | |
| } | |
| impl CollatzTransform for u32 { | |
| fn transform(self) -> Option<Self> { | |
| match self { | |
| 1 => None, | |
| n if n & 1 == 0 => Some(n / 2), | |
| n => Some(n * 3 + 1), | |
| } | |
| } | |
| } | |
| fn main() { | |
| use std::time::Instant; | |
| let now = Instant::now(); | |
| let collatz = (1..50_000_000) | |
| .step_by(8) | |
| .map(|n| Simd::<[u32; 4]>::from_array([n, n + 2, n + 4, n + 6])) | |
| .map(|simd| { | |
| ( | |
| simd, | |
| CollatzIterator { | |
| current: Some(simd), | |
| } | |
| .count(), | |
| ) | |
| }) | |
| .max_by_key(|&(simd, seq)| seq) | |
| .unwrap(); | |
| let elapsed = now.elapsed(); | |
| println!( | |
| "{:?}, {:?} in {:?}", | |
| collatz.0.into_array(), | |
| collatz.1, | |
| elapsed | |
| ); | |
| let now = Instant::now(); | |
| let collatz = (1..50_000_000) | |
| .filter(|n| n & 1 == 1) | |
| .map(|n| (n, CollatzIterator { current: Some(n) }.count())) | |
| .max_by_key(|&(_, seq)| seq) | |
| .unwrap(); | |
| let elapsed = now.elapsed(); | |
| println!("{:?} {:?} in {:?}", collatz.0, collatz.1, elapsed); | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment