Last active
September 27, 2022 00:24
-
-
Save msullivan/b45e3c785c55d5d8ea4db196e1d9cb55 to your computer and use it in GitHub Desktop.
algorithm for eliminating argument tuples
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
from dataclasses import dataclass | |
import json | |
import time | |
import textwrap | |
from typing import Any, Callable | |
@dataclass(eq=False)
class Scalar:
    """A leaf type, identified only by its name.

    The class is declared with ``eq=False``, so no ``__eq__`` is generated:
    two instances compare (and hash) by identity even when their names
    match.
    """

    # Name of the scalar as it should be rendered, e.g. 'str' or 'int64'.
    name: str = 'T'

    def fmt(self) -> str:
        """Return the textual spelling of this type, which is just its name."""
        label = self.name
        return label
@dataclass(eq=False)
class Tuple:
    """A fixed-length product type made of an ordered sequence of member types.

    Declared with ``eq=False``: instances compare and hash by identity.
    """

    # Member types, in positional order.
    typs: tuple[Type, ...]

    def fmt(self) -> str:
        """Return ``tuple<...>`` with the formatted members separated by ', '."""
        members = [member.fmt() for member in self.typs]
        joined = ", ".join(members)
        return f'tuple<{joined}>'
@dataclass(eq=False)
class Array:
    """An array whose elements all have the (possibly nested) type ``typ``.

    Declared with ``eq=False``: instances compare and hash by identity.
    """

    # Element type; may itself be a scalar, tuple or array.
    typ: Type

    def fmt(self) -> str:
        """Return ``array<...>`` wrapped around the formatted element type."""
        element = self.typ.fmt()
        return f'array<{element}>'
@dataclass(eq=False)
class FArray:
    """A "flat" array: an array whose element type is annotated as a scalar.

    Declared with ``eq=False``: instances compare and hash by identity.
    """

    # Element type of the flat array.
    typ: Scalar

    def fmt(self) -> str:
        """Return ``array<...>`` wrapped around the formatted element type."""
        element = self.typ.fmt()
        return f'array<{element}>'
# Source-level types: scalars, tuples and arrays, which may nest freely.
Type = Scalar | Tuple | Array
# Flattened argument types: either a bare scalar or a flat array (FArray).
# translate_type only ever emits these two shapes.
FType = Scalar | FArray
def _array_adjust(typ: Type) -> int:
    """Return 2 if the leftmost leaf of *typ* is an array, otherwise 1.

    Tuples are descended through their first member until a non-tuple is
    reached.  The decoders subtract this value from the length of a
    top-level array's first flattened argument when computing the upper
    loop bound.  When that argument is an offsets array, ``encode`` seeds it
    with a leading 0, so it holds one more entry than there are elements.
    """
    leftmost = typ
    while isinstance(leftmost, Tuple):
        leftmost = leftmost.typs[0]
    if isinstance(leftmost, Array):
        return 2
    return 1
# XXX: INVARIANT: every node inside a type must be a distinct object!
# Map is keyed by the type nodes themselves, and the type classes are declared
# with eq=False, so lookups are by identity: reusing one node object at two
# positions of the same type would make the second entry overwrite the first.
# Each node maps to (index of its first flattened argument, whether the node
# sits inside an array).
Map = dict[Type, tuple[int, bool]]
def translate_type(typ: Type) -> tuple[tuple[FType, ...], Map]:
    """Flatten *typ* into a sequence of tuple-free argument types.

    Returns a pair ``(flat_types, node_map)``:

    * ``flat_types`` lists, in traversal order, one ``FType`` per flattened
      argument.  A scalar outside any array stays a ``Scalar``; a scalar
      inside an array becomes an ``FArray``; an array nested inside another
      array additionally contributes an ``array<int64>`` argument that is
      placed before its contents.
    * ``node_map`` maps every node of *typ* to ``(start, in_array)``, where
      ``start`` is the number of flattened arguments emitted before the node
      was visited and ``in_array`` tells whether the node is inside an array.
    """
    flat: list[FType] = []
    positions: Map = {}

    def visit(node: Type, in_array: bool) -> None:
        first = len(flat)
        if isinstance(node, Tuple):
            # Tuples vanish: their members are laid out side by side.
            for member in node.typs:
                visit(member, in_array)
        elif isinstance(node, Array):
            if in_array:
                # A nested array needs an extra int64 array argument.
                flat.append(FArray(Scalar('int64')))
            visit(node.typ, True)
        elif isinstance(node, Scalar):
            flat.append(FArray(node) if in_array else node)
        # Recorded after the children so the map keeps post-order insertion.
        positions[node] = (first, in_array)

    visit(typ, False)
    return tuple(flat), positions
def encode(typ: Type, ntyps: tuple[FType, ...], map: Map, data: Any) -> tuple[Any, ...]:
    """Flatten the nested value *data* of type *typ* into one value per FType.

    *ntyps* and *map* are the two results of ``translate_type(typ)``.  The
    returned tuple has one entry per element of *ntyps*: a plain value for a
    ``Scalar`` slot and a list for every other slot.

    For an array nested inside another array, the slot at its mapped index
    receives running offsets: a leading 0 followed by the cumulative element
    count after each encoded array value.
    """
    slots: list[Any] = []
    for flat_type in ntyps:
        slots.append(0 if isinstance(flat_type, Scalar) else [])

    def walk(node: Type, value: Any) -> None:
        slot, nested = map[node]
        if isinstance(node, Tuple):
            assert isinstance(value, tuple)
            assert len(node.typs) == len(value)
            for member, item in zip(node.typs, value):
                walk(member, item)
            return
        if isinstance(node, Array):
            assert isinstance(value, list)
            if nested:
                offsets = slots[slot]
                if not offsets:
                    # Seed the offsets with the start of the first slice.
                    offsets.append(0)
                offsets.append(offsets[-1] + len(value))
            for item in value:
                walk(node.typ, item)
            return
        if isinstance(node, Scalar):
            if nested:
                slots[slot].append(value)
            else:
                slots[slot] = value

    walk(typ, data)
    return tuple(slots)
def naive_decode(typ: Type, map: Map, data: tuple[Any, ...]) -> Any:
    """Rebuild the nested value of type *typ* from flattened *data*, bottom-up.

    Each node is decoded for all positions at once: a scalar yields its
    whole flattened argument, a nested array slices its decoded contents
    using consecutive pairs of offsets, and a tuple inside an array zips
    its members' columns into a list of tuples.
    """
    def rebuild(node: Type) -> Any:
        slot, nested = map[node]
        if isinstance(node, Tuple):
            columns = [rebuild(member) for member in node.typs]
            if not nested:
                return tuple(columns)
            return [tuple(row) for row in zip(*columns)]
        if isinstance(node, Array):
            elements = rebuild(node.typ)
            if not nested:
                # A top-level array is exactly its decoded contents.
                return elements
            offsets = data[slot]
            return [elements[lo:hi] for lo, hi in zip(offsets, offsets[1:])]
        if isinstance(node, Scalar):
            return data[slot]
        return None

    return rebuild(typ)
def decode(typ: Type, map: Map, data: tuple[Any, ...]) -> Any:
    """Rebuild the nested value of type *typ* from flattened *data*, top-down.

    *map* is the node map returned by ``translate_type(typ)`` and *data* is
    the tuple produced by ``encode``.  Unlike ``naive_decode``, elements are
    decoded one at a time: ``idx`` is the position inside the flattened
    arrays while decoding inside an array, and ``None`` outside any array.
    """
    # Author's notes, kept from the original:
    # I don't think this will *actually* be faster than naive_decode
    # (since it will have way more python function calls), but it *is*
    # linear time and matches the algorithm used in the edgeql
    # decoder.
    # Actually, naive_decode is I think also linear time?
    # But naive decode does a lot more intermediate object creation.
    #
    # Actually, actually! Whether naive_decode is linear time depends
    # on the object storage model! In a pointer based model it is,
    # but in a model where the objects actually get included, it is
    # *not*.
    # What does Postgres do??
    def dec(typ: Type, idx: int | None) -> Any:
        arg, _ = map[typ]
        if isinstance(typ, Scalar):
            return data[arg][idx] if idx is not None else data[arg]
        elif isinstance(typ, Array):
            if idx is None:
                # Top-level array: iterate over every element.  The count is
                # derived from the first flattened argument of the contents;
                # _array_adjust accounts for the extra leading offset when
                # that argument is an offsets array.
                lo = 0
                hi = len(data[arg]) - _array_adjust(typ.typ) + 1
            else:
                # Nested array: its element range is given by two
                # consecutive entries of the offsets argument.
                lo = data[arg][idx]
                hi = data[arg][idx + 1]
            return [dec(typ.typ, idx=i) for i in range(lo, hi)]
        elif isinstance(typ, Tuple):
            return tuple(dec(t, idx=idx) for t in typ.typs)

    return dec(typ, idx=None)
def make_decoder(typ: Type, map: Map, ftypes: tuple[FType, ...]) -> str:
    """Generate query text that rebuilds a value of *typ* from parameters.

    *map* and *ftypes* are the results of ``translate_type(typ)``.  The
    returned expression refers to the flattened arguments positionally as
    ``$0``, ``$1``, ... and casts each one to its flattened type.  It mirrors
    ``decode``: arrays become ``array_agg`` over a ``for`` loop whose index
    variable is used to subscript the flattened arrays.

    NOTE(review): ``_gen_series`` is not defined in this file; the bounds
    computed here treat its upper bound as inclusive -- confirm against the
    target database.
    """
    cnt = 0

    def mk_name(x: str) -> str:
        # Produce a fresh loop-variable name: x1, x2, ...
        nonlocal cnt
        cnt += 1
        return f'{x}{cnt}'

    # Extra cast inserted in front of every scalar parameter reference.
    # An alternative that was tried: BS = '<json>'
    BS = ''

    def mk(typ: Type, idx: str | None) -> str:
        arg, in_array = map[typ]
        if isinstance(typ, Scalar):
            tname = f'array<{typ.fmt()}>' if in_array else typ.fmt()
            if idx is None:
                return f'<{tname}>{BS}${arg}'
            else:
                return f'(<{tname}>{BS}${arg})[{idx}]'
        elif isinstance(typ, Array):
            a = f'(<array<int64>>${arg})'
            # If the contents is just a scalar, then we can take
            # values directly from the scalar array parameter, without
            # needing to iterate over the array directly.
            # This is an optimization, and not necessary for correctness.
            if isinstance(typ.typ, Scalar):
                sub = mk(typ.typ, idx=None)
                # If we are in an array, do a slice!
                if idx is not None:
                    sub = f'({sub})[{a}[{idx}]:{a}[{idx}+1]]'
                return sub
            inner_idx = mk_name('i')
            sub = mk(typ.typ, idx=inner_idx)
            if idx is None:
                # Top-level array: loop over every element of the contents.
                adjust = _array_adjust(typ.typ)
                lo = '0'
                hi = f'len(<{ftypes[arg].fmt()}>${arg}) - {adjust}'
            else:
                # Nested array: loop over the range given by its offsets.
                lo = f'{a}[{idx}]'
                hi = f'{a}[{idx}+1]-1'
            # dedent strips the source indentation of the template; the
            # nested expression is substituted for %s afterwards.
            grp = textwrap.dedent(f'''\
                array_agg((for {inner_idx} in _gen_series({lo}, {hi}) union (
                %s
                )))'''
            ) % textwrap.indent(sub, ' ')
            return grp
        elif isinstance(typ, Tuple):
            # Every member is followed by a comma, so a one-element tuple
            # is still rendered as a tuple.
            lparts = [mk(t, idx=idx) + ',' for t in typ.typs]
            return f'({" ".join(lparts)})'

    return mk(typ, idx=None)
######### TESTING | |
def test(t1, data):
    """Round-trip *data* of type *t1* through encode/decode, printing each step.

    Prints the flattened types, the node map, the encoded arguments, the
    decoded value, a ``select`` statement built from ``make_decoder`` and
    finally every encoded argument as JSON.  Asserts that decoding returns
    the original data.
    """
    flat_types, node_map = translate_type(t1)
    print(flat_types)
    print(node_map)

    flat_args = encode(t1, flat_types, node_map, data)
    print(flat_args)
    print()

    roundtrip = decode(t1, node_map, flat_args)
    print(roundtrip)
    assert roundtrip == data
    print()

    decoder_expr = make_decoder(t1, node_map, flat_types)
    print(f'select {decoder_expr};')
    print()

    for flat_arg in flat_args:
        print(json.dumps(flat_arg))
    print()
# Hand-written sample types and matching values used by go().
# t1.fmt() == 'array<tuple<array<str>>>'
t1 = Array(Tuple((Array(Scalar('str')),)))
# t2.fmt() == 'array<tuple<array<tuple<array<str>>>, str>>'
t2 = Array(Tuple((t1, Scalar('str'))))
# Two values of type t1.
test_data: list[tuple[list[str]]] = [
    (['a'],),
    (['b','c'],),
    (['d','e','f'],),
]
test_data_2p: list[tuple[list[str]]] = [
    (['x','y','z','w'],),
    (['g','h','i'],),
    (['j','k'],),
    (['l'],),
]
# A value of type t2: each element pairs a t1 value with a string.
test_data2 = [(test_data, 'foo'), (test_data_2p, 'bar')]
# simpler
# t3.fmt() == 'array<tuple<array<tuple<array<str>>>>>'
t3 = Array(Tuple((t1,)))
test_data3 = [(test_data,), (test_data_2p,)]
def go():
    """Run the printing round-trip check on the three hand-written samples."""
    samples = (
        (t1, test_data),
        (t2, test_data2),
        (t3, test_data3),
    )
    for sample_type, sample_value in samples:
        test(sample_type, sample_value)
# go() | |
################# | |
import hypothesis as h | |
import hypothesis.strategies as hs | |
# Hypothesis strategy that generates random type trees.
# Leaves are always a freshly built Scalar('str') (presumably so that every
# node is a distinct object, as the Map invariant requires).  Each recursive
# step either wraps a child in an Array -- children that are themselves
# Arrays are filtered out, so an array never directly contains an array --
# or collects 1 to 8 children into a Tuple.
typ = hs.recursive(
    hs.builds(lambda _: Scalar('str'), hs.none()),
    lambda children: (
        hs.builds(Array, children.filter(lambda x: not isinstance(x, Array)))
        | hs.lists(children, min_size=1, max_size=8)
        .map(lambda l: Tuple(tuple(l)))
    )
)
def type_to_strategy(t):
    """Return a hypothesis strategy producing values that match type *t*.

    Tuples map to ``hs.tuples`` of their members' strategies, arrays to
    ``hs.lists`` of the element strategy, and scalars to a single character
    sampled from a fixed lowercase alphabet.  Any other input yields None.
    """
    if isinstance(t, Tuple):
        member_strategies = [type_to_strategy(member) for member in t.typs]
        return hs.tuples(*member_strategies)
    if isinstance(t, Array):
        return hs.lists(type_to_strategy(t.typ))
    if isinstance(t, Scalar):
        # Alternatives that were tried here: hs.text(alphabet=...) and
        # hs.integers().
        # NOTE(review): the alphabet has no 'p' -- possibly a typo; confirm.
        return hs.sampled_from('abcdefghijklmnoqrstuvwxyz')
    return None
@hs.composite
def type_and_data(draw):
    """Draw a random type, then a random value of that type; return both."""
    chosen_type = draw(typ)
    value = draw(type_to_strategy(chosen_type))
    return chosen_type, value
@h.given(type_and_data())
def test_encode_decode(td):
    """Property test: encode followed by either decoder returns the input.

    *td* is a ``(type, value)`` pair drawn from ``type_and_data``.  Both
    ``decode`` and ``naive_decode`` must reproduce the original value.
    """
    t, d = td
    print()
    print(t.fmt())
    print(d)

    nts, m = translate_type(t)
    encoded = encode(t, nts, m, d)

    decoded = decode(t, m, encoded)
    assert d == decoded, (
        d,
        encoded,
        decoded,
    )

    naive_decoded = naive_decode(t, m, encoded)
    # Report the naive decoder's own result on failure; the original message
    # repeated `decoded`, which is already known to equal `d` at this point.
    assert d == naive_decoded, (
        d,
        encoded,
        naive_decoded,
    )
    print('PASS')
# Lazily created client shared by all calls to get_conn().
_conn = None


def get_conn():
    """Return the shared edgedb client, creating it on first use.

    The client is created for port 5656 with ``tls_security='insecure'``
    and cached in the module-level ``_conn``.
    """
    import edgedb
    global _conn
    if _conn:
        return _conn
    _conn = edgedb.create_client(port=5656, tls_security='insecure')
    return _conn
def _test_edgeql(t, d):
    """Check that the generated query decodes the encoding of *d* back to *d*.

    Encodes *d* (a value of type *t*), builds the decoder expression with
    ``make_decoder`` and runs ``select <expr>`` twice through the shared
    client with the flattened values as positional arguments.  The result of
    the first run is discarded; the second result is compared with *d*.
    Prints the elapsed time of both runs.
    """
    print()
    print(t.fmt())
    # print(t)
    print(d)
    nts, m = translate_type(t)
    encoded = encode(t, nts, m, d)
    query = make_decoder(t, m, nts)
    db = get_conn()
    print("Q", query)
    # print('args', encoded, nts)
    # perf_counter is monotonic and high-resolution, which makes it the
    # appropriate clock for measuring short durations (time.time is not).
    tb = time.perf_counter()
    db.query_single(f'select {query}', *encoded)
    t0 = time.perf_counter()
    decoded = db.query_single(f'select {query}', *encoded)
    t1 = time.perf_counter()
    assert d == decoded, (
        d,
        encoded,
        decoded,
    )
    # assert t1 - t0 < 1.0 # LOL
    print(f'PASS {t0-tb:.3f} {t1-t0:.3f}')
@h.settings(deadline=None)
@h.given(type_and_data())
def test_edgeql(td):
    """Property test: run the database round-trip check on a random pair.

    The hypothesis deadline is disabled for this test.
    """
    generated_type, generated_value = td
    _test_edgeql(generated_type, generated_value)
# _test_edgeql(t1, test_data)
# NOTE: both property tests are invoked at module level, so they run whenever
# this file is executed or imported.  test_edgeql goes through get_conn(),
# which connects on port 5656.
test_encode_decode()
test_edgeql()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment