Created
July 24, 2012 04:35
-
-
Save inducer/3168065 to your computer and use it in GitHub Desktop.
PyOpenCL scan code generator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#define K ${k_group_size} | |
KERNEL | |
REQD_WG_SIZE(WG_SIZE, 1, 1) | |
void ${name_prefix}_scan_intervals( | |
${argument_signature}, | |
GLOBAL_MEM scan_type *partial_scan_buffer, | |
const index_type N, | |
const index_type interval_size | |
%if is_first_level: | |
, GLOBAL_MEM scan_type *interval_results | |
%endif | |
%if is_segmented and is_first_level: | |
/* NO_SEG_BOUNDARY if no segment boundary in interval. | |
Otherwise, index relative to interval beginning. | |
*/ | |
, GLOBAL_MEM index_type *g_first_segment_start_in_interval | |
%endif | |
%if store_segment_start_flags: | |
, GLOBAL_MEM char *g_segment_start_flags | |
%endif | |
) | |
{ | |
// padded in WG_SIZE to avoid bank conflicts | |
// index K in first dimension used for carry storage | |
LOCAL_MEM scan_type ldata[K + 1][WG_SIZE + 1]; | |
%if is_segmented: | |
LOCAL_MEM char l_segment_start_flags[K][WG_SIZE]; | |
LOCAL_MEM index_type l_first_segment_start_in_subtree[WG_SIZE]; | |
// only relevant/populated for local id 0 | |
index_type first_segment_start_in_interval = NO_SEG_BOUNDARY; | |
index_type first_segment_start_in_k_group, first_segment_start_in_subtree; | |
%endif | |
// {{{ set up local data for input_fetch_exprs if any of them are stenciled | |
<% | |
fetch_expr_offsets = {} | |
for name, arg_name, ife_offset in input_fetch_exprs: | |
fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset) | |
local_fetch_expr_args = set( | |
arg_name | |
for arg_name, ife_offsets in fetch_expr_offsets.iteritems() | |
if -1 in ife_offsets or len(ife_offsets) > 1) | |
%> | |
%for arg_name in local_fetch_expr_args: | |
LOCAL_MEM ${arg_ctypes[arg_name]} l_${arg_name}[WG_SIZE*K]; | |
%endfor | |
// }}} | |
const index_type interval_begin = interval_size * GID_0; | |
const index_type interval_end = min(interval_begin + interval_size, N); | |
const index_type unit_size = K * WG_SIZE; | |
index_type unit_base = interval_begin; | |
%for is_tail in [False, True]: | |
%if not is_tail: | |
for(; unit_base + unit_size <= interval_end; unit_base += unit_size) | |
%else: | |
if (unit_base < interval_end) | |
%endif | |
{ | |
// {{{ carry out input_fetch_exprs | |
// (if there are ones that need to be fetched into local) | |
%if local_fetch_expr_args: | |
for(index_type k = 0; k < K; k++) | |
{ | |
const index_type offset = k*WG_SIZE + LID_0; | |
const index_type read_i = unit_base + offset; | |
%for arg_name in local_fetch_expr_args: | |
%if is_tail: | |
if (read_i < interval_end) | |
%endif | |
{ | |
l_${arg_name}[offset] = ${arg_name}[read_i]; | |
} | |
%endfor | |
} | |
local_barrier(); | |
%endif | |
// }}} | |
// {{{ read a unit's worth of data from global | |
for(index_type k = 0; k < K; k++) | |
{ | |
const index_type offset = k*WG_SIZE + LID_0; | |
const index_type read_i = unit_base + offset; | |
%if is_tail: | |
if (read_i < interval_end) | |
%endif | |
{ | |
%for name, arg_name, ife_offset in input_fetch_exprs: | |
${arg_ctypes[arg_name]} ${name}; | |
%if arg_name in local_fetch_expr_args: | |
if (offset + ${ife_offset} >= 0) | |
${name} = l_${arg_name}[offset + ${ife_offset}]; | |
else if (read_i + ${ife_offset} >= 0) | |
${name} = ${arg_name}[read_i + ${ife_offset}]; | |
/* | |
else | |
if out of bounds, name is left undefined */ | |
%else: | |
// ${arg_name} gets fetched directly from global | |
${name} = ${arg_name}[read_i]; | |
%endif | |
%endfor | |
scan_type scan_value = INPUT_EXPR(read_i); | |
const index_type o_mod_k = offset % K; | |
const index_type o_div_k = offset / K; | |
ldata[o_mod_k][offset / K] = scan_value; | |
%if is_segmented: | |
bool is_seg_start = IS_SEG_START(read_i, scan_value); | |
l_segment_start_flags[o_mod_k][o_div_k] = is_seg_start; | |
%endif | |
%if store_segment_start_flags: | |
g_segment_start_flags[read_i] = is_seg_start; | |
%endif | |
} | |
} | |
// }}} | |
// {{{ carry in from previous unit, if applicable | |
%if is_segmented: | |
local_barrier(); | |
first_segment_start_in_k_group = NO_SEG_BOUNDARY; | |
if (l_segment_start_flags[0][LID_0]) | |
first_segment_start_in_k_group = unit_base + K*LID_0; | |
%endif | |
if (LID_0 == 0 && unit_base != interval_begin | |
%if is_segmented: | |
&& !l_segment_start_flags[0][0] | |
%endif | |
) | |
{ | |
ldata[0][0] = SCAN_EXPR(ldata[K][WG_SIZE - 1], ldata[0][0]); | |
} | |
// }}} | |
local_barrier(); | |
// {{{ scan along k (sequentially in each work item) | |
scan_type sum = ldata[0][LID_0]; | |
%if is_tail: | |
const index_type offset_end = interval_end - unit_base; | |
%endif | |
for(index_type k = 1; k < K; k++) | |
{ | |
%if is_tail: | |
if (K * LID_0 + k < offset_end) | |
%endif | |
{ | |
scan_type tmp = ldata[k][LID_0]; | |
index_type seq_i = unit_base + K*LID_0 + k; | |
%if is_segmented: | |
if (l_segment_start_flags[k][LID_0]) | |
{ | |
first_segment_start_in_k_group = min( | |
first_segment_start_in_k_group, | |
seq_i); | |
sum = tmp; | |
} | |
else | |
%endif | |
sum = SCAN_EXPR(sum, tmp); | |
ldata[k][LID_0] = sum; | |
} | |
} | |
// }}} | |
// store carry in out-of-bounds (padding) array entry (index K) in the K direction | |
ldata[K][LID_0] = sum; | |
%if is_segmented: | |
l_first_segment_start_in_subtree[LID_0] = first_segment_start_in_k_group; | |
%endif | |
local_barrier(); | |
// {{{ tree-based local parallel scan | |
// This tree-based scan works as follows: | |
// - Each work item adds the previous item to its current state | |
// - barrier | |
// - Each work item adds in the item from two positions to the left | |
// - barrier | |
// - Each work item adds in the item from four positions to the left | |
// ... | |
// At the end, each item has summed all prior items. | |
// across k groups, along local id | |
// (uses out-of-bounds k=K array entry for storage) | |
scan_type val = ldata[K][LID_0]; | |
<% scan_offset = 1 %> | |
% while scan_offset <= wg_size: | |
// {{{ reads from local allowed, writes to local not allowed | |
if (LID_0 >= ${scan_offset}) | |
{ | |
scan_type tmp = ldata[K][LID_0 - ${scan_offset}]; | |
% if is_tail: | |
if (K*LID_0 < offset_end) | |
% endif | |
{ | |
%if is_segmented: | |
if (l_first_segment_start_in_subtree[LID_0] == NO_SEG_BOUNDARY) | |
val = SCAN_EXPR(tmp, val); | |
%else: | |
val = SCAN_EXPR(tmp, val); | |
%endif | |
} | |
%if is_segmented: | |
// Prepare for l_first_segment_start_in_subtree, below. | |
// Note that this update must take place *even* if we're | |
// out of bounds. | |
first_segment_start_in_subtree = min( | |
l_first_segment_start_in_subtree[LID_0], | |
l_first_segment_start_in_subtree[LID_0 - ${scan_offset}]); | |
%endif | |
} | |
%if is_segmented: | |
else | |
{ | |
first_segment_start_in_subtree = | |
l_first_segment_start_in_subtree[LID_0]; | |
} | |
%endif | |
// }}} | |
local_barrier(); | |
// {{{ writes to local allowed, reads from local not allowed | |
ldata[K][LID_0] = val; | |
%if is_segmented: | |
l_first_segment_start_in_subtree[LID_0] = | |
first_segment_start_in_subtree; | |
%endif | |
// }}} | |
local_barrier(); | |
%if 0: | |
if (LID_0 == 0) | |
{ | |
printf("${scan_offset}: "); | |
for (int i = 0; i < WG_SIZE; ++i) | |
{ | |
if (l_first_segment_start_in_subtree[i] == NO_SEG_BOUNDARY) | |
printf("- "); | |
else | |
printf("%d ", l_first_segment_start_in_subtree[i]); | |
} | |
printf("\n"); | |
} | |
%endif | |
<% scan_offset *= 2 %> | |
% endwhile | |
// }}} | |
// {{{ update local values | |
if (LID_0 > 0) | |
{ | |
sum = ldata[K][LID_0 - 1]; | |
for(index_type k = 0; k < K; k++) | |
{ | |
bool do_update = true; | |
%if is_tail: | |
do_update = K * LID_0 + k < offset_end; | |
%endif | |
%if is_segmented: | |
do_update = unit_base + K * LID_0 + k | |
< first_segment_start_in_k_group; | |
%endif | |
if (do_update) | |
{ | |
scan_type tmp = ldata[k][LID_0]; | |
ldata[k][LID_0] = SCAN_EXPR(sum, tmp); | |
} | |
} | |
} | |
%if is_segmented: | |
if (LID_0 == 0) | |
{ | |
// update interval-wide first-seg variable from current unit | |
first_segment_start_in_interval = min( | |
first_segment_start_in_interval, | |
l_first_segment_start_in_subtree[WG_SIZE-1]); | |
} | |
%endif | |
// }}} | |
local_barrier(); | |
// {{{ write data | |
for (index_type k = 0; k < K; k++) | |
{ | |
const index_type offset = k*WG_SIZE + LID_0; | |
%if is_tail: | |
if (unit_base + offset < interval_end) | |
%endif | |
{ | |
partial_scan_buffer[unit_base + offset] = | |
ldata[offset % K][offset / K]; | |
} | |
} | |
// }}} | |
local_barrier(); | |
} | |
% endfor | |
// write interval sum | |
if (LID_0 == 0) | |
{ | |
%if is_first_level: | |
interval_results[GID_0] = partial_scan_buffer[interval_end - 1]; | |
%endif | |
%if is_segmented and is_first_level: | |
g_first_segment_start_in_interval[GID_0] = first_segment_start_in_interval; | |
%endif | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment