Last active: March 29, 2020 17:50
-
-
Save cwindolf/8d3737cb6fd589efb251fee536df845c to your computer and use it in GitHub Desktop.
Syntactic sugar for reductions in numpy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
class _redder:
    """Syntactic sugar for reductions in numpy.

    Users interact with the global instance `rdr` rather than with this
    class directly. Applying an operator with `rdr` on the left and a
    sequence/array on the right dispatches to a numpy reduction:

        rdr + xs           ->  np.sum(xs)
        rdr * xs           ->  np.prod(xs)
        rdr | xs           ->  np.any(xs)
        rdr & xs           ->  np.all(xs)
        rdr @ mats         ->  np.linalg.multi_dot(mats)
        rdr(fn) << xs      ->  np.frompyfunc(fn, 2, 1).reduce(xs)

    Keyword arguments given as `rdr(**kwargs)` (e.g. `axis`, `initial`,
    `out`) are forwarded to the reduction that the operator dispatches to.

    Matrix multiplication is the odd one out: `rdr @ <list of matrices>`
    works, but `rdr(np.matmul) << <list of matrices>` does not. `reduce`
    converts the list into one stacked numeric array and the object ufunc
    built by `np.frompyfunc` combines it *element-wise* along axis 0, so
    `np.matmul` receives scalars and raises.

    Examples:
    =========
    ```
    >>> rdr + [0, 1, 2]
    3
    >>> rdr * [0, 1, 2]
    0
    >>> rdr | [0, 1, 2]
    True
    >>> rdr & [0, 1, 2]
    False
    >>> a = np.eye(3); b = 2 * a
    >>> rdr @ [a, a, b, b]
    array([[4., 0., 0.],
           [0., 4., 0.],
           [0., 0., 4.]])
    >>> def f(a, b): return a - b
    >>> rdr(f) << [0, 1, 2]
    -3
    >>> rdr(f, initial=-100) << [0, 1, 2]
    -103
    >>> rdr(f) << [1, 2, 3]
    -4
    >>> rdr(f, initial=0) << [1, 2, 3]
    -6
    ```
    """

    def __init__(self, fn=None, **kwargs):
        # Binary callable used only by `<<`; must stay None for the
        # built-in operators (+, *, |, &, @).
        self.fn = fn
        # Extra keyword arguments forwarded verbatim to the numpy reduction.
        self.kwargs = kwargs

    def __call__(self, fn=None, **kwargs):
        """Return a new reducer configured with `fn` and/or `kwargs`.

        `rdr(fn) << lst` reduces `lst` using the binary operation `fn`.
        `rdr(**kwargs) <op> lst` forwards `kwargs` to the reducer that
        `<op>` calls (`fn.reduce`, `np.sum`, `np.prod`, `np.any`,
        `np.all`, `np.linalg.multi_dot`).
        The two can be combined as `rdr(fn, **kwargs)`.

        The configuration of `self` is not inherited: the returned reducer
        has exactly the `fn` and `kwargs` given here.
        """
        return type(self)(fn=fn, **kwargs)

    def _require_no_fn(self, op):
        """Raise TypeError if a custom `fn` was set but operator `op` ignores it.

        This is an explicit check rather than `assert`, so it still runs
        under `python -O`.
        """
        if self.fn is not None:
            raise TypeError(
                "`rdr {0} ...` uses a built-in numpy reduction and does not "
                "take a custom function; use `rdr(fn) << ...` instead".format(op)
            )

    def __add__(self, other):
        """`rdr + [0, 1, 2]` => 3, via `np.sum`."""
        self._require_no_fn("+")
        return np.sum(other, **self.kwargs)

    def __mul__(self, other):
        """`rdr * [0, 1, 2]` => 0, via `np.prod`."""
        self._require_no_fn("*")
        return np.prod(other, **self.kwargs)

    def __or__(self, other):
        """`rdr | [0, 1, 2]` => True (is any element truthy?), via `np.any`."""
        self._require_no_fn("|")
        return np.any(other, **self.kwargs)

    def __and__(self, other):
        """`rdr & [0, 1, 2]` => False (are all elements truthy?), via `np.all`."""
        self._require_no_fn("&")
        return np.all(other, **self.kwargs)

    def __matmul__(self, other):
        """`rdr @ [mat1, mat2, mat3]` => mat1 @ mat2 @ mat3, via `np.linalg.multi_dot`.

        `multi_dot` chooses the cheapest multiplication order. Keyword
        arguments are forwarded like for the other operators; numpy's
        `multi_dot` accepts only `out`.
        """
        self._require_no_fn("@")
        return np.linalg.multi_dot(other, **self.kwargs)

    def __lshift__(self, other):
        """`rdr(f) << [0, 1, 2]` => f(f(0, 1), 2), like `functools.reduce(f, ...)`.

        `f` is wrapped by `np.frompyfunc` into an object ufunc and reduced
        with that ufunc's `reduce`. `other` is therefore converted to an
        array and combined element-wise along `axis` (default 0).
        Keyword arguments such as `axis` or `initial` go to `reduce`.

        Raises:
            TypeError: if no binary function was configured via `rdr(fn)`.
        """
        if self.fn is None:
            raise TypeError(
                "`rdr << ...` needs a binary function; use `rdr(fn) << ...`"
            )
        # self.fn must be a binary operation: 2 inputs, 1 output.
        return np.frompyfunc(self.fn, 2, 1).reduce(other, **self.kwargs)
# Public entry point: a reducer with no custom function and no extra keyword
# arguments. Calling it, e.g. `rdr(fn, **kwargs)`, returns a new, configured
# reducer.
rdr = _redder(fn=None)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.