This is an attempt at extending the broadcasting rules of NumPy without losing the benefits of vectorization, by converting the new rules into pre-existing ones.
The broadcasting policy is borrowed from the APL standard, along the following lines:
- a subset of nested arrays from APL is implemented here;
- an axis parameter is implemented here.
For instance, the following two arrays cannot be broadcast together:
    1 2 3 4    and    1 2
    5 6 7 8           3 4
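Plain NumPy rejects this pair because it compares trailing dimensions, here 4 and 2, and neither of them is 1. A quick check (the names a and b are only illustrative):

```python
import numpy as np

a = np.array([[1, 2, 3, 4],
              [5, 6, 7, 8]])   # shape (2, 4)
b = np.array([[1, 2],
              [3, 4]])         # shape (2, 2)

try:
    a + b
except ValueError as err:
    # trailing dimensions 4 and 2 differ and neither is 1
    print("ValueError:", err)
```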
However, you can enclose parts of the first array in order to turn them into "scalars":
+-----+-----+
| 1 2 | 3 4 |
+-----+-----+
| 5 6 | 7 8 |
+-----+-----+
and thus make it compatible with the (2, 2) shape of the other array by giving it the structure (2, 2, (2,)), with the following call:
apl_broadcast(np.add,
[[1,2],[3,4]], [[1,2,3,4],[5,6,7,8]],
structures=(None, (2,2,(2,))))
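Internally, this call reduces to ordinary NumPy broadcasting: the enclosed operand is viewed with shape (2, 2, 2), the other operand gets a trailing axis of length 1, and the result is reshaped back to the shape of the larger array. A minimal sketch of these equivalent steps in plain NumPy (the variable names are illustrative):

```python
import numpy as np

left = np.array([[1, 2], [3, 4]])                 # shape (2, 2)
right = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])    # shape (2, 4)

# structure (2, 2, (2,)): each row of `right` is split into two enclosed
# pairs, so `right` is viewed with shape (2, 2, 2)
right_nested = right.reshape(2, 2, 2)
# `left` has no inner level, so it gets a trailing axis of length 1
left_padded = left.reshape(2, 2, 1)

result = np.add(left_padded, right_nested)   # plain broadcasting, shape (2, 2, 2)
result = result.reshape(right.shape)         # back to the shape of the larger array
print(result.tolist())   # [[2, 3, 5, 6], [8, 9, 11, 12]]
```

Each enclosed pair of the right operand is combined with the matching "scalar" of the left operand, e.g. 1 + (1 2) gives (2 3).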
Deeper levels of nesting can be achieved by appending, each time, a child tuple as the last element of the parent tuple, as long as the total number of elements is preserved.
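To illustrate how a nested structure unfolds into levels, here is a small hypothetical helper, levels_of, that mirrors the structure-parsing loop of apl_broadcast: each trailing tuple opens a new, deeper level, and the product of all dimensions must equal the number of elements of the array.

```python
import math
import numpy as np

def levels_of(structure):
    """Hypothetical helper: unfold a nested structure such as
    (2, (2, (2,))) into its per-level shapes [(2,), (2,), (2,)]."""
    levels = []
    while structure:
        last = structure[-1]
        if isinstance(last, (tuple, list)):
            levels.append(tuple(structure[:-1]))  # dimensions of this level
            structure = last                      # descend into the child tuple
        else:
            levels.append(tuple(structure))
            structure = None
    return levels

data = np.arange(8)
for structure in [(2, 2, (2,)), (2, (2, (2,))), (4, (2,))]:
    levels = levels_of(structure)
    flat_shape = sum(levels, ())   # concatenate the per-level shapes
    # nesting is only valid if the element count is preserved
    assert math.prod(flat_shape) == data.size
    print(structure, "->", levels, "->", data.reshape(flat_shape).shape)
```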
A list of axes can also be provided, as in the APL standard (except for the index origin, of course, which is 1 by default in APL and has been changed to 0 here):
apl_broadcast(np.add, [1,10], [[2,2],[3,3]], axis=[0])
apl_broadcast(np.add, [1,10], [[2,2],[3,3]], axis=[1])
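With axis=[0], the smaller operand is aligned with the first axis of the larger one, and with axis=[1] with the second one; the code does this by giving the smaller operand dimensions of length 1 along all other axes. In plain NumPy, the two calls above amount to the following sketch (v and m are illustrative names):

```python
import numpy as np

v = np.array([1, 10])               # shape (2,)
m = np.array([[2, 2], [3, 3]])      # shape (2, 2)

# axis=[0]: v runs along the first axis of m, so v is viewed as shape (2, 1)
along_0 = v.reshape(2, 1) + m
# axis=[1]: v runs along the second axis of m, so v is viewed as shape (1, 2)
along_1 = v.reshape(1, 2) + m

print(along_0.tolist())   # [[3, 3], [13, 13]]
print(along_1.tolist())   # [[3, 12], [4, 13]]
```

The second case coincides with NumPy's default behaviour (v + m), which always aligns trailing axes; the first one is what the axis parameter adds.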
An additional parameter, reshape, has been implemented; when it is set to True, the shape used internally for the computation is kept in the result (for instance (2, 2, 2) in the first example above). By default, however, it is set to False and the result is reshaped to the initial shape of the larger array, i.e. the one with more elements (the code currently picks the right operand's shape when both sizes are equal). It has not been fully decided yet which shape should be kept in that case, when both arrays have the same size but different shapes; maybe this parameter should be removed anyway (and the True behaviour made the default).
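The default (reshape=False) policy can be summarized by the following hypothetical helper, output_shape, which follows the final lines of apl_broadcast: the operand with more elements wins, and on a tie the right operand's shape is used, which is exactly the undecided case described above.

```python
import numpy as np

def output_shape(left, right):
    """Hypothetical helper: shape of the result when reshape=False.
    The operand with more elements wins; ties go to the right operand."""
    left, right = np.asarray(left), np.asarray(right)
    return left.shape if left.size > right.size else right.shape

# first example of this page: the (2, 4) operand has more elements
print(output_shape(np.zeros((2, 2)), np.zeros((2, 4))))   # (2, 4)
# same size, different shapes: the right operand's shape is picked
print(output_shape(np.zeros((2, 3)), np.zeros((3, 2))))   # (3, 2)
print(output_shape(np.zeros((3, 2)), np.zeros((2, 3))))   # (2, 3)
```

The last two lines show that, on a tie, swapping the operands changes the shape of the result.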
This piece of code should be tested extensively by APL gurus in order to find issues. A link to this page will be posted to Hacker News and to Reddit.
import numpy as np


def _parse_structure(struct, shape):
    # Flatten a nested structure such as (2, 2, (2,)) into a list of
    # per-level shapes, here [(2, 2), (2,)]: each trailing tuple (or list)
    # opens a deeper level. Without a structure, the array's own shape is
    # used as the only level.
    if not struct:
        return [tuple(shape)]
    levels = []
    while struct:
        last = struct[-1]
        if isinstance(last, (tuple, list)):
            # converting to tuple lets lists be used in `structures` while
            # keeping the comparisons and concatenations below consistent
            levels.append(tuple(struct[:-1]))
            struct = last
        else:
            levels.append(tuple(struct))
            struct = None
    return levels


def _pretty_struct(levels):
    # Render per-level shapes as a nested tuple: [(2, 2), (2,)] -> "(2, 2, (2,))"
    s = ""
    for k in levels:
        s += "(" + "".join(str(x) + ", " for x in k)
    if len(levels[-1]) > 1:
        s = s[:-1]
    return s[:-1] + (")" * len(levels))


def apl_broadcast(func, left, right, axis=None,
                  structures=(None, None),
                  reshape=False):
    if not isinstance(left, np.ndarray):
        left = np.array(left)
    if not isinstance(right, np.ndarray):
        right = np.array(right)
    lstruct = _parse_structure(structures[0], left.shape)
    rstruct = _parse_structure(structures[1], right.shape)
    ln, rn = len(lstruct), len(rstruct)
    # ls and rs are the shapes both operands are reshaped to, level by
    # level, so that plain NumPy broadcasting can be applied to them
    ls, rs = (), ()
    for i in range(max(ln, rn)):
        if i < ln and lstruct[i]:
            if i < rn and rstruct[i]:
                if axis:
                    # APL-style axis (first level only): one operand must
                    # match the selected axes of the other one, and gets
                    # dimensions of length 1 along all remaining axes
                    if axis[-1] < len(lstruct[i]):
                        A = tuple(lstruct[i][j] for j in axis)
                        if A == rstruct[i]:
                            S = [1] * len(lstruct[i])
                            for j in axis:
                                S[j] = lstruct[i][j]
                            ls += lstruct[i]
                            rs += tuple(S)
                            axis = None
                            continue
                    if axis[-1] < len(rstruct[i]):
                        A = tuple(rstruct[i][j] for j in axis)
                        if A == lstruct[i]:
                            S = [1] * len(rstruct[i])
                            for j in axis:
                                S[j] = rstruct[i][j]
                            ls += tuple(S)
                            rs += rstruct[i]
                            axis = None
                            continue
                    raise ValueError("invalid axis " + str(axis))
                if lstruct[i] == rstruct[i]:
                    # same shape at this level
                    ls += lstruct[i]
                    rs += rstruct[i]
                elif all(x == 1 for x in lstruct[i]):
                    # the left level is "scalar-like": stretch it
                    ls += (1,) * len(rstruct[i])
                    rs += rstruct[i]
                elif all(x == 1 for x in rstruct[i]):
                    # the right level is "scalar-like": stretch it
                    ls += lstruct[i]
                    rs += (1,) * len(lstruct[i])
                else:
                    raise ValueError(
                        "operands could not be broadcast together with shapes "
                        + _pretty_struct(lstruct) + " " + _pretty_struct(rstruct))
            else:
                # only the left operand has this level: pad the right one
                ls += lstruct[i]
                rs += (1,) * len(lstruct[i])
                axis = None
        elif i < rn and rstruct[i]:
            # only the right operand has this level: pad the left one
            rs += rstruct[i]
            ls += (1,) * len(rstruct[i])
            axis = None
    result = func(left.reshape(ls), right.reshape(rs))
    if reshape:
        return result
    # TODO both size equal and different shapes?
    # the best policy has to be investigated:
    # either the simplest of initial shapes or forcing "reshape" as above
    if left.size > right.size:
        return result.reshape(left.shape)
    return result.reshape(right.shape)