Skip to content

Instantly share code, notes, and snippets.

@inducer
Created July 24, 2012 04:35
Show Gist options
  • Save inducer/3168065 to your computer and use it in GitHub Desktop.
Save inducer/3168065 to your computer and use it in GitHub Desktop.
PyOpenCL scan code generator
#define K ${k_group_size}

/*
 * <name_prefix>_scan_intervals -- Mako template for an OpenCL C kernel.
 *
 * Each work group (GID_0) processes one "interval" of interval_size elements
 * (clamped to N) and writes an inclusive partial scan of that interval to
 * partial_scan_buffer. The interval is walked in "units" of K*WG_SIZE
 * elements; the template body is emitted twice, once as a loop over full
 * units (is_tail == False) and once for a trailing partial unit
 * (is_tail == True), which adds bounds checks.
 *
 * Per unit:
 *   1. stenciled input arguments are staged in local memory (if any),
 *   2. K*WG_SIZE values are read and stored transposed in ldata, so that
 *      work item LID_0 owns the K consecutive elements K*LID_0 .. K*LID_0+K-1,
 *   3. the carry of the previous unit is folded into the first element,
 *   4. each work item scans its K elements sequentially,
 *   5. the per-work-item totals (row K of ldata) are scanned across the
 *      work group with a doubling-offset tree scan,
 *   6. each work item folds its left neighbour's total into its elements,
 *   7. the unit is written back to partial_scan_buffer.
 *
 * Kernel arguments visible here:
 *   partial_scan_buffer  output, one scan_type per input element
 *   N                    upper bound on element indices (interval_end <= N)
 *   interval_size        number of elements assigned to each work group
 *   interval_results     (is_first_level only) receives, per work group,
 *                        partial_scan_buffer[interval_end - 1]
 *   g_first_segment_start_in_interval
 *                        (is_segmented and is_first_level only) see the
 *                        comment at its declaration
 *   g_segment_start_flags
 *                        (store_segment_start_flags only) receives the
 *                        IS_SEG_START result for every element read
 *
 * Template variables consumed: k_group_size, name_prefix, argument_signature,
 * is_first_level, is_segmented, store_segment_start_flags, input_fetch_exprs
 * (triples of name, arg_name, offset), arg_ctypes, wg_size.
 *
 * KERNEL, REQD_WG_SIZE, WG_SIZE, GLOBAL_MEM, LOCAL_MEM, LID_0, GID_0,
 * local_barrier, scan_type, index_type, INPUT_EXPR, SCAN_EXPR, IS_SEG_START
 * and NO_SEG_BOUNDARY are expected to be defined by the surrounding preamble,
 * which is not part of this template.
 *
 * NOTE(review): the min() reductions below only make sense if
 * NO_SEG_BOUNDARY compares greater than every valid index -- confirm in the
 * preamble.
 */
KERNEL
REQD_WG_SIZE(WG_SIZE, 1, 1)
void ${name_prefix}_scan_intervals(
    ${argument_signature},
    GLOBAL_MEM scan_type *partial_scan_buffer,
    const index_type N,
    const index_type interval_size
    %if is_first_level:
        , GLOBAL_MEM scan_type *interval_results
    %endif
    %if is_segmented and is_first_level:
        /* NO_SEG_BOUNDARY if no segment boundary in interval.
        Otherwise, index relative to interval beginning.

        NOTE(review): the values assembled in the kernel body have the form
        unit_base + K*LID_0 + k, where unit_base starts at interval_begin,
        i.e. they are indices into the whole array rather than offsets from
        interval_begin. Confirm which one the consumer expects.
        */
        , GLOBAL_MEM index_type *g_first_segment_start_in_interval
    %endif
    %if store_segment_start_flags:
        , GLOBAL_MEM char *g_segment_start_flags
    %endif
    )
{
    // padded in WG_SIZE to avoid bank conflicts
    // index K in first dimension used for carry storage
    LOCAL_MEM scan_type ldata[K + 1][WG_SIZE + 1];

    %if is_segmented:
        LOCAL_MEM char l_segment_start_flags[K][WG_SIZE];
        LOCAL_MEM index_type l_first_segment_start_in_subtree[WG_SIZE];

        // only relevant/populated for local id 0
        index_type first_segment_start_in_interval = NO_SEG_BOUNDARY;

        index_type first_segment_start_in_k_group, first_segment_start_in_subtree;
    %endif

    // {{{ set up local data for input_fetch_exprs if any of them are stenciled

    <%
        # Collect, per input argument, the set of offsets it is fetched at.
        fetch_expr_offsets = {}
        for name, arg_name, ife_offset in input_fetch_exprs:
            fetch_expr_offsets.setdefault(arg_name, set()).add(ife_offset)

        # An argument is staged in local memory if it is read at offset -1
        # or at more than one offset. dict.items() is used (rather than
        # iteritems()) so the template renders on both Python 2 and 3.
        local_fetch_expr_args = set(
            arg_name
            for arg_name, ife_offsets in fetch_expr_offsets.items()
            if -1 in ife_offsets or len(ife_offsets) > 1)
    %>

    %for arg_name in local_fetch_expr_args:
        LOCAL_MEM ${arg_ctypes[arg_name]} l_${arg_name}[WG_SIZE*K];
    %endfor

    // }}}

    const index_type interval_begin = interval_size * GID_0;
    const index_type interval_end = min(interval_begin + interval_size, N);

    const index_type unit_size = K * WG_SIZE;

    index_type unit_base = interval_begin;

    %for is_tail in [False, True]:

        %if not is_tail:
            // full units: no bounds checks needed
            for(; unit_base + unit_size <= interval_end; unit_base += unit_size)
        %else:
            // trailing partial unit, if any
            if (unit_base < interval_end)
        %endif

        {
            %if is_tail:
                // number of valid elements in this partial unit
                const index_type offset_end = interval_end - unit_base;
            %endif

            // {{{ carry out input_fetch_exprs
            // (if there are ones that need to be fetched into local)

            %if local_fetch_expr_args:
                for(index_type k = 0; k < K; k++)
                {
                    const index_type offset = k*WG_SIZE + LID_0;
                    const index_type read_i = unit_base + offset;

                    %for arg_name in local_fetch_expr_args:
                        %if is_tail:
                        if (read_i < interval_end)
                        %endif
                        {
                            l_${arg_name}[offset] = ${arg_name}[read_i];
                        }
                    %endfor
                }

                local_barrier();
            %endif

            // }}}

            // {{{ read a unit's worth of data from global

            for(index_type k = 0; k < K; k++)
            {
                // consecutive work items read consecutive global elements
                const index_type offset = k*WG_SIZE + LID_0;
                const index_type read_i = unit_base + offset;

                %if is_tail:
                if (read_i < interval_end)
                %endif
                {
                    // NOTE(review): only the lower bound of a stencil offset
                    // is checked below; a positive offset is not range
                    // checked against the end of the local or global array.
                    %for name, arg_name, ife_offset in input_fetch_exprs:
                        ${arg_ctypes[arg_name]} ${name};

                        %if arg_name in local_fetch_expr_args:
                            if (offset + ${ife_offset} >= 0)
                                ${name} = l_${arg_name}[offset + ${ife_offset}];
                            else if (read_i + ${ife_offset} >= 0)
                                ${name} = ${arg_name}[read_i + ${ife_offset}];
                            /*
                            else
                                if out of bounds, name is left undefined */
                        %else:
                            // ${arg_name} gets fetched directly from global
                            ${name} = ${arg_name}[read_i];
                        %endif
                    %endfor

                    scan_type scan_value = INPUT_EXPR(read_i);

                    // transposed store: element `offset` of the unit lands in
                    // column offset / K, row offset mod K
                    const index_type o_mod_k = offset % K;
                    const index_type o_div_k = offset / K;
                    ldata[o_mod_k][o_div_k] = scan_value;

                    // The flag is declared whenever either consumer below
                    // needs it, so that store_segment_start_flags does not
                    // reference an undeclared name.
                    %if is_segmented or store_segment_start_flags:
                        bool is_seg_start = IS_SEG_START(read_i, scan_value);
                    %endif
                    %if is_segmented:
                        l_segment_start_flags[o_mod_k][o_div_k] = is_seg_start;
                    %endif
                    %if store_segment_start_flags:
                        g_segment_start_flags[read_i] = is_seg_start;
                    %endif
                }
            }

            // }}}

            // {{{ carry in from previous unit, if applicable

            %if is_segmented:
                // flags are written by other work items in the read loop above
                local_barrier();

                first_segment_start_in_k_group = NO_SEG_BOUNDARY;
                %if is_tail:
                // Flags at or beyond offset_end were not written for this
                // unit; do not interpret them.
                if (K*LID_0 < offset_end && l_segment_start_flags[0][LID_0])
                %else:
                if (l_segment_start_flags[0][LID_0])
                %endif
                    first_segment_start_in_k_group = unit_base + K*LID_0;
            %endif

            // ldata[K][WG_SIZE - 1] still holds the total of the previous
            // unit; fold it into the very first element of this one, unless
            // that element starts a new segment.
            if (LID_0 == 0 && unit_base != interval_begin
                %if is_segmented:
                    && !l_segment_start_flags[0][0]
                %endif
                )
            {
                ldata[0][0] = SCAN_EXPR(ldata[K][WG_SIZE - 1], ldata[0][0]);
            }

            // }}}

            local_barrier();

            // {{{ scan along k (sequentially in each work item)

            scan_type sum = ldata[0][LID_0];

            for(index_type k = 1; k < K; k++)
            {
                %if is_tail:
                if (K * LID_0 + k < offset_end)
                %endif
                {
                    scan_type tmp = ldata[k][LID_0];
                    index_type seq_i = unit_base + K*LID_0 + k;

                    %if is_segmented:
                    if (l_segment_start_flags[k][LID_0])
                    {
                        first_segment_start_in_k_group = min(
                            first_segment_start_in_k_group,
                            seq_i);
                        // a segment start restarts the running value
                        sum = tmp;
                    }
                    else
                    %endif
                        sum = SCAN_EXPR(sum, tmp);

                    ldata[k][LID_0] = sum;
                }
            }

            // }}}

            // store carry in out-of-bounds (padding) array entry (index K) in the K direction
            ldata[K][LID_0] = sum;

            %if is_segmented:
                l_first_segment_start_in_subtree[LID_0] = first_segment_start_in_k_group;
            %endif

            local_barrier();

            // {{{ tree-based local parallel scan

            // This tree-based scan works as follows:
            // - Each work item adds the previous item to its current state
            // - barrier
            // - Each work item adds in the item from two positions to the left
            // - barrier
            // - Each work item adds in the item from four positions to the left
            // ...
            // At the end, each item has summed all prior items.

            // across k groups, along local id
            // (uses out-of-bounds k=K array entry for storage)

            scan_type val = ldata[K][LID_0];

            <% scan_offset = 1 %>

            % while scan_offset <= wg_size:
                // {{{ reads from local allowed, writes to local not allowed

                if (LID_0 >= ${scan_offset})
                {
                    scan_type tmp = ldata[K][LID_0 - ${scan_offset}];
                    % if is_tail:
                    if (K*LID_0 < offset_end)
                    % endif
                    {
                        %if is_segmented:
                            // only fold in values from the left while no
                            // segment start has been seen in our own subtree
                            if (l_first_segment_start_in_subtree[LID_0] == NO_SEG_BOUNDARY)
                                val = SCAN_EXPR(tmp, val);
                        %else:
                            val = SCAN_EXPR(tmp, val);
                        %endif
                    }

                    %if is_segmented:
                        // Prepare for l_first_segment_start_in_subtree, below.

                        // Note that this update must take place *even* if we're
                        // out of bounds.

                        first_segment_start_in_subtree = min(
                            l_first_segment_start_in_subtree[LID_0],
                            l_first_segment_start_in_subtree[LID_0 - ${scan_offset}]);
                    %endif
                }
                %if is_segmented:
                    else
                    {
                        first_segment_start_in_subtree =
                            l_first_segment_start_in_subtree[LID_0];
                    }
                %endif

                // }}}

                local_barrier();

                // {{{ writes to local allowed, reads from local not allowed

                ldata[K][LID_0] = val;
                %if is_segmented:
                    l_first_segment_start_in_subtree[LID_0] =
                        first_segment_start_in_subtree;
                %endif

                // }}}

                local_barrier();

                %if 0:
                    if (LID_0 == 0)
                    {
                        printf("${scan_offset}: ");
                        for (int i = 0; i < WG_SIZE; ++i)
                        {
                            if (l_first_segment_start_in_subtree[i] == NO_SEG_BOUNDARY)
                                printf("- ");
                            else
                                printf("%d ", l_first_segment_start_in_subtree[i]);
                        }
                        printf("\n");
                    }
                %endif

                <% scan_offset *= 2 %>
            % endwhile

            // }}}

            // {{{ update local values

            if (LID_0 > 0)
            {
                // scanned total of everything owned by work items to the left
                sum = ldata[K][LID_0 - 1];

                for(index_type k = 0; k < K; k++)
                {
                    bool do_update = true;
                    %if is_tail:
                        do_update = K * LID_0 + k < offset_end;
                    %endif
                    %if is_segmented:
                        // Combine with (rather than overwrite) the tail
                        // bounds check: only elements ahead of the first
                        // segment start in this k group receive the carry.
                        do_update = do_update
                            && (unit_base + K * LID_0 + k
                                < first_segment_start_in_k_group);
                    %endif

                    if (do_update)
                    {
                        scan_type tmp = ldata[k][LID_0];
                        ldata[k][LID_0] = SCAN_EXPR(sum, tmp);
                    }
                }
            }

            %if is_segmented:
                if (LID_0 == 0)
                {
                    // update interval-wide first-seg variable from current unit
                    first_segment_start_in_interval = min(
                        first_segment_start_in_interval,
                        l_first_segment_start_in_subtree[WG_SIZE-1]);
                }
            %endif

            // }}}

            local_barrier();

            // {{{ write data

            for (index_type k = 0; k < K; k++)
            {
                // same index mapping as in the read loop above
                const index_type offset = k*WG_SIZE + LID_0;

                %if is_tail:
                if (unit_base + offset < interval_end)
                %endif
                {
                    partial_scan_buffer[unit_base + offset] =
                        ldata[offset % K][offset / K];
                }
            }

            // }}}

            local_barrier();
        }

    % endfor

    // write interval sum
    // NOTE(review): this reads back a global value that may have been written
    // by another work item; whether local_barrier() makes that write visible
    // depends on its definition, which is not shown here.
    if (LID_0 == 0)
    {
        %if is_first_level:
            interval_results[GID_0] = partial_scan_buffer[interval_end - 1];
        %endif

        %if is_segmented and is_first_level:
            g_first_segment_start_in_interval[GID_0] = first_segment_start_in_interval;
        %endif
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment