Skip to content

Instantly share code, notes, and snippets.

@bartvm
Created May 16, 2017 20:42
Show Gist options
  • Save bartvm/c4d0eb4906c091d3f006cf389f2d8a6e to your computer and use it in GitHub Desktop.
Save bartvm/c4d0eb4906c091d3f006cf389f2d8a6e to your computer and use it in GitHub Desktop.
import ast
from ast import NodeTransformer, parse
from collections import namedtuple
from copy import deepcopy
from inspect import getsource
from textwrap import dedent
# Templates are instantiated in two stages:
#   1. plain strings are substituted directly into the template text;
#   2. placeholder variable names are substituted into the text, the result
#      is parsed, and each placeholder ``Name`` node is then swapped for an
#      AST subtree.
# Stack operations are template parameters rather than hard-coded calls.
PUSH = 'list.append'
POP = 'list.pop'

# A gradient specification carries no state, so an immutable named tuple of
# (primal template, adjoint template) is enough.
GradientSpecification = namedtuple('GradientSpecification', 'primal adjoint')
# Create a bunch of unique placeholder names
def create_placeholders(**kwargs):
return {name: 'placeholder_{}'.format(name) for name in kwargs}
class ReplacePlaceholder(NodeTransformer):
    """Swap placeholder ``Name`` nodes for AST subtrees.

    Keyword arguments map placeholder identifiers to their replacements.
    A replacement may be:

    * a single AST expression node (e.g. a loop target or iterable), which
      replaces every ``Name`` node carrying that identifier;
    * a list of statements (e.g. a loop body) or a single statement node,
      which replaces any expression statement that consists solely of the
      placeholder name. The statements are spliced into the enclosing
      statement list.

    Every occurrence receives its own deep copy, so one subtree can be used
    at several sites. The copy's expression context (Load/Store/Del) is set
    to match the placeholder it replaces. Non-AST replacements are returned
    unchanged.
    """

    def __init__(self, **placeholders):
        self.placeholders = placeholders

    def visit_Expr(self, node):
        # A bare ``placeholder`` statement stands for statement(s). Splice
        # them into the surrounding body. Putting a list or a statement inside
        # ``Expr.value`` would produce an invalid AST.
        value = node.value
        if isinstance(value, ast.Name) and value.id in self.placeholders:
            replacement = self.placeholders[value.id]
            if isinstance(replacement, list):
                return [deepcopy(statement) for statement in replacement]
            if isinstance(replacement, ast.stmt):
                return deepcopy(replacement)
        return self.generic_visit(node)

    def visit_Name(self, node):
        if node.id not in self.placeholders:
            return node
        replacement = self.placeholders[node.id]
        if not isinstance(replacement, ast.AST):
            # Preserve the original behaviour for non-AST replacements.
            return replacement
        # Copy so the same subtree can appear at several sites with different
        # contexts. For example, a loop target is Store in the ``for`` header
        # but Load when it is pushed onto a stack.
        replacement = deepcopy(replacement)
        self._set_context(replacement, node.ctx)
        return replacement

    @staticmethod
    def _set_context(node, ctx):
        """Give ``node`` and any nested unpacking elements the context ``ctx``."""
        if hasattr(node, 'ctx'):
            node.ctx = ctx
        # Tuple/List/Starred propagate their context to their elements, as
        # in ``for a, (b, *c) in ...``. Attribute/Subscript values stay Load.
        if isinstance(node, (ast.Tuple, ast.List)):
            for element in node.elts:
                ReplacePlaceholder._set_context(element, ctx)
        elif isinstance(node, ast.Starred):
            ReplacePlaceholder._set_context(node.value, ctx)
# Gradient specification for ``for`` loops.
#
# ``primal`` creates the stack(s) and runs the original loop. Before each
# iteration's body it pushes the loop target's value onto the first stack.
# ``adjoint`` pops those values back off one at a time and runs
# ``{reverse_body}`` (presumably the reversed loop body) once per pushed
# value. The order is last-in, first-out when PUSH/POP are
# list.append/list.pop.
#
# Callers dedent the templates before formatting (reverse_for does), so the
# common leading indentation here is irrelevant. The *relative* indentation
# of the loop bodies, however, is required for the instantiated source to
# parse.
for_grad = GradientSpecification(
    primal="""
        {create_stacks}
        for {target} in {iter}:
            {push}({stacks[0]}, {target})
            {body}
        """,
    adjoint="""
        while {stacks[0]}:
            {target} = {pop}({stacks[0]})
            {reverse_body}
        """)
def reverse_for(tree):
    """Build the primal code for a ``for`` loop from the ``for_grad`` template.

    Args:
        tree: an ``ast.For`` node. Its ``target``, ``iter`` and ``body`` are
            read.

    Returns:
        The ``ast.Module`` obtained by parsing the formatted primal template,
        with the loop's target, iterable and body grafted in at the
        placeholder sites. The stack is created as ``stack = []`` and is
        pushed to with ``PUSH``.

    Only the primal half of the specification is instantiated here. The
    adjoint template is not used.
    """
    # Subtrees of the original loop that get grafted into the template.
    subtrees = dict(target=tree.target, iter=tree.iter, body=tree.body)
    names = create_placeholders(**subtrees)

    # Stage 1: textual substitution of plain strings, global stack
    # operations, and placeholder identifiers.
    source = dedent(for_grad.primal).format(
        create_stacks='stack = []',
        stacks=['stack'],
        push=PUSH,
        pop=POP,
        **names
    )

    # Stage 2: parse, then swap each placeholder name for its subtree.
    swaps = {}
    for key, subtree in subtrees.items():
        swaps[names[key]] = subtree
    return ReplacePlaceholder(**swaps).visit(parse(source))
# Test: build the primal for the ``for`` loop inside ``f`` and show its AST.
def f(x):
    for i, j in zip(range(10), range(3)):
        x = i * x
    return x


full_tree = parse(getsource(f))
# full_tree.body[0] is the ``def f`` node; its body[0] is the ``for`` loop.
for_tree = full_tree.body[0].body[0]
# Printing the Module object directly shows only its repr, so dump the tree
# to make the generated code visible.
print(ast.dump(reverse_for(for_tree)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment