Created
June 24, 2025 17:22
-
-
Save msullivan/1bde300d75e9530d845e578ce3ae2e1b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations

import json
import string
import textwrap
import time
from dataclasses import dataclass
from typing import Any, Callable
@dataclass(eq=False)
class Scalar:
    """A leaf type such as ``str`` or ``int64``.

    ``eq=False`` leaves object identity as equality and hash, so each node
    object can act as its own key in a ``Map``.
    """

    name: str = 'T'

    def fmt(self) -> str:
        """Render the type in EdgeQL syntax: just the scalar's name."""
        rendered = self.name
        return rendered
@dataclass(eq=False)
class Tuple:
    """An unnamed tuple type whose fields are arbitrary nested types."""

    typs: tuple[Type, ...]

    def fmt(self) -> str:
        """Render as ``tuple<A, B, ...>`` using each field's own ``fmt``."""
        inner = ", ".join(field.fmt() for field in self.typs)
        return f'tuple<{inner}>'
@dataclass(eq=False)
class Array:
    """An array whose element type may itself be any nested type."""

    typ: Type

    def fmt(self) -> str:
        """Render as ``array<ELEM>``."""
        elem = self.typ.fmt()
        return f'array<{elem}>'
@dataclass(eq=False)
class FArray:
    """A flat array of scalars: one of the parameter types that nested
    types are flattened into."""

    typ: Scalar

    def fmt(self) -> str:
        """Render as ``array<SCALAR>``."""
        elem = self.typ.fmt()
        return f'array<{elem}>'
# Flat parameter types produced by translate_type: a plain scalar, or a
# flat array of scalars.
FType = Scalar | FArray
# Nested (user-facing) types: scalars, tuples and arrays in any combination.
Type = Scalar | Tuple | Array
def _array_adjust(typ: Type) -> int:
    """Return the correction for a top-level array's leftmost column.

    A top-level array has no offsets parameter of its own, so its length is
    read off the column of its element type's leftmost leaf (found by
    following tuples' first fields). That column is either a per-element
    list (n entries, result 1) or, when the leaf is a nested array, an
    offsets list (n + 1 entries, result 2). In both cases
    ``len(column) - _array_adjust(elem)`` is the index of the last element.
    That index is negative for an empty array.
    """
    if isinstance(typ, Tuple):
        return _array_adjust(typ.typs[0])
    return 2 if isinstance(typ, Array) else 1
# Slot of one node in the flat encoding: (index of the node's first flat
# parameter, whether the node sits inside an array).
_Slot = tuple[int, bool]
# Every node of a nested Type -> its slot.
# XXX: INVARIANT: NODES IN THE TYPE MUST BE DISTINCT! The node dataclasses
# use eq=False, so nodes are keyed by identity. Reusing one node object twice
# within a type would make both occurrences share a single entry.
Map = dict[Type, _Slot]
def translate_type(typ: Type) -> tuple[tuple[FType, ...], Map]:
    """Flatten a nested type into the flat parameter types used to carry it.

    The tree is walked depth-first, fields left to right. Each node
    contributes parameters as follows:

    * a Scalar outside any array -> that Scalar (one plain parameter);
    * a Scalar inside an array -> FArray(Scalar) (one value per element);
    * an Array inside an array -> FArray(Scalar('int32')) holding offsets,
      followed by the parameters of its element type;
    * an Array not inside any array -> no parameter of its own, only its
      element type's parameters, which are marked as in-array;
    * a Tuple -> the parameters of its fields, in order.

    Returns:
        ``(flat_types, slots)``. ``slots`` maps every node of *typ* to
        ``(index of its first flat parameter, whether it is inside an array)``.

    Raises:
        ValueError: if the same node object occurs more than once in *typ*.
            Nodes are keyed by identity, so reuse would corrupt ``slots``.
        TypeError: if a node is not a Scalar, Array or Tuple.
    """
    flat: list[FType] = []
    slots: Map = {}

    def trans(node: Type, in_array: bool) -> None:
        # Enforce the "nodes must be distinct" invariant. Otherwise a reused
        # node silently overwrites its earlier slot.
        if node in slots:
            raise ValueError(
                f'type node {node.fmt()} occurs more than once; '
                'nodes in a type must be distinct objects'
            )
        start = len(flat)
        if isinstance(node, Scalar):
            flat.append(FArray(node) if in_array else node)
        elif isinstance(node, Array):
            # Only arrays nested inside arrays need an offsets column; a
            # top-level array's length comes from its leftmost column.
            if in_array:
                flat.append(FArray(Scalar('int32')))
            trans(node.typ, in_array=True)
        elif isinstance(node, Tuple):
            for field in node.typs:
                trans(field, in_array=in_array)
        else:
            raise TypeError(f'unsupported type node: {node!r}')
        slots[node] = (start, in_array)

    trans(typ, in_array=False)
    return tuple(flat), slots
def encode(typ: Type, ntyps: tuple[FType, ...], map: Map, data: Any) -> tuple[Any, ...]:
    """Encode *data* (a value of *typ*) into one argument per flat parameter.

    Args:
        typ: the nested type of *data*.
        ntyps: flat parameter types from translate_type.
        map: slot map from translate_type.
        data: the value; arrays must be lists and tuples must be tuples.

    Returns:
        One entry per element of *ntyps*:

        * a plain scalar parameter holds the value itself;
        * a scalar column holds the list of values in depth-first order;
        * an offsets column holds cumulative end offsets. It starts with 0
          the first time it is written, so it has (number of enclosing
          elements + 1) entries, or stays empty if never written.

    Raises:
        TypeError: if an array value is not a list or a tuple value is not a
            tuple.
        ValueError: if a tuple value has the wrong number of fields.
    """
    args: list[Any] = [0 if isinstance(t, Scalar) else [] for t in ntyps]

    def enc(node: Type, value: Any) -> None:
        slot, in_array = map[node]
        if isinstance(node, Scalar):
            if in_array:
                args[slot].append(value)
            else:
                args[slot] = value
        elif isinstance(node, Array):
            # Explicit checks rather than asserts: under `python -O` asserts
            # vanish and malformed input would be encoded silently.
            if not isinstance(value, list):
                raise TypeError(
                    f'expected list for {node.fmt()}, got {type(value).__name__}'
                )
            if in_array:
                offsets = args[slot]
                if not offsets:
                    offsets.append(0)
                offsets.append(offsets[-1] + len(value))
            for elem in value:
                enc(node.typ, elem)
        elif isinstance(node, Tuple):
            if not isinstance(value, tuple):
                raise TypeError(
                    f'expected tuple for {node.fmt()}, got {type(value).__name__}'
                )
            # zip() would silently truncate a mismatched tuple.
            if len(value) != len(node.typs):
                raise ValueError(
                    f'expected {len(node.typs)} fields for {node.fmt()}, '
                    f'got {len(value)}'
                )
            for field, elem in zip(node.typs, value):
                enc(field, elem)

    enc(typ, data)
    return tuple(args)
def naive_decode(typ: Type, map: Map, data: tuple[Any, ...]) -> Any:
    """Decode encoded arguments back into a value, bottom-up.

    Each node is decoded into *all* of its values at once:

    * a scalar returns its whole parameter;
    * an in-array Array cuts its element values into sublists at its
      offsets;
    * an in-array Tuple zips its field columns into rows.

    Nodes outside arrays return a single value.
    """
    def dec(node: Type) -> Any:
        slot, in_array = map[node]
        if isinstance(node, Scalar):
            return data[slot]
        if isinstance(node, Array):
            flat = dec(node.typ)
            if not in_array:
                return flat
            offsets = data[slot]
            # Consecutive offset pairs delimit each enclosing element.
            return [flat[lo:hi] for lo, hi in zip(offsets, offsets[1:])]
        if isinstance(node, Tuple):
            columns = [dec(field) for field in node.typs]
            if not in_array:
                return tuple(columns)
            return [tuple(row) for row in zip(*columns)]
        return None

    return dec(typ)
def decode(typ: Type, map: Map, data: tuple[Any, ...]) -> Any:
    """Decode encoded arguments back into a value, top-down, element by element.

    This mirrors the algorithm used by the generated EdgeQL (make_decoder).
    Arrays iterate over an index range and each child is decoded at a
    specific index.

    Notes from the author:

    * This is probably not faster than naive_decode, since it makes many
      more Python calls, but it is linear time.
    * naive_decode does more intermediate object creation. Whether it is
      also linear depends on the object storage model: linear with
      pointer-based objects, but not when objects are stored inline.
    * Open question: which model does Postgres use?

    Raises:
        TypeError: if a node is not a Scalar, Array or Tuple.
    """
    def dec(typ: Type, idx: int | None) -> Any:
        arg, _ = map[typ]
        if isinstance(typ, Scalar):
            return data[arg][idx] if idx is not None else data[arg]
        elif isinstance(typ, Array):
            if idx is None:
                # Top-level array: no offsets of its own. Its length is read
                # off the leftmost column of its element type (see
                # _array_adjust). An empty offsets column gives a negative hi,
                # hence an empty range.
                lo = 0
                hi = len(data[arg]) - _array_adjust(typ.typ) + 1
            else:
                lo = data[arg][idx]
                hi = data[arg][idx + 1]
            return [dec(typ.typ, idx=i) for i in range(lo, hi)]
        elif isinstance(typ, Tuple):
            return tuple(dec(t, idx=idx) for t in typ.typs)
        raise TypeError(f'unsupported type node: {typ!r}')

    return dec(typ, idx=None)
def make_decoder(typ: Type, map: Map, ftypes: tuple[FType, ...]) -> str:
    """Generate an EdgeQL expression that rebuilds a value of *typ* from the
    flat query parameters ``$0, $1, ...`` produced by encode.

    The generated code follows the same top-down algorithm as decode:

    * a scalar reads (or indexes into) its parameter;
    * a tuple combines its fields;
    * an array either slices a scalar column directly, or loops over an
      index range with ``for ... in _gen_series(lo, hi)`` and gathers the
      results with ``array_agg``. Here ``hi`` is inclusive; note the ``- 1``
      relative to decode's exclusive bound.

    Args:
        typ: the nested type to decode.
        map: slot map from translate_type.
        ftypes: flat parameter types from translate_type. They are used to
            cast a top-level array's leftmost column when taking its length.

    Returns:
        The EdgeQL expression text, without a leading ``select``.

    Raises:
        TypeError: if a node is not a Scalar, Array or Tuple.
    """
    cnt = 0

    def mk_name(x: str) -> str:
        # Fresh loop-variable names (i1, i2, ...) so nested FORs never clash.
        nonlocal cnt
        cnt += 1
        return f'{x}{cnt}'

    # Extra cast inserted between the type cast and every scalar parameter.
    # With '<json>' each parameter would be read as json and then cast.
    # BS = '<json>'
    BS = ''

    def mk(typ: Type, idx: str | None) -> str:
        arg, in_array = map[typ]
        if isinstance(typ, Scalar):
            tname = f'array<{typ.fmt()}>' if in_array else typ.fmt()
            if idx is None:
                return f'<{tname}>{BS}${arg}'
            return f'(<{tname}>{BS}${arg})[{idx}]'
        if isinstance(typ, Array):
            a = f'(<array<int32>>${arg})'
            # If the contents is just a scalar, take values directly from the
            # scalar array parameter (slicing it when nested) instead of
            # iterating. This is an optimization, not needed for correctness.
            if isinstance(typ.typ, Scalar):
                sub = mk(typ.typ, idx=None)
                if idx is not None:
                    sub = f'({sub})[{a}[{idx}]:{a}[{idx}+1]]'
                return sub
            inner_idx = mk_name('i')
            sub = mk(typ.typ, idx=inner_idx)
            if idx is None:
                adjust = _array_adjust(typ.typ)
                lo = '0'
                hi = f'len(<{ftypes[arg].fmt()}>${arg}) - {adjust}'
            else:
                lo = f'{a}[{idx}]'
                hi = f'{a}[{idx}+1]-1'
            # Built by plain concatenation rather than `template % sub`, so
            # any '%' in lo/hi cannot be misread as a format directive.
            body = textwrap.indent(sub, ' ')
            return (
                f'array_agg((for {inner_idx} in _gen_series({lo}, {hi}) union (\n'
                f'{body}\n'
                ')))'
            )
        if isinstance(typ, Tuple):
            # Trailing comma on every field keeps 1-tuples valid.
            lparts = [mk(t, idx=idx) + ',' for t in typ.typs]
            return f'({" ".join(lparts)})'
        raise TypeError(f'unsupported type node: {typ!r}')

    return mk(typ, idx=None)
######### TESTING
def test(t1, data):
    """Round-trip *data* (of type *t1*) through encode/decode, printing every
    intermediate artefact along the way.

    It prints the flat types, the slot map, the encoded args, the decoded
    value, the generated EdgeQL query, and each arg as JSON.

    Despite its name this is a helper, not a pytest test: it takes arguments.
    It is marked ``__test__ = False`` below so pytest does not collect it and
    fail with "fixture 't1' not found".
    """
    ts1, m1 = translate_type(t1)
    print(ts1)
    print(m1)
    v1 = encode(t1, ts1, m1, data)
    print(v1)
    print()
    d1 = decode(t1, m1, v1)
    print(d1)
    assert d1 == data
    print()
    print(f'select {make_decoder(t1, m1, ts1)};')
    print()
    for x in v1:
        print(json.dumps(x))
    print()


test.__test__ = False  # type: ignore[attr-defined]
# Hand-written fixtures used by go(). t2 and t3 each embed t1 exactly once;
# translate_type requires the nodes within a single type to be distinct.
test_data: list[tuple[list[str]]] = [
    (['a'],),
    (['b', 'c'],),
    (['d', 'e', 'f'],),
]
test_data_2p: list[tuple[list[str]]] = [
    (['x', 'y', 'z', 'w'],),
    (['g', 'h', 'i'],),
    (['j', 'k'],),
    (['l'],),
]

# array<tuple<array<str>>>
t1 = Array(Tuple((Array(Scalar('str')),)))

# array<tuple<array<tuple<array<str>>>, str>>
t2 = Array(Tuple((t1, Scalar('str'))))
test_data2 = [(test_data, 'foo'), (test_data_2p, 'bar')]

# simpler: array<tuple<array<tuple<array<str>>>>>
t3 = Array(Tuple((t1,)))
test_data3 = [(test_data,), (test_data_2p,)]

# more complex: array<tuple<str, array<int64>>>
t4 = Array(Tuple((Scalar('str'), Array(Scalar('int64')))))
def mk_testdata4(n):
    """Build *n* rows for t4: row i is ('name_<i>', [0, 1, 4, ..., (i-1)**2])."""
    rows = []
    for i in range(n):
        squares = [j * j for j in range(i)]
        rows.append((f'name_{i}', squares))
    return rows
# 200 rows of data for t4.
test_data4 = mk_testdata4(200)
# simplest: array<tuple<int64>> -- a flat column of int64s. Used for the
# large-volume (1,000,000 row) benchmark.
t5 = Array(Tuple((Scalar('int64'),)))
def mk_testdata5(n):
    """Build *n* rows for t5: the 1-tuples (0,), (1,), ..., (n-1,)."""
    # zip() over a single iterable yields 1-tuples.
    return list(zip(range(n)))
def go():
    """Run the hand-written round-trip examples through test()."""
    cases = ((t1, test_data), (t2, test_data2), (t3, test_data3))
    for case_type, case_data in cases:
        test(case_type, case_data)
# go()
#################
# Property-based testing with hypothesis.
import hypothesis as h
import hypothesis.strategies as hs


def _fresh_str_scalar(_):
    # A *new* Scalar on every draw: nodes within a type must be distinct.
    return Scalar('str')


def _grow(children):
    # One level deeper in the type tree. Either an array of a non-array
    # child (arrays of arrays are never generated; presumably because EdgeQL
    # has no directly nested arrays), or a tuple of 1-8 children.
    arrays = hs.builds(Array, children.filter(lambda c: not isinstance(c, Array)))
    tuples = hs.lists(children, min_size=1, max_size=8).map(
        lambda elems: Tuple(tuple(elems))
    )
    return arrays | tuples


typ = hs.recursive(hs.builds(_fresh_str_scalar, hs.none()), _grow)
array_typ = typ.filter(lambda x: isinstance(x, Array))  # ONLY TOP ARRAYS
def type_to_strategy(t):
    """Return a hypothesis strategy generating values of nested type *t*.

    Scalars become single lowercase ASCII letters, arrays become lists, and
    tuples become Python tuples.

    Raises:
        TypeError: if *t* is not a Scalar, Array or Tuple.
    """
    if isinstance(t, Scalar):
        # ascii_lowercase replaces a hand-typed alphabet that lacked 'p'.
        # Alternatives: hs.text(alphabet=string.ascii_lowercase), hs.integers().
        return hs.sampled_from(string.ascii_lowercase)
    if isinstance(t, Array):
        return hs.lists(type_to_strategy(t.typ))
    if isinstance(t, Tuple):
        return hs.tuples(*[type_to_strategy(field) for field in t.typs])
    # Previously fell through and returned None, which fails obscurely later.
    raise TypeError(f'unsupported type node: {t!r}')
@hs.composite
def type_and_data(draw):
    """Draw a random top-level array type together with a matching value."""
    # Drawing from `typ` instead of `array_typ` would also yield bare scalars
    # and tuples at the top level.
    top = draw(array_typ)
    value = draw(type_to_strategy(top))
    return top, value
@h.given(type_and_data())
def test_encode_decode(td):
    """Property test: decode and naive_decode both invert encode."""
    t, d = td
    print()
    print(t.fmt())
    print(d)
    nts, m = translate_type(t)
    encoded = encode(t, nts, m, d)
    decoded = decode(t, m, encoded)
    assert d == decoded, (
        d,
        encoded,
        decoded,
    )
    naive_decoded = naive_decode(t, m, encoded)
    # On failure, report the naive decoder's own output.
    assert d == naive_decoded, (
        d,
        encoded,
        naive_decoded,
    )
    print('PASS')
# Lazily created client shared by every call to get_conn().
_conn = None


def get_conn():
    """Return the cached EdgeDB client, creating it on first use.

    The client connects on port 5656 with ``tls_security='insecure'``.
    """
    import edgedb
    global _conn
    if _conn:
        return _conn
    _conn = edgedb.create_client(port=5656, tls_security='insecure')
    return _conn
def _warm_cache(db, q, *args):
    """Issue *q* once so it is prepared before the timed run.

    Callers deliberately pass missing or placeholder arguments. The resulting
    QueryArgumentError is swallowed; presumably the server has already
    compiled (and cached) the query by the time it rejects the arguments.
    """
    import edgedb
    try:
        db.query_single(q, *args)
    except edgedb.QueryArgumentError:
        # Expected outcome of the placeholder arguments.
        return
def _test_edgeql(t, d):
    """Round-trip *d* (of type *t*) through the generated EdgeQL decoder.

    The steps are:

    1. Encode *d* into flat parameters.
    2. Warm the query cache.
    3. Run the query built by make_decoder against get_conn().
    4. Check the result equals *d*.

    It prints timings for the encode step (e:), the warm-up, and the query.
    """
    print()
    print(t.fmt())
    print(t)
    # Avoid flooding the output with huge payloads.
    if len(repr(d)) < 3000:
        print(d)
    nts, m = translate_type(t)
    query = make_decoder(t, m, nts)
    db = get_conn()
    print("Q", query)

    # perf_counter is monotonic and high-resolution; time.time() is not
    # meant for measuring intervals.
    enc_start = time.perf_counter()
    encoded = encode(t, nts, m, d)
    enc_end = time.perf_counter()
    print(f'PASS e:{enc_end-enc_start:.3f}')

    warm_start = time.perf_counter()
    # Placeholder argument list; _warm_cache swallows the argument error
    # this is expected to trigger for multi-parameter queries.
    _warm_cache(db, f'select {query}', [])
    q_start = time.perf_counter()
    decoded = db.query_single(f'select {query}', *encoded)
    q_end = time.perf_counter()
    assert d == decoded, (
        d,
        encoded,
        decoded,
    )
    # Disabled comparison: send d as JSON instead, via
    # f'select <{t.fmt()}><json>$0' with json.dumps(d), and time that path.
    print(f'PASS e:{enc_end-enc_start:.3f} {q_start-warm_start:.3f} {q_end-q_start:.3f}')
def _test_edgeql_real(t, d):
    """Baseline for _test_edgeql: pass *d* as a single query parameter ``$x``.

    The parameter is cast to *t*. The function checks that *d* round-trips
    unchanged and prints the warm-up and query timings.
    """
    print()
    print(t.fmt())
    # Avoid flooding the output with huge payloads.
    if len(repr(d)) < 3000:
        print(d)
    db = get_conn()
    query = f'select <{t.fmt()}>$x'

    # perf_counter is monotonic and high-resolution, unlike time.time().
    warm_start = time.perf_counter()
    # No arguments on purpose; _warm_cache swallows the resulting error.
    _warm_cache(db, query)
    q_start = time.perf_counter()
    decoded = db.query_single(query, x=d)
    q_end = time.perf_counter()
    assert d == decoded, (
        d,
        decoded,
    )
    print(f'PASS {q_start-warm_start:.3f} {q_end-q_start:.3f}')
@h.settings(deadline=None)
@h.given(type_and_data())
def test_edgeql(td):
    """Property test: random (type, value) pairs survive the generated EdgeQL decoder."""
    typ_drawn, value = td
    _test_edgeql(typ_drawn, value)
@h.settings(deadline=None)
@h.given(type_and_data())
def test_edgeql_real(td):
    """Property test: random (type, value) pairs survive a direct ``$x`` parameter."""
    typ_drawn, value = td
    _test_edgeql_real(typ_drawn, value)
# Ad-hoc benchmarks. They run at module import time and need a server
# reachable through get_conn() (port 5656).
# _test_edgeql(t1, test_data)
# Same type and data size through both paths: flat-parameter encoding vs.
# passing the value directly as $x.
_test_edgeql(t4, mk_testdata4(2000))
_test_edgeql_real(t4, mk_testdata4(2000))
# Direct-parameter path on a large (1,000,000 row) array<tuple<int64>>.
_test_edgeql_real(t5, mk_testdata5(1_000_000))
# Property-based variants; uncomment to run them directly.
# test_encode_decode()
# test_edgeql()
# test_edgeql_real()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment