Created
July 23, 2013 01:00
-
-
Save eltjpm/6059053 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Index: numba/transforms.py | |
=================================================================== | |
--- numba/transforms.py (revision 83292) | |
+++ numba/transforms.py (working copy) | |
@@ -630,17 +635,16 @@ | |
def visit_MathNode(self, math_node): | |
"Translate a nodes.MathNode to an intrinsic or libc math call" | |
- from numba.type_inference.modules import mathmodule | |
- lowerable = is_math_function([math_node.arg], math_node.py_func) | |
+ lowerable = is_math_function(math_node.args, math_node.py_func) | |
if math_node.type.is_array or not lowerable: | |
# Generate a Python call | |
assert math_node.py_func is not None | |
- result = nodes.call_pyfunc(math_node.py_func, [math_node.arg]) | |
+ result = nodes.call_pyfunc(math_node.py_func, math_node.args) | |
result = result.coerce(math_node.type) | |
else: | |
# Lower to intrinsic or libc math call | |
- args = [math_node.arg], math_node.py_func, math_node.type | |
+ args = math_node.args, math_node.py_func, math_node.type | |
if is_intrinsic(math_node.py_func): | |
result = resolve_intrinsic(*args) | |
else: | |
Index: numba/type_inference/modules/mathmodule.py | |
=================================================================== | |
--- numba/type_inference/modules/mathmodule.py (revision 83292) | |
+++ numba/type_inference/modules/mathmodule.py (working copy) | |
@@ -10,11 +10,6 @@ | |
import math | |
import cmath | |
-try: | |
- import __builtin__ as builtins | |
-except ImportError: | |
- import builtins | |
- | |
import numpy as np | |
from numba import * | |
@@ -30,11 +25,15 @@ | |
register_math_typefunc = utils.register_with_argchecking | |
-def binop_type(context, x, y): | |
- "Binary result type for math operations" | |
- x_type = get_type(x) | |
- y_type = get_type(y) | |
- return context.promote_types(x_type, y_type) | |
+def largest_type(default_result_type, types): | |
+ for type in types: | |
+#TODO: put back array support | |
+# if type.is_array and type.dtype.is_int: | |
+# type = type.copy(dtype=double) | |
+ if type.is_numeric and type.kind > default_result_type.kind: | |
+ default_result_type = type | |
+ | |
+ return default_result_type | |
#---------------------------------------------------------------------------- | |
# Determine math functions | |
@@ -42,39 +41,49 @@ | |
# sin(double), sinf(float), sinl(long double) | |
unary_libc_math_funcs = [ | |
- 'sin', | |
- 'cos', | |
- 'tan', | |
- 'sqrt', | |
'acos', | |
+ 'acosh', | |
'asin', | |
- 'atan', | |
- 'atan2', | |
- 'sinh', | |
- 'cosh', | |
- 'tanh', | |
'asinh', | |
- 'acosh', | |
+ 'atan', | |
'atanh', | |
- 'log', | |
- 'log2', | |
- 'log10', | |
- 'fabs', | |
- 'erfc', | |
'ceil', | |
+ 'cos', | |
+ 'cosh', | |
+ 'erfc', | |
'exp', | |
'exp2', | |
'expm1', | |
- 'rint', | |
+ 'fabs', | |
+ #factorial | |
+ 'floor', | |
+ #'isinf', # linux only -- returns bool | |
+ #'isnan', # -- returns bool | |
+ 'log', | |
+ 'log10', | |
'log1p', | |
+ 'log2', | |
+ #radians | |
+ 'rint', | |
+ 'round', # linux only | |
+ 'sin', | |
+ 'sinh', | |
+ 'sqrt', | |
+ 'tan', | |
+ 'tanh', | |
+ #'trunc', # linux only -- returns int in python | |
] | |
-n_ary_libc_math_funcs = [ | |
+binary_libc_math_funcs = [ | |
+ 'atan2', | |
+ 'copysign', | |
+ 'fmod', | |
+ 'hypot', | |
+ #ldexp -- int argument | |
'pow', | |
- 'round', | |
] | |
-all_libc_math_funcs = unary_libc_math_funcs + n_ary_libc_math_funcs | |
+all_libc_math_funcs = unary_libc_math_funcs + binary_libc_math_funcs | |
#---------------------------------------------------------------------------- | |
# Math Type Inferers | |
@@ -82,44 +91,27 @@ | |
# TODO: Move any rewriting parts to lowering phases | |
-def infer_unary_math_call(context, call_node, arg, default_result_type=double): | |
+def infer_math_or_cmath_call(default_result_type, context, call_node, *args): | |
"Resolve calls to math functions to llvm.log.f32() etc" | |
# signature is a generic signature, build a correct one | |
- type = get_type(call_node.args[0]) | |
- | |
- if type.is_numeric and type.kind < default_result_type.kind: | |
- type = default_result_type | |
- elif type.is_array and type.dtype.is_int: | |
- type = type.copy(dtype=double) | |
- | |
- # signature = minitypes.FunctionType(return_type=type, args=[type]) | |
- # result = nodes.MathNode(py_func, signature, call_node.args[0]) | |
+ type = largest_type(default_result_type, map(get_type, args)) | |
nodes.annotate(context.env, call_node, is_math=True) | |
call_node.variable = Variable(type) | |
return call_node | |
-def infer_unary_cmath_call(context, call_node, arg): | |
- result = infer_unary_math_call(context, call_node, arg, | |
- default_result_type=complex128) | |
+def infer_math_call(context, call_node, *arg): | |
+ return infer_math_or_cmath_call(double, context, call_node, *arg) | |
+ | |
+def infer_cmath_call(context, call_node, *arg): | |
+ result = infer_math_or_cmath_call(complex128, context, call_node, *arg) | |
nodes.annotate(context.env, call_node, is_cmath=True) | |
return result | |
# ______________________________________________________________________ | |
-# pow() | |
- | |
-def pow_(context, call_node, node, power, mod=None): | |
- dst_type = binop_type(context, node, power) | |
- call_node.variable = Variable(dst_type) | |
- return call_node | |
- | |
-register_math_typefunc((2, 3), math.pow) | |
-register_math_typefunc(2, np.power) | |
- | |
-# ______________________________________________________________________ | |
-# abs() | |
+# broken numpy funcs | |
def abs_(context, node, x): | |
- import builtinmodule | |
+ from . import builtinmodule | |
argtype = get_type(x) | |
@@ -132,7 +124,9 @@ | |
return builtinmodule.abs_(context, node, x) | |
-register_math_typefunc(1, np.abs) | |
+#FIXME: this one is broken, core.issues.test_issue_56 fails | |
+#register_math_typefunc(1)(abs_, np.abs) | |
+#register_math_typefunc(2)(infer_binary_math_call, np.power) | |
#---------------------------------------------------------------------------- | |
# Register Type Functions | |
@@ -140,23 +134,25 @@ | |
def register_math(nargs, value): | |
register = register_math_typefunc(nargs) | |
- register(infer_unary_math_call, value) | |
+ register(infer_math_call, value) | |
def register_cmath(nargs, value): | |
register = register_math_typefunc(nargs) | |
- register(infer_unary_cmath_call, value) | |
+ register(infer_cmath_call, value) | |
def register_typefuncs(): | |
- modules = [builtins, math, cmath, np] | |
- # print all_libc_math_funcs | |
- for libc_math_func in unary_libc_math_funcs: | |
- for module in modules: | |
- if hasattr(module, libc_math_func): | |
- if module is cmath: | |
- register = register_cmath | |
- else: | |
- register = register_math | |
+ modules = [math, cmath, np] | |
+ for nargs, libc_math_funcs in [(1, unary_libc_math_funcs), | |
+ (2, binary_libc_math_funcs)]: | |
+ for libc_math_func in libc_math_funcs: | |
+ for module in modules: | |
+ if hasattr(module, libc_math_func): | |
+ if module is cmath: | |
+ register = register_cmath | |
+ else: | |
+ register = register_math | |
- register(1, getattr(module, libc_math_func)) | |
+ register(nargs, getattr(module, libc_math_func)) | |
register_typefuncs() | |
+ | |
Index: numba/specialize/mathcalls.py | |
=================================================================== | |
--- numba/specialize/mathcalls.py (revision 83292) | |
+++ numba/specialize/mathcalls.py (working copy) | |
@@ -38,11 +38,20 @@ | |
is_intrinsic = hasattr(llvm.core, intrinsic_name) | |
return is_intrinsic | |
+if is_win32: | |
+ _MAPPING = { | |
+ 'abs' : 'fabs', | |
+ 'hypot' : '_hypot', | |
+ 'isnan' : '_isnan', | |
+ 'copysign' : '_copysign', | |
+ } | |
+else: | |
+ _MAPPING = { | |
+ 'abs' : 'fabs', | |
+ } | |
def math_suffix(name, type): | |
- if name == 'abs': | |
- name = 'fabs' | |
- | |
+ name = _MAPPING.get(name, name) | |
if type.is_float and type.itemsize == 4: | |
name += 'f' # sinf(float) | |
elif type.is_int and type.itemsize == 16: | |
@@ -61,11 +70,10 @@ | |
return math_suffix(math_name, double) in libc_math_funcs | |
def is_math_function(func_args, py_func): | |
- if len(func_args) == 0 or len(func_args) > 1 or py_func is None: | |
+ if not func_args or py_func is None: | |
return False | |
- type = get_type(func_args[0]) | |
- | |
+ type = mathmodule.largest_type(float_, map(get_type, func_args)) | |
if type.is_array: | |
type = type.dtype | |
valid_type = type.is_float or type.is_int or type.is_complex | |
@@ -104,8 +112,8 @@ | |
def resolve_math_call(call_node, py_func): | |
"Resolve calls to math functions to llvm.log.f32() etc" | |
- signature = call_node.type(call_node.type) | |
- return nodes.MathNode(py_func, signature, call_node.args[0]) | |
+ signature = call_node.type(*[call_node.type] * len(call_node.args)) | |
+ return nodes.MathNode(py_func, signature, call_node.args) | |
def filter_math_funcs(math_func_names): | |
if is_win32: | |
@@ -115,7 +123,8 @@ | |
result_func_names = [] | |
for name in math_func_names: | |
- if getattr(dll, name, None) is not None: | |
+ cname = _MAPPING.get(name, name) | |
+ if getattr(dll, cname, None) is not None: | |
result_func_names.append(name) | |
return result_func_names | |
Index: numba/nodes/callnodes.py | |
=================================================================== | |
--- numba/nodes/callnodes.py (revision 83292) | |
+++ numba/nodes/callnodes.py (working copy) | |
@@ -87,13 +87,13 @@ | |
Represents a high-level call to a math function. | |
""" | |
- _fields = ['arg'] | |
+ _fields = ['args'] | |
- def __init__(self, py_func, signature, arg, **kwargs): | |
+ def __init__(self, py_func, signature, args, **kwargs): | |
super(MathNode, self).__init__(**kwargs) | |
self.py_func = py_func | |
self.signature = signature | |
- self.arg = arg | |
+ self.args = args | |
self.type = signature.return_type | |
class LLVMExternalFunctionNode(ExprNode): |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment