Created
March 26, 2020 01:13
-
-
Save eigenfoo/43282c4e69156647d7bb2505f1dbafb2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class FunctionToGenerator(ast.NodeTransformer):
    """
    This subclass traverses the AST of the user-written, decorated,
    model specification and transforms it into a generator for the
    model. Subclassing in this way is the idiomatic way to transform
    an AST.

    Specifically:

    1. Add `yield` keywords to all assignments in the model function's
       own body. E.g. `x = tfd.Normal(0, 1)` -> `x = yield tfd.Normal(0, 1)`.
       Assignments whose right-hand side is already a `yield` /
       `yield from` expression are left as-is, so nothing is yielded twice.
    2. Rename the model specification function to
       `_pm_compiled_model_generator`. This is done out of an abundance
       of caution more than anything.
    3. Remove all decorators (including @Model) from the model
       specification function. Otherwise, we risk running into an
       infinite recursion.

    Functions and classes defined *inside* the model function are left
    untouched. They are separate scopes: adding `yield` to their
    assignments would turn a nested helper into a generator, or make
    `compile()` fail with "'yield' outside function" for a class body.
    Only the outermost function definition is renamed.
    """

    # Set to True while the outermost FunctionDef (the model function) is
    # being transformed. Any FunctionDef / AsyncFunctionDef / ClassDef
    # reached while it is True is nested inside the model and is skipped.
    _in_model_function = False

    def visit_Assign(self, node):
        """Wrap the assigned value of `node` in a `yield` expression.

        The node is modified in place and returned.
        """
        self.generic_visit(node)
        if not isinstance(node.value, (ast.Yield, ast.YieldFrom)):
            # Give the new Yield the source position of the expression it
            # wraps, so compile() and tracebacks point at user code.
            node.value = ast.copy_location(
                ast.Yield(value=node.value), node.value
            )
        # Fill in any location fields that are still missing (for example,
        # when the wrapped value was itself built by hand without positions).
        ast.fix_missing_locations(node)
        return node

    def visit_FunctionDef(self, node):
        """Rename and undecorate the model function; skip nested functions."""
        if self._in_model_function:
            # A helper defined inside the model: leave it (and its body)
            # exactly as written.
            return node
        self._in_model_function = True
        try:
            node.name = "_pm_compiled_model_generator"
            node.decorator_list = []
            self.generic_visit(node)
        finally:
            # Reset even on error so the transformer instance can be reused.
            self._in_model_function = False
        ast.fix_missing_locations(node)
        return node

    def _visit_nested_scope(self, node):
        """Leave scopes nested inside the model function untouched.

        Outside the model function, this falls back to the default
        traversal, which matches the behavior without this override.
        """
        if self._in_model_function:
            return node
        return self.generic_visit(node)

    visit_AsyncFunctionDef = _visit_nested_scope
    visit_ClassDef = _visit_nested_scope
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment