Last active
October 21, 2021 16:36
-
-
Save ro99/8346fd0536b5b37dfb3cd79060a77fe7 to your computer and use it in GitHub Desktop.
Tests with rayon (about 1.3× speedup)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
unsafe fn run_with_scratch_space_parallel( | |
&self, | |
m: usize, | |
n: usize, | |
non_linear: &[FusedSpec], | |
) -> anyhow::Result<()> { | |
let mr = K::mr(); | |
let nr = K::nr(); | |
let mut rows: Vec<usize> = (0..n / nr).collect(); | |
let size = rows.len() / 32 + rows.len() % 32; | |
let ctx: ThreadLocalCtx<Box<dyn ScratchSpace>, _> = ThreadLocalCtx::new(|| { | |
let mut scratch = self.allocate_scratch_space(); | |
scratch | |
.downcast_mut::<ScratchSpaceFusedNonLinear<TI>>() | |
.unwrap() | |
.prepare::<K>(non_linear); | |
scratch | |
}); | |
for ia in 0..m / mr { | |
rows.par_chunks_mut(size).for_each(|row_chunk|{ | |
let row_chunk = row_chunk.to_owned(); | |
let mut scratch = ctx.get(); | |
let scratch = scratch.downcast_mut::<ScratchSpaceFusedNonLinear<TI>>().unwrap(); | |
for ib in row_chunk { | |
scratch.for_valid_tile::<K>(&non_linear, ia, ib); | |
let err = K::kernel(&scratch.uspecs()); | |
debug_assert_eq!(err, 0, "Kernel return error {}", err); | |
} | |
}); | |
} | |
if m % mr != 0 { | |
rows.par_chunks_mut(size).for_each(|row_chunk|{ | |
let row_chunk = row_chunk.to_owned(); | |
let mut scratch = ctx.get(); | |
let scratch = scratch.downcast_mut::<ScratchSpaceFusedNonLinear<TI>>().unwrap(); | |
for ib in row_chunk { | |
scratch.for_border_tile::<K>(&non_linear, m / mr, ib); | |
let err = K::kernel(&scratch.uspecs()); | |
debug_assert_eq!(err, 0, "Kernel return error {}", err); | |
scratch.postprocess_tile::<K>(&non_linear, m / mr, ib, m % mr, nr); | |
} | |
}); | |
if n % nr != 0 { | |
let mut scratch = ctx.get(); | |
let scratch = scratch.downcast_mut::<ScratchSpaceFusedNonLinear<TI>>().unwrap(); | |
scratch.for_border_tile::<K>(&non_linear, m / mr, n / nr); | |
let err = K::kernel(&scratch.uspecs()); | |
debug_assert_eq!(err, 0, "Kernel return error {}", err); | |
scratch.postprocess_tile::<K>(&non_linear, m / mr, n / nr, m % mr, n % nr); | |
} | |
} | |
Ok(()) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment