Created
January 5, 2019 16:22
-
-
Save TheZoq2/1e1cf65f067284a258009f87f89e6250 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use rand::{ | |
Rng, | |
thread_rng, | |
distributions::{ | |
Distribution, | |
Uniform | |
} | |
}; | |
use std::time::Instant; | |
use std::thread; | |
use std::sync::Arc; | |
/// Distributes `data` into buckets delimited by `samples` and sorts every bucket.
///
/// This is the per-thread work of `parallel_sample_sort`; the function itself
/// is sequential and spawns no threads.
///
/// # Parameters
/// * `data` - the elements to distribute; they are cloned into the buckets.
/// * `samples` - the splitter values. They are looked up with
///   `slice::binary_search`, so they must be sorted in ascending order;
///   otherwise the bucket an element lands in is unspecified.
///
/// # Returns
/// Exactly `samples.len() + 1` buckets, each sorted ascending. An element is
/// placed in the bucket whose index `binary_search` reports: the index of a
/// matching sample if one exists, otherwise the insertion point. Bucket `i`
/// therefore holds values greater than `samples[i - 1]` and less than or equal
/// to `samples[i]`; the last bucket holds everything greater than the largest
/// sample.
pub fn parallel_sample_gather_sublists<T>(data: &[T], samples: &[T]) -> Vec<Vec<T>>
    where T: Clone + Ord + std::fmt::Debug
{
    // One bucket per gap between splitters, plus one on the far right.
    let mut buckets: Vec<Vec<T>> = vec![Vec::new(); samples.len() + 1];

    for item in data {
        // A hit and a miss both yield the index of the bucket to use.
        let index = match samples.binary_search(item) {
            Ok(index) | Err(index) => index,
        };
        buckets[index].push(item.clone());
    }

    // Stable sort keeps the original relative order of equal elements.
    for bucket in &mut buckets {
        bucket.sort();
    }

    buckets
}
/// Merges several individually sorted lists into one sorted vector.
///
/// On every step the smallest not-yet-consumed head element among all
/// sublists is selected and appended to the output. When several heads compare
/// equal, the one from the sublist with the lowest index is taken first, so
/// the merge is stable with respect to sublist order.
///
/// # Parameters
/// * `sublists` - the lists to merge. Each one must already be sorted
///   ascending; otherwise the output is not sorted.
///
/// # Returns
/// A new vector containing a clone of every element of every sublist. An empty
/// `sublists` (or only empty sublists) yields an empty vector.
///
/// Each output element costs one scan over all sublists, which is fine for the
/// small number of sublists this is used with.
pub fn merge_sublists<T>(sublists: &Vec<&Vec<T>>) -> Vec<T>
    where T: Clone + Sync + Send + Ord + PartialOrd + std::fmt::Debug
{
    // cursors[i] is the position of the next unconsumed element of sublists[i].
    let mut cursors = vec![0usize; sublists.len()];
    let total_len: usize = sublists.iter().map(|list| list.len()).sum();
    let mut merged = Vec::with_capacity(total_len);

    loop {
        // Compare the heads by reference; only the winner is cloned below.
        // `min_by` returns the first of several equal minima, which keeps
        // ties resolved in favour of the earliest sublist.
        let smallest = sublists
            .iter()
            .zip(cursors.iter())
            .enumerate()
            .filter_map(|(list_index, (list, &cursor))| {
                list.get(cursor).map(|value| (list_index, value))
            })
            .min_by(|left, right| left.1.cmp(right.1));

        match smallest {
            Some((list_index, value)) => {
                merged.push(value.clone());
                cursors[list_index] += 1;
            }
            // Every sublist is exhausted.
            None => break,
        }
    }

    merged
}
pub fn parallel_sample_sort<T>( | |
num_threads: usize, | |
data: Arc<Vec<u32>>, | |
abs_min: u32, | |
abs_max: u32 | |
) -> Vec<u32> | |
where T: Clone + Ord + Sync + Send | |
{ | |
// Select the samples | |
let mut rng = thread_rng(); | |
let samples = (0..num_threads) | |
.map(|_| rng.gen_range(0, data.len())) | |
.map(|index| data[index].clone()) | |
.collect::<Vec<_>>(); | |
let full_samples = { | |
let mut unsorted = [abs_min].iter() | |
.chain(samples.iter()) | |
.chain([abs_max].iter()) | |
.cloned() | |
.collect::<Vec<_>>(); | |
unsorted.sort(); | |
unsorted | |
}; | |
let mut threads = vec!(); | |
let arced_samples = Arc::new(full_samples); | |
let slice_length = data.len() / num_threads + 1; | |
for thread_id in 0..num_threads { | |
let cloned_data = data.clone(); | |
let arced_samples = arced_samples.clone(); | |
let index_start = slice_length * thread_id; | |
let index_end = (slice_length * (thread_id + 1)).min(data.len()); | |
threads.push(thread::spawn(move || { | |
parallel_sample_gather_sublists( | |
&cloned_data[index_start..index_end], | |
&arced_samples | |
) | |
})); | |
} | |
let sublists = threads.into_iter().map(|t| t.join().unwrap()).collect::<Vec<_>>(); | |
let num_merge_threads = num_threads+2; | |
let mut to_merge = vec!(); | |
for thread_id in 0..num_merge_threads { | |
to_merge.push(vec!()); | |
for i in 0..num_threads { | |
to_merge[thread_id].push(&sublists[i][thread_id]) | |
} | |
} | |
let to_merge = Arc::new(to_merge); | |
let mut result = vec!(); | |
crossbeam_utils::thread::scope(|scope| { | |
let mut threads = vec!(); | |
for thread_id in 0..num_merge_threads { | |
let to_merge = to_merge.clone(); | |
threads.push(scope.spawn(move |_| { | |
merge_sublists(&to_merge[thread_id]) | |
})); | |
} | |
let mut results = threads.into_iter().map(|t| t.join().unwrap()).collect::<Vec<_>>(); | |
for mut r in results.iter_mut() { | |
result.append(&mut r) | |
} | |
}).unwrap(); | |
result | |
} | |
fn main() { | |
let mut rng = rand::thread_rng(); | |
let mut input = Uniform::new_inclusive(1, 10000).sample_iter(&mut rng).take(1_00_000_000).collect::<Vec<_>>(); | |
println!("Input generated"); | |
let parallel_start = Instant::now(); | |
let result = parallel_sample_sort::<u32>(16, Arc::new(input.clone()), 0, std::u32::MAX); | |
let parallel_end = Instant::now(); | |
println!("Sorted in parallel"); | |
let sequential_start = Instant::now(); | |
input.sort(); | |
let sequential_end = Instant::now(); | |
println!("Sorted sequentially"); | |
// println!("got: {:?}\nexp: {:?}", result, input); | |
assert_eq!(result, input); | |
let par_time = parallel_end-parallel_start; | |
let seq_time = sequential_end-sequential_start; | |
println!( | |
"Parallel time: {}, Sequential time: {}", | |
par_time.as_secs() as f32 + par_time.subsec_millis() as f32/ 1000., | |
seq_time.as_secs() as f32 + seq_time.subsec_millis() as f32/ 1000. | |
); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment