Currently, Numba functions exhibit many of the behaviours expected from true first-class functions. There are, however, three useful behaviours that are currently missing:
- iterating over a sequence of functions
- avoiding re-compilation of higher-order functions
- ability to cache higher-order functions
This is an extremely common pattern:
```python
@njit
def foo(x):
    return x + 1

@njit
def foo2(x):
    return x + 2

@njit
def bar(fcs, val):
    x = 0
    for fc in fcs:
        x += fc(val)
    return x

bar((foo, foo2), 3.5)
```
At the moment the above does not work in Numba. `literal_unroll` could be used for simple cases, but not for something like
```python
@njit
def bar(fcs1, fcs2, val):
    x = 0
    for fc1, fc2 in zip(fcs1, fcs2):
        x += fc1(val) * fc2(val)
    return x

bar((foo, foo2), (foo2, foo), 3.5)
```
The current implementation of first-class function types would allow the following:
```python
@cfunc(int64(int64))
def foo(x):
    return x + 1

@cfunc(int64(int64))
def foo2(x):
    return x + 2

bar((foo, foo2), 3.5)
```
However, this is not practical when the user of the functions is not their creator (as is the case with libraries or frameworks). Alternatively, one could compile for one type and then disable compilation:
```python
@njit
def foo(x):
    return x + 1

@njit
def foo2(x):
    return x + 2

foo.compile(int64(int64))
foo.disable_compile()

bar((foo, foo2), 3.5)
```
However, this has a strong side effect on the original function, breaking potential future calls with other types:

```python
foo(1.2)  # fails
```
One could extract the underlying Python function via `py_func`:
```python
@njit
def foo(x):
    return x + 1

@njit
def foo2(x):
    return x + 2

foo_int = cfunc(int64(int64))(foo.py_func)
bar((foo_int, foo2), 3.5)
```
but then one must keep track of every version of `foo` that one has used, or keep recompiling the same function over and over (creating a problem in use case 2).
In summary, while some use cases can be served with the current features, more general cases cannot be accommodated in a smooth way.
Ultimately, the following should happen:

```python
@njit
def bar(fcs, val):
    x = 0
    for fc in fcs:
        x += fc(val)
    return x

bar((foo, foo2), 3.5)
bar.disable_compile()
bar((foo2, foo), 3.5)
```
without the need for any annotation.
However, this might not be easy to achieve, since it could require changes to the type inference stage (my guess is that it would need to allow input types to be type variables, rather than type constraints with a known value).
If a user annotation is required, then the following is much better for some (many?) use cases:
```python
@njit
def foo(x):
    return x + 1

@njit
def foo2(x):
    return x + 2

@njit(int64(UniTuple(FunctionType(int64(int64)), 2), int64))
def bar(fcs, val):
    x = 0
    for fc in fcs:
        x += fc(val)
    return x

bar((foo, foo2), 3.5)
bar((foo2, foo), 3.5)   # works
bar((foo2, foo2), 3.5)  # works
bar((foo2, foo3), 3.5)  # works, for any compatible @njit function foo3
```
The basic logic is that `Dispatcher` is an intersection type. For example, `type(foo) = Dispatcher(foo) = int64 -> int64 & float64 -> float64 & ...`.
PR #5579 implements the following subtyping rule: if `foo: T1 & T2`, then `foo <: T1`.
In simple terms, if `foo` can be compiled for `int` and for `float`, then `foo` should be accepted as a first-class function by a function that requires `int -> int`. This allows a seamless transition from the `Dispatcher` type to `FunctionType`, which enables first-class behaviour.
As a consequence, and since Numba already implements the subtyping rule for tuples (`(S1, S2) <: (T1, T2)` if `S1 <: T1` and `S2 <: T2`), a tuple of dispatchers is automatically a subtype of a tuple of `FunctionType`s with any signature supported by all dispatchers in the tuple.
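The two rules above can be sketched as a toy model in plain Python. This is purely illustrative: the class names (`Sig`, `Intersection`) and the helper `tuple_subtype` are invented for this sketch and are not part of Numba's internals.

```python
# Toy model of the subtyping rules described above (not Numba's API).

class Sig:
    """A monomorphic function type, e.g. int64 -> int64."""
    def __init__(self, arg, ret):
        self.arg, self.ret = arg, ret
    def __eq__(self, other):
        return (self.arg, self.ret) == (other.arg, other.ret)
    def __hash__(self):
        return hash((self.arg, self.ret))

class Intersection:
    """Models type(foo) = int64->int64 & float64->float64 & ..."""
    def __init__(self, *sigs):
        self.sigs = set(sigs)
    def subtype_of(self, sig):
        # foo : T1 & T2  implies  foo <: T1 (and foo <: T2)
        return sig in self.sigs

def tuple_subtype(subs, sups):
    # (S1, S2) <: (T1, T2) if S1 <: T1 and S2 <: T2
    return len(subs) == len(sups) and all(
        s.subtype_of(t) for s, t in zip(subs, sups))

i64 = Sig('int64', 'int64')
f64 = Sig('float64', 'float64')
foo = Intersection(i64, f64)   # "compiled" for int64 and float64
foo2 = Intersection(i64)       # "compiled" only for int64

assert foo.subtype_of(i64)
assert tuple_subtype((foo, foo2), (i64, i64))
assert not tuple_subtype((foo, foo2), (f64, f64))  # foo2 lacks float64
```

The last assertion shows why the tuple rule requires the signature to be supported by *all* dispatchers in the tuple.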
Thanks to Numba having solid cast machinery in place, the implementation of this feature only requires 4 lines of code:

```python
def can_convert_to(self, typingctx, other):
    if isinstance(other, types.FunctionType):
        if self.dispatcher.get_compile_result(other.signature):
            return Conversion.safe
```
Function types also follow the subtyping rule `T1 -> S2 <: S1 -> T2` if `S1 <: T1` and `S2 <: T2`, i.e. functions are contravariant in their inputs and covariant in their outputs. The current PR does not implement that, but a future PR could do it.
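The variance rule can also be checked with a small sketch. The numeric widening lattice below (`int32 <: int64 <: float64`) and the function names are assumptions made for illustration only; none of this reflects Numba's actual conversion table.

```python
# Toy check of function-type variance. WIDENS is an illustrative
# value-subtyping relation, not Numba's real conversion rules.
WIDENS = {('int32', 'int64'), ('int32', 'float64'), ('int64', 'float64')}

def value_subtype(s, t):
    return s == t or (s, t) in WIDENS

def fn_subtype(f, g):
    """f = (T1, S2), g = (S1, T2): f <: g if S1 <: T1 and S2 <: T2,
    i.e. inputs are contravariant and outputs are covariant."""
    (t1, s2), (s1, t2) = f, g
    return value_subtype(s1, t1) and value_subtype(s2, t2)

# A function taking int64 and returning int32 can be used where a
# function taking int32 and returning int64 is expected...
assert fn_subtype(('int64', 'int32'), ('int32', 'int64'))
# ...but not the other way around:
assert not fn_subtype(('int32', 'int64'), ('int64', 'int32'))
```

Intuitively, the callee may accept a *wider* input than the caller provides and return a *narrower* output than the caller expects, which is exactly what the two `value_subtype` checks encode.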