#!/usr/bin/env sed -re s|^|\x20\x20\x20\x20| -e s|^\x20{4}\x23\x23{(.*)$|<details><summary>\1</summary>\n| -e s|^\x20{4}\x23\x23}$|\n</details>| -e s|^\x20{4}\x23\x23\x20?|| -e s|\x0c|\x20|
    
<details><summary>license, imports</summary>

    # Yannakakis.py by Paul Khuong
    #
    # To the extent possible under law, the person who associated CC0 with
    # Yannakakis.py has waived all copyright and related or neighboring rights
    # to Yannakakis.py.
    #
    # You should have received a copy of the CC0 legalcode along with this
    # work.  If not, see <http://creativecommons.org/publicdomain/zero/1.0/>.
    
    import collections.abc
    

</details>
    
# Linear-time analytical queries in plain Python

<small>I didn't know [Mihalis Yannakakis is about to turn 70](https://mihalisfest.cs.columbia.edu), but that's a fun coïncidence.</small>

This short hack shows how [Yannakakis's algorithm](https://pages.cs.wisc.edu/~paris/cs784-f19/lectures/lecture4.pdf)
lets us implement linear-time (wrt the number of input data rows)
analytical queries in regular programming languages, without
offloading joins to a specialised database query language, thus
avoiding the associated impedance mismatch.  There are restrictions
on the queries we can express -- Yannakakis's algorithm relies on a
hypertree width of 1, and on having a hypertree decomposition as a
witness -- but that's kind of reasonable: a fractional hypertree
width > 1 would mean there are databases for which the intermediate
results could be superlinearly larger than the input database, due
to the [AGM bound](https://arxiv.org/abs/1711.03860).
The hypertree decomposition witness isn't a burden either:
structured programs naturally yield a hypertree decomposition,
unlike more declarative logic programs that tend to hide the
structure implicit in the programmer's thinking.

The key is to mark function arguments as either negative (true
inputs) or positive (possible interesting input values derived from
negative arguments).  In this hack, closed over values are
positive, and the only negative argument is the current data row.

We also assume that these functions are always used in a map/reduce
pattern, and thus we only memoise the result of
`map_reduce(function, input)`, with a group-structured reduction:
the reduce function must be associative and commutative, and there
must be a zero (neutral) value.

With these constraints, we can express joins in natural Python
without incurring the poor runtime scaling of the nested loops we
actually wrote.  This Python file describes the building blocks
to handle aggregation queries like the following

    >>> id_skus = [(1, 2), (2, 2), (1, 3)]
    >>> sku_costs = [(1, 10), (2, 20), (3, 30)]
    >>> def sum_odd_or_even_skus(mod_two):
    ...     @map_reduce.over(id_skus, Sum())
    ...     def count_if_mod_two(id_sku):
    ...         id, sku = id_sku
    ...         if id % 2 == mod_two:
    ...             @map_reduce.over(sku_costs, Min(0))
    ...             def min_cost(sku_cost):
    ...                 if sku_cost[0] == sku:
    ...                     return Min(sku_cost[1])
    ...             return Sum(min_cost)
    ...     return count_if_mod_two
    ...

with linear scaling in the length of `id_skus` and `sku_costs`, and
caching for similar queries.

At a small scale, everything's fast.

    >>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
    20
    0.0007169246673583984
    >>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
    50
    0.0002627372741699219

As we increase the scale by 1000x for both input lists, the runtime
scales (sub) linearly for the first query, and is unchanged for the
second:

    >>> id_skus = id_skus * 1000
    >>> sku_costs = sku_costs * 1000
    >>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
    20000
    0.09455370903015137
    >>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
    50000
    0.00025773048400878906

This still pretty much holds up when we multiply by another factor of 100:

    >>> id_skus = id_skus * 100
    >>> sku_costs = sku_costs * 100
    >>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
    2000000
    6.946590185165405
    >>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
    5000000
    0.00025200843811035156

The magic behind the curtains is memoisation (unsurprisingly), but
a special implementation that can share work for similar closures:
the memoisation key consists of the function *without closed over
bindings* and the call arguments, while the memoised value is a
data structure that maps each tuple of closed over values to the
corresponding `map_reduce` output.

This concrete representation of the function as a branching program
is the core of Yannakakis's algorithm: we'll iterate over each
datum in the input, run the function on it with logical variables
instead of the closed over values, and generate a mapping from
closed over values to result for all non-zero results.  We'll then
merge the mappings for all input data together (there is no natural
ordering here, hence the group structure).
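
Concretely, for the `min_cost` subquery above, the memoised value is
conceptually a map from the closed-over `sku` to the merged `Min`
over the matching `sku_costs` rows.  Here's a hand-rolled sketch of
that idea, with a plain dict and `min` standing in for the
`NestedDict` and `Min` machinery defined below (the
`branching_program` name is just for illustration):

    >>> sku_costs = [(1, 10), (2, 20), (3, 30)]
    >>> branching_program = {}   # closed-over sku -> merged minimum cost
    >>> for sku, cost in sku_costs:
    ...     branching_program[sku] = min(branching_program.get(sku, cost), cost)
    >>> branching_program[2]     # same answer as min_cost with sku == 2
    20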

The output data structure controls the join we can implement.  We
show how a simple stack of nested key-value mappings handles
equijoins, but a [k-d range tree](https://dl.acm.org/doi/10.1145/356789.356797)
would handle inequalities, i.e., "theta" joins (the rest of the
machinery already works in terms of less than/greater than
constraints).

As long as we explore a bounded number of paths for each datum and
a bounded number of `function, input` cache keys, we'll spend a
bounded amount of time on each input datum, and thus linear time
total.  The magic of Yannakakis's algorithm is that this works even
when there are nested `map_reduce` calls, which would naïvely
result in polynomial time (degree equal to the nesting depth).
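
To make the naïve cost concrete, here's what the opening query
literally describes when read as plain Python: the inner scan
re-runs for every outer row, so the work is proportional to
`len(id_skus) * len(sku_costs)` (a minimal sketch, without any of
the memoisation machinery below):

    >>> id_skus = [(1, 2), (2, 2), (1, 3)]
    >>> sku_costs = [(1, 10), (2, 20), (3, 30)]
    >>> total = 0
    >>> for id, sku in id_skus:               # outer map_reduce
    ...     if id % 2 == 0:
    ...         best = None
    ...         for sku_, cost in sku_costs:  # inner map_reduce, rescanned per row
    ...             if sku_ == sku:
    ...                 best = cost if best is None else min(best, cost)
    ...         total += best or 0
    >>> total
    20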
    
     
<details><summary><h2>Memoising through a Python function's closed over values</h2></summary>


Even if functions were hashable and comparable for extensional
equality, directly using closures as memoisation keys in calls like
`map_reduce(function, input)` would result in superlinear runtime
for nested `map_reduce` calls.

This `extract_function_state` accepts a function (with or without
closed over state), and returns four values:
  1. The underlying code object
  2. The tuple of closed over values (current value for mutable cells)
  3. A function to rebind the closure with new closed over values
  4. The names of the closed over bindings

The third return value, the `rebind` function, silently fails on
complicated cases; this is a hack, after all.  In short, it only
handles closing over immutable atomic values like integers or
strings, but not, e.g., functions (yet), or mutable bindings.
    
    
    def extract_function_state(function):
        """Accepts a function object and returns information about it: a
        hash key for the object, a tuple of closed over values, a function
        to return a fresh closure with a different tuple of closed over
        values, and a tuple of closed over names
    
        >>> def test(x): return lambda y: x == y
        >>> extract_function_state(test)[1]
        ()
        >>> extract_function_state(test)[3]
        ()
        >>> fun = test(4)
        >>> extract_function_state(fun)[1]
        (4,)
        >>> extract_function_state(fun)[3]
        ('x',)
        >>>
        >>> fun(4)
        True
        >>> fun(5)
        False
        >>> rebind = extract_function_state(fun)[2]
        >>> rebound_4, rebound_5 = rebind([4]), rebind([5])
        >>> rebound_4(4)
        True
        >>> rebound_4(5)
        False
        >>> rebound_5(4)
        False
        >>> rebound_5(5)
        True
        """
        code = function.__code__
        names = code.co_freevars
    
        if function.__closure__ is None:  # Toplevel function
            assert names == ()
    
            def rebind(values):
                if len(values) != 0:
                    raise RuntimeError(
                        f"Values must be empty for toplevel function. values={values}"
                    )
                return function
    
            return code, (), rebind, names
    
        closure = tuple(cell.cell_contents for cell in function.__closure__)
        assert len(names) == len(closure), (closure, names)
    
        # TODO: rebind recursively (functions are also cells)
        def rebind(values):
            if len(values) != len(names):
                raise RuntimeError(
                    f"Values must match names. names={names} values={values}"
                )
            return function.__class__(
                code,
                function.__globals__,
                function.__name__,
                function.__defaults__,
                tuple(
                    cell.__class__(value)
                    for cell, value in zip(function.__closure__, values)
                ),
            )
    
        return code, closure, rebind, names
    
    

</details>
    
     
## Logical variables for closed-over values

We wish to enumerate the support of a function call (parameterised
over closed over values), and the associated result value.  We'll
do that by rebinding the closure to point at instances of
`OpaqueValue` and enumerating all the possible constraints on these
`OpaqueValue`s.  These `OpaqueValue`s work like logical variables
that let us run a function in reverse: when we get a non-zero
(non-None) return value, we look at the accumulated constraint set
on the opaque values and use it to update the data representation
of the function's result (we assume that we can represent the
constraints on all `OpaqueValue`s in our result data structure).

Currently, we only support nested dictionaries, so each
`OpaqueValue` must be either fully unconstrained (a wildcard that
matches any value), or constrained to be exactly equal to a value.
There's no reason we can't use k-d range trees though, and it's not
harder to track a pair of bounds (lower and upper) than a set of
inequalities, so we'll handle the general ordered `OpaqueValue`
case.

In the input program (the query), we assume closed over values are
only used for comparisons (equality, inequality, relational
operators, or conversion to bool, i.e., non-zero testing).  Knowing
the result of each (non-redundant) comparison tightens the range
of potential values for the `OpaqueValue`... eventually down to a
single point value that our hash-based indexes can handle.

Of course, if a comparison isn't redundant, there are multiple
feasible results, so we need an external oracle to pick one.  An
external caller is responsible for injecting its logic as
`OpaqueValue.CMP_HANDLER`, and driving the exploration of the
search space.

N.B., the set of constraints we can handle is determined by the
ground data structure we use to represent finitely supported functions.
    
    
    class OpaqueValue:
        """An opaque value is a one-dimensional range of Python values,
        represented as a lower and an upper bound, each of which is
        potentially exclusive.
    
        `OpaqueValue`s are only used in queries for comparisons with
        ground values.  All comparisons are turned into three-way
        `__cmp__` calls; non-redundant `__cmp__` calls (which could return
        more than one value) are resolved by calling `CMP_HANDLER`
        and tightening the bound in response.
    
        >>> x = OpaqueValue("x")
        >>> x == True
        True
        >>> 1 if x else 2
        1
        >>> x.reset()
        >>> ### Not supported by our index data structure (yet)
        >>> # >>> x > 4
        >>> # False
        >>> # >>> x < 4
        >>> # False
        >>> # >>> x == 4
        >>> # True
        >>> # >>> x < 10
        >>> # True
        >>> # >>> x >= 10
        >>> # False
        >>> # >>> x > -10
        >>> # True
        >>> # >>> x <= -10
        >>> # False
        """
    
        # Resolving function for opaque values
        CMP_HANDLER = lambda opaque, value: 0
    
        def __init__(self, name):
            self.name = name
            # Upper and lower bounds
            self.lower = self.upper = None
            self.lowerExcl = self.upperExcl = True
    
        def __str__(self):
            left = "(" if self.lowerExcl else "["
            right = ")" if self.upperExcl else "]"
            return f"<OpaqueValue {self.name} {left}{self.lower}, {self.upper}{right}>"
    
        def reset(self):
            """Clears all accumulated constraints on this `OpaqueValue`."""
            self.lower = self.upper = None
            self.lowerExcl = self.upperExcl = True
    
        def indefinite(self):
            """Returns whether this `OpaqueValue` is still unconstrained."""
            return self.lower is None and self.upper is None
    
        def definite(self):
            """Returns whether is `OpaqueValue` is constrained to an exact value."""
            return (
                self.lower == self.upper
                and self.lower is not None
                and not self.lowerExcl
                and not self.upperExcl
            )
    
        def value(self):
            """Returns the exact value for this `OpaqueValue`, assuming there is one."""
            return self.lower
    
        def _contains(self, value, strictly):
            """Returns whether this `OpaqueValue`'s range includes `value`
            (maybe `strictly` inside the range).
    
            The first value is the containment truth value, and the second
            is the forced `__cmp__` value, if the `value` is *not*
            (strictly) contained in the range.
            """
            if self.lower is not None:
                if self.lower > value:
                    return False, 1
                if self.lower == value and (strictly or self.lowerExcl):
                    return False, 1
            if self.upper is not None:
                if self.upper < value:
                    return False, -1
                if self.upper == value and (strictly or self.upperExcl):
                    return False, -1
            return True, 0 if self.definite() else None
    
        def contains(self, value, strictly=False):
            """Returns whether `values` is `strictly` contained in the
            `OpaqueValue`'s range.
            """
            try:
                return self._contains(value, strictly)[0]
            except TypeError:
                return False
    
        def potential_mask(self, other):
            """Returns the set of potential `__cmp__` values for `other`
            that are compatible with the current range: bit 0 is set if
            `-1` is possible, bit 1 if `0` is possible, and bit 2 if `1`
            is possible.
            """
            if not self.contains(other):
                return 0
            if self.contains(other, strictly=True):
                return 7
            if self.definite() and self.value() == other:
                return 2
            # We have a non-strict inclusion, and inequality.
            if self.lower == other and self.lowerExcl:
                return 6
            assert self.upper == other and self.upperExcl
            return 3
    
        def __cmp__(self, other):
            """Three-way comparison between this `OpaqueValue` and `other`.
    
            When the result is known from the current bound, we just return
            that value.  Otherwise, we ask `CMP_HANDLER` what value to return
            and update the bound accordingly.
            """
            if isinstance(other, OpaqueValue) and self.definite() and not other.definite():
                # If we have a definite value and `other` is an indefinite
                # `OpaqueValue`, flip the comparison order to let the `other`
                # argument be a ground value.
                return -other.__cmp__(self.value())
    
            if isinstance(other, OpaqueValue) and not other.definite():
                raise RuntimeError(
                    f"OpaqueValue may only be compared with ground values. self={self} other={other}"
                )
            if isinstance(other, OpaqueValue):
                other = other.value()  # Make sure `other` is a ground value
    
            if other is None:
                # We use `None` internally, and it doesn't compare well
                raise RuntimeError("OpaqueValue may not be compared with None")
    
            compatible, order = self._contains(other, False)
            if order is not None:
                return order
            order = OpaqueValue.CMP_HANDLER(self, other)
            if order < 0:
                self._add_bound(upper=other, upperExcl=True)
            elif order == 0:
                self._add_bound(lower=other, lowerExcl=False, upper=other, upperExcl=False)
            else:
                self._add_bound(lower=other, lowerExcl=True)
            return order
    
        def _add_bound(self, lower=None, lowerExcl=False, upper=None, upperExcl=False):
            """Updates the internal range for this new bound."""
            assert lower is None or self.contains(lower, strictly=lowerExcl)
            assert upper is None or self.contains(upper, strictly=upperExcl)
            if lower is not None:
                self.lower = lower
                self.lowerExcl = lowerExcl
    
            assert upper is None or self.contains(upper, strictly=upperExcl)
            if upper is not None:
                self.upper = upper
                self.upperExcl = upperExcl
    
        def __bool__(self):
            return self != 0
    
        def __eq__(self, other):
            return self.__cmp__(other) == 0
    
        def __ne__(self, other):
            return self.__cmp__(other) != 0
    
        # No other comparator because we don't index ranges
        # (no range tree).
    
        # def __lt__(self, other):
        #     return self.__cmp__(other) < 0
    
        # def __le__(self, other):
        #     return self.__cmp__(other) <= 0
    
        # def __gt__(self, other):
        #     return self.__cmp__(other) > 0
    
        # def __ge__(self, other):
        #     return self.__cmp__(other) >= 0
    
     
## Depth-first exploration of a function call's support

We assume a `None` result represents a zero value wrt the aggregate
merging function (e.g., 0 for a sum).  For convenience, we also treat
tuples and lists that only contain `None`s as zero values.

We simply maintain a stack of `CMP_HANDLER` calls, where each entry
in the stack consists of an `OpaqueValue` and a bitset of `CMP_HANDLER`
results still to explore (-1, 0, or 1).  This stack is filled on demand,
and `CMP_HANDLER` returns the first result allowed by the bitset.

Once we have a result, we tweak the stack to force depth-first
exploration of a different part of the solution space: we drop
the first bit in the bitset of results to explore, and drop the
entry wholesale if the bitset is now empty (all zero).  When
this tweaking leaves an empty stack, we're done.

This ends up enumerating all the paths through the function call
with a non-recursive depth-first traversal.

We then do the same for each datum in our input sequence, and merge
results for identical keys together.
    
    
    def is_zero_result(value):
        """Checks if `value` is a "zero" aggregate value: either `None`,
        or an iterable of all `None`.
    
        >>> is_zero_result(None)
        True
        >>> is_zero_result(False)
        False
        >>> is_zero_result(True)
        False
        >>> is_zero_result(0)
        False
        >>> is_zero_result(-10)
        False
        >>> is_zero_result(1.5)
        False
        >>> is_zero_result("")
        False
        >>> is_zero_result("asd")
        False
        >>> is_zero_result((None, None))
        True
        >>> is_zero_result((None, 1))
        False
        >>> is_zero_result([])
        True
        >>> is_zero_result([None])
        True
        >>> is_zero_result([None, (None, None)])
        False
        """
        if value is None:
            return True
        if isinstance(value, (tuple, list)):
            return all(item is None for item in value)
        return False
    
    
    def enumerate_opaque_values(function, values):
        """Explores the set of `OpaqueValue` constraints when calling
        `function`.
    
        Enumerates all constraints for the `OpaqueValue` instances in
        `values`, and, for each non-zero result, yields a pair of the
        equality constraints on `values` and the corresponding result.
    
        This essentially turns `function()` into a branching program on
        `values`.
    
        >>> x, y = OpaqueValue("x"), OpaqueValue("y")
        >>> list(enumerate_opaque_values(lambda: 1 if x == 0 else (2 if x == 1 and y == 2 else None), [x, y]))
        [((0, None), 1), ((1, 2), 2)]
    
        """
        explorationStack = []  # List of (value, bitmaskOfCmp)
        while True:
            for value in values:
                value.reset()
    
            stackIndex = 0
    
            def handle(value, other):
                nonlocal stackIndex
                if len(explorationStack) == stackIndex:
                    explorationStack.append((value, value.potential_mask(other)))
    
                expectedValue, mask = explorationStack[stackIndex]
                assert value is expectedValue
                assert mask != 0
    
                if (mask & 1) != 0:
                    ret = -1
                elif (mask & 2) != 0:
                    ret = 0
                elif (mask & 4) != 0:
                    ret = 1
                else:
                    assert False, f"bad mask {mask}"
    
                stackIndex += 1
                return ret
    
            OpaqueValue.CMP_HANDLER = handle
            result = function()
            if not is_zero_result(result):
                for value in values:
                    assert (
                        value.definite() or value.indefinite()
                    ), f"partially constrained {value} temporarily unsupported"
                yield (tuple(key.value() if key.definite() else None for key in values),
                       result)
    
            # Drop everything that was fully explored, then move the new
            # top of stack to its next option.
            while explorationStack:
                value, mask = explorationStack[-1]
                assert 0 <= mask < 8
    
                mask &= mask - 1  # Drop first bit
                if mask != 0:
                    explorationStack[-1] = (value, mask)
                    break
                explorationStack.pop()
            if not explorationStack:
                break
    
    
    def enumerate_supporting_values(function, args):
        """Lists the bag of mapping from closed over values to non-zero result,
        for all calls `function(arg) for args in args`.
    
        >>> def count_eql(needle): return lambda x: 1 if x == needle else None
        >>> list(enumerate_supporting_values(count_eql(4), [1, 2, 4, 4, 2]))
        [((1,), 1), ((2,), 1), ((4,), 1), ((4,), 1), ((2,), 1)]
        """
        _, _, rebind, names = extract_function_state(function)
        values = [OpaqueValue(name) for name in names]
        reboundFunction = rebind(values)
        for arg in args:
            yield from enumerate_opaque_values(lambda: reboundFunction(arg), values)
    
     
## Type driven merges

The interesting part of map/reduce is the reduction step.  While
some like to use first-class functions to describe reduction, in my
opinion, it often makes more sense to define reduction at the type
level: it's essential that merge operators be commutative and
associative, so isolating the merge logic in dedicated classes
makes sense to me.

This file defines two trivial mergeable value types, `Sum` and
`Min`, but we could have different ones, e.g., hyperloglog unique
counts, or streaming statistical moments... or even a list of row
ids.
    
    
    class Sum:
        """A counter for summed values."""
    
        def __init__(self, value=0):
            self.value = value
    
        def merge(self, other):
            assert isinstance(other, Sum)
            self.value += other.value
    
    
    class Min:
        """A running `min` value tracker."""
    
        def __init__(self, value=None):
            self.value = value
    
        def merge(self, other):
            assert isinstance(other, Min)
            if not self.value:
                self.value = other.value
            elif other.value:
                self.value = min(self.value, other.value)
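
As a sketch of the "list of row ids" idea mentioned above, here's a
hypothetical `RowIds` type (not used anywhere else in this file);
anything with a neutral default and an order-insensitive `merge`
fits the same mold:

    class RowIds:
        """A bag of matching row ids; the order of ids is not meaningful.

        (Hypothetical example type, not used by the rest of this file.)
        """

        def __init__(self, ids=()):
            self.value = list(ids)

        def merge(self, other):
            assert isinstance(other, RowIds)
            self.value.extend(other.value)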
    
     
## Nested dictionary with wildcard

There's a direct relationship between the data structure we use to
represent the result of function calls as branching functions, and
the constraints we can support on closed over values for non-zero
results.

In a real implementation, this data structure would host most of
the complexity: it's the closest thing we have to indexes.

For now, we support equality *with a ground value* as our only
constraint.  This means the only two cases we must look for at
each level are an exact match, or a wildcard match.

We do have to check for both cases at each level, so the worst-case
complexity for lookups is exponential in the depth (number of join
variables).  That's actually reasonable because we don't expect too
many join variables (compare to [range trees](https://en.wikipedia.org/wiki/Range_tree#Range_queries)
that are also exponential in the number of dimensions... but with
a base of log(n) instead of 2).

A real implementation could maybe save work by memoising merged
results for internal subtrees... it's tempting to broadcast
wildcard values to the individual keyed entries, but I think that
might explode the time complexity of our pre-computation phase.
    
    class NestedDictLevel:
        """One level in a nested dictionary index.  We may have a value
        for everything (leaf node), or a key-value dict *and a wildcard
        entry* for a specific index in the tuple key.
        """
    
        def __init__(self, depth):
            self.depth = depth
            self.value = None
            self.wildcard = None
            self.dict = dict()
    
        def visit(self, keys, visitor):
            """Passes the values for `keys` to `visitor`."""
            if self.value is not None:
                assert self.wildcard is None and not self.dict
                visitor(self.value)
    
            if self.depth >= len(keys):
                return
    
            if self.wildcard is not None:
                self.wildcard.visit(keys, visitor)
    
            next = self.dict.get(keys[self.depth], None)
            if next is not None:
                next.visit(keys, visitor)
    
        def set(self, keys, mergeFunction, depth=0):
            """Sets the value for `keys` in this level."""
            assert depth <= len(keys)
            if depth == len(keys):  # Leaf
                self.value = mergeFunction(self.value)
                return
    
            assert self.depth == depth
            key = keys[depth]
            
            if key is None:
                if self.wildcard is None:
                    self.wildcard = NestedDictLevel(depth + 1)
                dst = self.wildcard
            else:
                dst = self.dict.get(key)
                if dst is None:
                    dst = NestedDictLevel(depth + 1)
                    self.dict[key] = dst
            dst.set(keys, mergeFunction, depth + 1)
    
    
    class NestedDict:
        """A nested dict of a given `depth` maps tuples of `depth` keys to
        a value.  Each `NestedDictLevel` handles a different level.
        """
    
        def __init__(self, length):
            self.top = NestedDictLevel(0)
            self.length = length
    
        def visit(self, keys, visitor):
            """Gets the value associated with `keys`, or `default` if None."""
            assert len(keys) == self.length
            assert all(key is not None for key in keys)
            self.top.visit(keys, visitor)
    
        def set(self, keys, mergeFn):
            """Sets the value associated with `((index, key), ...)`."""
            assert len(keys) == self.length
            self.top.set(keys, mergeFn)
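
A quick sketch of how the wildcard level behaves, with hypothetical
keys and plain integers standing in for aggregate values: a `None`
key at `set` time matches any key at `visit` time.

    >>> nd = NestedDict(2)
    >>> nd.set((1, None), lambda old: (old or 0) + 5)  # key 1, then wildcard
    >>> nd.set((1, 7), lambda old: (old or 0) + 3)
    >>> nd.visit((1, 7), print)  # the wildcard entry, then the exact match
    5
    3
    >>> nd.visit((1, 8), print)  # only the wildcard applies
    5
    >>> nd.visit((2, 7), print)  # nothing matches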
    
    
<details><summary><h2>Identity key-value maps</h2></summary>

    class IdMap:
        def __init__(self):
            self.entries = dict()  # tuple of id -> (key, value)
            # storing the keys alongside the value keeps them alive, so their ids stay stable.
    
        def get(self, keys, default=None):
            ids = tuple(id(key) for key in keys)
            return self.entries.get(ids, (None, default))[1]
    
        def __contains__(self, keys):
            ids = tuple(id(key) for key in keys)
            return ids in self.entries
    
        def __getitem__(self, keys):
            ids = tuple(id(key) for key in keys)
            return self.entries[ids][1]
    
        def __setitem__(self, keys, value):
            ids = tuple(id(key) for key in keys)
            self.entries[ids] = (keys, value)
    
    

</details>
    
     
## Cached `map_reduce`

As mentioned earlier, we assume `reduce` is determined implicitly
by the reduced values' type.  We also have
`enumerate_supporting_values` to find all the closed over values
that yield a non-zero result, for all values in a sequence.

We can thus accept a function and an input sequence, find the
supporting values, and merge the result associated with identical
supporting values.

Again, we only support ground equality constraints (see the
assertion in `enumerate_opaque_values`), i.e., only equijoins.
There's nothing that stops a more
sophisticated implementation from using range trees to support
inequality or range joins.

We'll cache the precomputed values by code object (i.e., function
without closed over values) and input sequence.  If we don't have a
precomputed value, we'll use `enumerate_supporting_values` to run
the function backward for each input datum from the sequence, and
accumulate the results in a `NestedDict`.  Working backward to find
closure values that yield a non-zero result (for each input datum)
lets us precompute a branching program that directly yields the
result.  We represent these branching programs explicitly, so we
can also directly update a branching program for the result of
merging all the values returned by mapping over the input sequence,
for a given closure.

This last `map_reduce` definition ties everything together, and
I think it is really the general heart of Yannakakis's algorithm
as an instance of bottom-up dynamic programming.
    
    def _merge(dst, update):
        if dst is None:
            return update
    
        if isinstance(dst, (tuple, list)):
            assert len(dst) == len(update)
            for value, new in zip(dst, update):
                value.merge(new)
        else:
            dst.merge(update)
        return dst
    
    
    def _extractValues(accumulator):
        if accumulator is None:
            return None
        if isinstance(accumulator, tuple):
            return tuple(item.value for item in accumulator)
        if isinstance(accumulator, list):
            return list(item.value for item in accumulator)
    
        return accumulator.value
    
    def _precompute_map_reduce(function, depth, inputIterable):
        """Given a function (a closure), the number of values the function
        closes over, and an input iterable, generates a `NestedDict`
        representation for `reduce(map(function, inputIterable))`, where
        the reduction step simply calls `merge` on the return values
        (tuples are merged elementwise), and the `NestedDict` keys
        represent closed over values.
    
        >>> def count_eql(needle): return lambda x: Sum(1) if x == needle else None
        >>> nd = _precompute_map_reduce(count_eql(4), 1, [1, 2, 4, 4, 2])
        >>> nd.visit((0,), lambda sum: print(sum.value))
        >>> nd.visit((1,), lambda sum: print(sum.value))
        1
        >>> nd.visit((2,), lambda sum: print(sum.value))
        2
        >>> nd.visit((4,), lambda sum: print(sum.value))
        2
        """
        cache = NestedDict(depth)
        for indexKeyValues, result in enumerate_supporting_values(function, inputIterable):
            cache.set(indexKeyValues, lambda old: _merge(old, result))
        return cache
    
    
    AGGREGATE_CACHE = IdMap()  # Map from function, input sequence -> NestedDict
    
    
    def map_reduce(function, inputIterable, initialValue=None, *, extractResult=True):
        """Returns the result of merging `map(function, inputIterable)`
        into `initialValue`.
    
        `None` return values represent neutral elements (i.e., the result
        of mapping an empty `inputIterable`), and values are otherwise
        reduced by calling `merge` on a mutable accumulator.
    
        Assuming `function` is well-behaved, `map_reduce` runs in time
        linear wrt `len(inputIterable)`.  It's also always cached on a
        composite key that consists of the `function`'s code object (i.e.,
        without closed over values) and the `inputIterable`.
    
        These complexity guarantees let us nest `map_reduce` with
        different closed over values, and still guarantee a linear-time
        total complexity.
    
        This wrapper ties together all the components:
    
        >>> INVOCATION_COUNTER = 0
        >>> data = (1, 2, 2, 4, 2, 4)
        >>> def count_eql(needle):
        ...     def count(x):
        ...         global INVOCATION_COUNTER
        ...         INVOCATION_COUNTER += 1
        ...         return Sum(x) if x == needle else None
        ...     return count
        >>> INVOCATION_COUNTER
        0
        >>> map_reduce(count_eql(4), data, Sum(), extractResult=False).value
        8
        >>> INVOCATION_COUNTER
        18
        >>> map_reduce(count_eql(2), data)
        6
        >>> INVOCATION_COUNTER
        18
        >>> id_skus = [(1, 2), (2, 2), (1, 3)]
        >>> sku_costs = [(1, 10), (2, 20), (3, 30)]
        >>> def sku_min_cost(sku):
        ...     return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs, Min(0))
        >>> def sum_odd_or_even_skus(mod_two):
        ...     def count_if_mod_two(id_sku):
        ...         id, sku = id_sku
        ...         if id % 2 == mod_two:
        ...             return Sum(sku_min_cost(sku))
        ...     return map_reduce(count_if_mod_two, id_skus, Sum())
        >>> sum_odd_or_even_skus(0)
        20
        >>> sum_odd_or_even_skus(1)
        50
    
        """
        assert isinstance(inputIterable, collections.abc.Iterable)
        assert not isinstance(inputIterable, collections.abc.Iterator)
        code, closure, *_ = extract_function_state(function)
        if (code, inputIterable) not in AGGREGATE_CACHE:
            AGGREGATE_CACHE[code, inputIterable] = _precompute_map_reduce(
                function, len(closure), inputIterable
            )
    
        acc = [initialValue]
        def visitor(result):
            acc[0] = _merge(acc[0], result)
    
        AGGREGATE_CACHE[code, inputIterable].visit(closure, visitor)
        return _extractValues(acc[0]) if extractResult else acc[0]
    
    
    map_reduce.over = \
        lambda inputIterable, initialValue=None, *, extractResult=True: \
            lambda fn: map_reduce(fn, inputIterable, initialValue, extractResult=extractResult)
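
`map_reduce.over` merely curries the arguments so queries can be
written as decorators, like in the introduction.  A small usage
sketch (the `count_eql`/`total` names are just for illustration):

    >>> data = (1, 2, 2, 4, 2, 4)
    >>> def count_eql(needle):
    ...     @map_reduce.over(data, Sum())
    ...     def total(x):
    ...         return Sum(1) if x == needle else None
    ...     return total
    >>> count_eql(2)
    3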
    
     
    if __name__ == "__main__":
        import doctest
    
        doctest.testmod()
    
    
## Is this actually a DB post?

Although the intro name-dropped Yannakakis, the presentation here
has a very programming language / logic programming flavour.  I
think the logic programming point of view, where we run a program
backwards with logical variables, is much clearer than the specific
case of conjunctive equijoin queries in the usual presentation of
Yannakakis's algorithm.  In particular, I think there's a clear
path to handle range or comparison joins: it's all about having an
index data structure to handle range queries.

It should be clear how to write conjunctive queries as Python
functions, given a hypertree decomposition.  The reverse is much
more complex, if only because Python is much more powerful than
just CQ, and that's actually a liability: this hack will blindly
try to convert any function to a branching program, instead of
giving up noisily when the function is too complex.

The other difference from classical CQs is that we focus on
aggregates.  That's because aggregates are the more general form:
if we just want to avoid useless work while enumerating all join
rows, we only need a boolean aggregate that tells us whether the
join will yield at least one row.  We could also special case types
for which merges don't save space (e.g., set of row ids), and
instead enumerate values by walking the branching program tree.

The aggregate viewpoint also works for
[fun extensions like indexed access to ranked results](https://ntzia.github.io/download/Tractable_Orders_2020.pdf):
that extension ends up counting the number of output values up to a
certain key.

I guess, in a way, we just showed a trivial way to decorrelate
queries with a hypertree-width of 1.  We just have to be OK with
building one index for each loop in the nest... but it should be
possible to pattern match on pre-defined indexes and avoid obvious
redundancy.

## Extensions and future work

### Use a dedicated DSL

First, the whole idea of introspecting closures to stub in logical
variables is a terrible hack (looks cool though ;).  A real
production implementation should apply CPS partial evaluation to a
purely functional programming language, then bulk reverse-evaluate
with a SIMD implementation of the logical program.

There'll be restrictions on the output traces, but that's OK: a
different prototype makes me believe the restrictions correspond to
deterministic logspace (L), and it makes sense to restrict our
analyses to L.  Just like grammars are easier to work with when
restricted to LL(1), DSLs that only capture L tend to be easier to
analyse and optimise... and L is reasonably large (a
polynomial-time algorithm that's not in L would be a *huge*
result).

### Handle local functions

While we sin with the closure hack (`extract_function_state`), it
should really be extended to cover local functions.  This is mostly
a question of going deeply into values that are mapped to
functions, and of maintaining an id-keyed map from cell to
`OpaqueValue`.

We could also add support for partial application objects, which
may be easier for multiprocessing.

### Parallelism

There is currently no support for parallelism, only caching.  It
should be easy to handle the return values (`NestedDict`s and
aggregate classes like `Sum` or `Min`).  Distributing the work in
`_precompute_map_reduce` to merge locally is also not hard.

The main issue with parallelism is that we can't pass functions
as work units, so we'd have to stick to the `fork` process pool.

There's also no support for moving (child) work forward when
blocked waiting on a future.  We'd have to spawn workers on the fly
to oversubscribe when workers are blocked on a result (spawning on
demand is already a given for `fork` workers), and to implement our
own concurrency control to avoid wasted work, and probably internal
throttling to avoid thrashing when we'd have more active threads
than cores.

That being said, the complexity is probably worth the speed up on
realistic queries.

### Theta joins

At a higher level, we could support comparison joins (e.g., less
than, greater than or equal, in range) if only we represented the
branching programs with a data structure that supported these
queries.  A [range tree](https://dl.acm.org/doi/10.1145/356789.356797) would
let us handle these "theta" joins, for the low, low cost of a
polylogarithmic multiplicative factor in space and time.

### Self-adjusting computation

Finally, we could update the indexed branching programs
incrementally after small changes to the input data.  This might
sound like a job for streaming engines like [timely dataflow](https://github.com/timelydataflow/timely-dataflow),
but I think viewing each `_precompute_map_reduce` call as a purely
functional map/reduce job gives a better fit with [self-adjusting computation](https://www.umut-acar.org/research#h.x3l3dlvx3g5f).

Once we add logic to recycle previously constructed indexes, it
will probably make sense to allow an initial filtering step before
map/reduce, with a cache key on the filter function (with closed
over values and all).  We can often implement the filtering more
efficiently than we can run functions backward, and we'll also
observe that slightly different filter functions often result
in not too dissimilar filtered sets.  Factoring out this filtering
can thus enable more reuse of partial precomputed results.