Skip to content

Instantly share code, notes, and snippets.

@TheZoq2
Created January 5, 2019 16:22
Show Gist options
  • Save TheZoq2/1e1cf65f067284a258009f87f89e6250 to your computer and use it in GitHub Desktop.
Save TheZoq2/1e1cf65f067284a258009f87f89e6250 to your computer and use it in GitHub Desktop.
use rand::{
Rng,
thread_rng,
distributions::{
Distribution,
Uniform
}
};
use std::time::Instant;
use std::thread;
use std::sync::Arc;
/// Distributes (clones of) the elements of `data` into value-ordered buckets delimited by
/// `samples`, then sorts every bucket.
///
/// `samples` is expected to be sorted ascending, because it is searched with
/// `binary_search`; with unsorted samples the bucket choice is unspecified.
///
/// The result always contains `samples.len() + 1` buckets. An element equal to a sample is
/// placed at that sample's index. Any other element is placed at its insertion point. As a
/// result, every value in bucket `i` lies in `[samples[i - 1], samples[i]]`, with the range
/// open-ended for the first and last buckets.
pub fn parallel_sample_gather_sublists<T>(data: &[T], samples: &[T]) -> Vec<Vec<T>>
where T: Clone + Ord + std::fmt::Debug
{
    let bucket_count = samples.len() + 1;
    let mut buckets: Vec<Vec<T>> = (0..bucket_count).map(|_| Vec::new()).collect();

    for item in data.iter() {
        // An exact hit and an insertion point both map directly to a bucket index.
        let slot = samples
            .binary_search(item)
            .unwrap_or_else(|insert_at| insert_at);
        buckets[slot].push(item.clone());
    }

    buckets.iter_mut().for_each(|bucket| bucket.sort());
    buckets
}
/// Merges the given sublists into a single vector using a k-way merge.
///
/// At every step the smallest current head among all non-exhausted sublists is emitted.
/// Ties are broken in favour of the sublist with the lower index. When every sublist is
/// sorted ascending, the result is sorted ascending and contains every element exactly
/// once. Equal elements appear in sublist order.
///
/// Each emitted element is cloned once. The merge costs O(n log k) for n total elements
/// spread over k sublists.
pub fn merge_sublists<T>(sublists: &Vec<&Vec<T>>) -> Vec<T>
where T: Clone + Sync + Send + Ord + PartialOrd + std::fmt::Debug
{
    let total: usize = sublists.iter().map(|list| list.len()).sum();
    let mut result = Vec::with_capacity(total);

    // Min-heap of the current head of every non-exhausted sublist.
    // Each entry is (value, sublist index, position within that sublist).
    // `Reverse` turns std's max-heap into a min-heap. Ordering on the sublist index second
    // reproduces the "lowest sublist wins ties" rule. The position is never compared,
    // because each sublist has at most one entry in the heap.
    let mut heap = std::collections::BinaryHeap::with_capacity(sublists.len());
    for (list_index, list) in sublists.iter().enumerate() {
        if let Some(head) = list.first() {
            heap.push(std::cmp::Reverse((head, list_index, 0)));
        }
    }

    while let Some(std::cmp::Reverse((value, list_index, position))) = heap.pop() {
        result.push(value.clone());
        // Refill the heap with the next element of the sublist we just consumed from.
        if let Some(next) = sublists[list_index].get(position + 1) {
            heap.push(std::cmp::Reverse((next, list_index, position + 1)));
        }
    }

    result
}
/// Sorts `data` with a parallel sample sort and returns the sorted values.
///
/// 1. Pick `num_threads` random elements of `data` as splitters, add `abs_min` and
///    `abs_max`, and sort the splitters.
/// 2. Split `data` into `num_threads` contiguous slices (trailing ones may be empty).
///    Each slice is handled by its own spawned thread, which distributes it into
///    `splitters.len() + 1` sorted buckets via `parallel_sample_gather_sublists`.
/// 3. For every bucket index, merge that bucket from all workers on its own scoped thread
///    (`merge_sublists`), then concatenate the merged buckets in bucket order.
///
/// `abs_min` and `abs_max` act only as extra splitters. Values outside that range land in
/// the first or last bucket and are still included in the result.
///
/// An empty `data` returns an empty vector, and `num_threads == 0` is treated as 1.
///
/// The type parameter `T` is not used by the body, because the element type is fixed to
/// `u32`. It is kept so existing `parallel_sample_sort::<u32>(...)` call sites still compile.
///
/// # Panics
/// Panics if any worker thread panics, because the join results are unwrapped.
pub fn parallel_sample_sort<T>(
    num_threads: usize,
    data: Arc<Vec<u32>>,
    abs_min: u32,
    abs_max: u32
) -> Vec<u32>
where T: Clone + Ord + Sync + Send
{
    // Nothing to sort. This also avoids `gen_range(0, 0)`, which panics.
    if data.is_empty() {
        return vec!();
    }
    // A request for zero threads would otherwise divide by zero below.
    let num_threads = num_threads.max(1);

    // Random splitters drawn from the input, bracketed by the caller's bounds.
    let mut rng = thread_rng();
    let mut splitters = Vec::with_capacity(num_threads + 2);
    splitters.push(abs_min);
    splitters.extend((0..num_threads).map(|_| data[rng.gen_range(0, data.len())]));
    splitters.push(abs_max);
    splitters.sort_unstable();
    // Must match the bucket count produced by `parallel_sample_gather_sublists`. It
    // includes the open-ended buckets below the first and above the last splitter.
    let num_buckets = splitters.len() + 1;
    let splitters = Arc::new(splitters);

    // Phase 1: each worker buckets its own contiguous slice of the input.
    let slice_length = data.len() / num_threads + 1;
    let handles = (0..num_threads)
        .map(|thread_id| {
            let data = Arc::clone(&data);
            let splitters = Arc::clone(&splitters);
            // Clamp both ends. When `data.len()` is small relative to `num_threads`,
            // trailing workers would otherwise get start > end and panic on slicing.
            let start = (slice_length * thread_id).min(data.len());
            let end = (slice_length * (thread_id + 1)).min(data.len());
            thread::spawn(move || parallel_sample_gather_sublists(&data[start..end], &splitters))
        })
        .collect::<Vec<_>>();
    let sublists = handles
        .into_iter()
        .map(|handle| handle.join().unwrap())
        .collect::<Vec<_>>();

    // Transpose: merge job `b` receives bucket `b` from every worker. Buckets are
    // value-ordered, so the merged jobs concatenated in order form the sorted output.
    let to_merge = (0..num_buckets)
        .map(|bucket| {
            sublists
                .iter()
                .map(|worker_buckets| &worker_buckets[bucket])
                .collect::<Vec<_>>()
        })
        .collect::<Vec<_>>();

    // Phase 2: merge every bucket in parallel. Scoped threads may borrow `to_merge`
    // directly. All handles are collected before any join so the merges run concurrently.
    let merged = crossbeam_utils::thread::scope(|scope| {
        let handles = to_merge
            .iter()
            .map(|lists| scope.spawn(move |_| merge_sublists(lists)))
            .collect::<Vec<_>>();
        handles
            .into_iter()
            .map(|handle| handle.join().unwrap())
            .collect::<Vec<_>>()
    })
    .unwrap();

    merged.concat()
}
/// Benchmarks `parallel_sample_sort` against `Vec::sort` on 100 million random values in
/// `1..=10000`. It checks that both produce the same output and prints both timings in
/// seconds, at millisecond resolution.
fn main() {
    let mut rng = rand::thread_rng();
    let mut input = Uniform::new_inclusive(1, 10000)
        .sample_iter(&mut rng)
        .take(100_000_000)
        .collect::<Vec<_>>();
    println!("Input generated");

    // Copy the input for the parallel sort *before* starting the clock, so the
    // clone of 100M elements is not billed to the parallel timing.
    let shared_input = Arc::new(input.clone());
    let parallel_start = Instant::now();
    let result = parallel_sample_sort::<u32>(16, shared_input, 0, std::u32::MAX);
    let par_time = parallel_start.elapsed();
    println!("Sorted in parallel");

    let sequential_start = Instant::now();
    input.sort();
    let seq_time = sequential_start.elapsed();
    println!("Sorted sequentially");

    assert_eq!(result, input);

    println!(
        "Parallel time: {}, Sequential time: {}",
        par_time.as_secs() as f32 + par_time.subsec_millis() as f32 / 1000.,
        seq_time.as_secs() as f32 + seq_time.subsec_millis() as f32 / 1000.
    );
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment