Created July 19, 2013 23:21
Save eltjpm/6043068 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Index: numba/visitors.py | |
=================================================================== | |
--- numba/visitors.py (revision 81089) | |
+++ numba/visitors.py (revision 81090) | |
@@ -364,47 +364,68 @@ | |
function_level = 0 | |
def __init__(self, *args, **kwargs): | |
- self.referenced = {} | |
- self.assigned = {} | |
+ self.params = set() | |
+ self.assigned = set() | |
+ self.referenced = set() | |
+ self.globals = set() | |
self.func_defs = [] | |
- def register_assignment(self, node, target, operator): | |
- if isinstance(target, nodes.MaybeUnusedNode): | |
- target = target.name_node | |
- if isinstance(target, ast.Name): | |
- self.assigned[target.id] = node | |
+ def visit_Name(self, node): | |
+ if isinstance(node.ctx, ast.Load): | |
+ add_to = self.referenced | |
+ elif isinstance(node.ctx, ast.Param): | |
+ add_to = self.params | |
+ else: | |
+ add_to = self.assigned | |
+ add_to.add(node.id) | |
- def visit_Assign(self, node): | |
- self.generic_visit(node) | |
- op = getattr(node, "inplace_op", None) | |
- self.register_assignment(node, node.targets[0], op) | |
+ def visit_Global(self, node): | |
+ self.globals.update(node.names) | |
- def visit_AugAssign(self, node): | |
- self.generic_visit(node) | |
- self.register_assignment(node, node.target, node.op) | |
+ def visit_Import(self, node): | |
+ self.assigned.update((alias.asname or alias.name.split('.', 1)[0]) | |
+ for alias in node.names | |
+ if alias.name != '*') | |
- def visit_For(self, node): | |
- self.generic_visit(node) | |
- self.register_assignment(node, node.target, None) | |
+ visit_ImportFrom = visit_Import | |
- def visit_Name(self, node): | |
- self.referenced[node.id] = node | |
- | |
def visit_FunctionDef(self, node): | |
if self.function_level == 0: | |
+ if node.args.vararg: | |
+ self.params.add(node.args.vararg) | |
+ if node.args.kwarg: | |
+ self.params.add(node.args.kwarg) | |
+ | |
self.function_level += 1 | |
self.generic_visit(node) | |
self.function_level -= 1 | |
else: | |
+ if hasattr(node, 'name'): | |
+ self.assigned.add(node.name) | |
self.func_defs.append(node) | |
- return node | |
+ visit_Lambda = visit_FunctionDef | |
+ def visit_ClassDef(self, node): | |
+ self.assigned.add(node.name) | |
+ self.generic_visit(node) | |
+ | |
def visit_ClosureNode(self, node): | |
+ self.assigned.add(node.name) | |
self.func_defs.append(node) | |
- self.generic_visit(node) | |
- return node | |
+ def visit_GeneratorExp(self, node): | |
+ raise error.NumbaError( | |
+ node, "Generator comprehensions are not yet supported") | |
+ | |
+ def visit_SetComp(self, node): | |
+ raise error.NumbaError( | |
+ node, "Set comprehensions are not yet supported") | |
+ | |
+ def visit_DictComp(self, node): | |
+ raise error.NumbaError( | |
+ node, "Dict comprehensions are not yet supported") | |
+ | |
def determine_variable_status(env, ast, locals_dict): | |
""" | |
Determine what category referenced and assignment variables fall in: | |
@@ -420,14 +441,12 @@ | |
v = VariableFindingVisitor() | |
v.visit(ast) | |
- locals = set(v.assigned) | |
- locals.update(locals_dict) | |
+ if not v.params.isdisjoint(v.globals): | |
+ raise error.NumbaError( | |
+ node, "Parameters cannot be declared global") | |
- locals.update([name.id for name in ast.args.args]) | |
- | |
- locals.update(func_def.name for func_def in v.func_defs) | |
- | |
- freevars = set(v.referenced) - locals | |
+ locals = v.params.union(v.assigned, locals_dict) - v.globals | |
+ freevars = v.referenced - locals | |
cellvars = set() | |
# Compute cell variables | |
@@ -440,12 +459,14 @@ | |
inner_locals_dict) | |
cellvars.update(locals.intersection(inner_freevars)) | |
-# print ast.name, "locals", pformat(locals), \ | |
-# "cellvars", pformat(cellvars), \ | |
-# "freevars", pformat(freevars), \ | |
-# "locals_dict", pformat(locals_dict) | |
-# print ast.name, "locals", pformat(locals) | |
+# from pprint import pformat | |
+# print(ast.name, "locals", pformat(locals), | |
+# "cellvars", pformat(cellvars), | |
+# "freevars", pformat(freevars), | |
+# "locals_dict", pformat(locals_dict)) | |
+# print(ast.name, "locals", pformat(locals)) | |
# Cache state | |
annotate(env, ast, variable_status_tuple=(locals, cellvars, freevars)) | |
return locals, cellvars, freevars | |
+ |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.