```python
#!/usr/bin/env sed -re s|^|\x20\x20\x20\x20| -e s|^\x20{4}\x23\x23{(.*)$|<details><summary>\1</summary>\n| -e s|^\x20{4}\x23\x23}$|\n</details>| -e s|^\x20{4}\x23\x23\x20?|| -e s|\x0c|\x20|
```

<details><summary>license, imports</summary>

```python
# Yannakakis.py by Paul Khuong
#
# To the extent possible under law, the person who associated CC0 with
# Yannakakis.py has waived all copyright and related or neighboring rights
# to Yannakakis.py.
#
# You should have received a copy of the CC0 legalcode along with this
# work.  If not, see <http://creativecommons.org/publicdomain/zero/1.0/>.

import collections.abc
```

</details>

# Linear-time analytical queries in plain Python

<small>I didn't know [Mihalis Yannakakis is about to turn 70](https://mihalisfest.cs.columbia.edu), but that's a fun coïncidence.</small>

This short hack shows how [Yannakakis's algorithm](https://pages.cs.wisc.edu/~paris/cs784-f19/lectures/lecture4.pdf) lets us implement linear-time (wrt the number of input data rows) analytical queries in regular programming languages. We don't have to offload joins to a specialised database query language, so we avoid the associated impedance mismatch.

There are restrictions on the queries we can express -- Yannakakis's algorithm relies on a hypertree width of 1, and on having a hypertree decomposition as a witness -- but that's kind of reasonable. A fractional hypertree width > 1 would mean there are databases for which the intermediate results could be superlinearly larger than the input database, due to the [AGM bound](https://arxiv.org/abs/1711.03860).

The hypertree decomposition witness isn't a burden either. Structured programs naturally yield a hypertree decomposition, unlike more declarative logic programs that tend to hide the structure implicit in the programmer's thinking.

The key is to mark function arguments as either negative (true inputs) or positive (possible interesting input values derived from negative arguments).
In this hack, closed over values are positive, and the only negative argument is the current data row. We also assume that these functions are always used in a map/reduce pattern, so we only memoise the result of `map_reduce(function, input)`.

The reduction must have a commutative monoid structure: the reduce function must be associative and commutative, and there must be a zero (neutral) value. Under these constraints, we can express joins in natural Python without incurring the poor runtime scaling of the nested loops we actually wrote.

This Python file describes the building blocks to handle aggregation queries like the following

```python
>>> id_skus = [(1, 2), (2, 2), (1, 3)]
>>> sku_costs = [(1, 10), (2, 20), (3, 30)]
>>> def sum_odd_or_even_skus(mod_two):
...     @map_reduce.over(id_skus, Sum())
...     def count_if_mod_two(id_sku):
...         id, sku = id_sku
...         if id % 2 == mod_two:
...             @map_reduce.over(sku_costs, Min(0))
...             def min_cost(sku_cost):
...                 if sku_cost[0] == sku:
...                     return Min(sku_cost[1])
...             return Sum(min_cost)
...     return count_if_mod_two
...
```

with linear scaling in the length of `id_skus` and `sku_costs`, and caching for similar queries.

At a small scale, everything's fast.
```python
>>> import time
>>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
20
0.0007169246673583984
>>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
50
0.0002627372741699219
```

Now scale both input lists up by 1000x. The runtime scales (sub)linearly for the first query. For the second query it is essentially unchanged, because that query reuses the index built for the first:

```python
>>> id_skus = id_skus * 1000
>>> sku_costs = sku_costs * 1000
>>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
20000
0.09455370903015137
>>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
50000
0.00025773048400878906
```

This still pretty much holds up when we multiply by another factor of 100:

```python
>>> id_skus = id_skus * 100
>>> sku_costs = sku_costs * 100
>>> begin = time.time(); print(sum_odd_or_even_skus(0)); print(time.time() - begin)
2000000
6.946590185165405
>>> begin = time.time(); print(sum_odd_or_even_skus(1)); print(time.time() - begin)
5000000
0.00025200843811035156
```

The magic behind the curtains is memoisation (unsurprisingly), but it's a special implementation that can share work for similar closures. The memoisation key consists of the function *without closed over bindings* and the call arguments. The memoised value is a data structure mapping the tuple of closed over values to the `map_reduce` output.

This concrete representation of the function as a branching program is the core of Yannakakis's algorithm. We iterate over each datum in the input, run the function on it with logical variables instead of the closed over values, and generate a mapping from closed over values to result for all non-zero results. We then merge the mappings for all input data together. There is no natural ordering here, hence the commutative monoid structure. The output data structure controls the join we can implement.
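
For intuition, here is the intro query decorrelated by hand into the two indexes that `map_reduce` ends up memoising. One linear pass over `sku_costs` maps the inner closure's `sku` to its minimum cost. One pass over `id_skus` maps the outer closure's `mod_two` to its sum.

This is a hedged sketch that only covers equality joins. `index_min_cost` and `index_sum_by_parity` are hypothetical helpers, not part of this file.

```python
from collections import defaultdict

id_skus = [(1, 2), (2, 2), (1, 3)]
sku_costs = [(1, 10), (2, 20), (3, 30)]


def index_min_cost(rows):
    """Hypothetical helper: one pass mapping each `sku` to its minimum cost."""
    best = {}
    for sku, cost in rows:
        best[sku] = cost if sku not in best else min(best[sku], cost)
    return best


def index_sum_by_parity(rows, min_cost):
    """Hypothetical helper: one pass mapping each `mod_two` value to its sum."""
    totals = defaultdict(int)
    for row_id, sku in rows:
        # A missing sku contributes 0, like the `Min(0)` initial value.
        totals[row_id % 2] += min_cost.get(sku, 0)
    return totals


min_cost = index_min_cost(sku_costs)
by_parity = index_sum_by_parity(id_skus, min_cost)
print(by_parity[0], by_parity[1])  # 20 50
```

Once both indexes exist, answering `sum_odd_or_even_skus(0)` or `sum_odd_or_even_skus(1)` is a dictionary lookup. That is why the second timed query above barely registers.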

We show how a simple stack of nested key-value mappings handles equijoins. A [k-d range tree](https://dl.acm.org/doi/10.1145/356789.356797) would handle inequalities, i.e., "theta" joins; the rest of the machinery already works in terms of less than/greater than constraints.

As long as we explore a bounded number of paths for each datum and a bounded number of `function, input` cache keys, we'll spend a bounded amount of time on each input datum, and thus linear time total. The magic of Yannakakis's algorithm is that this works even with nested `map_reduce` calls. Naïvely, those would result in polynomial time, with degree equal to the nesting depth.

<details><summary><h2>Memoising through a Python function's closed over values</h2></summary>

Suppose functions were hashable and comparable for extensional equality. Directly using closures as memoisation keys in calls like `map_reduce(function, input)` would still result in superlinear runtime for nested `map_reduce` calls. Every distinct closed over value would be a separate cache key, and each would trigger a fresh pass over `input`.

This `extract_function_state` accepts a function (with or without closed over state), and returns four values:

1. The underlying code object
2. The tuple of closed over values (current value for mutable cells)
3. A function to rebind the closure with new closed over values
4. The names of the closed over bindings

The third return value, the `rebind` function, silently fails on complicated cases; this is a hack, after all. In short, it only handles closing over immutable atomic values like integers or strings. It does not handle, e.g., functions (yet), or mutable bindings.
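
Python already separates the two halves we care about. Closures built by the same factory share one code object and differ only in their cells.

The sketch below uses only the standard `__code__` and `__closure__` attributes and `types.CellType` (Python 3.8+). `make_eq` and `split_closure` are hypothetical names for illustration.

```python
import types


def make_eq(needle):
    return lambda x: x == needle


def split_closure(fn):
    """Hypothetical helper: returns (code object, closed over values) for `fn`."""
    cells = fn.__closure__ or ()
    return fn.__code__, tuple(cell.cell_contents for cell in cells)


f4, f5 = make_eq(4), make_eq(5)
code4, values4 = split_closure(f4)
code5, values5 = split_closure(f5)
print(code4 is code5, values4, values5)  # True (4,) (5,)

# Rebuild the same code over a fresh cell, roughly what `rebind` does below.
f6 = types.FunctionType(code4, f4.__globals__, f4.__name__, f4.__defaults__,
                        (types.CellType(6),))
print(f6(6), f6(4))  # True False
```

If `code4` is the cache key, every `make_eq(...)` closure hits the same memoised entry. The closed over `(needle,)` tuple only selects a result inside that entry.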

```python
def extract_function_state(function):
    """Accepts a function object and returns information about it: the
    code object (a hash key for the function without its closed over
    values), a tuple of closed over values, a function to return a fresh
    closure with a different tuple of closed over values, and a tuple of
    closed over names.

    >>> def test(x): return lambda y: x == y
    >>> extract_function_state(test)[1]
    ()
    >>> extract_function_state(test)[3]
    ()
    >>> fun = test(4)
    >>> extract_function_state(fun)[1]
    (4,)
    >>> extract_function_state(fun)[3]
    ('x',)
    >>>
    >>> fun(4)
    True
    >>> fun(5)
    False
    >>> rebind = extract_function_state(fun)[2]
    >>> rebound_4, rebound_5 = rebind([4]), rebind([5])
    >>> rebound_4(4)
    True
    >>> rebound_4(5)
    False
    >>> rebound_5(4)
    False
    >>> rebound_5(5)
    True
    """
    code = function.__code__
    names = code.co_freevars
    if function.__closure__ is None:
        # Toplevel function
        assert names == ()

        def rebind(values):
            if len(values) != 0:
                raise RuntimeError(
                    f"Values must be empty for toplevel function. values={values}"
                )
            return function

        return code, (), rebind, names

    closure = tuple(cell.cell_contents for cell in function.__closure__)
    assert len(names) == len(closure), (closure, names)

    # TODO: rebind recursively (functions are also cells)
    def rebind(values):
        if len(values) != len(names):
            raise RuntimeError(
                f"Values must match names. names={names} values={values}"
            )
        # Same code, globals, name and defaults; only the cells change.
        return function.__class__(
            code,
            function.__globals__,
            function.__name__,
            function.__defaults__,
            tuple(
                cell.__class__(value)
                for cell, value in zip(function.__closure__, values)
            ),
        )

    return code, closure, rebind, names
```

</details>

## Logical variables for closed-over values

We wish to enumerate the support of a function call (parameterised over closed over values), and the associated result value. We'll do that by rebinding the closure to point at instances of `OpaqueValue` and enumerating all the possible constraints on these `OpaqueValue`s.
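
Queries usually put the row value on the left of a comparison, as in `x == needle`. That still reaches the logical variable. For a non-numeric right operand, `int.__eq__` returns `NotImplemented`, so Python then tries the reflected `__eq__` on the right operand.

The sketch below uses a hypothetical `Recorder` class, not `OpaqueValue`. Its answers come from an external oracle, in the spirit of `CMP_HANDLER`.

```python
class Recorder:
    """Hypothetical stand-in for a logical variable: every equality test
    is answered by `oracle` and logged."""

    def __init__(self, name, oracle):
        self.name = name
        self.oracle = oracle
        self.log = []

    def __eq__(self, other):
        answer = self.oracle(other)
        self.log.append((other, answer))
        return answer


needle = Recorder("needle", oracle=lambda other: other == 4)
query = lambda x: 1 if x == needle else None
print([query(x) for x in [1, 4, 2]])  # [None, 1, None]
print(needle.log)  # [(1, False), (4, True), (2, False)]
```

`OpaqueValue` plays the same trick with one difference. A fixed oracle gives one answer per comparison. Instead, the explorer tries each feasible answer in turn and records the resulting bounds.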

These `OpaqueValue`s work like logical variables that let us run a function in reverse. When we get a non-zero (non-None) return value, we look at the accumulated constraint set on the opaque values. We then use those constraints to update the data representation of the function's result. This assumes we can represent the constraints on all `OpaqueValue`s in our result data structure.

Currently, we only support nested dictionaries. Each `OpaqueValue` must therefore be either fully unconstrained (a wildcard that matches any value), or constrained to be exactly equal to a value. There's no reason we can't use k-d range trees, though. It's also not harder to track a pair of bounds (lower and upper) than a set of inequalities, so we'll handle the general ordered `OpaqueValue` case.

In the input program (the query), we assume closed over values are only used for comparisons: equality, inequality, relational operators, or conversion to bool, i.e., non-zero testing. Knowing the result of each (non-redundant) comparison tightens the range of potential values for the `OpaqueValue`... eventually down to a single point value that our hash-based indexes can handle.

Of course, if a comparison isn't redundant, there are multiple feasible results, so we need an external oracle to pick one. An external caller is responsible for injecting its logic as `OpaqueValue.CMP_HANDLER`, and for driving the exploration of the search space.

N.B., the set of constraints we can handle is determined by the ground data structure we use to represent finitely supported functions.

```python
class OpaqueValue:
    """An opaque value is a one-dimensional range of Python values,
    represented as a lower and an upper bound, each of which is
    potentially exclusive.

    `OpaqueValue`s are only used in queries for comparisons with
    ground values.  All comparisons are turned into three-way `__cmp__`
    calls; non-redundant `__cmp__` calls (which could return more than
    one value) are resolved by calling `CMP_HANDLER` and tightening
    the bound in response.

    >>> x = OpaqueValue("x")
    >>> x == True
    True
    >>> 1 if x else 2
    1
    >>> x.reset()
    >>> ### Not supported by our index data structure (yet)
    >>> # >>> x > 4
    >>> # False
    >>> # >>> x < 4
    >>> # False
    >>> # >>> x == 4
    >>> # True
    >>> # >>> x < 10
    >>> # True
    >>> # >>> x >= 10
    >>> # False
    >>> # >>> x > -10
    >>> # True
    >>> # >>> x <= -10
    >>> # False
    """

    # Resolving function for opaque values
    CMP_HANDLER = lambda opaque, value: 0

    def __init__(self, name):
        self.name = name
        # Upper and lower bounds
        self.lower = self.upper = None
        self.lowerExcl = self.upperExcl = True

    def __str__(self):
        left = "(" if self.lowerExcl else "["
        right = ")" if self.upperExcl else "]"
        return f"<OpaqueValue {self.name} {left}{self.lower}, {self.upper}{right}>"

    def reset(self):
        """Clears all accumulated constraints on this `OpaqueValue`."""
        self.lower = self.upper = None
        self.lowerExcl = self.upperExcl = True

    def indefinite(self):
        """Returns whether this `OpaqueValue` is still unconstrained."""
        return self.lower is None and self.upper is None

    def definite(self):
        """Returns whether this `OpaqueValue` is constrained to an exact value."""
        return (
            self.lower == self.upper
            and self.lower is not None
            and not self.lowerExcl
            and not self.upperExcl
        )

    def value(self):
        """Returns the exact value for this `OpaqueValue`, assuming there is one."""
        return self.lower

    def _contains(self, value, strictly):
        """Returns whether this `OpaqueValue`'s range includes `value`
        (maybe `strictly` inside the range).

        The first value is the containment truth value.  The second is
        the forced `__cmp__` value: -1 or 1 if `value` is *not*
        (strictly) contained in the range, 0 if the range is exactly
        `value`, and `None` if the comparison isn't forced.
        """
        if self.lower is not None:
            if self.lower > value:
                return False, 1
            if self.lower == value and (strictly or self.lowerExcl):
                return False, 1
        if self.upper is not None:
            if self.upper < value:
                return False, -1
            if self.upper == value and (strictly or self.upperExcl):
                return False, -1
        return True, 0 if self.definite() else None

    def contains(self, value, strictly=False):
        """Returns whether `value` is (`strictly`) contained in the
        `OpaqueValue`'s range.
        """
        try:
            return self._contains(value, strictly)[0]
        except TypeError:
            return False

    def potential_mask(self, other):
        """Returns the set of potential `__cmp__` values for `other`
        that are compatible with the current range: bit 0 is set if
        `-1` is possible, bit 1 if `0` is possible, and bit 2 if `1`
        is possible.
        """
        if not self.contains(other):
            return 0
        if self.contains(other, strictly=True):
            return 7
        if self.definite() and self.value() == other:
            return 2
        # We have a non-strict inclusion, and inequality: `other` sits
        # on an inclusive endpoint of the range.
        if self.lower == other and not self.lowerExcl:
            return 6
        assert self.upper == other and not self.upperExcl
        return 3

    def __cmp__(self, other):
        """Three-way comparison between this `OpaqueValue` and `other`.

        When the result is known from the current bound, we just return
        that value.  Otherwise, we ask `CMP_HANDLER` what value to return
        and update the bound accordingly.
        """
        if isinstance(other, OpaqueValue) and self.definite() and not other.definite():
            # If we have a definite value and `other` is an indefinite
            # `OpaqueValue`, flip the comparison order to let the `other`
            # argument be a ground value.
            return -other.__cmp__(self.value())

        if isinstance(other, OpaqueValue) and not other.definite():
            raise RuntimeError(
                f"OpaqueValue may only be compared with ground values. self={self} other={other}"
            )

        if isinstance(other, OpaqueValue):
            other = other.value()

        # Make sure `other` is a ground value
        if other is None:
            # We use `None` internally, and it doesn't compare well
            raise RuntimeError("OpaqueValue may not be compared with None")

        compatible, order = self._contains(other, False)
        if order is not None:
            return order

        order = OpaqueValue.CMP_HANDLER(self, other)
        if order < 0:
            self._add_bound(upper=other, upperExcl=True)
        elif order == 0:
            self._add_bound(lower=other, lowerExcl=False, upper=other, upperExcl=False)
        else:
            self._add_bound(lower=other, lowerExcl=True)
        return order

    def _add_bound(self, lower=None, lowerExcl=False, upper=None, upperExcl=False):
        """Updates the internal range for this new bound."""
        assert lower is None or self.contains(lower, strictly=lowerExcl)
        assert upper is None or self.contains(upper, strictly=upperExcl)
        if lower is not None:
            self.lower = lower
            self.lowerExcl = lowerExcl
        assert upper is None or self.contains(upper, strictly=upperExcl)
        if upper is not None:
            self.upper = upper
            self.upperExcl = upperExcl

    def __bool__(self):
        return self != 0

    def __eq__(self, other):
        return self.__cmp__(other) == 0

    def __ne__(self, other):
        return self.__cmp__(other) != 0

    # No other comparator because we don't index ranges
    # (no range tree).
    # def __lt__(self, other):
    #     return self.__cmp__(other) < 0
    # def __le__(self, other):
    #     return self.__cmp__(other) <= 0
    # def __gt__(self, other):
    #     return self.__cmp__(other) > 0
    # def __ge__(self, other):
    #     return self.__cmp__(other) >= 0
```

## Depth-first exploration of a function call's support

We assume a `None` result represents a zero value wrt the aggregate merging function (e.g., 0 for a sum). For convenience, we also treat tuples and lists of `None`s identically.

We simply maintain a stack of `CMP_HANDLER` calls. Each entry in the stack consists of an `OpaqueValue` and a bitset of `CMP_HANDLER` results still to explore (-1, 0, or 1).
This stack is filled on demand, and `CMP_HANDLER` returns the first result allowed by the bitset. Once we have a result, we tweak the stack to force depth-first exploration of a different part of the solution space. We drop the first (lowest) bit in the topmost entry's bitset of results to explore. If that bitset is now empty (all zero), we pop the entry wholesale and repeat with the new topmost entry. When this tweaking leaves an empty stack, we're done.

This ends up enumerating all the paths through the function call with a non-recursive depth-first traversal. We then do the same for each datum in our input sequence, and merge results for identical keys together.

```python
def is_zero_result(value):
    """Checks if `value` is a "zero" aggregate value: either `None`,
    or a tuple or list whose items are all `None`.

    >>> is_zero_result(None)
    True
    >>> is_zero_result(False)
    False
    >>> is_zero_result(True)
    False
    >>> is_zero_result(0)
    False
    >>> is_zero_result(-10)
    False
    >>> is_zero_result(1.5)
    False
    >>> is_zero_result("")
    False
    >>> is_zero_result("asd")
    False
    >>> is_zero_result((None, None))
    True
    >>> is_zero_result((None, 1))
    False
    >>> is_zero_result([])
    True
    >>> is_zero_result([None])
    True
    >>> is_zero_result([None, (None, None)])
    False
    """
    if value is None:
        return True
    if isinstance(value, (tuple, list)):
        return all(item is None for item in value)
    return False


def enumerate_opaque_values(function, values):
    """Explores the set of `OpaqueValue` constraints when calling `function`.

    Enumerates all constraints for the `OpaqueValue` instances in
    `values`.  For each non-zero result, yields a pair: the equality
    constraints on `values` (`None` for unconstrained values), and the
    corresponding result.

    This essentially turns `function()` into a branching program on
    `values`.

    >>> x, y = OpaqueValue("x"), OpaqueValue("y")
    >>> list(enumerate_opaque_values(lambda: 1 if x == 0 else (2 if x == 1 and y == 2 else None), [x, y]))
    [((0, None), 1), ((1, 2), 2)]
    """
    explorationStack = []  # List of (value, bitmaskOfCmp)
    while True:
        for value in values:
            value.reset()
        stackIndex = 0

        def handle(value, other):
            nonlocal stackIndex
            if len(explorationStack) == stackIndex:
                explorationStack.append((value, value.potential_mask(other)))
            expectedValue, mask = explorationStack[stackIndex]
            assert value is expectedValue
            assert mask != 0
            # Return the first comparison result still allowed by the mask.
            if (mask & 1) != 0:
                ret = -1
            elif (mask & 2) != 0:
                ret = 0
            elif (mask & 4) != 0:
                ret = 1
            else:
                assert False, f"bad mask {mask}"
            stackIndex += 1
            return ret

        OpaqueValue.CMP_HANDLER = handle
        result = function()
        if not is_zero_result(result):
            for value in values:
                assert (
                    value.definite() or value.indefinite()
                ), f"partially constrained {value} temporarily unsupported"
            yield (tuple(key.value() if key.definite() else None for key in values), result)

        # Drop everything that was fully explored, then move the next
        # top of stack to the next option.
        while explorationStack:
            value, mask = explorationStack[-1]
            assert 0 <= mask < 8
            mask &= mask - 1  # Drop first bit
            if mask != 0:
                explorationStack[-1] = (value, mask)
                break
            explorationStack.pop()

        if not explorationStack:
            break


def enumerate_supporting_values(function, args):
    """Lists the bag of mappings from closed over values to non-zero
    result, for all calls `function(arg) for arg in args`.

    >>> def count_eql(needle): return lambda x: 1 if x == needle else None
    >>> list(enumerate_supporting_values(count_eql(4), [1, 2, 4, 4, 2]))
    [((1,), 1), ((2,), 1), ((4,), 1), ((4,), 1), ((2,), 1)]
    """
    _, _, rebind, names = extract_function_state(function)
    values = [OpaqueValue(name) for name in names]
    reboundFunction = rebind(values)
    for arg in args:
        yield from enumerate_opaque_values(lambda: reboundFunction(arg), values)
```

## Type-driven merges

The interesting part of map/reduce is the reduction step.
Some like to use first-class functions to describe reduction. In my opinion, it often makes more sense to define reduction at the type level. Merge operators must be commutative and associative, so isolating the merge logic in dedicated classes makes sense to me.

This file defines two trivial mergeable value types, `Sum` and `Min`, but we could have different ones, e.g., hyperloglog unique counts, or streaming statistical moments... or even a list of row ids.

```python
class Sum:
    """A counter for summed values."""

    def __init__(self, value=0):
        self.value = value

    def merge(self, other):
        assert isinstance(other, Sum)
        self.value += other.value


class Min:
    """A running `min` value tracker.

    Falsy values (`None`, but also `0`) stand for "no value yet", so
    `Min(0)` acts as an empty tracker.
    """

    def __init__(self, value=None):
        self.value = value

    def merge(self, other):
        assert isinstance(other, Min)
        if not self.value:
            self.value = other.value
        elif other.value:
            self.value = min(self.value, other.value)
```

## Nested dictionary with wildcard

There's a direct relationship between two things: the data structure we use to represent the result of function calls as branching programs, and the constraints we can support on closed over values for non-zero results. In a real implementation, this data structure would host most of the complexity: it's the closest thing we have to indexes.

For now, we only support one kind of constraint: equality *with a ground value*. This means the only two cases we must look for at each level are an exact match and a wildcard match. We have to check for both cases at each level, so the worst-case complexity for lookups is exponential in the depth (number of join variables). That's actually reasonable because we don't expect too many join variables. Compare to [range trees](https://en.wikipedia.org/wiki/Range_tree#Range_queries), which are also exponential in the number of dimensions... but with a base of log(n) instead of 2.

A real implementation could maybe save work by memoising merged results for internal subtrees...
it's tempting to broadcast wildcard values to the individual keyed entries, but I think that might explode the time complexity of our pre-computation phase.

```python
class NestedDictLevel:
    """One level in a nested dictionary index.

    We may have a value for everything (leaf node), or a key-value
    dict *and a wildcard entry* for a specific index in the tuple key.
    """

    def __init__(self, depth):
        self.depth = depth
        self.value = None
        self.wildcard = None
        self.dict = dict()

    def visit(self, keys, visitor):
        """Passes the values for `keys` to `visitor`."""
        if self.value is not None:
            assert self.wildcard is None and not self.dict
            visitor(self.value)
        if self.depth >= len(keys):
            return
        # Both the wildcard subtree and the exact-match subtree may apply.
        if self.wildcard is not None:
            self.wildcard.visit(keys, visitor)
        child = self.dict.get(keys[self.depth], None)
        if child is not None:
            child.visit(keys, visitor)

    def set(self, keys, mergeFunction, depth=0):
        """Replaces the value for `keys` (where a `None` key is a
        wildcard) with `mergeFunction(old value)`."""
        assert depth <= len(keys)
        if depth == len(keys):  # Leaf
            self.value = mergeFunction(self.value)
            return
        assert self.depth == depth
        key = keys[depth]
        if key is None:
            if self.wildcard is None:
                self.wildcard = NestedDictLevel(depth + 1)
            dst = self.wildcard
        else:
            dst = self.dict.get(key)
            if dst is None:
                dst = NestedDictLevel(depth + 1)
                self.dict[key] = dst
        dst.set(keys, mergeFunction, depth + 1)


class NestedDict:
    """A nested dict of a given `length` maps tuples of `length` keys
    to a value.  Each `NestedDictLevel` handles a different level.
    """

    def __init__(self, length):
        self.top = NestedDictLevel(0)
        self.length = length

    def visit(self, keys, visitor):
        """Passes every value whose key matches the ground `keys`
        (exactly or via wildcards) to `visitor`."""
        assert len(keys) == self.length
        assert all(key is not None for key in keys)
        self.top.visit(keys, visitor)

    def set(self, keys, mergeFn):
        """Replaces the value associated with `keys` (a tuple of `length`
        keys, where `None` is a wildcard) with `mergeFn(old value)`."""
        assert len(keys) == self.length
        self.top.set(keys, mergeFn)
```

<details><summary><h2>Identity key-value maps</h2></summary>

```python
class IdMap:
    """A map keyed on the identity (`id`) of each element of a tuple key.

    Entries keep a reference to their keys, so the ids can't be
    recycled while the entry exists.
    """

    def __init__(self):
        self.entries = dict()  # tuple of id -> (key, value)
        # the value's first element keeps the ids stable.

    def get(self, keys, default=None):
        ids = tuple(id(key) for key in keys)
        return self.entries.get(ids, (None, default))[1]

    def __contains__(self, keys):
        ids = tuple(id(key) for key in keys)
        return ids in self.entries

    def __getitem__(self, keys):
        ids = tuple(id(key) for key in keys)
        return self.entries[ids][1]

    def __setitem__(self, keys, value):
        ids = tuple(id(key) for key in keys)
        self.entries[ids] = (keys, value)
```

</details>

## Cached `map_reduce`

As mentioned earlier, we assume `reduce` is determined implicitly by the reduced values' type. We also have `enumerate_supporting_values` to find all the closed over values that yield a non-zero result, for all values in a sequence. We can thus accept a function and an input sequence, find the supporting values, and merge the results associated with identical supporting values.

Again, we only support ground equality constraints (see the `partially constrained` assertion in `enumerate_opaque_values`), i.e., only equijoins. There's nothing that stops a more sophisticated implementation from using range trees to support inequality or range joins.

We'll cache the precomputed values by code object (i.e., function without closed over values) and input sequence, where the input sequence is compared by identity via `IdMap`. If we don't have a precomputed value, we'll use `enumerate_supporting_values` to run the function backward for each input datum from the sequence, and accumulate the results in a `NestedDict`.

Working backward to find closure values that yield a non-zero result (for each input datum) lets us precompute a branching program that directly yields the result. We represent these branching programs explicitly. That lets us directly update a branching program for the result of merging all the values returned by mapping over the input sequence, for a given closure.

This last `map_reduce` definition ties everything together. I think it is really the general heart of Yannakakis's algorithm, as an instance of bottom-up dynamic programming.

```python
def _merge(dst, update):
    """Merges `update` into the accumulator `dst` (elementwise for
    tuples and lists) and returns the accumulator; a `None`
    accumulator is the zero value."""
    if dst is None:
        return update
    if isinstance(dst, (tuple, list)):
        assert len(dst) == len(update)
        for value, new in zip(dst, update):
            value.merge(new)
    else:
        dst.merge(update)
    return dst


def _extractValues(accumulator):
    """Unwraps the `.value` of an accumulator (elementwise for tuples
    and lists)."""
    if accumulator is None:
        return None
    if isinstance(accumulator, tuple):
        return tuple(item.value for item in accumulator)
    if isinstance(accumulator, list):
        return list(item.value for item in accumulator)
    return accumulator.value


def _precompute_map_reduce(function, depth, inputIterable):
    """Given a function (a closure), the number of values the function
    closes over, and an input iterable, generates a `NestedDict`
    representation for `reduce(map(function, inputIterable))`, where
    the reduction step simply calls `merge` on the return values
    (tuples are merged elementwise), and the `NestedDict` keys
    represent closed over values.

    >>> def count_eql(needle): return lambda x: Sum(1) if x == needle else None
    >>> nd = _precompute_map_reduce(count_eql(4), 1, [1, 2, 4, 4, 2])
    >>> nd.visit((0,), lambda sum: print(sum.value))
    >>> nd.visit((1,), lambda sum: print(sum.value))
    1
    >>> nd.visit((2,), lambda sum: print(sum.value))
    2
    >>> nd.visit((4,), lambda sum: print(sum.value))
    2
    """
    cache = NestedDict(depth)
    for indexKeyValues, result in enumerate_supporting_values(function, inputIterable):
        cache.set(indexKeyValues, lambda old: _merge(old, result))
    return cache


AGGREGATE_CACHE = IdMap()  # Map from function, input sequence -> NestedDict


def map_reduce(function, inputIterable, initialValue=None, *, extractResult=True):
    """Returns the result of merging `map(function, inputIterable)`
    into `initialValue`.

    `None` return values represent neutral elements (i.e., the result
    of mapping an empty `inputIterable`), and values are otherwise
    reduced by calling `merge` on a mutable accumulator.

    Assuming `function` is well-behaved, `map_reduce` runs in time
    linear wrt `len(inputIterable)`.  It's also always cached on a
    composite key that consists of the `function`'s code object
    (i.e., without closed over values) and the `inputIterable`.

    These complexity guarantees let us nest `map_reduce` with
    different closed over values, and still guarantee a linear-time
    total complexity.

    This wrapper ties together all the components:

    >>> INVOCATION_COUNTER = 0
    >>> data = (1, 2, 2, 4, 2, 4)
    >>> def count_eql(needle):
    ...     def count(x):
    ...         global INVOCATION_COUNTER
    ...         INVOCATION_COUNTER += 1
    ...         return Sum(x) if x == needle else None
    ...     return count
    >>> INVOCATION_COUNTER
    0
    >>> map_reduce(count_eql(4), data, Sum(), extractResult=False).value
    8
    >>> INVOCATION_COUNTER
    18
    >>> map_reduce(count_eql(2), data)
    6
    >>> INVOCATION_COUNTER
    18

    >>> id_skus = [(1, 2), (2, 2), (1, 3)]
    >>> sku_costs = [(1, 10), (2, 20), (3, 30)]
    >>> def sku_min_cost(sku):
    ...     return map_reduce(lambda sku_cost: Min(sku_cost[1]) if sku_cost[0] == sku else None, sku_costs, Min(0))
    >>> def sum_odd_or_even_skus(mod_two):
    ...     def count_if_mod_two(id_sku):
    ...         id, sku = id_sku
    ...         if id % 2 == mod_two:
    ...             return Sum(sku_min_cost(sku))
    ...     return map_reduce(count_if_mod_two, id_skus, Sum())
    >>> sum_odd_or_even_skus(0)
    20
    >>> sum_odd_or_even_skus(1)
    50
    """
    # The input may be iterated once per distinct code object, so
    # one-shot iterators are rejected.
    assert isinstance(inputIterable, collections.abc.Iterable)
    assert not isinstance(inputIterable, collections.abc.Iterator)
    code, closure, *_ = extract_function_state(function)
    # Build the branching program for this code object and input once...
    if (code, inputIterable) not in AGGREGATE_CACHE:
        AGGREGATE_CACHE[code, inputIterable] = _precompute_map_reduce(
            function, len(closure), inputIterable
        )

    # ...then merge every entry compatible with the actual closed over
    # values, including wildcard entries.
    acc = [initialValue]

    def visitor(result):
        acc[0] = _merge(acc[0], result)

    AGGREGATE_CACHE[code, inputIterable].visit(closure, visitor)
    return _extractValues(acc[0]) if extractResult else acc[0]


# Decorator form: `@map_reduce.over(inputIterable, initialValue)` replaces
# the decorated function with the result of `map_reduce` over it.
map_reduce.over = \
    lambda inputIterable, initialValue=None, *, extractResult=True: \
    lambda fn: map_reduce(fn, inputIterable, initialValue, extractResult=extractResult)


if __name__ == "__main__":
    import doctest

    doctest.testmod()
```

## Is this actually a DB post?

Although the intro name-dropped Yannakakis, the presentation here has a very programming language / logic programming flavour. I think the logic programming point of view, where we run a program backwards with logical variables, is much clearer than the usual presentation of Yannakakis's algorithm, which is specific to conjunctive equijoin queries. In particular, I think there's a clear path to handle range or comparison joins: it's all about having an index data structure to handle range queries.

It should be clear how to write conjunctive queries as Python functions, given a hypertree decomposition.
The reverse is much more complex, if only because Python is much more powerful than just CQ. That's actually a liability: this hack will blindly try to convert any function to a branching program, instead of giving up noisily when the function is too complex.

The other difference from classical CQs is that we focus on aggregates. That's because aggregates are the more general form. If we just want to avoid useless work while enumerating all join rows, we only need a boolean aggregate that tells us whether the join will yield at least one row. We could also special-case types for which merges don't save space (e.g., sets of row ids), and instead enumerate values by walking the branching program tree. The aggregate viewpoint also works for [fun extensions like indexed access to ranked results](https://ntzia.github.io/download/Tractable_Orders_2020.pdf): that extension ends up counting the number of output values up to a certain key.

I guess, in a way, we just showed a trivial way to decorrelate queries with a hypertree width of 1. We just have to be OK with building one index for each loop in the nest... but it should be possible to pattern match on pre-defined indexes and avoid obvious redundancy.

## Extensions and future work

### Use a dedicated DSL

First, the whole idea of introspecting closures to stub in logical variables is a terrible hack (looks cool though ;). A real production implementation should apply CPS partial evaluation to a purely functional programming language, then bulk reverse-evaluate with a SIMD implementation of the logical program. There'll be restrictions on the output traces, but that's OK. A different prototype makes me believe the restrictions correspond to deterministic logspace (L), and it makes sense to restrict our analyses to L. Just like grammars are easier to work with when restricted to LL(1), DSLs that only capture L tend to be easier to analyse and optimise...
and L is reasonably large: showing that some polynomial-time problem isn't in L would be a *huge* result.

### Handle local functions

While we sin with the closure hack (`extract_function_state`), it should really be extended to cover local functions. This is mostly a question of going deeply into values that are mapped to functions, and of maintaining an id-keyed map from cell to `OpaqueValue`. We could also add support for partial application objects, which may be easier for multiprocessing.

### Parallelism

There is currently no support for parallelism, only caching. It should be easy to handle the return values (`NestedDict`s and aggregate classes like `Sum` or `Min`). Distributing the work in `_precompute_map_reduce` to merge locally is also not hard.

The main issue with parallelism is that we can't pass functions as work units, since closures don't pickle, so we'd have to stick to the `fork` process pool. There's also no support for moving (child) work forward when blocked waiting on a future. We'd have to spawn workers on the fly to oversubscribe when workers are blocked on a result; spawning on demand is already a given for `fork` workers. We'd also have to implement our own concurrency control to avoid wasted work. We'd probably need internal throttling too, to avoid thrashing when we'd have more active threads than cores. That being said, the complexity is probably worth the speed-up on realistic queries.

### Theta joins

At a higher level, we could support comparison joins (e.g., less than, greater than or equal, in range) if only we represented the branching programs with a data structure that supported these queries. A [range tree](https://dl.acm.org/doi/10.1145/356789.356797) would let us handle these "theta" joins, for the low, low cost of a polylogarithmic multiplicative factor in space and time.

### Self-adjusting computation

Finally, we could update the indexed branching programs incrementally after small changes to the input data.
This might sound like a job for streaming engines like [timely dataflow](https://github.com/timelydataflow/timely-dataflow). I think, however, that viewing each `_precompute_map_reduce` call as a purely functional map/reduce job gives a better fit with [self-adjusting computation](https://www.umut-acar.org/research#h.x3l3dlvx3g5f).

Once we add logic to recycle previously constructed indexes, it will probably make sense to allow an initial filtering step before map/reduce, with a cache key on the filter function (with closed over values and all). We can often implement the filtering more efficiently than we can run functions backward. We'll also often find that slightly different filter functions result in not too dissimilar filtered sets. Factoring out this filtering can thus enable more reuse of partial precomputed results.
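
As a rough illustration of the incremental updates mentioned above, here is the simplest case. It assumes the index is a flat dict from join key to aggregate. That is a big simplification of `NestedDict`, and not self-adjusting computation proper. `SumIndex` is a hypothetical name.

Sums have inverses, so an inserted or deleted row is a constant-time update. An aggregate like `Min` has no inverse, so deletions would need a richer per-key structure.

```python
from collections import defaultdict


class SumIndex:
    """Hypothetical incrementally maintained index: key -> sum of values."""

    def __init__(self, rows):
        self.totals = defaultdict(int)
        for key, value in rows:  # initial linear-time build
            self.totals[key] += value

    def insert(self, key, value):
        self.totals[key] += value  # O(1) update

    def delete(self, key, value):
        # Sums form a group, so deletion merges the inverse contribution.
        self.totals[key] -= value

    def lookup(self, key):
        return self.totals.get(key, 0)


index = SumIndex([(1, 10), (2, 20), (1, 5)])
index.insert(2, 7)
index.delete(1, 10)
print(index.lookup(1), index.lookup(2), index.lookup(3))  # 5 27 0
```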