Created
July 19, 2013 18:38
-
-
Save eltjpm/6041381 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Index: numba/environment.py | |
=================================================================== | |
--- numba/environment.py (revision 80669) | |
+++ numba/environment.py (revision 80670) | |
@@ -41,6 +41,7 @@ | |
'update_signature', | |
'create_lfunc1', | |
'NormalizeASTStage', | |
+ 'TransformBuiltinLoops', | |
'ControlFlowAnalysis', | |
#'ConstFolding', | |
'TypeInfer', | |
@@ -74,6 +75,7 @@ | |
default_type_infer_pipeline_order = [ | |
'ast3to2', | |
+ 'TransformBuiltinLoops', | |
'ControlFlowAnalysis', | |
'TypeInfer', | |
] | |
Index: numba/control_flow/control_flow.py | |
=================================================================== | |
--- numba/control_flow/control_flow.py (revision 80669) | |
+++ numba/control_flow/control_flow.py (revision 80670) | |
@@ -773,8 +773,8 @@ | |
warn_unused=warn_unused) | |
# TODO: Generate fake RHS for for iteration target variable | |
- elif (isinstance(lhs, ast.Attribute) and self.flow.block and | |
- assignment is not None): | |
+ elif (isinstance(lhs, (ast.Attribute, nodes.TempStoreNode)) and | |
+ self.flow.block and assignment is not None): | |
self.flow.block.stats.append(AttributeAssignment(assignment)) | |
if self.flow.exceptions: | |
Index: numba/nodes/tempnodes.py | |
=================================================================== | |
--- numba/nodes/tempnodes.py (revision 80669) | |
+++ numba/nodes/tempnodes.py (revision 80670) | |
@@ -50,7 +50,7 @@ | |
def __init__(self, temp, invariant=False): | |
self.temp = temp | |
self.type = temp.type | |
- self.variable = Variable(self.type) | |
+ self.variable = temp.variable | |
self.invariant = invariant | |
def __repr__(self): | |
Index: numba/pipeline.py | |
=================================================================== | |
--- numba/pipeline.py (revision 80669) | |
+++ numba/pipeline.py (revision 80670) | |
@@ -402,6 +402,12 @@ | |
env) | |
return transform.visit(ast) | |
+class TransformBuiltinLoops(PipelineStage): | |
+ """Pipeline stage that runs loops.TransformBuiltinLoops over the AST.""" | |
+ def transform(self, ast, env): | |
+ rewriter = self.make_specializer(loops.TransformBuiltinLoops, ast, env) | |
+ return rewriter.visit(ast) | |
+ | |
#---------------------------------------------------------------------------- | |
# Specializing/Lowering Transforms | |
#---------------------------------------------------------------------------- | |
Index: numba/specialize/loops.py | |
=================================================================== | |
--- numba/specialize/loops.py (revision 79617) | |
+++ numba/specialize/loops.py (revision 83202) | |
@@ -2,6 +2,10 @@ | |
from __future__ import print_function, division, absolute_import | |
import ast | |
import textwrap | |
+try: | |
+ import __builtin__ as builtins | |
+except ImportError: | |
+ import builtins | |
import numba | |
from numba import * | |
@@ -57,6 +61,12 @@ | |
while_node = nodes.build_while(**vars(while_node)) | |
return while_node | |
+def untypedTemp(): | |
+ """Create a TempNode whose type is deferred to its own variable.""" | |
+ deferred = typesystem.DeferredType(None) | |
+ node = nodes.TempNode(deferred) | |
+ deferred.variable = node.variable | |
+ return node | |
#------------------------------------------------------------------------ | |
# Transform for loops | |
@@ -261,6 +265,145 @@ | |
return node | |
#------------------------------------------------------------------------ | |
+# Transform for loops over builtins | |
+#------------------------------------------------------------------------ | |
+ | |
+class TransformBuiltinLoops(visitors.NumbaTransformer): | |
+ def rewrite_enumerate(self, node): | |
+ """ | |
+ Rewrite a loop like | |
+ | |
+ for i, x in enumerate(array[, start]): | |
+ ... | |
+ | |
+ into | |
+ | |
+ _arr = array | |
+ [_s = start] | |
+ for _i in range(len(_arr)): | |
+ i = _i [+ _s] | |
+ x = _arr[_i] | |
+ ... | |
+ """ | |
+ call = node.iter | |
+ if (len(call.args) not in (1, 2) or call.keywords or | |
+ call.starargs or call.kwargs): | |
+ self.error(call, 'expected 1 or 2 arguments to enumerate()') | |
+ | |
+ target = node.target | |
+ if (not isinstance(target, (ast.Tuple, ast.List)) or | |
+ len(target.elts) != 2): | |
+ self.error(call, 'expected 2 iteration variables') | |
+ | |
+ array = call.args[0] | |
+ start = call.args[1] if len(call.args) > 1 else None | |
+ idx = target.elts[0] | |
+ var = target.elts[1] | |
+ | |
+ array_temp = untypedTemp() | |
+ if start: | |
+ start_temp = untypedTemp() # TODO: only allow integer start | |
+ idx_temp = nodes.TempNode(typesystem.Py_ssize_t) | |
+ | |
+ # for _i in range(len(_arr)): | |
+ node.target = idx_temp.store() | |
+ node.iter = ast.Call(ast.Name('range', ast.Load()), | |
+ [ast.Call(ast.Name('len', ast.Load()), | |
+ [array_temp.load(True)], | |
+ [], None, None)], | |
+ [], None, None) | |
+ | |
+ # i = _i [+ _s] | |
+ new_idx = idx_temp.load() | |
+ if start: | |
+ new_idx = ast.BinOp(new_idx, ast.Add(), start_temp.load(True)) | |
+ node.body.insert(0, ast.Assign([idx], new_idx)) | |
+ | |
+ # x = _arr[_i] | |
+ value = ast.Subscript(array_temp.load(True), | |
+ ast.Index(idx_temp.load()), | |
+ ast.Load()) | |
+ node.body.insert(1, ast.Assign([var], value)) | |
+ | |
+ # _arr = array; [_s = start]; ... | |
+ body = [ ast.Assign([array_temp.store()], array), node ] | |
+ if start: | |
+ body.insert(1, ast.Assign([start_temp.store()], start)) | |
+ return [self.visit(stmt) for stmt in body] | |
+ | |
+ def rewrite_zip(self, node): | |
+ """ | |
+ Rewrite a loop like | |
+ | |
+ for x, y... in zip(xs, ys...): | |
+ ... | |
+ | |
+ into | |
+ | |
+ _xs = xs; _ys = ys... | |
+ for _i in range(min(len(_xs), len(_ys)...)): | |
+ x = _xs[_i]; y = _ys[_i]... | |
+ ... | |
+ """ | |
+ call = node.iter | |
+ if not call.args or call.keywords or call.starargs or call.kwargs: | |
+ self.error(call, 'expected at least 1 argument to zip()') | |
+ | |
+ target = node.target | |
+ if (not isinstance(target, (ast.Tuple, ast.List)) or | |
+ len(target.elts) != len(call.args)): | |
+ self.error(call, 'expected %d iteration variables' % len(call.args)) | |
+ | |
+ temps = [untypedTemp() for _ in call.args] | |
+ idx_temp = nodes.TempNode(typesystem.Py_ssize_t) | |
+ | |
+ # min(len(_xs), len(_ys)...) | |
+ len_call = ast.Call(ast.Name('min', ast.Load()), | |
+ [ast.Call(ast.Name('len', ast.Load()), | |
+ [tmp.load(True)], [], None, None) | |
+ for tmp in temps], | |
+ [], None, None) | |
+ | |
+ # for _i in range(...): | |
+ node.target = idx_temp.store() | |
+ node.iter = ast.Call(ast.Name('range', ast.Load()), | |
+ [len_call], [], None, None) | |
+ | |
+ # x = _xs[_i]; y = _ys[_i]... | |
+ node.body = [ast.Assign([tgt], | |
+ ast.Subscript(tmp.load(True), | |
+ ast.Index(idx_temp.load()), | |
+ ast.Load())) | |
+ for tgt, tmp in zip(target.elts, temps)] + \ | |
+ node.body | |
+ | |
+ # _xs = xs; _ys = ys... | |
+ body = [ast.Assign([tmp.store()], arg) | |
+ for tmp, arg in zip(temps, call.args)] + \ | |
+ [node] | |
+ return [self.visit(stmt) for stmt in body] | |
+ | |
+ HANDLERS = { | |
+ id(enumerate): rewrite_enumerate, | |
+ id(zip): rewrite_zip, | |
+ } | |
+ | |
+ def visit_For(self, node): | |
+ if (isinstance(node.iter, ast.Call) and | |
+ isinstance(node.iter.func, ast.Name)): | |
+ name = node.iter.func.id | |
+ if name not in self.symtab: | |
+ obj = (self.func_globals[name] | |
+ if name in self.func_globals else | |
+ getattr(builtins, name, None)) | |
+ rewriter = self.HANDLERS.get(id(obj)) | |
+ if rewriter: | |
+ return rewriter(self, node) | |
+ | |
+ self.visitchildren(node) | |
+ return node | |
+ | |
+#------------------------------------------------------------------------ | |
# Transform for loops over Objects | |
#------------------------------------------------------------------------ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment