Skip to content

Instantly share code, notes, and snippets.

@lovely-error
Last active March 6, 2024 07:52
Show Gist options
  • Save lovely-error/b3dfc6f53486782628494bc4ad2670a8 to your computer and use it in GitHub Desktop.
Save lovely-error/b3dfc6f53486782628494bc4ad2670a8 to your computer and use it in GitHub Desktop.
Conditional execution without branches on the GPU
// This kernel writes a value only if the invocation (work-item) index is even.
// When `flag` is zero (false), the store address is computed to be exactly byte address (1 << 48).
// Apparently stores beyond the valid address space on the GPU simply become no-ops
// (this is undefined behaviour and is not guaranteed by OpenCL).
/* Store `val` (as type `Ty`) to `addr` when `flag` is nonzero, without a branch.
 *
 * When `flag` is zero the target is rewritten to the absolute byte address
 * (1 << 48), because addr + ((1 << 48) - addr) == (1 << 48).
 * NOTE(review): that store is out of bounds and undefined behaviour in OpenCL C.
 * It only "works" on devices that silently drop stores to unmapped addresses,
 * so it is not portable.
 *
 * Every argument is parenthesized so expressions such as `p + i` or `a ^ b`
 * expand correctly. `addr` and `flag` are each evaluated exactly once. The
 * do/while(0) wrapper makes the macro a single statement that needs a trailing `;`.
 */
#define store_if_true(Ty, addr, val, flag)                                          \
    do {                                                                            \
        unsigned long store_if_true_addr_ = (unsigned long)(addr);                  \
        /* 1 when the store must be diverted to the (1 << 48) sink, else 0 */       \
        unsigned long store_if_true_skip_ = (unsigned long)((flag) == 0);           \
        *((__global Ty*)(store_if_true_addr_ +                                      \
            store_if_true_skip_ * ((1LU << 48) - store_if_true_addr_))) = (val);    \
    } while (0)
/* One work-item per element: clear this work-item's slot, then set it to all
 * ones (0xFFFFFFFF) only when the global id is even, via the branch-free
 * store_if_true macro. */
__kernel void kern(__global unsigned int* buffer) {
    const unsigned long gid = get_global_id(0);
    __global unsigned int* slot = buffer + gid;

    *slot = 0;

    /* 1 for even ids, 0 for odd ids */
    const unsigned long is_even = !(gid & 1);
    store_if_true(unsigned int, slot, -1, is_even);
}
use ocl::{Buffer, MemFlags};
/// Host driver for `kern.cl`.
///
/// It builds and runs kernel `kern` over a `BUFFER_SIZE`-element `u32` buffer on
/// the first device of the first OpenCL platform. It then reads the buffer back and
/// checks the pattern: even indices hold `u32::MAX` and odd indices hold `0`.
///
/// # Panics
///
/// Panics if no OpenCL platform or device is available, if any OpenCL call fails,
/// or if the read-back data does not match the expected pattern.
fn main() {
    const BUFFER_SIZE: usize = 64 * 1024 * 32;

    let platform = ocl::Platform::list()
        .into_iter()
        .next()
        .expect("no OpenCL platform found");
    let device = ocl::Device::first(platform).expect("no OpenCL device on the first platform");
    let ctx = ocl::Context::builder()
        .platform(platform)
        .devices(device)
        .build()
        .expect("failed to create OpenCL context");
    let q = ocl::Queue::new(&ctx, device, None).expect("failed to create OpenCL command queue");
    let program = ocl::Program::builder()
        .source_file("kern.cl")
        .build(&ctx)
        .expect("failed to build kern.cl");

    // SAFETY: the buffer is created without host data (`None`), so its initial contents
    // are unspecified. The kernel writes every element (`buffer[ix] = 0`) over a global
    // work size of BUFFER_SIZE before the host reads anything back.
    let buffer = unsafe { Buffer::<u32>::new(q.clone(), MemFlags::empty(), BUFFER_SIZE, None) }
        .expect("failed to allocate device buffer");

    let kern = ocl::Kernel::builder()
        .program(&program)
        .name("kern")
        .arg(&buffer)
        .build()
        .expect("failed to build kernel `kern`");

    // NOTE(review): for odd work-items, kern.cl deliberately stores to byte address
    // (1 << 48) and relies on the device dropping that store. This is undefined
    // behaviour in OpenCL, so this enqueue is only sound on devices observed to
    // behave that way.
    unsafe {
        kern.cmd()
            .global_work_size(BUFFER_SIZE)
            .queue(&q)
            .enq()
            .expect("failed to enqueue kernel `kern`");
    }

    let mut host = vec![0u32; buffer.len()];
    buffer
        .read(&mut host)
        .enq()
        .expect("failed to read buffer back to host");

    for (index, &value) in host.iter().enumerate() {
        let expected = if index % 2 == 0 { u32::MAX } else { 0 };
        assert_eq!(value, expected, "unexpected value at index {}", index);
    }
}
/* Store `val` to `*addr_true` when `cond` is nonzero, otherwise to `*addr_false`.
 * The target pointer is selected arithmetically rather than with a branch:
 * addr_true + 0 * diff == addr_true, and addr_true + 1 * diff == addr_false.
 *
 * NOTE(review): `addr_false - addr_true` is pointer subtraction. C / OpenCL C only
 * define it when both pointers point into the same array object, so keep both
 * targets in one buffer.
 * `addr_true` is evaluated twice, so it must not have side effects.
 * Every argument is parenthesized so expressions such as `p + 1` or `a & b` expand
 * correctly. The do/while(0) wrapper makes this a single statement that needs a
 * trailing `;`.
 */
#define store_phi(addr_true, addr_false, val, cond)                                 \
    do {                                                                            \
        *((addr_true) + ((cond) == 0) * ((addr_false) - (addr_true))) = (val);      \
    } while (0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment