Skip to content

Instantly share code, notes, and snippets.

@DutchGhost
Last active January 2, 2020 13:17
Show Gist options
  • Select an option

  • Save DutchGhost/da521480fed920167e738a0eb91f0541 to your computer and use it in GitHub Desktop.

Select an option

Save DutchGhost/da521480fed920167e738a0eb91f0541 to your computer and use it in GitHub Desktop.
Collatz sequence with SIMD
/// Thin wrappers around 128-bit x86_64 integer SIMD intrinsics, typed by lane layout.
///
/// NOTE: not every intrinsic used here is part of the x86_64 SSE2 baseline:
/// `_mm_mullo_epi32` and `_mm_test_all_ones` require SSE4.1, and `_mm_sllv_epi32` /
/// `_mm_srlv_epi32` require AVX2. Callers must make sure the running CPU supports them
/// (for example with `is_x86_feature_detected!`); executing them otherwise is undefined
/// behaviour. Compile with `-C target-feature=+sse4.1,+avx2` (or `-C target-cpu=native`)
/// so those intrinsics can be inlined into the `#[inline(always)]` wrappers below.
mod simd {
    use std::arch::x86_64::{
        __m128i, _mm_add_epi32, _mm_and_si128, _mm_cmpeq_epi32, _mm_mullo_epi32, _mm_set1_epi32,
        _mm_setzero_si128, _mm_sllv_epi32, _mm_srlv_epi32, _mm_sub_epi32, _mm_test_all_ones,
        _mm_xor_si128,
    };

    /// Maps an array type to the SIMD register type that holds the same lanes.
    ///
    /// # Safety
    ///
    /// Implementors must guarantee that `Self` and `Self::Simd` have the same size and that
    /// every bit pattern of either type is a valid value of the other. `from_array`,
    /// `into_array` and the `Zero`/`One` constants reinterpret one as the other through the
    /// private `Vector` union.
    pub unsafe trait SimdArray: Copy {
        /// Scalar type of a single lane.
        type Lane: Copy;
        /// Register type holding all lanes.
        type Simd: Copy;
    }

    // SAFETY: `[u32; 4]` and `__m128i` are both 16 bytes of plain integer data, so every bit
    // pattern of one is a valid value of the other.
    unsafe impl SimdArray for [u32; 4] {
        type Lane = u32;
        type Simd = __m128i;
    }

    /// A SIMD register tagged with the array type `A` whose lanes it holds.
    #[repr(C)]
    pub struct Simd<A: SimdArray> {
        simd: A::Simd,
    }

    impl<A: SimdArray> Copy for Simd<A> {}

    impl<A: SimdArray> Clone for Simd<A> {
        fn clone(&self) -> Self {
            *self
        }
    }

    /// Lane-wise operations on a SIMD value.
    pub trait SimdExt {
        /// Array type describing the lanes.
        type Array: SimdArray;
        /// Wraps a raw register.
        fn from_simd(simd: <Self::Array as SimdArray>::Simd) -> Self;
        /// Loads the lanes from an array.
        fn from_array(array: Self::Array) -> Self;
        /// Stores the lanes into an array.
        fn into_array(self) -> Self::Array;
        /// Returns a value with every lane set to zero.
        fn new() -> Self;
        /// Lane-wise addition.
        fn add(self, other: Self) -> Self;
        /// Lane-wise multiplication (low half of each product).
        fn mul(self, other: Self) -> Self;
        /// Lane-wise subtraction.
        fn sub(self, other: Self) -> Self;
        /// Bitwise AND.
        fn and(self, other: Self) -> Self;
        /// Lane-wise equality: all bits set in equal lanes, zero elsewhere.
        fn cmp(self, other: Self) -> Self;
        /// Lane-wise logical right shift of `self` by the matching lane of `other`.
        fn shr(self, other: Self) -> Self;
        /// Lane-wise left shift of `self` by the matching lane of `other`.
        fn shl(self, other: Self) -> Self;
        /// Bitwise XOR.
        fn xor(self, other: Self) -> Self;
        /// Broadcasts `lane` into every lane.
        fn set1(lane: <Self::Array as SimdArray>::Lane) -> Self;
    }

    // SAFETY (intrinsic calls below): the intrinsics take registers by value and have no
    // preconditions beyond CPU support for their instruction set (see the module docs).
    impl SimdExt for Simd<[u32; 4]> {
        type Array = [u32; 4];
        fn from_array(array: Self::Array) -> Self {
            // SAFETY: `SimdArray for [u32; 4]` guarantees the array and register are interchangeable.
            unsafe { Vector { array }.simd }
        }
        fn from_simd(simd: <Self::Array as SimdArray>::Simd) -> Self {
            Self { simd }
        }
        fn into_array(self) -> Self::Array {
            // SAFETY: `SimdArray for [u32; 4]` guarantees the array and register are interchangeable.
            unsafe { Vector { simd: self }.array }
        }
        #[inline(always)]
        fn new() -> Self {
            Self::from_simd(unsafe { _mm_setzero_si128() })
        }
        #[inline(always)]
        fn add(self, other: Self) -> Self {
            Self::from_simd(unsafe { _mm_add_epi32(self.simd, other.simd) })
        }
        #[inline(always)]
        fn mul(self, other: Self) -> Self {
            Self::from_simd(unsafe { _mm_mullo_epi32(self.simd, other.simd) })
        }
        #[inline(always)]
        fn sub(self, other: Self) -> Self {
            Self::from_simd(unsafe { _mm_sub_epi32(self.simd, other.simd) })
        }
        #[inline(always)]
        fn and(self, other: Self) -> Self {
            Self::from_simd(unsafe { _mm_and_si128(self.simd, other.simd) })
        }
        #[inline(always)]
        fn cmp(self, other: Self) -> Self {
            Self::from_simd(unsafe { _mm_cmpeq_epi32(self.simd, other.simd) })
        }
        #[inline(always)]
        fn shr(self, other: Self) -> Self {
            Self::from_simd(unsafe { _mm_srlv_epi32(self.simd, other.simd) })
        }
        #[inline(always)]
        fn shl(self, other: Self) -> Self {
            Self::from_simd(unsafe { _mm_sllv_epi32(self.simd, other.simd) })
        }
        #[inline(always)]
        fn xor(self, other: Self) -> Self {
            Self::from_simd(unsafe { _mm_xor_si128(self.simd, other.simd) })
        }
        #[inline(always)]
        fn set1(lane: <Self::Array as SimdArray>::Lane) -> Self {
            // `as i32` reinterprets the bits; the intrinsic just takes a signed argument.
            Self::from_simd(unsafe { _mm_set1_epi32(lane as i32) })
        }
    }

    /// Values with an all-zero constant.
    pub trait Zero {
        /// Every lane equal to zero.
        const ZERO: Self;
        /// Returns `true` if every lane is zero.
        fn is_zero(self) -> bool;
    }

    /// Values with a lane-wise one and an all-bits-set constant.
    pub trait One {
        /// Every bit of every lane set.
        const ALL_ONE: Self;
        /// Every lane equal to 1.
        const ONE: Self;
        /// Returns `true` if every bit is set.
        fn is_all_one(self) -> bool;
    }

    /// Reinterprets an array as its register wrapper and back (see `SimdArray`'s safety contract).
    #[repr(C)]
    union Vector<A: SimdArray> {
        simd: Simd<A>,
        array: A,
    }

    impl Zero for Simd<[u32; 4]> {
        const ZERO: Self = {
            let zero = Vector { array: [0; 4] };
            // SAFETY: `SimdArray for [u32; 4]` makes the array and register interchangeable.
            unsafe { zero.simd }
        };
        #[inline(always)]
        fn is_zero(self) -> bool {
            // Lanes equal to zero compare as all-ones; all lanes zero <=> every bit set.
            self.cmp(Self::ZERO).is_all_one()
        }
    }

    impl One for Simd<[u32; 4]> {
        const ONE: Self = {
            let one = Vector { array: [1; 4] };
            // SAFETY: `SimdArray for [u32; 4]` makes the array and register interchangeable.
            unsafe { one.simd }
        };
        const ALL_ONE: Self = {
            let all_one = Vector {
                array: [u32::MAX; 4],
            };
            // SAFETY: `SimdArray for [u32; 4]` makes the array and register interchangeable.
            unsafe { all_one.simd }
        };
        #[inline(always)]
        fn is_all_one(self) -> bool {
            unsafe { _mm_test_all_ones(self.simd) == 1 }
        }
    }
}
/// A single step of the Collatz map over some state type.
///
/// An implementation returns the next state, or `None` once the sequence has finished.
pub trait CollatzTransform
where
    Self: Sized,
{
    /// Consumes the current state and produces the following one, if there is one.
    fn transform(self) -> Option<Self>;
}
use crate::simd::{One, Simd, SimdArray, SimdExt, Zero};
impl CollatzTransform for Simd<[u32; 4]> {
    /// Advances all four lanes by one Collatz step at once.
    ///
    /// A lane equal to 1 is retired by turning it into 0, and a lane that is 0 stays 0, so
    /// finished lanes stop contributing. Returns `None` once every lane is 0. Lane arithmetic
    /// is 32-bit and wraps, so `3n + 1` is not checked for overflow.
    #[inline(always)]
    fn transform(mut self) -> Option<Self> {
        // Retire lanes that reached 1: `cmp` gives all-ones there, `and ONE` turns that into 1.
        let reached_one = self.cmp(Self::ONE).and(Self::ONE);
        self = self.sub(reached_one);
        // 1 in odd lanes, 0 in even lanes, e.g. (0, 1, 1, 0) for (even, odd, odd, even).
        let odd_bit = Self::ONE.and(self);
        // 0xFFFF_FFFF in odd lanes, 0 in even lanes.
        let odd_mask = Self::ZERO.sub(odd_bit);
        // Even lanes shift right by 1 (n / 2); odd lanes shift by 0 (unchanged).
        self = self.shr(Self::ONE.sub(odd_bit));
        // Odd lanes:  (n << 1) + n + 1 = 3n + 1.
        // Even lanes: (n/2 << 0) + 0 + 0 = n / 2.
        // Zero lanes also evaluate to 0 here, so no extra "not zero" mask is needed to keep
        // retired lanes at 0.
        self = self.shl(odd_bit).add(self.and(odd_mask)).add(odd_bit);
        if self.is_zero() {
            None
        } else {
            Some(self)
        }
    }
}
/// Iterator over the successive states produced by repeatedly applying `CollatzTransform`.
///
/// Yields the starting state first, then each transformed state, and stops after yielding
/// the state whose `transform` returned `None`.
struct CollatzIterator<T> {
    /// The state to yield next, or `None` once the sequence is exhausted.
    current: Option<T>,
}

impl<T> Iterator for CollatzIterator<T>
where
    T: CollatzTransform + Copy,
{
    type Item = T;

    fn next(&mut self) -> Option<T> {
        let state = self.current?;
        self.current = state.transform();
        Some(state)
    }
}
/// Scalar Collatz step: `n / 2` for even `n`, `3n + 1` for odd `n`.
///
/// Both 1 and 0 end the sequence. 1 is the usual terminal value. Without the 0 case, 0 would
/// map to itself forever, so `CollatzIterator` would never terminate. 0 can also be reached
/// mid-sequence, because `3 * 0x5555_5555 + 1` wraps to 0. Treating 0 as finished matches
/// the SIMD implementation, which retires lanes holding 0.
///
/// NOTE: `3n + 1` uses wrapping arithmetic, the same as the SIMD lanes. This keeps the
/// result identical in debug and release builds, but a starting value whose trajectory
/// exceeds `u32::MAX` does not follow the true Collatz sequence.
impl CollatzTransform for u32 {
    fn transform(self) -> Option<Self> {
        match self {
            0 | 1 => None,
            n if n & 1 == 0 => Some(n / 2),
            n => Some(n.wrapping_mul(3).wrapping_add(1)),
        }
    }
}
/// Finds the longest Collatz sequence among the odd numbers below 50 000 000 and times it.
///
/// The search runs twice: first with four `u32` lanes per SIMD step, then with the scalar
/// `u32` implementation.
fn main() {
    use std::time::Instant;

    // The SIMD path executes SSE4.1 (`_mm_test_all_ones`) and AVX2 (`_mm_sllv_epi32`,
    // `_mm_srlv_epi32`) instructions. Running them on a CPU without those extensions is
    // undefined behaviour, so check for support at runtime first.
    if is_x86_feature_detected!("sse4.1") && is_x86_feature_detected!("avx2") {
        let now = Instant::now();
        // Each vector holds four consecutive odd numbers n, n+2, n+4, n+6 for n = 1, 9, 17, ...
        let collatz = (1u32..50_000_000)
            .step_by(8)
            .map(|n| Simd::<[u32; 4]>::from_array([n, n + 2, n + 4, n + 6]))
            .map(|simd| {
                (
                    simd,
                    CollatzIterator {
                        current: Some(simd),
                    }
                    .count(),
                )
            })
            // The count is the length of the longest sequence among the four lanes.
            .max_by_key(|&(_, seq)| seq)
            .expect("the input range is not empty");
        let elapsed = now.elapsed();
        println!(
            "{:?}, {:?} in {:?}",
            collatz.0.into_array(),
            collatz.1,
            elapsed
        );
    } else {
        eprintln!("skipping SIMD benchmark: CPU lacks SSE4.1 and/or AVX2");
    }

    let now = Instant::now();
    // Odd starting values only, the same inputs the SIMD lanes cover.
    let collatz = (1u32..50_000_000)
        .filter(|n| n & 1 == 1)
        .map(|n| (n, CollatzIterator { current: Some(n) }.count()))
        .max_by_key(|&(_, seq)| seq)
        .expect("the input range is not empty");
    let elapsed = now.elapsed();
    println!("{:?} {:?} in {:?}", collatz.0, collatz.1, elapsed);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment