Last active
May 12, 2024 12:36
-
-
Save jeffdonahue/12ff1b8e90bed6ed22221cbd9ba49578 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Python 2/3 compatibility: under Python 2, make `range` refer to the lazy
# `xrange`; under Python 3 `xrange` does not exist and the built-in `range`
# (already lazy) is left alone.
try:
    xrange
except NameError:
    pass  # Python 3
else:
    range = xrange  # Python 2
def lazy_product(*iter_funcs, **kwargs):
    """
    Lazily iterate over the Cartesian product of re-creatable iterables.

    If f1, f2, ..., are functions which have no (required) arguments and
    return iterables, then
        lazy_product(f1, f2, ..., repeat=k)
    is equivalent to
        itertools.product(f1(), f2(), ..., repeat=k);
    but much faster in certain cases.

    For example, let f have the following definition:

        def f(n):
            def func():
                return xrange(n)
            return func

    Then, this code:

        p = itertools.product(*[f(N)() for _ in xrange(M)], repeat=K)
        first_element = next(p)

    takes O(NMK) time and memory to execute, whereas

        p = lazy_product(*[f(N) for _ in xrange(M)], repeat=K)
        first_element = next(p)

    is equivalent, and takes just O(MK) time and memory.

    (Of course, iterating over either result is exactly N^(MK) steps, and each
    step takes O(1) time; the only difference between itertools.product and
    lazy_product is at the time of initialization of the iterable p,
    including the call to next(p) to get the first element, as shown above.)

    itertools.product's O(N) speed/memory overhead results from its saving the
    full result of xrange(N) as a list (or similar data structure) in memory.
    This is necessary as itertools.product takes iterables as input, and it is
    not generally possible to "reset" an iterator, so all of its values
    instead need to be stored. So, the input to lazy_product is an iterable
    of *functions* returning iterables, rather than the iterables themselves,
    allowing for repeated iteration over each iterable (by calling iter_func
    again when we reach the end of the iterable that iter_func created on
    the previous call).

    Inputs:
        - iter_funcs: functions with no (required) arguments that create and
          return an iterable. Each function is assumed to be deterministic --
          i.e., return an identical iterable on each call. (Otherwise, the
          behavior of lazy_product is undefined.)
        - kwargs: a dict which is either empty or contains only the key
          `repeat`, with an integer value. In Python 3, the function header
          could (much more cleanly) be written as:
              def lazy_product(*iter_funcs, repeat=1):
          and the first two lines of ugly parsing code could be dropped.

    Returns:
        an iterator over the Cartesian product of the iterables returned
        by the elements of iter_funcs -- equivalent to:
            return itertools.product(*(f() for f in iter_funcs), **kwargs)
        If any of the iterables is empty, the product is empty.

    Raises:
        ValueError: if kwargs contains any key other than `repeat`. Because
        this is a generator function, the error surfaces on the first call
        to next() rather than at call time.
    """
    repeat = kwargs.pop('repeat', 1)
    if kwargs:
        raise ValueError('unknown kwargs: %s' % sorted(kwargs))

    # One live iterator per position of the output tuple; the functions are
    # cycled `repeat` times, so position i is produced by
    # iter_funcs[i % len(iter_funcs)].
    iters = [iter(f()) for _ in range(repeat) for f in iter_funcs]

    # Prime every position with its first value. An empty factor makes the
    # whole product empty. The StopIteration must be caught here: if it
    # escaped the generator body, PEP 479 (Python 3.7+) would turn it into a
    # RuntimeError instead of cleanly ending the iteration.
    values = []
    for iterator in iters:
        try:
            values.append(next(iterator))
        except StopIteration:
            return

    num_funcs = len(iter_funcs)
    while True:
        yield tuple(values)
        # Advance like an odometer: bump the rightmost position; when it is
        # exhausted, re-create it from its function and carry to the left.
        for index in reversed(range(len(iters))):
            try:
                values[index] = next(iters[index])
                break
            except StopIteration:
                iters[index] = iter(iter_funcs[index % num_funcs]())
                # Non-empty on the priming pass above, so (given the
                # determinism assumption) a first value exists here too.
                values[index] = next(iters[index])
        else:
            # Every position wrapped around (or there are no positions at
            # all): the product has been fully enumerated.
            return
from functools import partial | |
def lazy_product_func(*a, **k):
    """
    Bind arguments to lazy_product without running it yet.

    Returns a functools.partial wrapping lazy_product with the given
    positional and keyword arguments; calling the result with no further
    arguments creates a fresh lazy_product iterator each time.
    """
    bound_product = partial(lazy_product, *a, **k)
    return bound_product
def range_func(*a, **k):
    """
    Bind arguments to range without building the range yet.

    Returns a functools.partial wrapping the module's `range` with the given
    arguments; calling the result with no further arguments creates a fresh
    range each time.
    """
    make_range = partial(range, *a, **k)
    return make_range


# Second name for the same factory.
xrange_func = range_func
if __name__ == '__main__':
    import itertools

    def test_equivalence(*iter_funcs, **kwargs):
        """
        Compare lazy_product with itertools.product on one input.

        Returns True when both yield the same sequence of tuples for the
        given iterable-producing functions and keyword arguments.
        """
        reference = itertools.product(*[make() for make in iter_funcs],
                                      **kwargs)
        produced = list(lazy_product(*iter_funcs, **kwargs))
        return produced == list(reference)

    # Each case is (iterable-producing functions, keyword arguments); the
    # cases are checked in the listed order.
    cases = [
        ((), {}),
        ((), {'repeat': 0}),
        ((), {'repeat': 1}),
        ((), {'repeat': 2}),
        ((range_func(0),), {}),
        ((range_func(0),), {'repeat': 2}),
        ((range_func(2),), {}),
        ((range_func(2),), {'repeat': 2}),
        ((range_func(2), range_func(3)), {}),
        ((range_func(2), range_func(0), range_func(3)), {}),
        ((range_func(2), range_func(0), range_func(3)), {'repeat': 2}),
        ((range_func(2), range_func(3)), {'repeat': 2}),
        ((range_func(2), range_func(3)), {'repeat': 2}),  # repeated case
        ((range_func(3), range_func(2, 7)), {'repeat': 0}),
        ((range_func(3), range_func(2, 7)), {'repeat': 4}),
    ]
    for funcs, options in cases:
        assert test_equivalence(*funcs, **options)
    print('Test passed!')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
In [1]: import itertools; from lazy_product import * | |
In [2]: def f(n): | |
def func(): | |
return xrange(n) | |
return func | |
In [3]: K=1; M=10; N=1000000; | |
In [4]: itertools_input = [f(N)() for _ in xrange(M)] | |
In [5]: lazy_input = [f(N) for _ in xrange(M)] | |
In [6]: %timeit p = itertools.product(*itertools_input, repeat=K); next(p) | |
10 loops, best of 3: 155 ms per loop | |
In [7]: %timeit p = lazy_product(*lazy_input, repeat=K); next(p) | |
100000 loops, best of 3: 6.21 µs per loop | |
In [8]: N*=10 | |
In [9]: itertools_input = [f(N)() for _ in xrange(M)] | |
In [10]: lazy_input = [f(N) for _ in xrange(M)] | |
In [11]: %timeit p = itertools.product(*itertools_input, repeat=K); next(p) | |
1 loops, best of 3: 1.03 s per loop | |
In [12]: %timeit p = lazy_product(*lazy_input, repeat=K); next(p) | |
100000 loops, best of 3: 6.32 µs per loop |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment