Skip to content

Instantly share code, notes, and snippets.

@math314
Created May 28, 2014 11:37
Show Gist options
  • Save math314/f2fb2ada355dfc619c4d to your computer and use it in GitHub Desktop.
Save math314/f2fb2ada355dfc619c4d to your computer and use it in GitHub Desktop.
#coding=utf-8
import sys,os
import re
from collections import defaultdict,namedtuple
from itertools import *
def scc(edges, redges):
    """Return the strongly connected components of a directed graph (Kosaraju).

    edges maps a vertex to the list of its successors and redges maps a
    vertex to the list of its predecessors (the reversed graph).  A vertex
    without outgoing (resp. incoming) edges may be missing from edges (resp.
    redges).  The first pass only starts from keys of edges, so a vertex
    with no edges at all is not reported.

    Returns a list of components, each a list of vertices, in the same order
    the classic recursive formulation produces.  Both depth-first passes use
    explicit stacks so large graphs cannot hit Python's recursion limit.
    """
    def successors(graph, v):
        return graph[v] if v in graph else ()

    # Pass 1: record vertices in DFS post-order (finish order).
    used = set()
    finish_order = []
    for root in edges:
        if root in used:
            continue
        used.add(root)
        stack = [(root, iter(successors(edges, root)))]
        while stack:
            v, pending = stack[-1]
            for t in pending:
                if t not in used:
                    used.add(t)
                    stack.append((t, iter(successors(edges, t))))
                    break
            else:
                # All successors handled: v is finished.
                stack.pop()
                finish_order.append(v)

    # Pass 2: DFS on the reversed graph in reverse finish order; every tree
    # found is one component (vertices listed in pre-order).
    used = set()
    components = []
    for root in reversed(finish_order):
        if root in used:
            continue
        used.add(root)
        component = [root]
        stack = [iter(successors(redges, root))]
        while stack:
            for t in stack[-1]:
                if t not in used:
                    used.add(t)
                    component.append(t)
                    stack.append(iter(successors(redges, t)))
                    break
            else:
                stack.pop()
        components.append(component)
    return components
def lexical_order_toporogycal_sort(_nodes, edges):
    """Return the lexicographically smallest topological order of _nodes.

    At every step the smallest node (by ``<``) whose predecessors have all
    been emitted is chosen.  edges[node] lists the successors of node; edges
    may be a list indexed by node or a mapping, and a mapping may omit nodes
    that have no successors.  Every successor must itself be in _nodes.

    Raises ValueError if the graph contains a cycle (no order exists).
    """
    import heapq

    nodes = sorted(_nodes)

    def successors(node):
        try:
            return edges[node]
        except KeyError:
            # Mapping without an entry: the node has no outgoing edges.
            return ()

    in_rank = dict((node, 0) for node in nodes)
    for node in nodes:
        for t in successors(node):
            in_rank[t] += 1

    # Min-heap of ready nodes: always pops the smallest available node.
    ready = [node for node in nodes if in_rank[node] == 0]
    heapq.heapify(ready)
    ret = []
    while ready:
        select = heapq.heappop(ready)
        ret.append(select)
        for t in successors(select):
            in_rank[t] -= 1
            if in_rank[t] == 0:
                heapq.heappush(ready, t)

    if len(ret) != len(nodes):
        raise ValueError("graph contains a cycle; no topological order exists")
    return ret
class Asm(object):
    """One instruction line parsed from an objdump listing (see load_assembler)."""

    def __init__(self, address, binary, opecode, operand):
        self.address = address    # instruction address (int)
        self.binary = binary      # encoded bytes as a list of ints
        self.opecode = opecode    # mnemonic text; may be None (optional in the parser's regex)
        self.operand = operand    # operand text; may be None
        self.indent = 0           # number of leading spaces used by __str__
        self.comment = ''
        self.attribute = []

    def next_address(self):
        """Return the fall-through address, or None after 'ret'/'jmp'."""
        if self.opecode in ('ret', 'jmp'):
            return None
        return self.address + len(self.binary)

    def jump_to(self):
        """Return the direct target of a jump (mnemonic starting with 'j'), else None.

        The target is the first operand word read as hex, e.g. the '8048400'
        of '8048400 <main+0x10>'.  None is also returned when the line has no
        mnemonic or no operand, and when the operand is not a hex literal
        (e.g. an indirect jump such as '*%eax'), instead of raising.
        """
        if not self.opecode or not self.opecode.startswith('j'):
            return None
        words = self.operand.split() if self.operand else []
        if not words:
            return None
        try:
            return int(words[0], 16)
        except ValueError:
            return None

    def __str__(self):
        """Format as '<indent><addr>:<TAB><bytes><TAB><mnemonic><TAB><operand>; <comment> <attributes>'."""
        opecode = self.opecode if self.opecode is not None else ''
        operand = self.operand if self.operand is not None else ''
        comment = self.comment if self.comment is not None else ''
        attribute = str(self.attribute) if self.attribute else ''
        return "%s%s:\t%-22s\t%s\t%s; %s %s" % (
            ' ' * self.indent,
            "%x" % self.address,
            ' '.join(['%02x' % i for i in self.binary]),
            opecode,
            operand,
            comment,
            attribute,
        )
class BasicBlockNode(object):
    """Inner node of the block tree built by build_graph (a loop region or the root).

    Children are BasicBlockLeaf or nested BasicBlockNode objects.
    """

    def __init__(self, nodes):
        self._nodes = nodes
        # All instruction addresses covered by the children, in child order.
        self._address_list = list(chain.from_iterable(node.address_list() for node in nodes))
        # Same addresses as a set: contain_address is called once per child
        # successor, so list scans would make construction quadratic.
        self._address_set = frozenset(self._address_list)
        self._next_to = self.init_next_to()
        self.depth = None

    def init_next_to(self):
        """Return the children's successor addresses that lie outside this node."""
        return [to
                for node in self._nodes
                for to in node.next_to()
                if not self.contain_address(to)]

    def head_address(self):
        """Return the head address of the first child."""
        return self._nodes[0].head_address()

    def address_list(self):
        return self._address_list

    def contain_address(self, address):
        return address in self._address_set

    def get_nodes(self):
        return self._nodes

    def next_to(self):
        """Return the exit addresses computed at construction time."""
        return self._next_to
class BasicBlockLeaf(object):
    """A basic block: a straight-line run of Asm instructions (a leaf of the block tree)."""

    def __init__(self, asms):
        self._asms = asms
        self.depth = None
        # Structural tags added by Function.analyze, e.g. 'loop', 'break', 'continue'.
        self.attribute = []

    def head_address(self):
        return self._asms[0].address

    def address_list(self):
        return [asm.address for asm in self._asms]

    def contain_address(self, address):
        return address in self.address_list()

    def next_to(self):
        """Return successor addresses of the last instruction.

        The jump target comes first (omitted when it lies inside this block),
        then the fall-through address (omitted after 'ret'/'jmp').
        """
        last = self._asms[-1]
        successors = []
        jump_to = last.jump_to()
        # Compare with None explicitly: 0 is a valid address.
        if jump_to is not None and jump_to not in self.address_list():
            successors.append(jump_to)
        next_address = last.next_address()
        if next_address is not None:
            successors.append(next_address)
        return successors

    def jump_to(self):
        """Return the jump target of the last instruction, or None."""
        return self._asms[-1].jump_to()

    def has_attributes(self, *attributes):
        """Return True if any of the given tags is set on this block."""
        return any(attr in self.attribute for attr in attributes)

    def __str__(self):
        lines = [str(self.attribute)] if self.attribute else []
        lines.extend(str(asm) for asm in self._asms)
        return '\n'.join(' ' + line for line in lines)
class Function(object):
    """One function's block tree plus the structural annotations found in it.

    root_node is the tree produced by build_graph: BasicBlockNode objects
    group blocks (loop regions, and the root) and BasicBlockLeaf objects hold
    the instructions.  analyze() tags leaves with attributes such as 'loop',
    'end loop', 'continue', 'break', 'jump to return', 'jump to loop' and
    'if', then prints the annotated function.
    """

    def __init__(self, root_node):
        self._root_node = root_node
        # Leaves in depth-first tree order; _expand_graph also sets every .depth.
        self._basic_blocks = self._expand_graph(root_node, -1)
        self._address_map = {bb.head_address(): bb for bb in self._basic_blocks}
        self._address_order = {bb.head_address(): i for i, bb in enumerate(self._basic_blocks)}
        self._prev_basic_blocks = {
            bb.head_address(): prev_bb
            for bb, prev_bb in zip(self._basic_blocks[1:], self._basic_blocks)
        }
        # Nothing precedes the first block (keyed by head address like the rest).
        self._prev_basic_blocks[self._basic_blocks[0].head_address()] = None
        # Target of "jump to return" jumps; discovered by _set_break_statement.
        self._jump_to_return_destination = None

    def get_address_order(self, address):
        """Return the position of the leaf starting at address in tree order."""
        return self._address_order[address]

    def get_prev_basic_blocks(self, address):
        """Return the leaf preceding the one at address in tree order (None for the first)."""
        return self._prev_basic_blocks[address]

    def get_basic_blocks(self):
        return self._basic_blocks

    def _expand_graph(self, root, depth):
        """Set .depth on root and all descendants; return the leaves in tree order."""
        root.depth = depth
        if isinstance(root, BasicBlockLeaf):
            return [root]
        return list(chain.from_iterable(self._expand_graph(node, depth + 1)
                                        for node in root.get_nodes()))

    def _set_loop_attribute(self):
        """Tag the first/last leaf of each non-root node 'loop'/'end loop'.

        Non-root nodes are loop regions.  A last leaf with two successors
        additionally gets 'break'.
        """
        def dfs(node):
            if isinstance(node, BasicBlockLeaf):
                return
            children = node.get_nodes()
            if node.depth >= 0:
                assert isinstance(children[0], BasicBlockLeaf)
                assert isinstance(children[-1], BasicBlockLeaf)
                children[0].attribute.append('loop')
                children[-1].attribute.append('end loop')
                if len(children[-1].next_to()) == 2:
                    children[-1].attribute.append('break')
            for child in children:
                dfs(child)
        dfs(self._root_node)

    def _set_return_attribute(self):
        """Tag the last leaf in tree order as 'return'."""
        self._basic_blocks[-1].attribute.append('return')

    def _set_continue_statement(self):
        """Tag leaves inside a loop that jump back to an earlier child of that loop as 'continue'."""
        def dfs(node):
            if isinstance(node, BasicBlockLeaf):
                return
            children = node.get_nodes()
            if node.depth >= 0 and 'loop' in children[0].attribute:
                earlier_heads = set()
                for child in children[:-1]:
                    earlier_heads.add(child.head_address())
                    if not isinstance(child, BasicBlockLeaf):
                        # Nested loop node: has no jump_to(); its own leaves
                        # are handled by the recursion below.
                        continue
                    jump_to = child.jump_to()
                    if jump_to is not None and jump_to in earlier_heads:
                        child.attribute.append('continue')
            for child in children:
                dfs(child)
        dfs(self._root_node)

    def _set_break_statement(self):
        """Tag forward jumps that leave one or more loops as 'break' or 'jump to return'."""
        for bb in self._basic_blocks:
            jump_to = bb.jump_to()
            if jump_to is None:
                continue
            if jump_to not in self._address_map:
                continue  # target outside this function (e.g. a tail jump)
            if self.get_address_order(jump_to) < self.get_address_order(bb.head_address()):
                continue  # a backward jump is never a break
            jump_bb = self._address_map[jump_to]
            if bb.depth <= jump_bb.depth:
                continue  # same depth: an if; deeper target: a jump into a loop
            if jump_bb.depth + 1 < bb.depth:
                # Leaves more than one loop level: presumably a jump to the
                # return sequence (heuristic, needs tuning).
                if self._jump_to_return_destination is not None:
                    assert self._jump_to_return_destination == jump_to
                else:
                    self._jump_to_return_destination = jump_to
            if jump_to == self._jump_to_return_destination:
                bb.attribute.append('jump to return')
            else:
                bb.attribute.append('break')

    def _set_begin_loop_statement(self):
        """Tag leaves that jump into a directly nested loop as 'jump to loop'."""
        def dfs(node):
            # Returns the addresses of node's direct leaf children ([] for a leaf).
            if isinstance(node, BasicBlockLeaf):
                return []
            children = node.get_nodes()
            nested_loop_addresses = set()
            for child in children:
                nested_loop_addresses.update(dfs(child))
            for child in children:
                if not isinstance(child, BasicBlockLeaf):
                    continue
                jump_to = child.jump_to()
                if jump_to is None:
                    continue
                if jump_to in nested_loop_addresses:
                    assert not child.has_attributes('end loop', 'continue', 'break', 'jump to return')
                    child.attribute.append('jump to loop')
            return [address
                    for child in children if isinstance(child, BasicBlockLeaf)
                    for address in child.address_list()]
        dfs(self._root_node)

    def _rebuild_with_if_statement(self):
        """Tag leaves whose conditional jump skips forward over siblings as 'if'.

        Only detection is implemented; the skipped siblings are not yet
        regrouped into their own node.
        """
        def find_if_range(nodes, i):
            """Return (first, last) indices of the siblings skipped by nodes[i]'s jump, or None."""
            node = nodes[i]
            if not isinstance(node, BasicBlockLeaf):
                return None
            jump_to = node.jump_to()
            if jump_to is None:
                return None
            if node.has_attributes('end loop', 'continue', 'break', 'jump to return', 'jump to loop'):
                return None
            if jump_to == self._jump_to_return_destination:
                return None
            for j in range(i + 1, len(nodes)):
                after_node = nodes[j]
                if not isinstance(after_node, BasicBlockLeaf):
                    continue
                if after_node.contain_address(jump_to):
                    assert i + 1 != j  # assume no pointless jump to the very next block
                    return (i + 1, j - 1)
            return None

        def dfs(node):
            if isinstance(node, BasicBlockLeaf):
                return
            children = node.get_nodes()
            for i, child in enumerate(children):
                # Two successors (target + fall-through) = conditional jump.
                if find_if_range(children, i) is not None and len(child.next_to()) == 2:
                    child.attribute.append('if')
            for child in children:
                dfs(child)

        dfs(self._root_node)

    def analyze(self):
        """Run the annotation passes in order, then print the annotated function."""
        # 1. return
        self._set_return_attribute()
        # 2. loops
        self._set_loop_attribute()
        # 3. continue inside loops
        self._set_continue_statement()
        # 4. break
        self._set_break_statement()
        # 5. jumps into a loop
        self._set_begin_loop_statement()
        # 6. if
        self._rebuild_with_if_statement()
        # TODO: detect 'jump to return', switch-case, and check for tail recursion.
        # With tail-call optimization, g() may end by jumping into f();
        # even if functions are separated by DFS, g() ... (note left unfinished).
        print(str(self))

    def dot_format(self):
        """Return the leaf-level control-flow graph in Graphviz DOT syntax."""
        w = []
        for k, v in self._address_map.items():
            w.append('"%x"[label="%s"];' % (k, ''.join(i + '\\l' for i in str(v).split('\n'))))
        w.append('"%s" -> "%x";' % ('start', self._basic_blocks[0].head_address()))
        for node in self._basic_blocks:
            for to in node.next_to():
                w.append('"%x" -> "%x";' % (node.head_address(), to))
        return 'digraph asm {\n node [shape = box];\n%s \n}' % '\n'.join([' ' + i for i in w])

    def __str__(self):
        """Render every leaf indented by its depth, between begin/end markers."""
        l = ['begin function...']
        for leaf in self._basic_blocks:
            indent = ' ' * leaf.depth
            l.append('\n'.join(indent + i for i in str(leaf).split('\n')))
        l.append('end function...')
        return '\n\n'.join(l)
def load_assembler(file_name):
    """Parse the .text section of an ``objdump -d`` listing into Asm objects.

    Parsing starts after the "Disassembly of section .text:" header (the two
    lines following it are skipped too) and stops at the first blank line.

    Raises ValueError for a line inside that range that is not an
    instruction line.
    """
    # Leading \s+ also accepts addresses indented by more than one space;
    # \s* before the mnemonic also accepts bytes-only continuation lines,
    # which produce an Asm whose opecode and operand are None.
    part = re.compile(r"\s+(?P<address>\w+):\s+(?P<binary>(\w\w\s)+)\s*((?P<opecode>\w+)(\s+(?P<operand>.*))?)?$")
    with open(file_name) as f:
        data = [line.rstrip('\r\n') for line in f]
    ret = []
    text_section = False
    i = 0
    while i < len(data):
        line = data[i]
        if line == "Disassembly of section .text:":
            text_section = True
            # Skip the header and, presumably, the blank line and first
            # "<symbol>:" label objdump prints after it.
            i += 3
            continue
        if text_section:
            if line == "":
                break
            m = part.match(line)
            if m is None:
                raise ValueError("%s:%d: not an instruction line: %r" % (file_name, i + 1, line))
            address = int(m.group("address"), 16)
            binary = [int(j, 16) for j in m.group("binary").split()]
            ret.append(Asm(address, binary, m.group("opecode"), m.group("operand")))
        i += 1
    return ret
def divide_to_function(asms):
    """Split a linear instruction stream into functions (lists of Asm).

    Walking in address order, the highest address known to be reachable
    (fall-through or jump target) is tracked; an instruction beyond it
    starts a new function.  Yields nothing for an empty stream.
    """
    if not asms:
        return
    next_instruction_address = asms[0].address
    current_graph = []
    for asm in asms:
        if next_instruction_address < asm.address:
            # Nothing seen so far reaches this instruction: new function.
            yield current_graph
            current_graph = []
            next_instruction_address = asm.address
        current_graph.append(asm)
        if asm.opecode == 'ret':
            continue
        # jump_to() is None for non-jumps; next_address() is None after 'jmp'.
        for to in (asm.jump_to(), asm.next_address()):
            if to is not None:
                next_instruction_address = max(next_instruction_address, to)
    if current_graph:
        yield current_graph
def divide_to_asm_block(asms):
    """Yield the basic blocks (BasicBlockLeaf) of one function's instructions.

    A new block starts at every jump target and right after every jump
    instruction.  Leader addresses that are not instruction addresses in
    asms (e.g. jumps leaving the function) are ignored, and no empty block
    is produced when the first instruction is itself a jump target.
    """
    leaders = set()
    for asm in asms:
        jump_to = asm.jump_to()
        if jump_to is not None:
            leaders.add(jump_to)
            leaders.add(asm.address + len(asm.binary))
    block = []
    for asm in asms:
        if block and asm.address in leaders:
            yield BasicBlockLeaf(block)
            block = []
        block.append(asm)
    if block:
        yield BasicBlockLeaf(block)
def build_graph(asm_nodes):
    """Nest basic blocks into a tree whose inner nodes are loop regions.

    The blocks are split into strongly connected components with scc().
    Edges back into the lowest-addressed block (assumed to be the region's
    entry) are ignored.  Recursing into a multi-block component therefore
    splits its entry off and exposes inner loops as nested components.  A
    single block that jumps into itself is wrapped in its own
    BasicBlockNode.  Blocks with no edges at all are kept as single-block
    components.

    asm_nodes: non-empty list of BasicBlockLeaf.
    Returns a BasicBlockNode whose children are sorted by head address.
    """
    if len(asm_nodes) == 1:
        return BasicBlockNode(asm_nodes)
    asm_nodes = sorted(asm_nodes, key=lambda node: node.head_address())
    head_address = asm_nodes[0].head_address()
    address_to_node = dict((node.head_address(), node) for node in asm_nodes)

    edges = defaultdict(list)
    redges = defaultdict(list)
    for node in asm_nodes:
        for to in node.next_to():
            # Keep only edges inside this sub-graph, and cut edges into the entry.
            if to in address_to_node and to != head_address:
                edges[node.head_address()].append(to)
                redges[to].append(node.head_address())

    components = scc(dict(edges), dict(redges))
    # scc() never reports a vertex without edges; keep such blocks anyway.
    covered = set(chain.from_iterable(components))
    for address in sorted(address_to_node):
        if address not in covered:
            components.append([address])

    ret = []
    for component in components:
        result_nodes = [address_to_node[address] for address in component]
        if len(result_nodes) == 1:
            node = result_nodes[0]
            jump_to = node.jump_to()
            if jump_to is not None and node.contain_address(jump_to):
                ret.append(BasicBlockNode([node]))  # basic block with a self-loop
            else:
                ret.append(node)
        else:
            ret.append(build_graph(result_nodes))
    ret.sort(key=lambda blocks: blocks.head_address())
    return BasicBlockNode(ret)
def analyze_function(asms):
    """Build and annotate the block tree for one function's instructions.

    Function.analyze() also prints the annotated listing.  Returns the
    analyzed Function.
    """
    blocks = list(divide_to_asm_block(asms))
    result = Function(build_graph(blocks))
    result.analyze()
    return result
def main(argv=None):
    """Command-line entry point: analyze every function in an objdump listing.

    argv defaults to sys.argv; argv[1] must be the path of the listing.
    Exits with a usage message when the path is missing.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        prog = os.path.basename(argv[0]) if argv else 'prog'
        raise SystemExit("usage: %s OBJDUMP_LISTING" % prog)
    asms = load_assembler(argv[1])
    for func in divide_to_function(asms):
        analyze_function(func)  # Function.analyze() prints the result
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment