Created
October 7, 2011 01:24
-
-
Save jeffkistler/1269212 to your computer and use it in GitHub Desktop.
Control flow graph building visitor for JavaScript.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
An abstract syntax tree visitor for building control flow graphs for ECMAScript | |
programs and functions. | |
""" | |
from bigrig.visitor import NodeVisitor | |
from bigrig.node import Node | |
from bigrig import ast | |
from .graph import Digraph | |
# | |
# Edge Labels | |
# | |
try:
    intern  # builtin on Python 2
except NameError:
    # Python 3 moved intern() into the sys module.
    from sys import intern

# Labels attached to control flow graph edges. Interned so that labels can
# be compared cheaply and share one string object.
UNCONDITIONAL = intern('UNCONDITIONAL')
TRUE = intern('TRUE')
FALSE = intern('FALSE')
THROW = intern('THROW')  # TODO: should a *possible* exception get its own label?
BREAK = intern('BREAK')
CONTINUE = intern('CONTINUE')
RETURN = intern('RETURN')
class AncestryTrackingVisitor(NodeVisitor):
    """
    Node visitor that records the parent of every ``Node`` it visits.

    ``parents`` is the stack of nodes currently being visited (innermost
    last). ``parent_map`` maps each pushed node to whatever was on top of
    that stack when it was pushed (``None`` if the stack was empty).
    """

    def __init__(self, parent=None):
        # Test against None rather than truthiness so that an initial parent
        # which happens to be falsy (e.g. an empty container node) is kept.
        self.parents = [parent] if parent is not None else []
        self.parent_map = {}
        super(AncestryTrackingVisitor, self).__init__()

    def get_parent(self, node):
        """
        Return the recorded parent of ``node``.

        Raises ``KeyError`` if ``node`` was never pushed. An ``assert`` was
        used here before, which ``python -O`` strips.
        """
        try:
            return self.parent_map[node]
        except KeyError:
            raise KeyError('no parent recorded for %r' % (node,))

    def push(self, node):
        """
        Record the current innermost node as ``node``'s parent, then push it.
        """
        self.parent_map[node] = self.parents[-1] if self.parents else None
        self.parents.append(node)

    def pop(self):
        """
        Pop and return the innermost node on the ancestry stack.
        """
        return self.parents.pop()

    def visit(self, node):
        """
        Visit ``node``, keeping the ancestry stack in sync for Node instances.

        Non-Node values are visited without touching the stack.
        """
        if not isinstance(node, Node):
            super(AncestryTrackingVisitor, self).visit(node)
            return
        self.push(node)
        try:
            super(AncestryTrackingVisitor, self).visit(node)
        finally:
            # Pop even if the visit raises, so the stack never claims we are
            # still inside a node we have left.
            self.pop()
# Statements that are their own entry node in get_entry: there is no inner
# statement to descend into.
LEAF_TYPES = (
    ast.ExpressionStatement,
    ast.ContinueStatement,
    ast.BreakStatement,
    ast.ReturnStatement,
    ast.Throw,
    ast.WithStatement,
    ast.EmptyStatement,
)

# Function and program nodes. They are the targets of return edges, the
# terminal node when control falls off the end of a body, and the only
# nodes ControlFlowGraphVisitor.build accepts.
SCOPE_NODES = (
    ast.FunctionDeclaration,
    ast.FunctionExpression,
    ast.Program,
)

# Expression node types that may_throw treats as possibly raising.
MAY_THROW = (
    ast.DotProperty,
    ast.BracketProperty,
    ast.CallExpression,
    ast.NewExpression,
    ast.UnaryOperation,
    ast.DeleteOperation,
    ast.PrefixCountOperation,
    ast.PostfixCountOperation,
    ast.BinaryOperation,
    ast.Assignment,
)
def may_throw(node):
    """
    Return ``True`` if evaluating the expression ``node`` may throw.

    A node may throw if it is an instance of one of the ``MAY_THROW`` types,
    or if any child yielded by its ``iter_children()`` may throw.

    ``None`` never throws. It stands for an absent optional expression, and
    callers pass it for a bare ``return;`` or for an omitted ``for`` clause.
    """
    if node is None:
        return False
    if isinstance(node, MAY_THROW):
        return True
    return any(may_throw(child) for child in node.iter_children())
class TryBlock(ast.Block):
    """
    Synthetic block holding the statements of a ``try`` clause.

    Successor and exception-handler lookups can then recognise statements
    that sit directly in a try body.
    """
    abstract = False
class CatchBlock(ast.Block):
    """
    Synthetic block holding the statements of a ``catch`` clause.

    Successor and exception-handler lookups can then recognise statements
    that sit directly in a catch body.
    """
    abstract = False
class FinallyBlock(ast.Block):
    """
    Synthetic block holding the statements of a ``finally`` clause.

    Successor lookups can then recognise statements that sit directly in a
    finally body.
    """
    abstract = False
class ControlFlowGraphVisitor(AncestryTrackingVisitor):
    """
    Build a control flow graph from an abstract syntax tree.

    The graph's nodes are AST nodes of three kinds:

    * simple statements;
    * the condition and clause expressions of compound statements;
    * scope nodes (functions and the program), used as the targets of
      returns and of falling off the end of a body.

    Every edge carries one of the module's edge labels.

    NOTE(review): switch statements are not wired up yet. A statement in a
    ``case`` body that reaches ``get_successor`` raises.
    """

    def __init__(self, parent=None, graph=None):
        # Compare with None explicitly. A caller-supplied graph that is
        # still empty may be falsy (if it defines __len__), and it must not
        # be silently replaced by a fresh Digraph.
        self.graph = graph if graph is not None else Digraph()
        # Maps a statement to the entry node of the statement that follows
        # it in its enclosing statement list (filled in by
        # visit_statement_list).
        self.successor_map = {}
        super(ControlFlowGraphVisitor, self).__init__(parent=parent)

    def add_edge(self, from_node, to_node, label=UNCONDITIONAL):
        """
        Add a directed edge labelled ``label`` to the control flow graph.
        """
        assert from_node is not None
        assert to_node is not None
        self.graph.add_edge(from_node, to_node, label=label)

    #
    # Entry / successor computation
    #

    def get_entry(self, statement):
        """
        Get the entrypoint of a given statement.

        ``statement`` may also be a list of statements. In that case the
        entry of the first one is returned, or ``None`` for an empty list.
        Unhandled statement types also yield ``None``.
        """
        if isinstance(statement, LEAF_TYPES):
            return statement
        elif isinstance(statement, ast.IfStatement):
            return statement.condition
        elif isinstance(statement, ast.Block):
            if statement.statements:
                return self.get_entry(statement.statements)
            # An empty block contributes no node: control passes straight
            # through to whatever follows it.
            return self.get_successor(statement)
        # DoWhile is tested before While in case one subclasses the other.
        elif isinstance(statement, ast.DoWhileStatement):
            return self.get_entry(statement.body)
        elif isinstance(statement, ast.WhileStatement):
            # The condition is evaluated before the first iteration.
            return statement.condition
        # ForIn is tested before For in case one subclasses the other.
        elif isinstance(statement, ast.ForInStatement):
            # The enumerable object is evaluated once, before iterating.
            return statement.enumerable
        elif isinstance(statement, ast.ForStatement):
            if statement.initialize:
                return statement.initialize
            return self._for_loop_head(statement)
        elif isinstance(statement, ast.LabelledStatement):
            return self.get_entry(statement.statement)
        elif isinstance(statement, ast.TryStatement):
            return self.get_entry(statement.try_block)
        elif isinstance(statement, ast.SwitchStatement):
            # NOTE(review): the switch discriminant expression is probably
            # the real entry point. Confirm its attribute name in bigrig.ast.
            return self.get_entry(statement.cases)
        elif isinstance(statement, ast.CaseClause):
            return self.get_entry(statement.statements)
        elif isinstance(statement, list):
            if not statement:
                return None
            return self.get_entry(statement[0])
        assert statement is not None
        return None

    def get_successor(self, node):
        """
        Get the node that control reaches when ``node`` completes normally.
        """
        # A statement that is not last in its statement list has its
        # successor recorded by visit_statement_list.
        if node in self.successor_map:
            return self.successor_map[node]
        # Otherwise we are at the end of a statement list, or inside a
        # single-statement construct, so the enclosing node decides.
        parent = self.get_parent(node)
        # The synthetic exception-handling blocks subclass ast.Block, so
        # they must be tested before the generic Block case.
        if isinstance(parent, (TryBlock, CatchBlock)):
            # Normal completion of a try or catch body never enters a catch
            # clause. It runs the finally block (if any) and then continues
            # after the try statement. Exceptional flow is wired separately
            # by connect_node_to_exception_handler.
            return self._after_try_body(self.get_parent(parent))
        if isinstance(parent, FinallyBlock):
            return self.get_successor(self.get_parent(parent))
        if isinstance(parent, SCOPE_NODES):
            # Terminal: falling off the end of a function or the program.
            # Perhaps this should become an IMPLICIT_RETURN edge?
            return parent
        if isinstance(parent, ast.Block):
            # Reaching the end of a nested block means reaching the end of
            # the block itself, so whatever follows the block follows node.
            return self.get_successor(parent)
        if isinstance(parent, (ast.ForInStatement, ast.ForStatement,
                               ast.DoWhileStatement, ast.WhileStatement)):
            # End of a loop body: go to the loop's next-iteration point.
            return self._continue_target(parent)
        if isinstance(parent, (ast.IfStatement, ast.LabelledStatement,
                               ast.WithStatement)):
            # These end exactly when their nested statement ends.
            return self.get_successor(parent)
        # Unsupported construct (e.g. a switch CaseClause).
        raise Exception(parent)

    def _after_try_body(self, try_statement):
        """
        Return where control goes after a try or catch body completes normally.
        """
        finally_block = try_statement.finally_block
        if finally_block and finally_block.statements:
            return self.get_entry(finally_block.statements)
        return self.get_successor(try_statement)

    def _for_loop_head(self, loop):
        """
        Return the node that starts each iteration of a ``for`` loop.

        This is the condition when present. A loop without a condition
        (``for (;;)``) uses the ForStatement node itself as its head, so the
        back edge has somewhere to land.
        """
        return loop.condition if loop.condition else loop

    def _continue_target(self, loop):
        """
        Return where control goes when an iteration of ``loop`` ends.
        """
        if isinstance(loop, ast.ForInStatement):
            return loop.each
        if isinstance(loop, ast.ForStatement):
            if loop.next:
                return loop.next
            return self._for_loop_head(loop)
        # while / do-while: re-test the condition.
        return loop.condition

    #
    # Entrypoint
    #

    def build(self, node):
        """
        Build and return the control flow graph for a scope node.
        """
        assert isinstance(node, SCOPE_NODES)
        # The root goes on the ancestry stack directly, without a
        # parent_map entry.
        self.parents.append(node)
        self.visit_node(node)
        return self.graph

    #
    # Statement Blocks
    #

    def visit_statement_list(self, statements):
        """
        Visit a list of statements, recording each statement's successor
        (the entry of the next statement) before visiting it.
        """
        last_index = len(statements) - 1
        for i, statement in enumerate(statements):
            if i < last_index:
                self.successor_map[statement] = self.get_entry(statements[i + 1])
            self.visit(statement)

    def visit_Block(self, node):
        self.visit_statement_list(node.statements)

    def visit_ExpressionStatement(self, node):
        if may_throw(node.expression):
            self.connect_node_to_exception_handler(node)
        self.add_edge(node, self.get_successor(node))

    def visit_WithStatement(self, node):
        self.generic_visit(node)
        self.add_edge(node, self.get_entry(node.statement))

    def visit_EmptyStatement(self, node):
        self.generic_visit(node)
        self.add_edge(node, self.get_successor(node))

    #
    # Jumps
    #

    def get_labelled_jump_target(self, node, target_types):
        """
        Find the innermost enclosing labelled statement whose label matches
        ``node.label`` and whose labelled statement is one of ``target_types``.

        Returns the LabelledStatement itself, or ``None`` if none matches.
        """
        target_label = node.label
        for parent in reversed(self.parents):
            # The LabelledStatement wraps the real target, so test the
            # wrapped statement's type rather than the wrapper's.
            if (isinstance(parent, ast.LabelledStatement)
                    and parent.label == target_label
                    and isinstance(parent.statement, target_types)):
                return parent
        return None

    def get_jump_target(self, node, target_types):
        """
        Find the innermost enclosing node that is one of ``target_types``.

        Falls back to the outermost ancestor when none matches, or to
        ``None`` when the ancestry stack is empty.
        """
        for parent in reversed(self.parents):
            if isinstance(parent, target_types):
                return parent
        return self.parents[0] if self.parents else None

    def visit_Break(self, node):
        break_targets = (
            ast.ForStatement, ast.ForInStatement, ast.DoWhileStatement,
            ast.WhileStatement, ast.SwitchStatement,
        )
        labelled_break_targets = break_targets + (
            ast.Block, ast.IfStatement, ast.TryStatement,
        )
        if node.label:
            # Breaking out of a labelled statement continues after it.
            target = self.get_labelled_jump_target(node, labelled_break_targets)
        else:
            target = self.get_jump_target(node, break_targets)
            if not isinstance(target, break_targets):
                target = None
        if target is None:
            return
        successor = self.get_successor(target)
        if successor is not None:
            self.add_edge(node, successor, BREAK)

    def visit_Continue(self, node):
        continue_targets = (
            ast.ForStatement, ast.ForInStatement, ast.DoWhileStatement,
            ast.WhileStatement,
        )
        if node.label:
            labelled = self.get_labelled_jump_target(node, continue_targets)
            loop = labelled.statement if labelled is not None else None
        else:
            loop = self.get_jump_target(node, continue_targets)
        if not isinstance(loop, continue_targets):
            return
        successor = self._continue_target(loop)
        if successor is not None:
            self.add_edge(node, successor, CONTINUE)

    def visit_Return(self, node):
        expression = node.expression
        # A bare ``return;`` has no expression. The exception edge starts
        # at the return statement itself, as for expression statements:
        # the expression is never pushed onto the ancestry stack here, so
        # its parent cannot be looked up.
        if expression and may_throw(expression):
            self.connect_node_to_exception_handler(node)
        target = self.get_jump_target(node, SCOPE_NODES)
        self.add_edge(node, target, RETURN)

    # NOTE(review): every other visit_* method in this class is named after
    # its AST class (visit_ExpressionStatement, visit_IfStatement,
    # visit_Throw, ...). LEAF_TYPES names the jump classes
    # ast.BreakStatement, ast.ContinueStatement and ast.ReturnStatement.
    # If NodeVisitor dispatches on the class name, the short names above are
    # never called. These aliases make the handlers reachable either way.
    visit_BreakStatement = visit_Break
    visit_ContinueStatement = visit_Continue
    visit_ReturnStatement = visit_Return

    #
    # Conditional
    #

    def visit_IfStatement(self, node):
        self.generic_visit(node)
        condition = node.condition
        if may_throw(condition):
            self.connect_node_to_exception_handler(condition)
        self.add_edge(condition, self.get_entry(node.then_statement), label=TRUE)
        else_statement = node.else_statement
        if else_statement:
            false_successor = self.get_entry(else_statement)
        else:
            false_successor = self.get_successor(node)
        self.add_edge(condition, false_successor, label=FALSE)

    def visit_SwitchStatement(self, node):
        # TODO: connect the discriminant to each case (TRUE edges?), the
        # default case (FALSE edge?), and model fallthrough between cases.
        self.generic_visit(node)

    #
    # Looping
    #

    def visit_ForStatement(self, node):
        self.generic_visit(node)
        initialize = node.initialize
        condition = node.condition
        next_expression = node.next
        # Any of the three clauses may be omitted, e.g. ``for (;;)``.
        for expression in (initialize, condition, next_expression):
            if expression and may_throw(expression):
                self.connect_node_to_exception_handler(expression)
        head = self._for_loop_head(node)
        body_entry = self.get_entry(node.body)
        if initialize:
            self.add_edge(initialize, head)
        if next_expression:
            self.add_edge(next_expression, head)
        if condition:
            self.add_edge(condition, body_entry, TRUE)
            self.add_edge(condition, self.get_successor(node), FALSE)
        else:
            # Without a condition the loop only exits via a jump or throw.
            self.add_edge(node, body_entry)

    def visit_ForInStatement(self, node):
        self.generic_visit(node)
        each = node.each
        enumerable = node.enumerable
        for expression in (each, enumerable):
            if may_throw(expression):
                self.connect_node_to_exception_handler(expression)
        # The enumerable is evaluated once. ``each`` then acts as the loop
        # head: it binds the next property, or exits when none remain.
        self.add_edge(enumerable, each)
        self.add_edge(each, self.get_entry(node.body), TRUE)
        self.add_edge(each, self.get_successor(node), FALSE)

    def visit_WhileStatement(self, node):
        self.generic_visit(node)
        condition = node.condition
        if may_throw(condition):
            self.connect_node_to_exception_handler(condition)
        self.add_edge(condition, self.get_entry(node.body), TRUE)
        self.add_edge(condition, self.get_successor(node), FALSE)

    def visit_DoWhileStatement(self, node):
        self.generic_visit(node)
        condition = node.condition
        if may_throw(condition):
            self.connect_node_to_exception_handler(condition)
        self.add_edge(condition, self.get_entry(node.body), TRUE)
        self.add_edge(condition, self.get_successor(node), FALSE)

    #
    # Exceptions
    #

    def get_exception_handler(self, node):
        """
        Return the node that receives control if ``node`` throws.

        The handler is one of:

        * the nearest catch clause;
        * a finally block, when the try has no catch clause or the throw is
          inside a catch clause;
        * the enclosing function or program node.
        """
        exception_context = (
            ast.FunctionExpression, ast.FunctionDeclaration,
            ast.Program, TryBlock, CatchBlock,
        )
        parent = self.get_parent(node)
        while not isinstance(parent, exception_context):
            parent = self.get_parent(parent)
            assert parent is not None
        if isinstance(parent, TryBlock):
            try_statement = self.get_parent(parent)
            if try_statement.catch_block:
                return self.get_entry(try_statement.catch_block)
            return self.get_entry(try_statement.finally_block)
        elif isinstance(parent, CatchBlock):
            try_statement = self.get_parent(parent)
            if try_statement.finally_block:
                return self.get_entry(try_statement.finally_block)
            return self.get_exception_handler(try_statement)
        return parent

    def connect_node_to_exception_handler(self, from_node):
        handler = self.get_exception_handler(from_node)
        self.add_edge(from_node, handler, THROW)

    def visit_Throw(self, node):
        self.connect_node_to_exception_handler(node)

    def visit_TryStatement(self, node):
        """
        Visit the try, catch and finally bodies inside synthetic blocks.

        Wrapping each clause's statements in a TryBlock, CatchBlock or
        FinallyBlock lets get_successor and get_exception_handler tell which
        clause a statement belongs to.
        """
        clauses = (
            (node.try_block, TryBlock),
            (node.catch_block, CatchBlock),
            (node.finally_block, FinallyBlock),
        )
        for original_block, synthetic_class in clauses:
            if not original_block:
                continue
            synthetic_block = synthetic_class(original_block.statements)
            self.push(synthetic_block)
            try:
                self.visit_Block(synthetic_block)
            finally:
                self.pop()
# Sample JavaScript programs covering the constructs the builder handles,
# grouped by construct.
test_cases = [
    # Function declaration with a one-statement body.
    'function x() { print(arguments); }',
    # if / else: braced, unbraced, no else, and an empty then-branch.
    'if (x) { print(1); } else { print(0); }',
    'if (x) print(1); else print(0);',
    'if (x) { print(1); }',
    'if (x) ; else print(0);',
    # Loops.
    'while (x) { x--; }',
    'do { x--; } while (x)',
    'for (var x=0; x<10; x++) { print(x); }',
    'for (var x in enumerable) { print(x); }',
    # Exception handling.
    'try { call(); } catch(e) { print(e); }',
    'try { call(); } finally { print("Error"); }',
    'try { call(); } catch(e) { print(e); } finally { print("Error"); }',
    # Jumps.
    'while (x) { if (x < 0) break; x--; }',
]
def to_dot(graph, name=None):
    """
    Render a control flow graph as a Graphviz DOT string using pydot.

    Node naming:

    * With a locator, a node is named ``"<line>, <column>: <label>"``.
    * Without one, it is named ``"<id>: <label>"``.

    The label is the printed source for expression statements and
    expressions, and the class name for anything else. Edge attributes such
    as ``label`` are passed straight through to ``pydot.Edge``. ``name``, if
    truthy, becomes the DOT graph name.
    """
    import pydot
    from .code_consumer import print_string

    node_names = {}

    def get_name(node):
        # Cache names so every edge touching a node reuses the same string.
        if node in node_names:
            return node_names[node]
        if isinstance(node, (ast.ExpressionStatement, ast.Expression)):
            label = print_string(node)
        else:
            label = node.__class__.__name__
        locator = getattr(node, 'locator', None)
        if locator is not None:
            node_name = '%d, %d: %s' % (locator.line, locator.column, label)
        else:
            node_name = '%d: %s' % (id(node), label)
        node_names[node] = node_name
        return node_name

    dot_options = {'graph_type': 'digraph', 'strict': True}
    if name:
        dotgraph = pydot.Dot(name, **dot_options)
    else:
        dotgraph = pydot.Dot(**dot_options)

    for source, target, attributes in graph.edges_iter(data=True):
        edge = pydot.Edge(get_name(source), get_name(target), **attributes)
        dotgraph.add_edge(edge)
    return dotgraph.to_string()
def print_to_dotfile(graph, filename, name=None):
    """
    Write ``graph`` to ``filename`` in Graphviz DOT format.

    ``name``, if given, is used as the DOT graph name. It used to be
    accepted but silently ignored.
    """
    # Render before opening the file, so a rendering error does not leave
    # behind a truncated or empty file.
    dot_source = to_dot(graph, name=name)
    # to_dot returns text (pydot's to_string()), so open in text mode.
    # Writing a str to a binary-mode file fails on Python 3.
    with open(filename, 'w') as fd:
        fd.write(dot_source)
def print_cfg_for_program(source_string, filename):
    """
    Parse ``source_string``, build its control flow graph and write it to
    ``filename`` as a DOT file.
    """
    from .locator_parser import parse_string
    # Named ``program`` rather than ``ast`` so the module-level ``ast``
    # import is not shadowed inside this function.
    program = parse_string(source_string)
    cfg = ControlFlowGraphVisitor().build(program)
    print_to_dotfile(cfg, filename)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
That should be the switch expression, I guess — maybe `statement.expression`, or whatever it's called in the AST.

I'm interested in doing something like this for UglifyJS. Still digging through the literature...