Created
July 19, 2013 18:44
-
-
Save eltjpm/6041422 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Index: numba/specialize/loops.py | |
=================================================================== | |
--- numba/specialize/loops.py (revision 79617) | |
+++ numba/specialize/loops.py (revision 83202) | |
@@ -95,36 +105,31 @@ | |
else: | |
have_step = True | |
- start, stop, step = [nodes.CloneableNode(n) | |
- for n in (start, stop, step)] | |
+ start, stop, step = map(nodes.CloneableNode, (start, stop, step)) | |
if have_step: | |
- compute_nsteps = """ | |
- $length = {{stop}} - {{start}} | |
- {{nsteps}} = $length / {{step}} | |
- if {{nsteps_load}} * {{step}} != $length: #$length % {{step}}: | |
- # Test for truncation | |
- {{nsteps}} = {{nsteps_load}} + 1 | |
- # print "nsteps", {{nsteps_load}} | |
- """ | |
+ templ = textwrap.dedent(""" | |
+ {{temp}} = 0 | |
+ {{nsteps}} = ({{stop}} - {{start}} + {{step}} - | |
+ (1 if {{step}} >= 0 else -1)) / {{step}} | |
+ while {{temp_load}} < {{nsteps_load}}: | |
+ {{target}} = {{start}} + {{temp_load}} * {{step}} | |
+ {{body}} | |
+ {{temp}} = {{temp_load}} + 1 | |
+ """) | |
else: | |
- compute_nsteps = "{{nsteps}} = {{stop}} - {{start}}" | |
+ templ = textwrap.dedent(""" | |
+ {{temp}} = {{start}} | |
+ {{nsteps}} = {{stop}} | |
+ while {{temp_load}} < {{nsteps_load}}: | |
+ {{target}} = {{temp_load}} | |
+ {{body}} | |
+ {{temp}} = {{temp_load}} + 1 | |
+ """) | |
if node.orelse: | |
- else_clause = "else: {{else_body}}" | |
- else: | |
- else_clause = "" | |
+ templ += "\nelse: {{else_body}}" | |
- templ = textwrap.dedent(""" | |
- %s | |
- {{temp}} = 0 | |
- while {{temp_load}} < {{nsteps_load}}: | |
- {{target}} = {{start}} + {{temp_load}} * {{step}} | |
- {{body}} | |
- {{temp}} = {{temp_load}} + 1 | |
- %s | |
- """) % (textwrap.dedent(compute_nsteps), else_clause) | |
- | |
# Leave the bodies empty, they are already analyzed | |
body = ast.Suite(body=[]) | |
else_body = ast.Suite(body=[]) | |
@@ -196,8 +201,7 @@ | |
# Replace node.target with a temporary | |
#-------------------------------------------------------------------- | |
- target_name = orig_target.id + '.idx' | |
- target_temp = nodes.TempNode(Py_ssize_t) | |
+ target_temp = nodes.TempNode(typesystem.Py_ssize_t) | |
node.target = target_temp.store() | |
#-------------------------------------------------------------------- |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from numba import autojit | |
import numpy as np | |
import unittest | |
@autojit
def for_loop_fn_1(start, stop, inc):
    # Accumulate every integer yielded by the three-argument (stepped)
    # range(start, stop, inc); the compiled result is compared against
    # the interpreted py_func by TestForLoop.
    total = 0
    for i in range(start, stop, inc):
        total += i
    return total
@autojit
def for_loop_fn_1a(start, stop):
    # Accumulate every integer yielded by the two-argument range(start, stop)
    # (implicit step of 1); compared against py_func by TestForLoop.
    total = 0
    for i in range(start, stop):
        total += i
    return total
@autojit
def for_loop_fn_1b(stop):
    # Accumulate every integer yielded by the one-argument range(stop)
    # (implicit start of 0, step of 1); compared against py_func by
    # TestForLoop.
    total = 0
    for i in range(stop):
        total += i
    return total
class TestForLoop(unittest.TestCase):
    """Check that autojit-compiled range loops agree with plain Python."""

    def test_compiled_for_loop_fn_many(self):
        """Compare compiled and interpreted results over a grid of bounds.

        Each function is checked exactly once per distinct argument tuple:
        ``for_loop_fn_1b`` for every ``hi``, ``for_loop_fn_1a`` for every
        ``(lo, hi)`` pair, and ``for_loop_fn_1`` for every ``(lo, hi, step)``
        triple with a non-zero step (``range`` raises ValueError on a zero
        step, so step 0 is excluded).
        """
        # ``range`` rather than ``xrange`` so the test runs on Python 2 and 3;
        # the sequences are tiny, so materializing them on Python 2 is free.
        bounds = range(-10, 11)
        steps = [s for s in range(-20, 21) if s]

        # One-argument form: range(stop). Checked once per ``hi`` instead of
        # being repeated inside the (lo, step) loops.
        for hi in bounds:
            self.assertEqual(for_loop_fn_1b(hi),
                             for_loop_fn_1b.py_func(hi),
                             'failed for %d' % hi)

        for lo in bounds:
            for hi in bounds:
                # Two-argument form: range(start, stop). Checked once per
                # (lo, hi) pair, outside the step loop.
                self.assertEqual(for_loop_fn_1a(lo, hi),
                                 for_loop_fn_1a.py_func(lo, hi),
                                 'failed for %d/%d' % (lo, hi))
                # Three-argument form: range(start, stop, step). Steps span
                # both signs and magnitudes larger than the bound span, so
                # empty, single-iteration and truncated ranges are covered.
                for step in steps:
                    self.assertEqual(for_loop_fn_1(lo, hi, step),
                                     for_loop_fn_1.py_func(lo, hi, step),
                                     'failed for %d/%d/%d' % (lo, hi, step))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment