Created
October 26, 2021 15:50
-
-
Save brandonwillard/5f870e349e0e7f2be2ac9b57e18bcb8a to your computer and use it in GitHub Desktop.
Calling `PyUFuncObject`'s `reduce` method from Numba
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numba as nb | |
import numpy as np | |
from numba.core.datamodel.models import StructModel | |
from numba.core.datamodel.registry import register_default | |
from numba.core.extending import intrinsic, overload, overload_method | |
@nb.njit
def add(x, y):
    """Return ``x + y``.

    Compiled in nopython mode so it can be handed to ``nb.vectorize``.
    """
    total = x + y
    return total
# No signatures are listed, so compilation is left enabled.
# `vec_add` is expected to be a dynamically compiled `DUFunc`, which is the
# type that `dufunc_reduce` checks for.
custom_vectorize = nb.vectorize(
    [],
    identity=None,
    target="cpu",
)
vec_add = custom_vectorize(add)
#
# Data model intended to mirror the C struct `PyDUFuncObject`:
# [[file:numba/np/ufunc/_internal.c::} PyDUFuncObject;][PyDUFuncObject]]
#
@register_default(nb.np.ufunc.dufunc.DUFunc)
class PyDUFuncModel(StructModel):
    """A model for `DUFunc`.

    The members are intended to mirror the fields of the C-level
    ``PyDUFuncObject``. Each member is stored as an opaque Python object
    (``nb.types.pyobject``).

    NOTE(review): Numba's data-model registry is keyed on Numba *type*
    classes, but ``DUFunc`` is the Python-level dispatcher class. In this
    file, ``dufunc_reduce`` sees a DUFunc as an ``nb.types.Function`` whose
    ``typing_key`` is the DUFunc instance. This registration may therefore
    never be consulted; confirm which Numba type should be registered.
    """

    # Not read anywhere in the visible code; kept as-is for compatibility.
    _element_type = NotImplemented

    def __init__(self, dmm, fe_type):
        # Numba instantiates data models as ``model_cls(dmm, fe_type)``.
        # ``StructModel.__init__`` requires ``(dmm, fe_type, members)``, so
        # both arguments must be accepted and forwarded.
        members = [
            ("_dispatcher", nb.types.pyobject),
            ("ufunc", nb.types.pyobject),
            ("_keepalive", nb.types.pyobject),
        ]
        super().__init__(dmm, fe_type, members)
#
# TODO: Define a similar model for the `PyUFuncObject` held in the `ufunc`
# member of `PyDUFuncModel`?
#
# [[file:../../../../apps/anaconda3/envs/numba-env/lib/python3.7/site-packages/numpy/core/include/numpy/ufuncobject.h::} PyUFuncObject;][PyUFuncObject]]
#
# @register_default(?) | |
# class PyUFuncModel(StructModel): | |
# _element_type = NotImplemented | |
# | |
# def __init__(self): | |
# members = [ | |
# ("ptr", nb.types.pyobject), | |
# ("obj", nb.types.pyobject), | |
# ] | |
# super(PyUFuncModel, self).__init__(members) | |
@intrinsic
def intr_reduce(typcontext, ft, xt, yt, axist):
    """Typing-time stub for calling a `DUFunc`'s ``reduce`` from nopython code.

    Numba passes the typing context as the first argument of every
    ``@intrinsic``. The remaining parameters are the Numba *types* of the
    call-site arguments:

    - ``ft``: the DUFunc object itself.
    - ``xt``, ``yt``: the two array arguments.
    - ``axist``: the axis argument.

    Returns the ``(signature, codegen)`` pair that Numba expects. The return
    type is hard-coded to ``int64`` as a placeholder.

    NOTE(review): a real reduction's result type depends on the input
    types, so the ``int64`` return type must be revisited.
    """
    sig = nb.types.int64(ft, xt, yt, axist)

    def codegen(context, builder, signature, args):
        # ``args`` holds the LLVM values for (DUFunc object, x, y, axis),
        # in call order.
        #
        # TODO: obtain a usable reference to the underlying
        # `PyDUFuncObject`, e.g.
        #   cgutils.create_struct_proxy(signature.args[0])(context, builder)
        # and call ``reduce`` through it.
        #
        # Raise instead of falling through. A codegen function must return
        # an LLVM value of ``signature.return_type``; returning ``None``
        # would only fail later with an obscure lowering error.
        raise NotImplementedError(
            "intr_reduce: lowering of DUFunc.reduce is not implemented yet"
        )

    return sig, codegen
@overload_method(nb.types.Function, "reduce")
def dufunc_reduce(ft, xt, yt, axist):
    """Provide a ``reduce`` method for Numba ``Function`` types that wrap a `DUFunc`.

    For any other ``Function`` this returns ``None``, which tells Numba that
    this overload does not apply.

    NOTE(review): NumPy's ``ufunc.reduce`` is ``reduce(array, axis=0, ...)``.
    The ``(x, y, axis)`` argument layout here does not follow that
    signature; confirm the intended semantics.
    """
    is_dufunc = isinstance(ft.typing_key, nb.np.ufunc.dufunc.DUFunc)
    if not is_dufunc:
        return None

    # The implementation's parameter names must match the overload's.
    def impl(ft, xt, yt, axist):
        return intr_reduce(ft, xt, yt, axist)

    return impl
@nb.njit
def test_fn(x, y):
    """Call ``vec_add.reduce(x, y, 0)`` from nopython mode and return the result."""
    result = vec_add.reduce(x, y, 0)
    return result
# Trigger compilation of `test_fn` (and the experimental `reduce` overload)
# on small sample arrays.
test_fn(np.arange(10), 2 * np.arange(10))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment