Created September 23, 2014 13:15
Save serge-sans-paille/79b44dd89f374c96b20f to your computer and use it in GitHub Desktop.
Python - functional style!
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
import sys | |
import shutil | |
import unparse | |
import unittest | |
import doctest | |
import StringIO | |
import os | |
from copy import deepcopy | |
def _chain(l, s):
    """
    Compose the statement lambdas of `l' around the expression `s'.

    Every element of `l' is applied, in order, to the result of the previous
    application, so the list [f, g] yields the AST of ``g(f(s))``. An empty
    `l' returns `s' itself.
    """
    expr = s
    for func in l:
        expr = ast.Call(func, [expr], [], None, None)
    return expr
class NotSupportedError(RuntimeError):
    """
    Raised while rewriting a function body that contains a construct the
    functional-style transformation cannot express.
    """
class GatherIdentifiers(ast.NodeVisitor):
    """
    Collect, in visiting order, every Name node whose context is ast.Param,
    i.e. the formal parameters found in a function's argument list.
    """

    def __init__(self):
        self.result = []

    def visit_Name(self, node):
        if type(node.ctx) is not ast.Param:
            return
        self.result.append(node)
class IsRegular(ast.NodeVisitor):
    """
    Detect non-linear control flow in a subtree.

    `result' starts True and becomes False as soon as a break, continue or
    return statement is encountered anywhere in the visited subtree.
    """

    def __init__(self):
        self.result = True

    def irregular(self, node):
        # children of the offending statement are not visited
        self.result = False

    visit_Break = irregular
    visit_Continue = irregular
    visit_Return = irregular
def isregular(node):
    """Return True when `node' contains no break/continue/return statement."""
    checker = IsRegular()
    checker.visit(node)
    return checker.result
class FunctionalStyle(ast.NodeTransformer):
    """
    Turns a function into a lambda expression

    The whole idea is to turn a function into a lambda expression that should
    be complex enough to understand to prevent straightforward deobfuscation

    To do so two operators are introduced:

    E -> (expr -> store) -> new_expr
    S -> (stmt -> store) -> (store -> store)

    E(expr, store) returns the functional version of the expression `expr'
    when evaluated in store `store'.
    The result is an expression.

    S(stmt, store) returns a function that takes a store as input
    and returns a new store

    At module level (nesting_level == 0) every statement visitor returns its
    node unchanged: only top-level function definitions are rewritten, and a
    function whose body uses an unsupported construct is left untouched.

    The `nesting_level' parameter is only there for debugging purpose.
    Setting it to one simulates the processing of an instruction inside a
    function.
    """

    def __init__(self, nesting_level=0):
        # Identifiers used by the generated code.
        self.rec = '__'          # self-reference parameter of loop bodies
        self.store = "_"         # parameter holding the store (a dict)
        self.return_ = "$"       # store key receiving the returned value
        self.wtmp = "!"          # store key holding the value being assigned
        self.formal_rec = 'f'    # first parameter of the Y combinator
        self.formal_store = '_'  # second parameter of the Y combinator
        # AST of ``lambda f, _: f(f, _)``: it applies a function to itself
        # and to the store, which lets a loop body recurse without a name.
        args = ast.arguments([ast.Name(self.formal_rec, ast.Param()),
                              ast.Name(self.formal_store, ast.Param())],
                             None, None, [])
        body = ast.Call(ast.Name(self.formal_rec, ast.Load()),
                        [ast.Name(self.formal_rec, ast.Load()),
                         ast.Name(self.formal_store, ast.Load())],
                        [],
                        None,
                        None)
        self.ycombinator = ast.Lambda(args, body)
        self.nesting_level = nesting_level

    def not_supported(self, node):
        """
        Handler for constructs the transformation cannot express.

        Inside a function body this raises NotSupportedError, which makes
        visit_FunctionDef give up and keep the original function; at module
        level the node is returned unchanged.
        """
        if self.nesting_level:
            raise NotSupportedError(str(type(node)))
        return node

    visit_ClassDef = not_supported
    visit_Print = not_supported
    visit_With = not_supported
    visit_Raise = not_supported
    visit_TryExcept = not_supported
    visit_TryFinally = not_supported
    visit_Assert = not_supported
    visit_ImportFrom = not_supported
    visit_Exec = not_supported
    visit_Global = not_supported
    visit_Break = not_supported
    visit_Continue = not_supported
    visit_Yield = not_supported
    visit_Lambda = not_supported

    def visit_FunctionDef(self, node):
        """
        A function is turned into a lambda declaration

        Functions with decorators, default values, *args or **kwargs are
        returned unchanged, as are functions whose body raises
        NotSupportedError while being rewritten.

        >>> node = ast.parse('def foo(x): pass')
        >>> newnode = FunctionalStyle().visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        <BLANKLINE>
        foo = (lambda x: (lambda _: _)({'x': x, '$': None})['$'])
        """
        if self.nesting_level:
            raise NotSupportedError("Nested Functions")
        if (node.decorator_list or node.args.vararg or node.args.kwarg
                or node.args.defaults):
            return node
        # Visiting the body rewrites sub-expressions in place, so keep a
        # pristine copy to hand back if the body turns out unsupported.
        orig = deepcopy(node)
        # gather all the functions to chain, one per statement of the body
        nesting_level = self.nesting_level
        self.nesting_level += 1
        try:
            calls = [self.visit(stmt) for stmt in node.body]
        except NotSupportedError:
            return orig
        finally:
            # restore the level whatever happened while visiting the body
            self.nesting_level = nesting_level
        # create the initial state: formal parameters plus the return slot
        gi = GatherIdentifiers()
        for arg in node.args.args:
            gi.visit(arg)
        formal_parameters = gi.result
        keys = [ast.Str(n.id) for n in formal_parameters]
        keys.append(ast.Str(self.return_))
        values = [ast.Name(n.id, ast.Load()) for n in formal_parameters]
        values.append(ast.Name('None', ast.Load()))
        init_expr = ast.Dict(keys, values)
        # create the lambda: run the chained statements on the initial
        # store, then read the return slot
        lambda_ = ast.Lambda(node.args,
                             ast.Subscript(_chain(calls, init_expr),
                                           ast.Index(ast.Str(self.return_)),
                                           ast.Load()))
        res = ast.Assign([ast.Name(node.name, ast.Store())], lambda_)
        return ast.fix_missing_locations(res)

    def visit_Return(self, node):
        """
        A return just adds an entry in the state and returns the state

        NOTE: statements that follow a return in the same body are still
        chained after it, so they are evaluated too.

        >>> node = ast.parse('return 1')
        >>> newnode = FunctionalStyle(nesting_level=1).visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        (lambda _: (_.__setitem__('$', 1), _)[(-1)])
        """
        if not self.nesting_level:
            return node
        args = ast.arguments([ast.Name(self.store, ast.Param())],
                             None, None, [])
        if node.value is not None:
            returned = self.visit(node.value)
        else:
            returned = ast.Name("None", ast.Load())
        setreturn = ast.Call(ast.Attribute(ast.Name(self.store, ast.Load()),
                                           '__setitem__',
                                           ast.Load()),
                             [ast.Str(self.return_), returned],
                             [],
                             None,
                             None)
        # (side_effect, store)[-1] evaluates the side effect, yields the store
        body = ast.Subscript(ast.Tuple([setreturn,
                                        ast.Name(self.store, ast.Load())],
                                       ast.Load()),
                             ast.Index(ast.Num(-1)),
                             ast.Load())
        return ast.Lambda(args, body)

    def visit_Delete(self, node):
        """
        A delete removes entries from the store

        Only plain identifiers can be deleted.

        >>> node = ast.parse('del a')
        >>> newnode = FunctionalStyle(nesting_level=1).visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        (lambda _: (_.__delitem__('a'), _)[(-1)])
        """
        if not self.nesting_level:
            return node
        if any(type(t) is not ast.Name for t in node.targets):
            raise NotSupportedError("deleting non identifiers")
        args = ast.arguments([ast.Name(self.store, ast.Param())],
                             None, None, [])
        bodyn = [ast.Call(ast.Attribute(ast.Name(self.store, ast.Load()),
                                        '__delitem__', ast.Load()),
                          [ast.Str(t.id)],
                          [], None, None)
                 for t in node.targets]
        bodyl = ast.Name(self.store, ast.Load())
        body = ast.Subscript(ast.Tuple(bodyn + [bodyl], ast.Load()),
                             ast.Index(ast.Num(-1)),
                             ast.Load())
        return ast.Lambda(args, body)

    def visit_Pass(self, node):
        """
        A Pass is similar to applying the identity to the locals
        S('pass', state) = lambda state : state

        >>> node = ast.parse('pass')
        >>> newnode = FunctionalStyle(nesting_level=1).visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        (lambda _: _)
        """
        if not self.nesting_level:
            return node
        args = ast.arguments([ast.Name(self.store, ast.Param())],
                             None, None, [])
        body = ast.Name(self.store, ast.Load())
        return ast.Lambda(args, body)

    def visit_While(self, node):
        """
        The definition of a while is recursive!
        The recursive function itself is

        S('while cond: body else: orelse', state) =
            ( S('while cond: body else: orelse', S('body', state))
              if E('cond', state)
              else S('orelse', state) )

        So with the y combinator we get

        S('while cond: body else: orelse', state) =
            Y(lambda self, state:
                self(self, S('body', state))
                if E('cond', state)
                else S('orelse', state),
              state)

        NOTE: each iteration adds nested Python calls, so long-running loops
        may hit the interpreter's recursion limit.

        >>> node = ast.parse('while 1: pass')
        >>> newnode = FunctionalStyle(nesting_level=1).visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        (lambda _: (lambda f, _: f(f, _))\
((lambda __, _: \
((lambda _: __(__, _))((lambda _: _)(_)) if 1 else _)), \
_))
        """
        if not self.nesting_level:
            return node
        if not isregular(node):
            raise NotSupportedError("irregular control flow")
        args = ast.arguments([ast.Name(self.rec, ast.Param()),
                              ast.Name(self.store, ast.Param())],
                             None, None, [])
        body_ = [self.visit(stmt) for stmt in node.body]
        # after the body, loop again: lambda _: __(__, _)
        lambda_args = ast.arguments([ast.Name(self.store, ast.Param())],
                                    None, None, [])
        lambda_ = ast.Lambda(lambda_args,
                             ast.Call(ast.Name(self.rec, ast.Load()),
                                      [ast.Name(self.rec, ast.Load()),
                                       ast.Name(self.store, ast.Load())],
                                      [],
                                      None,
                                      None))
        body_.append(lambda_)
        body_ = _chain(body_, ast.Name(self.store, ast.Load()))
        orelse_ = [self.visit(stmt) for stmt in node.orelse]
        orelse_ = _chain(orelse_, ast.Name(self.store, ast.Load()))
        body = ast.IfExp(self.visit(node.test), body_, orelse_)
        return ast.Lambda(ast.arguments([ast.Name(self.store, ast.Param())],
                                        None, None, []),
                          ast.Call(self.ycombinator,
                                   [ast.Lambda(args, body),
                                    ast.Name(self.store, ast.Load())],
                                   [],
                                   None,
                                   None))

    def visit_AugAssign(self, node):
        """
        An augassign just updates the store

        The current value is read through a Load-context copy of the target.
        The target node itself keeps its Store context and is left untouched
        for assign_helper, which rewrites it exactly once.

        >>> node = ast.parse('a += 1')
        >>> newnode = FunctionalStyle(nesting_level=1).visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        (lambda _: \
(_.__setitem__('a', ((_['a'] if ('a' in _) else a) + 1)), _)[(-1)])
        """
        if not self.nesting_level:
            return node
        args = ast.arguments([ast.Name(self.store, ast.Param())],
                             None, None, [])
        load_target = deepcopy(node.target)
        load_target.ctx = ast.Load()
        load_target = self.visit(load_target)
        op = self.assign_helper(node.target,
                                ast.BinOp(load_target,
                                          node.op,
                                          self.visit(node.value)))
        body = ast.Subscript(ast.Tuple([op,
                                        ast.Name(self.store, ast.Load())],
                                       ast.Load()),
                             ast.Index(ast.Num(-1)),
                             ast.Load())
        return ast.Lambda(args, body)

    def visit_If(self, node):
        """
        An if evaluates its condition then yields one of the branch
        There is an if expression in python, take advantage of it!

        >>> node = ast.parse('if 1: 2')
        >>> newnode = FunctionalStyle(nesting_level=1).visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        (lambda _: ((lambda _: (2, _)[(-1)])(_) if 1 else _))
        """
        if not self.nesting_level:
            return node
        if not isregular(node):
            raise NotSupportedError("irregular control flow")
        args = ast.arguments([ast.Name(self.store, ast.Param())],
                             None, None, [])
        body_ = [self.visit(stmt) for stmt in node.body]
        body_ = _chain(body_, ast.Name(self.store, ast.Load()))
        orelse_ = [self.visit(stmt) for stmt in node.orelse]
        orelse_ = _chain(orelse_, ast.Name(self.store, ast.Load()))
        body = ast.IfExp(self.visit(node.test), body_, orelse_)
        return ast.Lambda(args, body)

    def assign_helper(self, target, value):
        """
        Return an expression that stores `value' into `target'.

        A Name target becomes ``_.__setitem__('name', value)``; an ``x[i]``
        target becomes ``x.__setitem__(i, value)`` with ``x`` and ``i``
        rewritten by the visitor. Any other target (slices, attributes,
        tuples, ...) raises NotSupportedError.
        """
        # assigning to a name is easy
        if type(target) is ast.Name:
            return ast.Call(ast.Attribute(ast.Name(self.store, ast.Load()),
                                          '__setitem__', ast.Load()),
                            [ast.Str(target.id), value],
                            [], None, None)
        # but assigning to a subscript is tricky
        elif type(target) is ast.Subscript:
            # really tricky: there are different types of slices
            tslice = type(target.slice)
            if tslice is ast.Index:
                vslice = self.visit(target.slice.value)
            else:
                raise NotSupportedError("complex slices")
            return ast.Call(ast.Attribute(self.visit(target.value),
                                          '__setitem__', ast.Load()),
                            [vslice, value],
                            [], None, None)
        else:
            raise NotSupportedError("Assigning to something "
                                    "not a subscript nor a name")

    def visit_Assign(self, node):
        """
        An assign creates one or several entries in the store
        Type destructuring is not supported

        The value is evaluated once into the scratch slot '!' and then copied
        to every target.

        >>> node = ast.parse('a = 2')
        >>> newnode = FunctionalStyle(nesting_level=1).visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        (lambda _: (_.__setitem__('!', 2), \
_.__setitem__('a', _['!']), _)[(-1)])
        """
        if not self.nesting_level:
            return node
        if any(type(t) is ast.Tuple for t in node.targets):
            raise NotSupportedError("type destructuring")
        args = ast.arguments([ast.Name(self.store, ast.Param())],
                             None, None, [])
        body0 = ast.Call(ast.Attribute(ast.Name(self.store, ast.Load()),
                                       '__setitem__', ast.Load()),
                         [ast.Str(self.wtmp), self.visit(node.value)],
                         [], None, None)
        value = ast.Subscript(ast.Name(self.store, ast.Load()),
                              ast.Index(ast.Str(self.wtmp)),
                              ast.Load())
        bodyn = [self.assign_helper(t, deepcopy(value)) for t in node.targets]
        bodyl = ast.Name(self.store, ast.Load())
        body = ast.Subscript(ast.Tuple([body0] + bodyn + [bodyl], ast.Load()),
                             ast.Index(ast.Num(-1)),
                             ast.Load())
        return ast.Lambda(args, body)

    def visit_Name(self, node):
        """
        When visiting a name, we don't know statically whether it is
        - a local name, in which case it should be looked up in the store
        - a global name, in which case it should be looked up in the globals

        Moreover, one cannot use the globals() function:
        it may be monkey patched

        >>> node = ast.parse('i')
        >>> newnode = FunctionalStyle(nesting_level=1).visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        (lambda _: ((_['i'] if ('i' in _) else i), _)[(-1)])
        """
        if not self.nesting_level:
            return node
        cond = ast.Compare(ast.Str(node.id),
                           [ast.In()],
                           [ast.Name(self.store, ast.Load())])
        body_ = ast.Subscript(ast.Name(self.store, ast.Load()),
                              ast.Index(ast.Str(node.id)),
                              ast.Load())
        orelse_ = node
        return ast.IfExp(cond, body_, orelse_)

    def visit_For(self, node):
        """
        A for loop can be emulated using list comprehension
        It assumes there is no break or continue, though

        The loop variable is bound directly in the store by using ``_['i']``
        as the comprehension target; the else branch is evaluated after the
        comprehension has been exhausted.

        >>> node = ast.parse('for i in []: pass')
        >>> newnode = FunctionalStyle(nesting_level=1).visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        (lambda _: ([(lambda _: _)(_) for _['i'] in []], _, _)[(-1)])
        """
        if not self.nesting_level:
            return node
        if not isregular(node):
            raise NotSupportedError("irregular control flow")
        if type(node.target) is not ast.Name:
            raise NotSupportedError("only identifiers as loop index")
        args = ast.arguments([ast.Name(self.store, ast.Param())],
                             None, None, [])
        # turn the for into a list comprehension
        body_ = [self.visit(stmt) for stmt in node.body]
        body_ = _chain(body_, ast.Name(self.store, ast.Load()))
        orelse_ = [self.visit(stmt) for stmt in node.orelse]
        orelse_ = _chain(orelse_, ast.Name(self.store, ast.Load()))
        target_ = ast.Subscript(ast.Name(self.store, ast.Load()),
                                ast.Index(ast.Str(node.target.id)),
                                ast.Store())
        comp = ast.ListComp(body_, [ast.comprehension(target_,
                                                      self.visit(node.iter),
                                                      [])])
        # combine the orelse statement
        body = ast.Subscript(ast.Tuple([comp, orelse_,
                                        ast.Name(self.store, ast.Load())],
                                       ast.Load()),
                             ast.Index(ast.Num(-1)),
                             ast.Load())
        return ast.Lambda(args, body)

    def visit_Import(self, node):
        """
        Emulate import using the __import__ function
        This is slightly fragile, as one could have monkey patched it

        >>> node = ast.parse('import math')
        >>> newnode = FunctionalStyle(nesting_level=1).visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        (lambda _: (_.__setitem__('math', __import__('math')), _)[(-1)])
        """
        if not self.nesting_level:
            return node
        args = ast.arguments([ast.Name(self.store, ast.Param())],
                             None, None, [])
        bodyn = [ast.Call(ast.Attribute(ast.Name(self.store, ast.Load()),
                                        '__setitem__', ast.Load()),
                          [ast.Str(n.asname or n.name),
                           ast.Call(ast.Name('__import__', ast.Load()),
                                    [ast.Str(n.name)],
                                    [], None, None)],
                          [], None, None)
                 for n in node.names]
        bodyl = ast.Name(self.store, ast.Load())
        body = ast.Subscript(ast.Tuple(bodyn + [bodyl], ast.Load()),
                             ast.Index(ast.Num(-1)),
                             ast.Load())
        return ast.Lambda(args, body)

    def visit_Expr(self, node):
        """
        An expression just needs to be wrapped in a lambda

        >>> node = ast.parse('1')
        >>> newnode = FunctionalStyle(nesting_level=1).visit(node)
        >>> _ = unparse.Unparser(newnode, sys.stdout)
        (lambda _: (1, _)[(-1)])
        """
        if not self.nesting_level:
            return node
        args = ast.arguments([ast.Name(self.store, ast.Param())],
                             None, None, [])
        body = ast.Subscript(ast.Tuple([self.visit(node.value),
                                        ast.Name(self.store, ast.Load())],
                                       ast.Load()),
                             ast.Index(ast.Num(-1)),
                             ast.Load())
        return ast.Lambda(args, body)

    def visit_comprehension(self, node):
        """
        Make a comprehension bind its loop variable in the store, as the
        for-loop translation does. Only plain identifiers are supported as
        loop variables; at module level the node is returned unchanged.
        """
        if not self.nesting_level:
            return node
        if type(node.target) is not ast.Name:
            raise NotSupportedError("only identifiers as loop index")
        target_ = ast.Subscript(ast.Name(self.store, ast.Load()),
                                ast.Index(ast.Str(node.target.id)),
                                ast.Store())
        return ast.comprehension(target_,
                                 self.visit(node.iter),
                                 [self.visit(cond) for cond in node.ifs])
class TestFunctionalStyle(unittest.TestCase): | |
def generic_test(self, code, *tests): | |
ref_env = globals().copy() | |
exec code in ref_env | |
for test in tests: | |
# generate reference | |
ref = eval(test, ref_env) | |
# parse, transform and eval | |
node = ast.parse(code) | |
node = FunctionalStyle().visit(node) | |
obj = compile(node, '<test>', 'exec') | |
obj_env = globals().copy() | |
exec obj in obj_env | |
candidate = eval(test, obj_env) | |
self.assertEqual(ref, candidate) | |
# also test that generated string can be compiled | |
out = StringIO.StringIO() | |
unparse.Unparser(node, out) | |
ast.parse(out.getvalue()) | |
def test_FunctionDef(self): | |
self.generic_test("def foo(x): return x", | |
"foo(1)", "foo(1.5)", "foo('hello')") | |
self.generic_test("def foo(x,y): return x,y", | |
"foo(1, True)", "foo(.5, {})", "foo('h', (0, None))") | |
def test_Pass(self): | |
self.generic_test("def foo(): pass", 'foo()') | |
def test_Return(self): | |
self.generic_test("def foo(x): return", "foo(0)") | |
self.generic_test("def foo(x): return x + 1", "foo(0)") | |
def test_AugAssign(self): | |
self.generic_test("def foo(x, y): x += y ; return x, y", | |
"foo(1,2)") | |
self.generic_test("def foo(x, y): x[y] += y ; return x, y", | |
"foo([1, 2, 3], 2)") | |
def test_Assign(self): | |
self.generic_test("def foo(x, y): x = y ; return x, y", | |
"foo(1,2)") | |
self.generic_test("def foo(x, y): x = y = x * y; return x, y", | |
"foo('1.4',2)") | |
self.generic_test("def foo(x, y): x[y] = y ; return x", | |
"foo([1, '3'], 1)") | |
self.generic_test("def foo(x, y): x[y][0][0] = y ; return x", | |
"foo([1, [['3']]], 1)") | |
def test_If(self): | |
self.generic_test("def foo(x, y):\n if x: return x\n else: return y", | |
"foo(1,2)", "foo(0, 2)") | |
self.generic_test(""" | |
def foo(x, y): | |
if x: return x | |
elif y: return y | |
else: return 'e'""", | |
"foo([1], [])", | |
"foo([], [1])", | |
"foo([], [])") | |
def test_While(self): | |
self.generic_test("def foo(x):\n while x: pass\n return x", | |
"foo(0)") | |
self.generic_test("def f(x):\n while x: x-=1\n else: x+=1\n return x", | |
"f(1)", | |
"f(0)") | |
self.generic_test("def foo(x):\n while x>0: x-=1\n return x", | |
"foo(3)", | |
"foo(0)") | |
def test_For(self): | |
self.generic_test("def foo(x,s):\n for i in x: s+=i;\n return s", | |
"foo('hello', '')", | |
"foo([1,2,3], 8)") | |
def test_Del(self): | |
self.generic_test("def foo(x): del x", "foo(1)") | |
def test_Import(self): | |
self.generic_test("def foo(x): import math as m ; return m.cos(x)", | |
"foo(1)") | |
def test_Expr(self): | |
self.generic_test("def foo(x, y): x(y); return y", | |
"foo(lambda x: x.append(1),[])") | |
def test_For(self): | |
self.generic_test("def foo(x, y):\n for i in x: y+= 1\n return y", | |
"foo('hello', 0)") | |
def test_Global(self): | |
self.generic_test("def foo(x): return range(x)", | |
"foo(3)") | |
self.generic_test("range = list\ndef foo(x): return range(x)", | |
"foo('e')") | |
def test_bootstrap(self): | |
# Verify we correctly process ourselves | |
module = sys.modules[__name__] | |
module_code = file(module.__file__).read() | |
module_node = ast.parse(module_code) | |
module_node = FunctionalStyle().visit(module_node) | |
module_obj = compile(module_node, '<test>', 'exec') | |
sys.path.append(os.path.dirname(__file__)) | |
env = {} | |
exec module_obj in env | |
def test_on_ast(self): | |
# Verify we correctly process the ast module | |
module_code = file(ast.__file__[:-1]).read() | |
module_node = ast.parse(module_code) | |
module_node = FunctionalStyle().visit(module_node) | |
module_obj = compile(module_node, '<test>', 'exec') | |
env = {} | |
exec module_obj in env | |
def transform(input_path, output_path):
    """
    Rewrite the Python file `input_path' in functional style and write the
    result, prefixed with a shebang line, to `output_path'.

    The output file receives the input's permission bits and timestamps, so
    an executable script stays executable. If the input cannot be parsed,
    nothing is written and the function returns None without raising.
    """
    with open(input_path) as input_file:
        source = input_file.read()
    try:
        node = ast.parse(source)
    except SyntaxError:
        # deliberate best-effort: unparsable files are skipped
        return
    node = FunctionalStyle().visit(node)
    with open(output_path, 'w') as output_file:
        output_file.write('#! /usr/bin/env python\n')
        unparse.Unparser(node, output_file)
        output_file.write('\n')
    # copystat copies the permission bits as well as the times, so a
    # separate copymode call would be redundant
    shutil.copystat(input_path, output_path)
if __name__ == "__main__":
    # Command-line use: <input file> [output file]. The input is parsed, its
    # top-level functions are rewritten and the resulting source is written
    # to the output file, or to stdout when none is given.
    if not 2 <= len(sys.argv) <= 3:
        # sys.exit with a message prints it on stderr and exits with status 1
        sys.exit('Usage: %s <input file> [output file]' % sys.argv[0])
    input_name = sys.argv[1]
    with open(input_name, 'r') as input_file:
        node = ast.parse(input_file.read())
    node = FunctionalStyle().visit(node)
    # open the output only once the input has parsed, so a bad input does
    # not truncate an existing output file
    if len(sys.argv) == 3:
        output = open(sys.argv[2], 'w')
    else:
        output = sys.stdout
    try:
        unparse.Unparser(node, output)
        output.write('\n')
    finally:
        if output is not sys.stdout:
            output.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"Usage: unparse.py <path to source file>" | |
import sys | |
import ast | |
import cStringIO | |
import os | |
# Large float and imaginary literals get turned into infinities in the AST.
# Such infinities are unparsed as INFSTR: a decimal literal one order of
# magnitude beyond the largest finite float, which reads back as infinity.
INFSTR = "1e%d" % (sys.float_info.max_10_exp + 1)
def interleave(inter, f, seq):
    """Call f on each item in seq, calling inter() in between.

    inter() is never called before the first item or after the last one.
    Only exhaustion of `seq' is treated as "nothing to do"; a StopIteration
    escaping from `f' itself is propagated instead of being swallowed.
    """
    items = iter(seq)
    try:
        first = next(items)
    except StopIteration:
        # empty sequence: nothing to emit
        return
    f(first)
    for item in items:
        inter()
        f(item)
class Unparser: | |
"""Methods in this class recursively traverse an AST and | |
output source code for the abstract syntax; original formatting | |
is disregarded. """ | |
def __init__(self, tree, file = sys.stdout): | |
"""Unparser(tree, file=sys.stdout) -> None. | |
Print the source for tree to file.""" | |
self.f = file | |
self.future_imports = [] | |
self._indent = 0 | |
self.dispatch(tree) | |
self.f.write("") | |
self.f.flush() | |
def fill(self, text = ""): | |
"Indent a piece of text, according to the current indentation level" | |
self.f.write("\n"+" "*self._indent + text) | |
def write(self, text): | |
"Append a piece of text to the current line." | |
self.f.write(text) | |
def enter(self): | |
"Print ':', and increase the indentation." | |
self.write(":") | |
self._indent += 1 | |
def leave(self): | |
"Decrease the indentation level." | |
self._indent -= 1 | |
def dispatch(self, tree): | |
"Dispatcher function, dispatching tree type T to method _T." | |
if isinstance(tree, list): | |
for t in tree: | |
self.dispatch(t) | |
return | |
meth = getattr(self, "_"+tree.__class__.__name__) | |
meth(tree) | |
############### Unparsing methods ###################### | |
# There should be one method per concrete grammar type # | |
# Constructors should be grouped by sum type. Ideally, # | |
# this would follow the order in the grammar, but # | |
# currently doesn't. # | |
######################################################## | |
def _Module(self, tree): | |
for stmt in tree.body: | |
self.dispatch(stmt) | |
# stmt | |
def _Expr(self, tree): | |
self.fill() | |
self.dispatch(tree.value) | |
def _Import(self, t): | |
self.fill("import ") | |
interleave(lambda: self.write(", "), self.dispatch, t.names) | |
def _ImportFrom(self, t): | |
# A from __future__ import may affect unparsing, so record it. | |
if t.module and t.module == '__future__': | |
self.future_imports.extend(n.name for n in t.names) | |
self.fill("from ") | |
self.write("." * t.level) | |
if t.module: | |
self.write(t.module) | |
self.write(" import ") | |
interleave(lambda: self.write(", "), self.dispatch, t.names) | |
def _Assign(self, t): | |
self.fill() | |
for target in t.targets: | |
self.dispatch(target) | |
self.write(" = ") | |
self.dispatch(t.value) | |
def _AugAssign(self, t): | |
self.fill() | |
self.dispatch(t.target) | |
self.write(" "+self.binop[t.op.__class__.__name__]+"= ") | |
self.dispatch(t.value) | |
def _Return(self, t): | |
self.fill("return") | |
if t.value: | |
self.write(" ") | |
self.dispatch(t.value) | |
def _Pass(self, t): | |
self.fill("pass") | |
def _Break(self, t): | |
self.fill("break") | |
def _Continue(self, t): | |
self.fill("continue") | |
def _Delete(self, t): | |
self.fill("del ") | |
interleave(lambda: self.write(", "), self.dispatch, t.targets) | |
def _Assert(self, t): | |
self.fill("assert ") | |
self.dispatch(t.test) | |
if t.msg: | |
self.write(", ") | |
self.dispatch(t.msg) | |
def _Exec(self, t): | |
self.fill("exec ") | |
self.dispatch(t.body) | |
if t.globals: | |
self.write(" in ") | |
self.dispatch(t.globals) | |
if t.locals: | |
self.write(", ") | |
self.dispatch(t.locals) | |
def _Print(self, t): | |
self.fill("print ") | |
do_comma = False | |
if t.dest: | |
self.write(">>") | |
self.dispatch(t.dest) | |
do_comma = True | |
for e in t.values: | |
if do_comma:self.write(", ") | |
else:do_comma=True | |
self.dispatch(e) | |
if not t.nl: | |
self.write(",") | |
def _Global(self, t): | |
self.fill("global ") | |
interleave(lambda: self.write(", "), self.write, t.names) | |
def _Yield(self, t): | |
self.write("(") | |
self.write("yield") | |
if t.value: | |
self.write(" ") | |
self.dispatch(t.value) | |
self.write(")") | |
def _Raise(self, t): | |
self.fill('raise ') | |
if t.type: | |
self.dispatch(t.type) | |
if t.inst: | |
self.write(", ") | |
self.dispatch(t.inst) | |
if t.tback: | |
self.write(", ") | |
self.dispatch(t.tback) | |
def _TryExcept(self, t): | |
self.fill("try") | |
self.enter() | |
self.dispatch(t.body) | |
self.leave() | |
for ex in t.handlers: | |
self.dispatch(ex) | |
if t.orelse: | |
self.fill("else") | |
self.enter() | |
self.dispatch(t.orelse) | |
self.leave() | |
def _TryFinally(self, t): | |
if len(t.body) == 1 and isinstance(t.body[0], ast.TryExcept): | |
# try-except-finally | |
self.dispatch(t.body) | |
else: | |
self.fill("try") | |
self.enter() | |
self.dispatch(t.body) | |
self.leave() | |
self.fill("finally") | |
self.enter() | |
self.dispatch(t.finalbody) | |
self.leave() | |
def _ExceptHandler(self, t): | |
self.fill("except") | |
if t.type: | |
self.write(" ") | |
self.dispatch(t.type) | |
if t.name: | |
self.write(" as ") | |
self.dispatch(t.name) | |
self.enter() | |
self.dispatch(t.body) | |
self.leave() | |
def _ClassDef(self, t): | |
self.write("\n") | |
for deco in t.decorator_list: | |
self.fill("@") | |
self.dispatch(deco) | |
self.fill("class "+t.name) | |
if t.bases: | |
self.write("(") | |
for a in t.bases: | |
self.dispatch(a) | |
self.write(", ") | |
self.write(")") | |
self.enter() | |
self.dispatch(t.body) | |
self.leave() | |
def _FunctionDef(self, t): | |
self.write("\n") | |
for deco in t.decorator_list: | |
self.fill("@") | |
self.dispatch(deco) | |
self.fill("def "+t.name + "(") | |
self.dispatch(t.args) | |
self.write(")") | |
self.enter() | |
self.dispatch(t.body) | |
self.leave() | |
def _For(self, t): | |
self.fill("for ") | |
self.dispatch(t.target) | |
self.write(" in ") | |
self.dispatch(t.iter) | |
self.enter() | |
self.dispatch(t.body) | |
self.leave() | |
if t.orelse: | |
self.fill("else") | |
self.enter() | |
self.dispatch(t.orelse) | |
self.leave() | |
def _If(self, t): | |
self.fill("if ") | |
self.dispatch(t.test) | |
self.enter() | |
self.dispatch(t.body) | |
self.leave() | |
# collapse nested ifs into equivalent elifs. | |
while (t.orelse and len(t.orelse) == 1 and | |
isinstance(t.orelse[0], ast.If)): | |
t = t.orelse[0] | |
self.fill("elif ") | |
self.dispatch(t.test) | |
self.enter() | |
self.dispatch(t.body) | |
self.leave() | |
# final else | |
if t.orelse: | |
self.fill("else") | |
self.enter() | |
self.dispatch(t.orelse) | |
self.leave() | |
def _While(self, t): | |
self.fill("while ") | |
self.dispatch(t.test) | |
self.enter() | |
self.dispatch(t.body) | |
self.leave() | |
if t.orelse: | |
self.fill("else") | |
self.enter() | |
self.dispatch(t.orelse) | |
self.leave() | |
def _With(self, t): | |
self.fill("with ") | |
self.dispatch(t.context_expr) | |
if t.optional_vars: | |
self.write(" as ") | |
self.dispatch(t.optional_vars) | |
self.enter() | |
self.dispatch(t.body) | |
self.leave() | |
# expr | |
def _Str(self, tree): | |
# if from __future__ import unicode_literals is in effect, | |
# then we want to output string literals using a 'b' prefix | |
# and unicode literals with no prefix. | |
if "unicode_literals" not in self.future_imports: | |
self.write(repr(tree.s)) | |
elif isinstance(tree.s, str): | |
self.write("b" + repr(tree.s)) | |
elif isinstance(tree.s, unicode): | |
self.write(repr(tree.s).lstrip("u")) | |
else: | |
assert False, "shouldn't get here" | |
def _Name(self, t): | |
self.write(t.id) | |
def _Repr(self, t): | |
self.write("`") | |
self.dispatch(t.value) | |
self.write("`") | |
def _Num(self, t): | |
repr_n = repr(t.n) | |
# Parenthesize negative numbers, to avoid turning (-1)**2 into -1**2. | |
if repr_n.startswith("-"): | |
self.write("(") | |
# Substitute overflowing decimal literal for AST infinities. | |
self.write(repr_n.replace("inf", INFSTR)) | |
if repr_n.startswith("-"): | |
self.write(")") | |
def _List(self, t): | |
self.write("[") | |
interleave(lambda: self.write(", "), self.dispatch, t.elts) | |
self.write("]") | |
def _ListComp(self, t): | |
self.write("[") | |
self.dispatch(t.elt) | |
for gen in t.generators: | |
self.dispatch(gen) | |
self.write("]") | |
def _GeneratorExp(self, t): | |
self.write("(") | |
self.dispatch(t.elt) | |
for gen in t.generators: | |
self.dispatch(gen) | |
self.write(")") | |
def _SetComp(self, t): | |
self.write("{") | |
self.dispatch(t.elt) | |
for gen in t.generators: | |
self.dispatch(gen) | |
self.write("}") | |
def _DictComp(self, t): | |
self.write("{") | |
self.dispatch(t.key) | |
self.write(": ") | |
self.dispatch(t.value) | |
for gen in t.generators: | |
self.dispatch(gen) | |
self.write("}") | |
def _comprehension(self, t): | |
self.write(" for ") | |
self.dispatch(t.target) | |
self.write(" in ") | |
self.dispatch(t.iter) | |
for if_clause in t.ifs: | |
self.write(" if ") | |
self.dispatch(if_clause) | |
def _IfExp(self, t): | |
self.write("(") | |
self.dispatch(t.body) | |
self.write(" if ") | |
self.dispatch(t.test) | |
self.write(" else ") | |
self.dispatch(t.orelse) | |
self.write(")") | |
def _Set(self, t): | |
assert(t.elts) # should be at least one element | |
self.write("{") | |
interleave(lambda: self.write(", "), self.dispatch, t.elts) | |
self.write("}") | |
def _Dict(self, t): | |
self.write("{") | |
def write_pair(pair): | |
(k, v) = pair | |
self.dispatch(k) | |
self.write(": ") | |
self.dispatch(v) | |
interleave(lambda: self.write(", "), write_pair, zip(t.keys, t.values)) | |
self.write("}") | |
def _Tuple(self, t): | |
self.write("(") | |
if len(t.elts) == 1: | |
(elt,) = t.elts | |
self.dispatch(elt) | |
self.write(",") | |
else: | |
interleave(lambda: self.write(", "), self.dispatch, t.elts) | |
self.write(")") | |
unop = {"Invert":"~", "Not": "not", "UAdd":"+", "USub":"-"} | |
def _UnaryOp(self, t):
    """Emit ``(op operand)``, e.g. ``(not x)`` or ``(- (7j))``.

    The operator spelling comes from ``self.unop``, keyed by the operator
    node's class name. A space always follows the operator, which is required
    for ``not`` and harmless for the others.
    """
    self.write("(")
    self.write(self.unop[t.op.__class__.__name__])
    self.write(" ")
    # Unary minus applied to a numeric literal gets an extra pair of
    # parentheses. This is necessary: -2147483648 differs from -(2147483648)
    # on a 32-bit machine (the first is an int, the second a long), and -7j
    # differs from -(7j) (real part 0.0 versus -0.0).
    wrap_literal = isinstance(t.op, ast.USub) and isinstance(t.operand, ast.Num)
    if wrap_literal:
        self.write("(")
    self.dispatch(t.operand)
    if wrap_literal:
        self.write(")")
    self.write(")")
# Binary-operator node class name -> source spelling (looked up by _BinOp).
binop = dict(
    Add="+", Sub="-", Mult="*", Div="/", Mod="%",
    LShift="<<", RShift=">>", BitOr="|", BitXor="^", BitAnd="&",
    FloorDiv="//", Pow="**",
)
def _BinOp(self, t):
    """Emit ``(left op right)``.

    The operator spelling comes from ``self.binop``, keyed by the operator
    node's class name.
    """
    self.write("(")
    self.dispatch(t.left)
    symbol = self.binop[t.op.__class__.__name__]
    self.write(" " + symbol + " ")
    self.dispatch(t.right)
    self.write(")")
# Comparison-operator node class name -> source spelling (looked up by _Compare).
cmpops = dict(
    Eq="==", NotEq="!=", Lt="<", LtE="<=", Gt=">", GtE=">=",
    Is="is", IsNot="is not", In="in", NotIn="not in",
)
def _Compare(self, t):
    """Emit a (possibly chained) comparison ``(left op1 c1 op2 c2 ...)``.

    Operators and comparators are paired positionally. Each operator's
    spelling comes from ``self.cmpops``, keyed by the node's class name.
    """
    self.write("(")
    self.dispatch(t.left)
    for operator_node, right in zip(t.ops, t.comparators):
        spelled = self.cmpops[operator_node.__class__.__name__]
        self.write(" " + spelled + " ")
        self.dispatch(right)
    self.write(")")
# Boolean-operator node class -> keyword. Unlike the other operator tables,
# this one is keyed by the class itself rather than by its name.
boolops = dict([(ast.And, 'and'), (ast.Or, 'or')])
def _BoolOp(self, t):
    """Emit ``(a and b and ...)`` or ``(a or b or ...)``.

    The keyword comes from ``self.boolops``, keyed by the operator node's
    class. It is written, padded with spaces, between every pair of operands.
    """
    self.write("(")
    joiner = " " + self.boolops[t.op.__class__] + " "
    interleave(lambda: self.write(joiner), self.dispatch, t.values)
    self.write(")")
def _Attribute(self, t):
    """Emit attribute access ``value.attr``.

    ``3.__abs__()`` is a syntax error, because ``3.`` lexes as a float. An
    integer-literal base is therefore followed by a space, giving
    ``3 .__abs__()``.
    """
    base = t.value
    self.dispatch(base)
    if isinstance(base, ast.Num) and isinstance(base.n, int):
        self.write(" ")
    self.write(".")
    self.write(t.attr)
def _Call(self, t):
    """Emit a call ``func(args, keywords, *starargs, **kwargs)``.

    This handles the AST layout in which ``Call`` carries separate
    ``starargs`` and ``kwargs`` fields; each of those is written only when it
    is set. Arguments are separated by ``", "``.
    """
    def arguments():
        # Yield (prefix, node) pairs lazily, in source order.
        for positional in t.args:
            yield "", positional
        for keyword in t.keywords:
            yield "", keyword
        if t.starargs:
            yield "*", t.starargs
        if t.kwargs:
            yield "**", t.kwargs

    self.dispatch(t.func)
    self.write("(")
    for position, (prefix, node) in enumerate(arguments()):
        if position:
            self.write(", ")
        if prefix:
            self.write(prefix)
        self.dispatch(node)
    self.write(")")
def _Subscript(self, t):
    """Emit a subscript ``value[slice]``."""
    container, index = t.value, t.slice
    self.dispatch(container)
    self.write("[")
    self.dispatch(index)
    self.write("]")
# slice
def _Ellipsis(self, t):
    """Emit the ``...`` token. The node carries no data, so ``t`` is unused."""
    token = "..."
    self.write(token)
def _Index(self, t):
    """An ``Index`` slice adds no syntax of its own: emit only its wrapped value."""
    wrapped = t.value
    self.dispatch(wrapped)
def _Slice(self, t):
    """Emit a simple slice ``lower:upper`` or ``lower:upper:step``.

    Any missing (falsy) bound is left blank. The first colon is always
    written; the second is written only when a step is present.
    """
    lower, upper, step = t.lower, t.upper, t.step
    if lower:
        self.dispatch(lower)
    self.write(":")
    if upper:
        self.dispatch(upper)
    if not step:
        return
    self.write(":")
    self.dispatch(step)
def _ExtSlice(self, t):
    """Emit an extended (multi-dimensional) slice, with dimensions separated by ``, ``."""
    comma = lambda: self.write(', ')
    interleave(comma, self.dispatch, t.dims)
# others
def _arguments(self, t):
    """Emit a parameter list ``a, b=default, *vararg, **kwarg``.

    ``t.defaults`` lines up with the *last* ``len(t.defaults)`` entries of
    ``t.args``, so the leading parameters are padded with ``None`` (meaning
    "no default"). ``t.vararg`` and ``t.kwarg`` are written as plain text
    rather than dispatched, so they are expected to be name strings here.
    """
    padding = [None] * (len(t.args) - len(t.defaults))
    need_comma = False
    for param, default in zip(t.args, padding + t.defaults):
        if need_comma:
            self.write(", ")
        need_comma = True
        self.dispatch(param)
        if default:
            self.write("=")
            self.dispatch(default)
    if t.vararg:
        if need_comma:
            self.write(", ")
        need_comma = True
        self.write("*")
        self.write(t.vararg)
    if t.kwarg:
        if need_comma:
            self.write(", ")
        self.write("**" + t.kwarg)
def _keyword(self, t):
    """Emit a keyword argument ``name=value`` as used inside a call."""
    for text in (t.arg, "="):
        self.write(text)
    self.dispatch(t.value)
def _Lambda(self, t):
    """Emit ``(lambda args: body)``, always wrapped in parentheses."""
    write = self.write
    write("(")
    write("lambda ")
    self.dispatch(t.args)
    write(": ")
    self.dispatch(t.body)
    write(")")
def _alias(self, t):
    """Emit one imported name: ``name``, or ``name as asname`` when renamed."""
    name, renamed_to = t.name, t.asname
    self.write(name)
    if not renamed_to:
        return
    self.write(" as " + renamed_to)
def roundtrip(filename, output=None):
    """Parse the Python source file *filename* and unparse it to *output*.

    :param filename: path of the source file to read and parse.
    :param output: stream passed to ``Unparser`` together with the parsed
        tree (presumably the text is written there). When omitted,
        ``sys.stdout`` is looked up at call time. Binding it as a default
        value would freeze whatever stdout was at import time and ignore
        later redirection such as ``contextlib.redirect_stdout``.
    """
    if output is None:
        output = sys.stdout
    with open(filename, "r") as pyfile:
        source = pyfile.read()
    # Parse only: PyCF_ONLY_AST makes compile() return the AST, not a code object.
    tree = compile(source, filename, "exec", ast.PyCF_ONLY_AST)
    Unparser(tree, output)
def testdir(a): | |
try: | |
names = [n for n in os.listdir(a) if n.endswith('.py')] | |
except OSError: | |
sys.stderr.write("Directory not readable: %s" % a) | |
else: | |
for n in names: | |
fullname = os.path.join(a, n) | |
if os.path.isfile(fullname): | |
output = cStringIO.StringIO() | |
print 'Testing %s' % fullname | |
try: | |
roundtrip(fullname, output) | |
except Exception as e: | |
print ' Failed to compile, exception is %s' % repr(e) | |
elif os.path.isdir(fullname): | |
testdir(fullname) | |
def main(args):
    """Command-line entry point.

    ``--testdir DIR...`` round-trips every ``.py`` file under each DIR.
    Otherwise each argument is a source file that is round-tripped to the
    default output stream. With no arguments, a usage line is written to
    stderr; previously, indexing ``args[0]`` raised ``IndexError``.
    """
    if not args:
        sys.stderr.write("usage: FILE... | --testdir DIR...\n")
        return
    if args[0] == '--testdir':
        for a in args[1:]:
            testdir(a)
    else:
        for a in args:
            roundtrip(a)


if __name__ == '__main__':
    main(sys.argv[1:])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment