Last active
December 20, 2015 00:19
-
-
Save eltjpm/6041224 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Index: builtinmodule.py | |
=================================================================== | |
--- builtinmodule.py (revision 79617) | |
+++ builtinmodule.py (revision 83202) | |
@@ -4,11 +4,11 @@ | |
""" | |
from __future__ import print_function, division, absolute_import | |
+import ast | |
from numba import * | |
from numba import nodes | |
from numba import error | |
# from numba import function_util | |
-# from numba.specialize.mathcalls import is_math_function | |
from numba.symtab import Variable | |
from numba import typesystem | |
from numba.typesystem import is_obj, promote_closest, get_type | |
@@ -27,6 +27,9 @@ | |
else: | |
return nodes.CoercionNode(node.args[0], dst_type=dst_type) | |
+def binop_type(context, x, y):  # promoted type of x and y (pow_ uses it as its result type)
+    return context.promote_types(get_type(x), get_type(y))
+ | |
#---------------------------------------------------------------------------- | |
# Type Functions for Builtins | |
#---------------------------------------------------------------------------- | |
@@ -75,6 +78,10 @@ | |
def _float(context, node, x): | |
return cast(node, double) | |
+@register_builtin((0, 1), can_handle_deferred_types=True)
+def _bool(context, node, x):  # bool(): coerce the call via cast() to bool_, as _float does with double
+    return cast(node, bool_)
+ | |
@register_builtin((0, 1, 2), can_handle_deferred_types=True) | |
def complex_(context, node, a, b): | |
if len(node.args) == 2: | |
@@ -100,12 +107,11 @@ | |
@register_builtin((2, 3))
def pow_(context, node, base, exponent, mod):
-    from . import mathmodule
-    return mathmodule.pow_(context, node, base, exponent)
+    node.variable = Variable(binop_type(context, base, exponent))  # type = promotion of base and exponent only; mod is not consulted
+    return node
@register_builtin((1, 2)) | |
def round_(context, node, number, ndigits): | |
- # is_math = is_math_function(node.args, round) | |
argtype = get_type(number) | |
if len(node.args) == 1 and argtype.is_int: | |
@@ -121,6 +127,43 @@ | |
node.variable = Variable(dst_type) | |
return node # nodes.CoercionNode(node, double) | |
+def minmax(context, args, op):  # fold min()/max() over >= 2 positional args into nested ExpressionNodes
+    if len(args) < 2:
+        return
+    # NOTE(review): 0/1 args (e.g. max(iterable)) return None; presumably register_builtin then leaves the call unspecialized -- confirm.
+    res = args[0]
+    for arg in args[1:]:
+        lhs_type = get_type(res)
+        rhs_type = get_type(arg)
+        res_type = context.promote_types(lhs_type, rhs_type)
+        if lhs_type != res_type:
+            res = nodes.CoercionNode(res, res_type)
+        if rhs_type != res_type:
+            arg = nodes.CoercionNode(arg, res_type)
+        # Like CPython, take the new operand only if `new op running` holds; ties/NaN keep the earlier value (max(0.0, -0.0) -> 0.0).
+        lhs_temp = nodes.TempNode(res_type)
+        rhs_temp = nodes.TempNode(res_type)
+        res_temp = nodes.TempNode(res_type)
+        lhs = lhs_temp.load(invariant=True)
+        rhs = rhs_temp.load(invariant=True)
+        expr = ast.IfExp(ast.Compare(rhs, [op], [lhs]), rhs, lhs)
+        body = [
+            ast.Assign([lhs_temp.store()], res),
+            ast.Assign([rhs_temp.store()], arg),
+            ast.Assign([res_temp.store()], expr),
+        ]
+        res = nodes.ExpressionNode(body, res_temp.load(invariant=True))
+
+    return res
+ | |
+@register_builtin(None)
+def min_(context, node, *args):  # min(a, b, ...): pairwise fold with '<' via minmax()
+    return minmax(context, args, ast.Lt())
+ | |
+@register_builtin(None)
+def max_(context, node, *args):  # max(a, b, ...): pairwise fold with '>' via minmax()
+    return minmax(context, args, ast.Gt())
+ | |
@register_builtin(0) | |
def globals_(context, node): | |
return typesystem.dict_ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from numba import autojit | |
import lib.numbatest | |
@autojit
def max1(x):
    """Return ``max(x)`` for a single iterable argument.

    >>> max1([100])
    100
    >>> max1([1,2.0,3])
    3
    >>> max1([-1,-2,-3.0])
    -1
    >>> max1(1)
    Traceback (most recent call last):
    ...
    TypeError: 'int' object is not iterable
    """
    # One-argument (iterable) form: per the doctests the selected element
    # comes back unconverted (3, not 3.0) and a non-iterable raises TypeError.
    return max(x)
@autojit
def min1(x):
    """Return ``min(x)`` for a single iterable argument.

    >>> min1([100])
    100
    >>> min1([1,2,3.0])
    1
    >>> min1([-1,-2.0,-3])
    -3
    >>> min1(1)
    Traceback (most recent call last):
    ...
    TypeError: 'int' object is not iterable
    """
    # One-argument (iterable) form: per the doctests the selected element
    # comes back unconverted (1, not 1.0) and a non-iterable raises TypeError.
    return min(x)
@autojit
def max2(x, y):
    """Return ``max(x, y)`` for two scalar arguments.

    >>> max2(1, 2)
    2
    >>> max2(1, -2)
    1
    >>> max2(10, 10.25)
    10.25
    >>> max2(10, 9.9)
    10.0
    >>> max2(0.1, 0.25)
    0.25
    >>> max2(1, 'a')
    Traceback (most recent call last):
    ...
    UnpromotableTypeError: (int, const char *)
    """
    # Two-argument form: both arguments are promoted to a common type first,
    # so max2(10, 9.9) yields 10.0 rather than 10; argument types with no
    # common promotion raise UnpromotableTypeError.
    return max(x, y)
@autojit
def min2(x, y):
    """Return ``min(x, y)`` for two scalar arguments.

    >>> min2(1, 2)
    1
    >>> min2(1, -2)
    -2
    >>> min2(10, 10.1)
    10.0
    >>> min2(10, 9.75)
    9.75
    >>> min2(0.25, 0.3)
    0.25
    >>> min2(1, 'a')
    Traceback (most recent call last):
    ...
    UnpromotableTypeError: (int, const char *)
    """
    # Two-argument form: both arguments are promoted to a common type first,
    # so min2(10, 10.1) yields 10.0 rather than 10; argument types with no
    # common promotion raise UnpromotableTypeError.
    return min(x, y)
@autojit
def max4(x):
    """Return the maximum of the constants 1, 2.0, 14 and ``x``.

    >>> max4(20)
    20.0
    """
    # Four-argument form mixing int and float literals: the result is
    # promoted to float, hence 20.0 rather than 20.
    return max(1, 2.0, x, 14)
@autojit
def min4(x):
    """Return the minimum of the constants 1, 2.0, 14 and ``x``.

    >>> min4(-2)
    -2.0
    """
    # Four-argument form mixing int and float literals: the result is
    # promoted to float, hence -2.0 rather than -2.
    return min(1, 2.0, x, 14)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment