Skip to content

Instantly share code, notes, and snippets.

@nvbn
Created August 10, 2016 14:20
Show Gist options
  • Save nvbn/8d8b242ae88c97d1746e3b8b8ebbc257 to your computer and use it in GitHub Desktop.
Save nvbn/8d8b242ae88c97d1746e3b8b8ebbc257 to your computer and use it in GitHub Desktop.
Partial application and piping with AST transformation
import ast
class EllipsisPartialTransform(ast.NodeTransformer):
    """Turn calls that contain ``...`` arguments into one-parameter lambdas.

    A call such as ``f(a, ..., b)`` is rewritten to
    ``lambda __ellipsis_partial_arg_N: f(a, __ellipsis_partial_arg_N, b)``,
    a lightweight form of partial application.  Every positional ``...`` in
    the same call is replaced by the *same* lambda parameter; keyword
    arguments are left untouched.  Each wrapped call gets a fresh parameter
    name, so nested or sibling lambdas never clash.
    """

    def __init__(self):
        super().__init__()
        # Counter used to build a unique lambda parameter name per wrap.
        self._counter = 0

    def _get_arg_name(self):
        """Return a unique argument name for the next generated lambda."""
        name = '__ellipsis_partial_arg_{}'.format(self._counter)
        self._counter += 1
        return name

    def _is_ellipsis(self, arg):
        """Return True if *arg* is a literal ``...`` expression node."""
        # Python 3.8+ parses ``...`` as ``ast.Constant(value=Ellipsis)``; the
        # old ``ast.Ellipsis`` class is only a deprecated alias there (it
        # warns on 3.12 and is removed in 3.14), so it is never accessed.
        constant = getattr(ast, 'Constant', ())
        if isinstance(arg, constant):
            return arg.value is Ellipsis
        # Parsers before 3.8 emit a dedicated node class named ``Ellipsis``.
        return type(arg).__name__ == 'Ellipsis'

    def _replace_argument(self, node, arg_name):
        """Replace every positional ``...`` in call *node* with *arg_name*.

        Mutates ``node.args`` in place and returns *node*.
        """
        node.args = [
            # Each replacement keeps the position of the ``...`` it replaces.
            ast.copy_location(ast.Name(id=arg_name, ctx=ast.Load()), arg)
            if self._is_ellipsis(arg) else arg
            for arg in node.args
        ]
        return node

    def _wrap_in_lambda(self, node):
        """Wrap call *node* in a lambda whose parameter fills its ``...``."""
        arg_name = self._get_arg_name()
        node = self._replace_argument(node, arg_name)
        lambda_node = ast.Lambda(
            args=ast.arguments(
                # ``posonlyargs`` is a required field since Python 3.8; on
                # older interpreters it is merely an unused extra attribute.
                posonlyargs=[],
                args=[ast.arg(arg=arg_name, annotation=None)],
                vararg=None, kwonlyargs=[], kw_defaults=[],
                kwarg=None, defaults=[]),
            body=node)
        # Give the lambda the call's position so tracebacks point at the
        # original source line rather than line 1.
        return ast.copy_location(lambda_node, node)

    def visit_Call(self, node):
        """Wrap *node* in a lambda if it has a ``...`` argument, then recurse."""
        if any(self._is_ellipsis(arg) for arg in node.args):
            node = self._wrap_in_lambda(node)
            node = ast.fix_missing_locations(node)
        return self.generic_visit(node)
class MatMulPipeTransformation(ast.NodeTransformer):
    """Rewrite the ``@`` operator into a pipe: ``x @ f`` becomes ``f(x)``.

    Because ``@`` is left-associative, ``x @ f @ g`` becomes ``g(f(x))``.
    Every ``@`` binary operation in the visited tree is rewritten
    unconditionally, so genuine matrix multiplication is no longer available
    in transformed code.  Augmented assignment (``x @= f``) is not a
    ``BinOp`` and is left unchanged.
    """

    def _replace_with_call(self, node):
        """Call right part of operation with left part as an argument."""
        call = ast.Call(func=node.right, args=[node.left], keywords=[])
        # Keep the BinOp's position so errors raised by the call are reported
        # at the original source line instead of line 1.
        return ast.copy_location(call, node)

    def visit_BinOp(self, node):
        """Replace ``left @ right`` with ``right(left)``, then recurse."""
        if isinstance(node.op, ast.MatMult):
            node = self._replace_with_call(node)
            node = ast.fix_missing_locations(node)
        # Recursing after the rewrite also converts ``@`` nested in operands.
        return self.generic_visit(node)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment