Created
          June 16, 2018 22:44 
        
      - 
      
- 
        Save kaushikcfd/c7ec13a67e730fbc8a63700543722ac3 to your computer and use it in GitHub Desktop. 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | import loopy as lp | |
| from loopy.isl_helpers import simplify_via_aff | |
| from pymbolic.primitives import CallWithKwargs | |
| from loopy.kernel.function_interface import (get_kw_pos_association, | |
| register_pymbolic_calls_to_knl_callables) | |
| from loopy.symbolic import IdentityMapper | |
class DimChanger(IdentityMapper):
    """Rewrite subscripts of callee arguments to use the caller's layout.

    Every subscript ``a[i0, ..., ik]`` whose aggregate ``a`` is an argument
    of the callee kernel is linearized with the strides of ``a``'s dim tags
    in the callee. The resulting flat offset is then split again along the
    dim tags of the caller variable that ``a`` is bound to. Subscripts of
    anything that is not a callee argument (e.g. callee temporaries) keep
    their shape; only their index expressions are recursed into.

    :arg caller_arg_dict: mapping from caller argument name to argument
        object (e.g. ``caller_knl.arg_dict``).
    :arg callee_arg_dict: mapping from callee argument name to argument
        object (e.g. ``subkernel.arg_dict``).
    :arg callee_to_caller_args: mapping from callee argument name to the
        name of the caller variable passed for it.
    """

    def __init__(self, caller_arg_dict, callee_arg_dict, callee_to_caller_args):
        self.caller_arg_dict = caller_arg_dict
        self.callee_arg_dict = callee_arg_dict
        self.callee_to_caller_args = callee_to_caller_args

    def map_subscript(self, expr):
        name = getattr(expr.aggregate, "name", None)
        if name not in self.callee_arg_dict:
            # Not a callee argument (e.g. a callee-local temporary): its
            # layout does not depend on the caller, so keep its shape and
            # only recurse into the index expressions.
            return super(DimChanger, self).map_subscript(expr)

        callee_arg_dim_tags = self.callee_arg_dict[name].dim_tags
        # NOTE(review): the caller variable is looked up among the caller's
        # arguments only; passing a caller temporary would raise KeyError.
        caller_arg_dim_tags = self.caller_arg_dict[
            self.callee_to_caller_args[name]].dim_tags

        # Map the indices first so that nested accesses such as a[idx[i]]
        # get their own subscripts rewritten as well.
        indices = tuple(self.rec(idx) for idx in expr.index_tuple)

        flattened_index = sum(dim_tag.stride*idx for dim_tag, idx in
                              zip(callee_arg_dim_tags, indices))

        # Peel off one caller index per dim tag. NOTE(review): this is only
        # correct if the caller's dim tags are ordered by decreasing stride
        # (e.g. C order) -- confirm for other layouts.
        new_indices = []
        for dim_tag in caller_arg_dim_tags:
            ind = flattened_index // dim_tag.stride
            flattened_index -= dim_tag.stride * ind
            new_indices.append(simplify_via_aff(ind))

        return expr.aggregate.index(tuple(new_indices))
def _get_arg_name(par):
    """Return the name of the variable passed as call argument *par*.

    Sub-array references yield the name of the array their subscript
    indexes; plain variables yield their own name; anything else (e.g. a
    literal constant) has no variable name and yields *None*.
    """
    subscript = getattr(par, "subscript", None)
    if subscript is not None:
        return subscript.aggregate.name
    return getattr(par, "name", None)


def _callee_to_caller_arg_names(insn, subkernel):
    """Map each argument name of *subkernel* to the caller variable name
    bound to it by the call instruction *insn*."""
    expression = insn.expression
    parameters = [_get_arg_name(par) for par in expression.parameters]

    kw_parameters = {}
    if isinstance(expression, CallWithKwargs):
        kw_parameters = expression.kw_parameters

    # Keyword arguments fill the input positions that follow the
    # positional ones; pos_to_kw translates a position to its arg name.
    _, pos_to_kw = get_kw_pos_association(subkernel)
    n_positional = len(parameters)
    for pos in range(n_positional, n_positional + len(kw_parameters)):
        parameters.append(_get_arg_name(kw_parameters[pos_to_kw[pos]]))

    # Outputs are supplied by the assignees, in order. Insert each one at
    # the position of its 'out' argument so that ``parameters`` ends up in
    # the callee's argument order.
    assignees = insn.assignees
    n_written = 0
    for pos, arg in enumerate(subkernel.args):
        if arg.direction == 'out':
            parameters.insert(pos, _get_arg_name(assignees[n_written]))
            n_written += 1

    return dict(zip([arg.name for arg in subkernel.args], parameters))


def _reshape_callee_instructions(subkernel, dim_changer):
    """Return *subkernel*'s instructions with *dim_changer* applied to the
    expressions and assignees of assignment-like instructions.

    Instructions of any other kind are kept unchanged (instead of being
    dropped) so the rewritten subkernel keeps all of its instructions.
    """
    new_insns = []
    for callee_insn in subkernel.instructions:
        if isinstance(callee_insn, lp.Assignment):
            callee_insn = callee_insn.copy(
                expression=dim_changer(callee_insn.expression),
                assignee=dim_changer(callee_insn.assignee))
        elif isinstance(callee_insn, lp.MultiAssignmentBase):
            # e.g. a nested call: these carry a tuple of assignees rather
            # than a single ``assignee``.
            callee_insn = callee_insn.copy(
                expression=dim_changer(callee_insn.expression),
                assignees=tuple(dim_changer(assignee)
                                for assignee in callee_insn.assignees))
        new_insns.append(callee_insn)
    return new_insns


def match_caller_callee_argument_dimension(caller_knl, callee_fn):
    """Rewrite the callee kernel *callee_fn* so that its array accesses use
    the dimensionality of the caller variables bound to its arguments.

    For every call instruction in *caller_knl* that invokes a kernel
    callable whose subkernel is named *callee_fn*, each subscript of a
    callee argument in the subkernel's assignment/call instructions is
    re-expressed (via :class:`DimChanger`) in terms of the dim tags of the
    caller argument passed for it. Only the instructions are rewritten; the
    subkernel's argument declarations are copied unchanged.

    One must call this after registering the callee kernel into the caller
    kernel.

    :arg caller_knl: the calling kernel.
    :arg callee_fn: the ``name`` of the callee subkernel to adjust.
    :returns: the result of ``register_pymbolic_calls_to_knl_callables``
        applied to *caller_knl* and the rewritten callables.
    """
    pymbolic_calls_to_new_callables = {}
    for insn in caller_knl.instructions:
        if not isinstance(insn, lp.CallInstruction):
            continue
        function_name = insn.expression.function.name
        if function_name not in caller_knl.scoped_functions:
            continue
        in_knl_callable = caller_knl.scoped_functions[function_name]

        # Only kernel callables carry a subkernel; other scoped callables
        # (e.g. scalar functions) must be skipped rather than crash here.
        subkernel = getattr(in_knl_callable, "subkernel", None)
        if subkernel is None or subkernel.name != callee_fn:
            continue

        dim_changer = DimChanger(
            caller_knl.arg_dict, subkernel.arg_dict,
            _callee_to_caller_arg_names(insn, subkernel))
        new_subkernel = subkernel.copy(
            instructions=_reshape_callee_instructions(subkernel, dim_changer))
        pymbolic_calls_to_new_callables[insn.expression] = (
            in_knl_callable.copy(subkernel=new_subkernel))

    return register_pymbolic_calls_to_knl_callables(
        caller_knl, pymbolic_calls_to_new_callables)
  
    Sign up for free
    to join this conversation on GitHub.
    Already have an account?
    Sign in to comment