Last active
December 8, 2018 23:03
-
-
Save JulianKnodt/df1c3326b2417a900c8f29e6e303a6a3 to your computer and use it in GitHub Desktop.
Basic SIMD instructions in Rust
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// The `unsafe` blocks below are required on x86/x86_64, where the functions are
// `unsafe fn`s gated on AVX; on other targets the fallback functions are safe
// and the blocks would otherwise trigger `unused_unsafe`.
#[cfg(test)]
#[allow(unused_unsafe)]
mod tests {
    // `super::` resolves in every Rust edition; a bare `use average;` only
    // works with edition-2015 crate-root-relative paths.
    use super::average;

    /// Returns `true` when the CPU features required by the `average`
    /// module are available at runtime.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn cpu_supported() -> bool {
        is_x86_feature_detected!("avx")
    }

    /// The portable fallback has no CPU feature requirements.
    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
    fn cpu_supported() -> bool {
        true
    }

    /// Builds the vector `[1.0, 2.0, ..., (end - 1) as f64]`.
    fn one_up_to(end: i64) -> Vec<f64> {
        (1..end).map(|x| x as f64).collect()
    }

    #[test]
    fn test_average() {
        if !cpu_supported() {
            eprintln!("skipping test_average: AVX not available");
            return;
        }
        let items = vec![3.0, 4.0, 5.0];
        // SAFETY: AVX availability was checked by `cpu_supported` above.
        unsafe {
            assert_eq!(average::average(&items), 4.0);
        }
    }

    #[test]
    fn test_large_average() {
        if !cpu_supported() {
            eprintln!("skipping test_large_average: AVX not available");
            return;
        }
        let b = 10000;
        let items = one_up_to(b);
        // SAFETY: AVX availability was checked by `cpu_supported` above.
        unsafe {
            assert_eq!(average::average(&items), (b as f64) / 2.0);
        }
    }

    #[test]
    fn test_sum() {
        if !cpu_supported() {
            eprintln!("skipping test_sum: AVX not available");
            return;
        }
        let b: i64 = 4;
        let items = one_up_to(b);
        // SAFETY: AVX availability was checked by `cpu_supported` above.
        unsafe {
            assert_eq!(average::kahan(&items), (b * (b - 1)) as f64 / 2.0);
        }
    }

    #[test]
    fn test_large_sum() {
        if !cpu_supported() {
            eprintln!("skipping test_large_sum: AVX not available");
            return;
        }
        let b: i64 = 9999999;
        let items = one_up_to(b);
        // SAFETY: AVX availability was checked by `cpu_supported` above.
        unsafe {
            assert_eq!(average::kahan(&items), (b * (b - 1)) as f64 / 2.0);
        }
    }
}
/// AVX-accelerated mean and compensated (Kahan) sum over `f64` slices.
///
/// Every public function here is compiled with the `avx` target feature
/// enabled and must only be called on a CPU that supports AVX.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
mod average {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    /// Number of `f64` lanes processed per 256-bit vector.
    const LANES: usize = 4;

    /// Packs a chunk of at most [`LANES`] values into a vector, padding the
    /// missing trailing values with `0.0` so they do not affect a sum.
    ///
    /// # Panics
    /// Panics if `chunk` holds more than [`LANES`] values.
    ///
    /// # Safety
    /// The CPU must support AVX.
    #[target_feature(enable = "avx")]
    unsafe fn pack(chunk: &[f64]) -> __m256d {
        let mut padded = [0.0_f64; LANES];
        padded[..chunk.len()].copy_from_slice(chunk);
        _mm256_set_pd(padded[0], padded[1], padded[2], padded[3])
    }

    /// Adds the four lanes of `v` together, lowest lane first.
    ///
    /// The lanes are stored into a plain array instead of being transmuted
    /// into a tuple, because the field layout of Rust tuples is unspecified.
    ///
    /// # Safety
    /// The CPU must support AVX.
    #[target_feature(enable = "avx")]
    unsafe fn horizontal_sum(v: __m256d) -> f64 {
        let mut lanes = [0.0_f64; LANES];
        _mm256_storeu_pd(lanes.as_mut_ptr(), v);
        lanes[0] + lanes[1] + lanes[2] + lanes[3]
    }

    /// Returns the arithmetic mean of `s`.
    ///
    /// The values are accumulated four at a time in a vector register and
    /// the lane totals are added together at the end. An empty slice yields
    /// `NaN` (`0.0 / 0.0`).
    ///
    /// # Safety
    /// The CPU must support AVX, e.g. as reported by
    /// `is_x86_feature_detected!("avx")`.
    #[target_feature(enable = "avx")]
    pub unsafe fn average(s: &[f64]) -> f64 {
        let mut acc = _mm256_setzero_pd();
        for chunk in s.chunks(LANES) {
            acc = _mm256_add_pd(pack(chunk), acc);
        }
        horizontal_sum(acc) / (s.len() as f64)
    }

    /// Returns the sum of `s` using Kahan compensated summation, carried out
    /// independently in each of the four vector lanes.
    ///
    /// The per-lane totals are combined with ordinary additions at the end.
    /// An empty slice yields `0.0`.
    ///
    /// # Safety
    /// The CPU must support AVX, e.g. as reported by
    /// `is_x86_feature_detected!("avx")`.
    #[target_feature(enable = "avx")]
    pub unsafe fn kahan(s: &[f64]) -> f64 {
        let mut sum = _mm256_setzero_pd();
        // Running estimate of the low-order bits lost by previous additions.
        let mut lost = _mm256_setzero_pd();
        for chunk in s.chunks(LANES) {
            // Re-inject the error from the previous step before adding.
            let compensated = _mm256_sub_pd(pack(chunk), lost);
            let with_error = _mm256_add_pd(sum, compensated);
            // (with_error - sum) is what was actually added; subtracting what
            // we intended to add leaves the rounding error of this step.
            lost = _mm256_sub_pd(_mm256_sub_pd(with_error, sum), compensated);
            sum = with_error;
        }
        horizontal_sum(sum)
    }
}
/// Portable scalar implementations used on targets without x86 SIMD support.
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
mod average {
    /// Returns the arithmetic mean of `s`.
    ///
    /// An empty slice yields `NaN` (`0.0 / 0.0`).
    pub fn average(s: &[f64]) -> f64 {
        // The length must be converted explicitly: `f64 / usize` is not defined.
        s.iter().sum::<f64>() / (s.len() as f64)
    }

    /// Returns the sum of `s` using Kahan compensated summation.
    ///
    /// An empty slice yields `0.0`.
    pub fn kahan(s: &[f64]) -> f64 {
        let (total, _lost) = s
            .iter()
            .fold((0.0_f64, 0.0_f64), |(sum, lost), &next| {
                // Re-inject the error from the previous step before adding.
                let compensated = next - lost;
                let with_error = sum + compensated;
                // (with_error - sum) is what was actually added; subtracting
                // what we intended to add leaves this step's rounding error.
                (with_error, (with_error - sum) - compensated)
            });
        total
    }
}
// Based on the approach described in:
// https://medium.com/@Razican/learning-simd-with-rust-by-finding-planets-b85ccfb724c3
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment