Created
May 16, 2017 20:42
-
-
Save bartvm/c4d0eb4906c091d3f006cf389f2d8a6e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast
import copy
from ast import parse, NodeTransformer
from collections import namedtuple
from inspect import getsource
from textwrap import dedent
# Templates are filled in two ways:
#   1. plain strings are substituted directly with ``str.format``;
#   2. a placeholder identifier is substituted, the result is parsed, and the
#      placeholder's node in the AST is then replaced by a real subtree.
# The stack operations are template parameters rather than hard-coded text,
# so their implementation is chosen in this one place.
PUSH, POP = 'list.append', 'list.pop'
# A gradient specification is just a pair of code templates with no state of
# its own, so an immutable named tuple is enough:
#   primal  -- template for the forward computation
#   adjoint -- template for the reverse (adjoint) computation
GradientSpecification = namedtuple('GradientSpecification', 'primal adjoint')
def create_placeholders(**kwargs):
    """Map every keyword name to a unique placeholder identifier.

    Only the keyword *names* are used; their values are ignored. Each name
    ``n`` maps to ``'placeholder_n'``. Formatted into a template, that string
    parses as an ordinary name, which can later be swapped for a subtree.
    """
    pattern = 'placeholder_{}'
    mapping = {}
    for key in kwargs:
        mapping[key] = pattern.format(key)
    return mapping
class ReplacePlaceholder(NodeTransformer):
    """Replace placeholder names in a parsed template with real AST subtrees.

    Keyword arguments map a placeholder identifier (e.g. one produced by
    ``create_placeholders``) to its replacement. A replacement is either a
    single AST node or a list of nodes, such as the statements of a loop body.

    * A placeholder that forms an entire expression statement (a name alone
      on its own line) is replaced by the statement(s) themselves. A list of
      statements is therefore spliced into the surrounding block.
    * Any other occurrence is replaced by a copy of the subtree whose ``ctx``
      matches the placeholder's. For example, a ``for`` target (``Store``)
      can also be used as an ordinary value (``Load``).

    Replacements are deep-copied at every use, so the resulting tree never
    shares node objects with the replacements or between substitution sites.
    """

    def __init__(self, **placeholders):
        self.placeholders = placeholders

    def _match_context(self, node, ctx_type):
        """Give *node* (and elements of unpacking targets) a fresh *ctx_type*."""
        if hasattr(node, 'ctx'):
            node.ctx = ctx_type()
        # Tuple/List/Starred pass their context down to their elements
        # (e.g. ``i, j`` as a loop target). The inner value of an
        # Attribute/Subscript is always Load, so it is left alone.
        if isinstance(node, (ast.Tuple, ast.List)):
            for element in node.elts:
                self._match_context(element, ctx_type)
        elif isinstance(node, ast.Starred):
            self._match_context(node.value, ctx_type)
        return node

    def visit_Expr(self, node):
        value = node.value
        if isinstance(value, ast.Name) and value.id in self.placeholders:
            replacement = self.placeholders[value.id]
            # Statements cannot live inside an Expr. Return them in its place
            # so NodeTransformer splices them into the enclosing body.
            if isinstance(replacement, (list, ast.stmt)):
                return copy.deepcopy(replacement)
        return self.generic_visit(node)

    def visit_Name(self, node):
        if node.id not in self.placeholders:
            return node
        replacement = copy.deepcopy(self.placeholders[node.id])
        ctx_type = type(node.ctx)
        # A list replacement is spliced by NodeTransformer when the parent
        # field is itself a list (e.g. call arguments).
        items = replacement if isinstance(replacement, list) else [replacement]
        for item in items:
            self._match_context(item, ctx_type)
        return replacement
# Templates for differentiating a ``for`` loop. Both strings are meant to be
# ``textwrap.dedent``-ed and then filled in with ``str.format``. Fields:
#   create_stacks  -- statement(s) that initialise the stack(s)
#   stacks         -- sequence of stack names; stacks[0] records loop targets
#   push / pop     -- dotted names of the stack operations (see PUSH / POP)
#   target / iter  -- the loop target and iterable
#   body           -- the loop body; it stands alone on its line so a
#                     placeholder there can be replaced by whole statements
#   reverse_body   -- the reversed loop body (not filled in by this file)
# The primal records every value of the loop target before running the body.
# The adjoint pops those values back off until the stack is empty.
# The body lines must be indented under ``for``/``while``; otherwise parsing
# the filled-in template raises IndentationError.
for_grad = GradientSpecification(
    primal="""
        {create_stacks}
        for {target} in {iter}:
            {push}({stacks[0]}, {target})
            {body}
    """,
    adjoint="""
        while {stacks[0]}:
            {target} = {pop}({stacks[0]})
            {reverse_body}
    """)
def reverse_for(tree, stack_name='stack'):
    """Build the primal (forward) code for the ``for`` loop *tree*.

    The loop is rewritten via the ``for_grad.primal`` template. The result
    first creates an empty stack. Then, at the start of every iteration, it
    pushes the loop target onto that stack before running the original body.

    Args:
        tree: an ``ast.For`` node.
        stack_name: name of the list used to record loop targets. Pass a
            distinct name for each loop when transforming nested loops, so
            their stacks do not clobber each other.

    Returns:
        The parsed, substituted template: an ``ast.Module`` holding the stack
        initialisation followed by the instrumented loop. The loop keeps the
        original ``else`` clause, if any.
    """
    # AST nodes that are spliced directly into the template.
    replacements = {
        'target': tree.target,
        'iter': tree.iter,
        'body': tree.body,
    }
    placeholders = create_placeholders(**replacements)

    # Fill in the template with (a) plain strings, (b) the global stack
    # operations and (c) placeholder names that parse as ordinary identifiers.
    primal = dedent(for_grad.primal).format(
        create_stacks='{} = []'.format(stack_name),
        stacks=[stack_name],
        push=PUSH,
        pop=POP,
        **placeholders)

    # Parse the filled-in template, then swap each placeholder name for the
    # subtree it stands for.
    substitutions = {placeholders[key]: node
                     for key, node in replacements.items()}
    primal_tree = ReplacePlaceholder(**substitutions).visit(parse(primal))

    # The template has no ``else`` clause. Carry the original one over so the
    # rewritten loop keeps the same semantics (it runs only if no ``break``).
    if tree.orelse:
        loops = [node for node in primal_tree.body
                 if isinstance(node, ast.For)]
        loops[-1].orelse = tree.orelse
    return primal_tree
# Demo: rewrite the loop in ``f`` and print the resulting primal AST.
def f(x):
    for i, j in zip(range(10), range(3)):
        x = i * x
    return x


full_tree = parse(getsource(f))
for_tree = full_tree.body[0].body[0]
# Printing the Module object itself only shows an object address. Dump the
# tree so the generated code can actually be inspected.
print(ast.dump(reverse_for(for_tree)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment