Last active
May 19, 2017 21:35
-
-
Save bartvm/404f769b98bda7837dafe5fb730d17b0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
import collections | |
import inspect | |
import numbers | |
import textwrap | |
import numpy | |
# Reusable AST fragments for the runtime tape: attribute lookups
# ``_stack.push`` and ``_stack.pop``.  They are spliced into the code
# templates below via ``replace`` (the generated code therefore assumes a
# ``_stack`` object is in scope at runtime).
PUSH = ast.Attribute(value=ast.Name(id='_stack', ctx=ast.Load()),
                     attr='push', ctx=ast.Load())
POP = ast.Attribute(value=ast.Name(id='_stack', ctx=ast.Load()),
                    attr='pop', ctx=ast.Load())
def parse_function(fn):
    """Return the AST (an ``ast.Module``) of *fn*'s source code.

    The source is dedented first so that functions defined inside other
    scopes (methods, nested defs) still parse as top-level definitions.
    """
    source = inspect.getsource(fn)
    return ast.parse(textwrap.dedent(source))
class NodeReverse(object):
    """Generate a primal and adjoint for a given AST tree.

    Notes
    -----
    In principle, this class simply walks the AST recursively and for each
    node returns a new primal and an adjoint.

    A limited amount of communication happens through the state of the
    class. Assign statements set `current_target` so that the adjoint of the
    right hand side knows what gradient to read. On the other hand, right
    hand side expressions set `current_partials` to tell assignment
    statements what variables the partials were written to.
    """

    def visit(self, node):
        """Dispatch to ``visit_<NodeType>``; unknown node types go to
        ``generic_visit`` (which rejects them)."""
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return visitor(node)

    @staticmethod
    def create_grad(node):
        """Given a variable, create variable name for the gradient.

        WARNING: This returns an invalid node, with the `ctx` attribute
        missing. It is assumed that this attribute is filled in later (e.g.
        by the `replace` function).
        """
        if not isinstance(node, ast.Name):
            raise TypeError
        # The gradient of variable ``x`` is named ``dx``.
        return ast.Name(id='d' + node.id)

    @staticmethod
    def create_var(id_):
        """Method to create a named variable. Used for temporaries."""
        return ast.Name(id=id_, ctx=ast.Load())

    def visit_FunctionDef(self, node):
        # TODO Change function signatures to receive stack
        # TODO Change adjoint signature to take output of primal and its
        # initial gradient
        pass

    def visit_statements(self, nodes):
        """Generate the adjoint of a series of statements.

        Returns a pair ``(primals, adjoints)`` of statement lists.  The
        adjoint runs in reverse statement order, hence the deque and
        ``extendleft`` (``[::-1]`` preserves the order of the statements
        belonging to one node).
        """
        primals, adjoints = [], collections.deque()
        for node in nodes:
            primal, adjoint = self.visit(node)
            primals.extend(primal)
            adjoints.extendleft(adjoint[::-1])
        return primals, list(adjoints)

    def visit_For(self, node):
        """Reverse a ``for`` loop: the primal counts iterations and pushes
        the count; the adjoint pops it and reruns the reversed body that
        many times."""
        # ``for``/``else`` is not supported.
        assert not node.orelse
        primal_body, adjoint_body = self.visit_statements(node.body)

        def primal_template(body, iter_, target, push):
            i = 0
            for target in iter_:
                i += 1
                body
            push(i)

        primal = replace(primal_template, body=primal_body, push=PUSH,
                         target=node.target, iter_=node.iter)

        def adjoint_template(body, pop):
            i = pop()
            for _ in range(i):
                body

        adjoint = replace(adjoint_template, body=adjoint_body, pop=POP)
        return primal, adjoint

    def visit_BinOp(self, node):
        """Return the expression itself as primal and, as adjoint, the
        statements computing the partials w.r.t. both operands."""
        # One adjoint template per supported operator: given the incoming
        # gradient dz of the result, write the partials into dx and dy.
        adjoint_templates = {}

        def adjoint_Mult_template(x, y, dx, dy, dz):
            dx = dz * y
            dy = dz * x
        adjoint_templates[ast.Mult] = adjoint_Mult_template

        def adjoint_Add_template(x, y, dx, dy, dz):
            dx = dz
            dy = dz
        adjoint_templates[ast.Add] = adjoint_Add_template

        def adjoint_Div_template(x, y, dx, dy, dz):
            dx = dz / y
            dy = -dz * x / y ** 2
        adjoint_templates[ast.Div] = adjoint_Div_template

        op = type(node.op)
        if op not in adjoint_templates:
            raise ValueError("unknown binary operator")
        # Tell the enclosing Assign which temporaries hold the partials.
        self.current_partials = {
            node.left: self.create_var('__dx'),
            node.right: self.create_var('__dy')
        }
        # dz reads the gradient of the assignment target, which
        # visit_Assign stored in self.current_target before visiting us.
        return node, replace(
            adjoint_templates[op],
            x=node.left, y=node.right,
            dx=self.current_partials[node.left],
            dy=self.current_partials[node.right],
            dz=self.create_grad(self.current_target))

    def visit_Assign(self, node):
        """Reverse a single-target assignment.

        The primal saves the target's old value on the stack before
        overwriting it; the adjoint restores it, computes the partials,
        resets the target's gradient, and accumulates the partials into
        the operand gradients.
        """
        if len(node.targets) != 1:
            raise ValueError
        if isinstance(node.targets[0], ast.Tuple):
            if not isinstance(node.value, ast.Name):
                raise ValueError("can only unpack variables")
            # TODO Pack the gradients into a tuple
            raise ValueError("no support for tuple assignments")
        if not isinstance(node.targets[0], ast.Name):
            raise ValueError("can only assign to names")
        # Extract the target and store it in the state so that the
        # right hand side templates can use it
        target = node.targets[0]
        self.current_target = target
        primal_rhs, adjoint_rhs = self.visit(node.value)
        # NOTE We simplify things here by EAFP. Ideally each variable that is
        # pushed at any point should be set to None at the beginning

        def primal_template(target, primal_rhs, push):
            try:
                push(target)
            except NameError:
                push(None)
            target = primal_rhs

        primal = replace(primal_template, target=target,
                         primal_rhs=primal_rhs, push=PUSH)
        # NOTE For each partial gradient from the rhs we want to accumulate
        # it into the existing gradient; this is the template for that
        # NOTE EAFP approach again; gradients should be initialized beforehand

        def accumulate_template(in_grad, partial_grad):
            try:
                in_grad = add_grad(in_grad, partial_grad)
            except NameError:
                in_grad = partial_grad

        gradient_accumulation = []
        # NOTE(review): create_grad assumes every operand is an ast.Name;
        # expressions like ``(a + b) * c`` would raise TypeError here.
        for partial in self.current_partials:
            gradient_accumulation.extend(replace(
                accumulate_template, in_grad=self.create_grad(partial),
                partial_grad=self.current_partials[partial]))
        # The final adjoint restores the input (pop), stores the partials
        # in temporary variables, resets the gradient w.r.t. output,
        # and finally updates the gradients

        def adjoint_template(target, adjoint_rhs, target_grad,
                             gradient_accumulation, pop):
            target = pop()
            adjoint_rhs
            target_grad = 0
            gradient_accumulation

        adjoint = replace(adjoint_template, target=target,
                          adjoint_rhs=adjoint_rhs,
                          gradient_accumulation=gradient_accumulation,
                          target_grad=self.create_grad(target), pop=POP)
        # Reset the state
        self.current_target = None
        self.current_partials = None
        return primal, adjoint

    def generic_visit(self, node):
        """Reject any node type that has no dedicated visitor."""
        raise ValueError("unknown node type")
class ReplaceTransformer(ast.NodeTransformer):
    """Substitute placeholder names in a template AST with real nodes."""

    def __init__(self, replacements):
        # Maps placeholder identifiers to the node(s) that replace them.
        self.replacements = replacements

    def visit_Name(self, node):
        new_node = self.replacements.get(node.id, node)
        # Carry the placeholder's Load/Store context over to the substitute
        # node (when it has one) so it stays valid in the same syntactic
        # position as the placeholder it replaces.
        has_ctx = isinstance(new_node, ast.AST) and 'ctx' in new_node._fields
        if has_ctx:
            new_node.ctx = node.ctx
        return new_node
def replace(fn, **replacements):
    """Replace placeholders in a Python template (quote).

    One special thing happens: if a replacement node has a ``ctx``
    attribute, it is made to match the ``ctx`` attribute of the variable
    it is replacing.

    Parameters
    ----------
    fn : function
        A function used as a metaprogramming template.
    replacements : dict
        A mapping from the variable names of the function's arguments to
        (lists of) AST nodes that these variables will be replaced with
        wherever they appear in the function body. A replacement can be a
        list, in which case it will be merged into the list of statements
        containing the node.

    Returns
    -------
    body : list
        A list of statements in the form of AST nodes.
    """
    template = parse_function(fn).body[0]
    placeholders = {arg.arg for arg in template.args.args}
    # Every template argument must be replaced, and nothing else.
    if set(replacements) != placeholders:
        raise ValueError("too many or few replacements")
    transformed = ReplaceTransformer(replacements).visit(template)
    return transformed.body
def add_grad(left, right):
    """Recursively add the gradient of e.g. tuples.

    Parameters
    ----------
    left, right : None, number, numpy.ndarray or tuple
        Gradients to combine. ``left`` may be ``None`` (undefined).

    Returns
    -------
    The element-wise sum, with the same structure as the inputs.

    Raises
    ------
    TypeError
        If the gradients have different types, different tuple lengths,
        or an unsupported type.
    """
    # If the gradient is undefined, then we simply return the rhs
    # NOTE This is more efficient than initializing empty gradients and
    # adding to them, since we could be adding to large matrix of zeros then
    if left is None:
        return right
    assert right is not None
    # NOTE(review): strict type equality rejects e.g. int + float — confirm
    # this is intended before relaxing it.
    if type(left) != type(right):
        raise TypeError("incompatible gradients")
    if isinstance(left, (numpy.ndarray, numbers.Number)):
        return left + right
    if isinstance(left, tuple):
        # zip would silently truncate mismatched tuples; reject instead.
        if len(left) != len(right):
            raise TypeError("incompatible gradients")
        # BUG FIX: recurse instead of ``lelem + relem`` — ``+`` would
        # CONCATENATE nested tuples rather than add their elements.
        return tuple(add_grad(lelem, relem)
                     for lelem, relem in zip(left, right))
    raise TypeError("unknown gradient type")
def f(x):
    # Sample primal for the AD transform: a single multiplication
    # assignment.  The body's SOURCE is the data here (it is meant to be
    # parsed with parse_function), so it must stay exactly as written;
    # ``#`` comments are safe because ast.parse discards them.
    y = x * x
def g(x):
    # Sample primal used by the __main__ demo below: a loop of
    # multiplication assignments.  Its source is parsed and transformed,
    # so the body must stay exactly as written (comments are discarded
    # by ast.parse and are therefore safe to add).
    for i in range(10):
        y = x * x
if __name__ == "__main__":
    # Demo: transform g into its primal/adjoint pair and pretty-print both.
    import astor
    stmts = parse_function(g).body[0].body
    primal_stmts, adjoint_stmts = NodeReverse().visit_statements(stmts)
    print("PRIMAL")
    print(astor.to_source(ast.Module(body=primal_stmts)))
    print("ADJOINT")
    print(astor.to_source(ast.Module(body=adjoint_stmts)))
# PRIMAL | |
# i = 0 | |
# for i in range(10): | |
# i += 1 | |
# try: | |
# _stack.push(y) | |
# except NameError: | |
# _stack.push(None) | |
# y = x * x | |
# _stack.push(i) | |
# ADJOINT | |
# i = _stack.pop() | |
# for _ in range(i): | |
# y = _stack.pop() | |
# __dx = dy * x | |
# __dy = dy * x | |
# dy = 0 | |
# try: | |
# dx = add_grad(dx, __dx) | |
# except NameError: | |
# dx = __dx | |
# try: | |
# dx = add_grad(dx, __dy) | |
# except NameError: | |
# dx = __dy |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I would also be curious to see how you specialize `Call()`. Whereas `BinOp` has a fixed, stereotyped argument structure, the gradients of a `Call` will have a variable number of arguments. So you'll have to break apart the args and kwargs coming from the primal `Call` node and feed them into the template somehow (perhaps a use case for `*args` and `**kwargs` in the quote).