Skip to content

Instantly share code, notes, and snippets.

@kaushikcfd
Created June 16, 2018 22:44
Show Gist options
  • Save kaushikcfd/c7ec13a67e730fbc8a63700543722ac3 to your computer and use it in GitHub Desktop.
Save kaushikcfd/c7ec13a67e730fbc8a63700543722ac3 to your computer and use it in GitHub Desktop.
import loopy as lp
from loopy.isl_helpers import simplify_via_aff
from pymbolic.primitives import CallWithKwargs
from loopy.kernel.function_interface import (get_kw_pos_association,
register_pymbolic_calls_to_knl_callables)
from loopy.symbolic import IdentityMapper
class DimChanger(IdentityMapper):
    """
    Rewrite subscripts into a callee kernel's arguments so that they index
    the corresponding caller arrays using the caller's dimensions.

    A subscript into a callee argument is first flattened into a single
    offset using the strides of the callee argument's ``dim_tags``. That
    offset is then split back into one index per dimension of the caller
    argument, using the caller argument's ``dim_tags`` strides.

    Subscripts into anything that is not a callee argument, for example
    callee temporaries, are left unchanged apart from the normal recursive
    mapping of their sub-expressions.

    :arg caller_arg_dict: mapping from caller argument name to argument
        (e.g. ``caller_knl.arg_dict``).
    :arg callee_arg_dict: mapping from callee argument name to argument
        (e.g. ``callee_knl.arg_dict``).
    :arg callee_to_caller_args: mapping from a callee argument name to the
        name of the caller argument bound to it at the call site.
    """

    def __init__(self, caller_arg_dict, callee_arg_dict, callee_to_caller_args):
        self.caller_arg_dict = caller_arg_dict
        self.callee_arg_dict = callee_arg_dict
        self.callee_to_caller_args = callee_to_caller_args

    def map_subscript(self, expr):
        name = getattr(expr.aggregate, "name", None)

        # Only callee arguments bound to a caller argument get their
        # dimensions changed. Everything else (e.g. temporaries of the
        # callee) previously raised a KeyError here.
        if (name not in self.callee_arg_dict
                or name not in self.callee_to_caller_args):
            return super(DimChanger, self).map_subscript(expr)

        callee_arg_dim_tags = self.callee_arg_dict[name].dim_tags
        caller_arg_dim_tags = self.caller_arg_dict[
                self.callee_to_caller_args[name]].dim_tags

        # Map the indices first so that nested subscripts (e.g. indirect
        # addressing such as ``a[b[i]]``) are rewritten as well.
        indices = tuple(self.rec(idx) for idx in expr.index_tuple)

        # NOTE(review): assumes every dim tag exposes a ``stride``
        # (fixed-stride dim tags) -- confirm for all callers.
        flattened_index = sum(dim_tag.stride*idx for dim_tag, idx in
                zip(callee_arg_dim_tags, indices))

        # Peel off one index per caller dimension in the order the caller's
        # dim_tags are listed: divide by the stride, keep the remainder.
        # NOTE(review): this decomposition is only correct if the caller's
        # dims are listed in decreasing-stride order (e.g. C order); for
        # other orders the first dim absorbs the whole offset -- confirm.
        new_indices = []
        for dim_tag in caller_arg_dim_tags:
            ind = flattened_index // dim_tag.stride
            flattened_index -= (dim_tag.stride * ind)
            new_indices.append(simplify_via_aff(ind))

        return expr.aggregate.index(tuple(new_indices))
def _get_arg_name(par):
    """
    Return the name of the array passed as the call parameter *par*.

    Parameters that carry a ``subscript`` attribute (sub-array references)
    resolve to the name of the subscripted aggregate. Anything else is
    assumed to be a named variable and resolves to its ``name``.
    """
    subscript = getattr(par, "subscript", None)
    if subscript is not None:
        return subscript.aggregate.name
    return par.name


def _change_insn_dims(callee_insn, dim_changer):
    """
    Return *callee_insn* with its expression and assignee(s) mapped through
    *dim_changer*. Instructions that are not assignments are returned
    unchanged.
    """
    if isinstance(callee_insn, lp.CallInstruction):
        # Call instructions carry a tuple of ``assignees``, not ``assignee``.
        return callee_insn.copy(
                expression=dim_changer(callee_insn.expression),
                assignees=tuple(dim_changer(assignee)
                    for assignee in callee_insn.assignees))
    if isinstance(callee_insn, lp.MultiAssignmentBase):
        return callee_insn.copy(
                expression=dim_changer(callee_insn.expression),
                assignee=dim_changer(callee_insn.assignee))
    # Barriers, no-ops, C instructions, ...: previously these were silently
    # dropped from the callee. Keep them as they are.
    return callee_insn


def match_caller_callee_argument_dimension(caller_knl, callee_fn):
    """
    Rewrite the callee kernel named *callee_fn* so that its instructions
    index the caller's arrays using the caller's array dimensions.

    For every call instruction in *caller_knl* that calls a scoped kernel
    callable whose ``subkernel.name`` equals *callee_fn*:

    1. associate each callee argument with the caller argument passed for
       it (positional parameters, then keyword parameters, with the call's
       assignees inserted at the positions of the callee's ``'out'`` args);
    2. rewrite the callee's instructions with :class:`DimChanger`;
    3. record the updated callable for that call expression.

    One must call this after registering the callee kernel into the caller
    kernel.

    :arg caller_knl: the calling kernel.
    :arg callee_fn: name of the callee subkernel to rewrite.
    :returns: the result of ``register_pymbolic_calls_to_knl_callables``
        applied to *caller_knl* and the collected call-to-callable mapping.
    """
    pymbolic_calls_to_new_callables = {}

    for insn in caller_knl.instructions:
        if not isinstance(insn, lp.CallInstruction) or (
                insn.expression.function.name not in
                caller_knl.scoped_functions):
            continue

        in_knl_callable = caller_knl.scoped_functions[
                insn.expression.function.name]
        # Scoped functions that are not kernels have no ``subkernel`` and
        # cannot be the requested callee; skip them instead of raising.
        subkernel = getattr(in_knl_callable, "subkernel", None)
        if subkernel is None or subkernel.name != callee_fn:
            continue

        # {{{ caller <-> callee argument association

        parameters = [_get_arg_name(par) for par in insn.expression.parameters]

        kw_parameters = {}
        if isinstance(insn.expression, CallWithKwargs):
            kw_parameters = insn.expression.kw_parameters

        # Keyword parameters follow the positional ones, in the positional
        # order the callee assigns to their keywords.
        _, pos_to_kw = get_kw_pos_association(subkernel)
        n_positional = len(parameters)
        for pos in range(n_positional, n_positional + len(kw_parameters)):
            parameters.append(_get_arg_name(kw_parameters[pos_to_kw[pos]]))

        # Output arguments appear as the call's assignees; splice them into
        # the positions at which the callee declares its 'out' arguments.
        assignee_idx = 0
        for pos, arg in enumerate(subkernel.args):
            if arg.direction == 'out':
                parameters.insert(
                        pos, _get_arg_name(insn.assignees[assignee_idx]))
                assignee_idx += 1

        callee_to_caller_arg_map = dict(zip(
            [arg.name for arg in subkernel.args], parameters))

        # }}}

        dim_changer = DimChanger(caller_knl.arg_dict, subkernel.arg_dict,
                callee_to_caller_arg_map)

        new_callee_insns = [_change_insn_dims(callee_insn, dim_changer)
                for callee_insn in subkernel.instructions]

        new_subkernel = subkernel.copy(instructions=new_callee_insns)
        new_in_knl_callable = in_knl_callable.copy(subkernel=new_subkernel)
        pymbolic_calls_to_new_callables[insn.expression] = new_in_knl_callable

    return register_pymbolic_calls_to_knl_callables(caller_knl,
            pymbolic_calls_to_new_callables)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment