Skip to content

Instantly share code, notes, and snippets.

@jeffkistler
Created October 7, 2011 01:24
Show Gist options
  • Save jeffkistler/1269212 to your computer and use it in GitHub Desktop.
Save jeffkistler/1269212 to your computer and use it in GitHub Desktop.
Control flow graph building visitor for JavaScript.
"""
An abstract syntax tree visitor for building control flow graphs for ECMAScript
programs and functions.
"""
from bigrig.visitor import NodeVisitor
from bigrig.node import Node
from bigrig import ast
from .graph import Digraph
#
# Edge Labels
#
# ``intern`` is a builtin only on Python 2; Python 3 moved it to ``sys``.
try:
    _intern = intern
except NameError:
    from sys import intern as _intern

UNCONDITIONAL = _intern('UNCONDITIONAL')
TRUE = _intern('TRUE')
FALSE = _intern('FALSE')
THROW = _intern('THROW')  # TODO: should each possible exception get its own label?
BREAK = _intern('BREAK')
CONTINUE = _intern('CONTINUE')
RETURN = _intern('RETURN')
class AncestryTrackingVisitor(NodeVisitor):
    """
    A NodeVisitor that records the ancestry of every Node it visits.

    ``parents`` is the stack of Nodes currently being visited (innermost
    last).  ``parent_map`` maps each pushed Node to the Node that was on top
    of the stack when it was pushed, or ``None`` if the stack was empty.
    """
    def __init__(self, parent=None):
        # An optional initial parent seeds the stack.
        self.parents = [parent] if parent else []
        self.parent_map = {}
        super(AncestryTrackingVisitor, self).__init__()

    def get_parent(self, node):
        """Return the recorded parent of ``node``, which must have been pushed."""
        assert node in self.parent_map
        return self.parent_map[node]

    def push(self, node):
        """Record the current stack top as ``node``'s parent, then push ``node``."""
        self.parent_map[node] = self.parents[-1] if self.parents else None
        self.parents.append(node)

    def pop(self):
        """Remove and return the innermost Node on the stack."""
        return self.parents.pop()

    def visit(self, node):
        """Visit ``node``, keeping the ancestry stack in sync for Node instances."""
        if not isinstance(node, Node):
            super(AncestryTrackingVisitor, self).visit(node)
            return
        self.push(node)
        super(AncestryTrackingVisitor, self).visit(node)
        self.pop()
# Statement types that are their own entry node in the control flow graph:
# get_entry() returns such a statement unchanged.
LEAF_TYPES = (
    ast.ExpressionStatement, ast.ContinueStatement, ast.BreakStatement,
    ast.ReturnStatement, ast.Throw, ast.WithStatement, ast.EmptyStatement,
)
# Node types that delimit a scope.  The scope node doubles as the exit node
# of the graph: returns, uncaught exceptions and falling off the end of the
# scope all lead to it.
SCOPE_NODES = (
    ast.FunctionDeclaration, ast.FunctionExpression, ast.Program
)
# Expression node types whose evaluation may raise an exception; may_throw()
# treats an expression as throwing if it is, or contains, one of these.
MAY_THROW = (
    ast.DotProperty, ast.BracketProperty, ast.CallExpression,
    ast.NewExpression, ast.UnaryOperation, ast.DeleteOperation,
    ast.PrefixCountOperation, ast.PostfixCountOperation,
    ast.BinaryOperation, ast.Assignment
)
def may_throw(node):
    """
    Returns ``True`` if the expression may throw an exception.

    An expression may throw when it, or any expression nested inside it, is
    one of the ``MAY_THROW`` node types.  ``None`` is accepted and never
    throws; it stands for an absent optional expression, such as the value
    of a bare ``return;`` or an omitted ``for`` clause.
    """
    if node is None:
        return False
    if isinstance(node, MAY_THROW):
        return True
    return any(may_throw(child) for child in node.iter_children())
# Synthetic block types.  ControlFlowGraphVisitor.visit_TryStatement wraps
# the statements of each clause of a try statement in one of these, so that
# a statement's parent tells which clause it belongs to.
# NOTE(review): ``abstract = False`` presumably marks the subclass as a
# concrete node type for bigrig's Node machinery -- confirm.
class TryBlock(ast.Block):
    """Statements of the ``try`` clause of a try statement."""
    abstract = False
class CatchBlock(ast.Block):
    """Statements of the ``catch`` clause of a try statement."""
    abstract = False
class FinallyBlock(ast.Block):
    """Statements of the ``finally`` clause of a try statement."""
    abstract = False
class ControlFlowGraphVisitor(AncestryTrackingVisitor):
    """
    Build a control flow graph from an abstract syntax tree.

    Graph nodes are statements, plus the test and step expressions of
    compound statements: ``if`` and loop conditions, ``for`` initializers
    and updates, and ``for-in`` operands.  Each edge carries one of the
    module-level labels (UNCONDITIONAL, TRUE, FALSE, THROW, BREAK, CONTINUE,
    RETURN).  The enclosing function or program node doubles as the exit
    node.
    """
    def __init__(self, parent=None, graph=None):
        # Test for None explicitly.  ``graph and graph or Digraph()`` would
        # discard a caller-supplied graph that evaluates as false, e.g. an
        # empty graph if Digraph defines __len__.
        self.graph = graph if graph is not None else Digraph()
        # Maps a statement to the entry node of the statement that follows
        # it in its enclosing statement list (filled by visit_statement_list).
        self.successor_map = {}
        super(ControlFlowGraphVisitor, self).__init__(parent=parent)

    def add_edge(self, from_node, to_node, label=UNCONDITIONAL):
        """
        Add a directed edge to the control flow graph.
        """
        assert from_node is not None
        assert to_node is not None
        self.graph.add_edge(from_node, to_node, label=label)

    def get_entry(self, statement):
        """
        Get the entrypoint of a given statement.

        This is the first graph node reached when the statement starts
        executing.  ``statement`` may also be a list of statements; then the
        entry of the first one is returned, or ``None`` for an empty list.
        Statement types not handled here yield ``None``.
        """
        if isinstance(statement, LEAF_TYPES):
            return statement
        elif isinstance(statement, ast.IfStatement):
            return statement.condition
        elif isinstance(statement, ast.Block):
            if statement.statements:
                return self.get_entry(statement.statements)
            # An empty block is transparent: enter whatever follows it.
            return self.get_successor(statement)
        elif isinstance(statement, ast.DoWhileStatement):
            # The body runs once before the condition is first tested.
            return self.get_entry(statement.body)
        elif isinstance(statement, ast.WhileStatement):
            return statement.condition
        elif isinstance(statement, ast.ForStatement):
            if statement.initialize:
                return statement.initialize
            elif statement.condition:
                return statement.condition
            return self.get_entry(statement.body)
        elif isinstance(statement, ast.ForInStatement):
            # NOTE(review): modelled as entering at ``each``; see
            # visit_ForInStatement for the matching edges.
            return statement.each
        elif isinstance(statement, ast.LabelledStatement):
            return self.get_entry(statement.statement)
        elif isinstance(statement, ast.TryStatement):
            return self.get_entry(statement.try_block)
        elif isinstance(statement, ast.SwitchStatement):
            # TODO: switch statements are not modelled yet (see
            # visit_SwitchStatement).  The entry should arguably be the switch
            # discriminant expression rather than the first case.
            return self.get_entry(statement.cases)
        elif isinstance(statement, ast.CaseClause):
            return self.get_entry(statement.statements)
        elif isinstance(statement, list):
            if not statement:
                return None
            return self.get_entry(statement[0])
        assert statement is not None
        return None

    def get_successor(self, node):
        """
        Get the successor statement for a given statement node.

        This is where control goes when ``node`` completes normally.
        Statements followed by a sibling in a statement list are looked up in
        ``successor_map``.  Otherwise the answer depends on the enclosing
        statement, which is found through the parent map.  The enclosing
        function/program node is returned when control falls off the end of
        the scope.
        """
        if node in self.successor_map:
            return self.successor_map[node]
        parent = self.get_parent(node)
        # Check the synthetic exception-handling blocks first.  They are
        # ast.Block subclasses and must not reach the generic Block case.
        if isinstance(parent, (TryBlock, CatchBlock)):
            # Completing the try or catch clause normally runs the finally
            # clause, if any.  The catch clause is reached only through THROW
            # edges, never by normal completion of the try clause.
            try_statement = self.get_parent(parent)
            if try_statement.finally_block:
                return self.get_entry(try_statement.finally_block)
            return self.get_successor(try_statement)
        elif isinstance(parent, FinallyBlock):
            return self.get_successor(self.get_parent(parent))
        elif isinstance(parent, ast.Block):
            # The last statement of a nested block is followed by whatever
            # follows the block itself.
            return self.get_successor(parent)
        elif isinstance(parent, (ast.ForStatement, ast.ForInStatement,
                                 ast.WhileStatement, ast.DoWhileStatement)):
            # Falling off the end of a loop body starts the next iteration.
            return self._get_continue_target(parent)
        elif isinstance(parent, (ast.IfStatement, ast.LabelledStatement,
                                 ast.WithStatement)):
            return self.get_successor(parent)
        elif isinstance(parent, SCOPE_NODES):
            # Terminal: the scope node doubles as the exit node.
            # Perhaps we should have a "connect_successor" method that
            # makes an IMPLICIT_RETURN edge?
            return parent
        # TODO: ast.CaseClause (switch statements are not modelled yet).
        raise Exception(parent)

    def _get_continue_target(self, loop):
        """
        Return the node where the next iteration of ``loop`` begins.

        This is the target of ``continue`` and of falling off the end of the
        loop body.  Returns ``None`` for ``for (...;;) {}``, where there is no
        node to go back to.
        """
        if isinstance(loop, ast.ForStatement):
            if loop.next:
                return loop.next
            elif loop.condition:
                return loop.condition
            body = loop.body
            if isinstance(body, ast.Block) and not body.statements:
                # get_entry() of an empty body asks get_successor(), which
                # would come straight back here: recursion without end.
                return None
            return self.get_entry(body)
        elif isinstance(loop, ast.ForInStatement):
            return loop.each
        return loop.condition

    #
    # Entrypoint
    #

    def build(self, node):
        """
        Build and return the control flow graph for a function or program.
        """
        assert isinstance(node, SCOPE_NODES)
        # The scope node is the root.  It goes on the ancestry stack, so it
        # is found as a jump/exception target, but it gets no parent_map
        # entry.
        self.parents.append(node)
        self.visit_node(node)
        return self.graph

    #
    # Statement Blocks
    #

    def visit_statement_list(self, statements):
        """
        Visit a list of statements, tracking successor nodes along the way.
        """
        last_index = len(statements) - 1
        for i, statement in enumerate(statements):
            # Record the successor before visiting, since visiting the
            # statement may ask for it.
            if i < last_index:
                self.successor_map[statement] = self.get_entry(statements[i + 1])
            self.visit(statement)

    def visit_Block(self, node):
        self.visit_statement_list(node.statements)

    def visit_ExpressionStatement(self, node):
        if may_throw(node.expression):
            self.connect_node_to_exception_handler(node)
        self.add_edge(node, self.get_successor(node))

    def visit_WithStatement(self, node):
        self.generic_visit(node)
        self.add_edge(node, self.get_entry(node.statement))

    def visit_EmptyStatement(self, node):
        self.generic_visit(node)
        self.add_edge(node, self.get_successor(node))

    #
    # Jumps
    #

    def get_labelled_jump_target(self, node, target_types):
        """
        Find the innermost matching labelled statement for ``node``.

        A match is a LabelledStatement whose label equals ``node.label`` and
        whose labelled statement is an instance of one of ``target_types``.
        Returns the LabelledStatement itself, or ``None`` if there is none.
        """
        target_label = node.label
        for ancestor in reversed(self.parents):
            if (isinstance(ancestor, ast.LabelledStatement) and
                    ancestor.label == target_label and
                    isinstance(ancestor.statement, target_types)):
                return ancestor
        return None

    def get_jump_target(self, node, target_types):
        """
        Find the first parent that is an instance of one of the target types.

        The innermost such ancestor is returned, or ``None`` if there is none.
        """
        for ancestor in reversed(self.parents):
            if isinstance(ancestor, target_types):
                return ancestor
        return None

    def visit_Break(self, node):
        """Add a BREAK edge to whatever follows the statement being exited."""
        BREAK_TARGETS = (
            ast.ForStatement, ast.ForInStatement, ast.DoWhileStatement,
            ast.WhileStatement, ast.SwitchStatement
        )
        LABELLED_BREAK_TARGETS = BREAK_TARGETS + (
            ast.Block, ast.IfStatement, ast.TryStatement
        )
        if node.label:
            # Exit the labelled statement itself.
            target = self.get_labelled_jump_target(node, LABELLED_BREAK_TARGETS)
        else:
            target = self.get_jump_target(node, BREAK_TARGETS)
        if target is None:
            return
        successor = self.get_successor(target)
        if successor is not None:
            self.add_edge(node, successor, BREAK)

    def visit_Continue(self, node):
        """Add a CONTINUE edge to where the next iteration of the loop begins."""
        CONTINUE_TARGETS = (
            ast.ForStatement, ast.ForInStatement, ast.DoWhileStatement,
            ast.WhileStatement
        )
        if node.label:
            labelled = self.get_labelled_jump_target(node, CONTINUE_TARGETS)
            loop = labelled.statement if labelled is not None else None
        else:
            loop = self.get_jump_target(node, CONTINUE_TARGETS)
        if loop is None:
            return
        successor = self._get_continue_target(loop)
        if successor is not None:
            self.add_edge(node, successor, CONTINUE)

    def visit_Return(self, node):
        """Add a RETURN edge to the enclosing scope node."""
        expression = node.expression
        # ``return;`` has no expression.  The THROW edge starts at the
        # statement itself (as in visit_ExpressionStatement): the expression
        # is not visited here, so get_exception_handler() cannot look up its
        # parent.
        if expression is not None and may_throw(expression):
            self.connect_node_to_exception_handler(node)
        target = self.get_jump_target(node, SCOPE_NODES)
        self.add_edge(node, target, RETURN)

    # NodeVisitor appears to dispatch on the node's class name (cf.
    # visit_IfStatement, visit_TryStatement) -- confirm.  The AST classes are
    # BreakStatement/ContinueStatement/ReturnStatement (see LEAF_TYPES), so
    # without these aliases the handlers above would never run.
    visit_BreakStatement = visit_Break
    visit_ContinueStatement = visit_Continue
    visit_ReturnStatement = visit_Return

    #
    # Conditional
    #

    def visit_IfStatement(self, node):
        self.generic_visit(node)
        condition = node.condition
        if may_throw(condition):
            self.connect_node_to_exception_handler(condition)
        then_statement = node.then_statement
        else_statement = node.else_statement
        self.add_edge(condition, self.get_entry(then_statement), label=TRUE)
        if else_statement:
            false_successor = self.get_entry(else_statement)
        else:
            false_successor = self.get_successor(node)
        self.add_edge(condition, false_successor, label=FALSE)

    def visit_SwitchStatement(self, node):
        # TODO: not modelled yet.
        # Connect object to each case with TRUE edge?
        # If default, connect with FALSE edge as well??
        # How to handle fallthrough?
        self.generic_visit(node)

    #
    # Looping
    #

    def visit_ForStatement(self, node):
        """Wire up ``for (initialize; condition; next) body``; any clause may be omitted."""
        self.generic_visit(node)
        initialize = node.initialize
        condition = node.condition
        next_expression = node.next
        for expression in (initialize, condition, next_expression):
            if expression is not None and may_throw(expression):
                self.connect_node_to_exception_handler(expression)
        body_entry = self.get_entry(node.body)
        # Each iteration begins at the condition.  When the condition is
        # omitted (``for (;;)``), it begins directly in the body.
        loop_head = condition if condition is not None else body_entry
        if initialize is not None and loop_head is not None:
            self.add_edge(initialize, loop_head)
        if next_expression is not None and loop_head is not None:
            self.add_edge(next_expression, loop_head)
        if condition is not None:
            self.add_edge(condition, body_entry, TRUE)
            self.add_edge(condition, self.get_successor(node), FALSE)

    def visit_ForInStatement(self, node):
        self.generic_visit(node)
        each = node.each
        enumerable = node.enumerable
        for expression in (each, enumerable):
            if may_throw(expression):
                self.connect_node_to_exception_handler(expression)
        # get_entry() enters the loop at ``each`` and get_successor() loops
        # back to it, while the TRUE/FALSE branches leave from
        # ``enumerable``.  Link the two so ``each`` is not a dead end.
        self.add_edge(each, enumerable)
        self.add_edge(enumerable, self.get_entry(node.body), TRUE)
        self.add_edge(enumerable, self.get_successor(node), FALSE)

    def visit_WhileStatement(self, node):
        self.generic_visit(node)
        condition = node.condition
        if may_throw(condition):
            self.connect_node_to_exception_handler(condition)
        self.add_edge(condition, self.get_entry(node.body), TRUE)
        self.add_edge(condition, self.get_successor(node), FALSE)

    def visit_DoWhileStatement(self, node):
        self.generic_visit(node)
        condition = node.condition
        if may_throw(condition):
            self.connect_node_to_exception_handler(condition)
        self.add_edge(condition, self.get_entry(node.body), TRUE)
        self.add_edge(condition, self.get_successor(node), FALSE)

    #
    # Exceptions
    #

    def get_exception_handler(self, node):
        """
        Return the node that receives control if ``node`` throws.

        This is the catch clause of an enclosing try clause, else its finally
        clause.  Inside a catch clause it is the finally clause, else the
        handler of the surrounding try statement.  Otherwise it is the
        enclosing scope node.
        """
        exception_context = (
            ast.FunctionExpression, ast.FunctionDeclaration,
            ast.Program, TryBlock, CatchBlock
        )
        parent = self.get_parent(node)
        while not isinstance(parent, exception_context):
            parent = self.get_parent(parent)
            assert parent is not None
        if isinstance(parent, TryBlock):
            try_statement = self.get_parent(parent)
            if try_statement.catch_block:
                return self.get_entry(try_statement.catch_block)
            else:
                return self.get_entry(try_statement.finally_block)
        elif isinstance(parent, CatchBlock):
            try_statement = self.get_parent(parent)
            if try_statement.finally_block:
                return self.get_entry(try_statement.finally_block)
            else:
                return self.get_exception_handler(try_statement)
        return parent

    def connect_node_to_exception_handler(self, from_node):
        """Add a THROW edge from ``from_node`` to its exception handler."""
        handler = self.get_exception_handler(from_node)
        self.add_edge(from_node, handler, THROW)

    def visit_Throw(self, node):
        self.connect_node_to_exception_handler(node)

    def visit_TryStatement(self, node):
        """
        Make sure that the catch and finally blocks get wired up appropriately.

        Each present clause is visited inside a synthetic TryBlock,
        CatchBlock or FinallyBlock that wraps the clause's own statements.
        This lets get_successor() and get_exception_handler() tell which
        clause a statement belongs to.
        """
        clauses = (
            (node.try_block, TryBlock),
            (node.catch_block, CatchBlock),
            (node.finally_block, FinallyBlock),
        )
        for block, synthetic_type in clauses:
            if block:
                synthetic_block = synthetic_type(block.statements)
                self.push(synthetic_block)
                self.visit_Block(synthetic_block)
                self.pop()
# Sample ECMAScript sources for manual checks of the graph builder, e.g.
# print_cfg_for_program(test_cases[0], 'cfg.dot').  They cover if/else,
# while, do-while, for, for-in, try/catch/finally and an unlabelled break.
test_cases = [
    'function x() { print(arguments); }', # Function with a one-statement body
    'if (x) { print(1); } else { print(0); }',
    'if (x) print(1); else print(0);',
    'if (x) { print(1); }',
    'if (x) ; else print(0);',
    'while (x) { x--; }',
    'do { x--; } while (x)',
    'for (var x=0; x<10; x++) { print(x); }',
    'for (var x in enumerable) { print(x); }',
    'try { call(); } catch(e) { print(e); }',
    'try { call(); } finally { print("Error"); }',
    'try { call(); } catch(e) { print(e); } finally { print("Error"); }',
    'while (x) { if (x < 0) break; x--; }',
]
def to_dot(graph, name=None):
    """
    Render ``graph`` as Graphviz DOT source and return it as a string.

    DOT node names are built from the AST node as follows:
    - With a ``locator``: ``"line, column: label"``.
    - Otherwise: ``"id: label"``, using the node's ``id()``.

    The label is the printed source for expression statements and
    expressions, and the node's class name for anything else.  Edge
    attributes (such as ``label``) are passed straight through to pydot.
    ``name``, if truthy, becomes the name of the DOT graph.
    """
    import pydot
    from .code_consumer import print_string

    node_names = {}

    def describe(node):
        # Memoized so each AST node maps to exactly one DOT node name.
        try:
            return node_names[node]
        except KeyError:
            pass
        if isinstance(node, (ast.ExpressionStatement, ast.Expression)):
            text = print_string(node)
        else:
            text = node.__class__.__name__
        locator = getattr(node, 'locator', None)
        if locator is not None:
            node_name = '%d, %d: %s' % (locator.line, locator.column, text)
        else:
            node_name = '%d: %s' % (id(node), text)
        node_names[node] = node_name
        return node_name

    dot_options = dict(graph_type='digraph', strict=True)
    if name:
        dotgraph = pydot.Dot(name, **dot_options)
    else:
        dotgraph = pydot.Dot(**dot_options)
    for source, target, attributes in graph.edges_iter(data=True):
        source_name = describe(source)
        target_name = describe(target)
        dotgraph.add_edge(pydot.Edge(source_name, target_name, **attributes))
    return dotgraph.to_string()
def print_to_dotfile(graph, filename, name=None):
    """
    Write ``graph`` to ``filename`` as Graphviz DOT source.

    ``name``, if given, is used as the name of the DOT graph.  The source is
    written UTF-8 encoded when to_dot() returns text rather than bytes.
    """
    dot_source = to_dot(graph, name)
    # The file is opened in binary mode, so text must be encoded first.  On
    # Python 2 a plain ``str`` is already bytes and is written unchanged.
    if not isinstance(dot_source, bytes):
        dot_source = dot_source.encode('utf-8')
    with open(filename, 'wb') as fd:
        fd.write(dot_source)
def print_cfg_for_program(source_string, filename):
    """
    Parse the JavaScript program in ``source_string``, build its control flow
    graph and write the graph to ``filename`` as Graphviz DOT source.
    """
    from .locator_parser import parse_string
    program = parse_string(source_string)
    print_to_dotfile(ControlFlowGraphVisitor().build(program), filename)
@mishoo
Copy link

mishoo commented Nov 17, 2012

#Line 137:
elif isinstance(statement, ast.SwitchStatement): # ????

That should probably be the switch's discriminant expression — maybe `statement.expression`, or whatever the attribute is called in the AST.

I'm interested in doing something like this for UglifyJS. Still digging through the literature.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment