Skip to content

Instantly share code, notes, and snippets.

@sjwiesman
Created July 7, 2025 19:42
Show Gist options
  • Save sjwiesman/a2699a51c3782203306f53dd3a978183 to your computer and use it in GitHub Desktop.
Save sjwiesman/a2699a51c3782203306f53dd3a978183 to your computer and use it in GitHub Desktop.
from dataclasses import dataclass
from enum import Enum
from typing import Optional, AsyncGenerator, TypeVar, Callable, Generator
from psycopg import sql
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool
# Event type produced by a subscription's row-processing generator
# (see Subscription._run, where it parameterizes the yielded events).
# NOTE(review): PEP 8 convention would name this TypeVar "T" (uppercase);
# left as-is because it is referenced elsewhere in the file.
t = TypeVar("t")
@dataclass(slots=True)
class Snapshot:
    """Emitted once at subscription start (when no as_of is given) with the
    timestamp of the first progress message; rows before it form the snapshot."""

    mz_timestamp: int  # Materialize timestamp of the initial progress row
@dataclass(slots=True)
class Progress:
    """Emitted for rows where mz_progressed is set: the subscription has
    advanced to mz_timestamp (enabled by WITH (PROGRESS) in the SUBSCRIBE)."""

    mz_timestamp: int  # timestamp the subscription has progressed to
@dataclass(slots=True)
class Upsert:
    """Emitted when a row is inserted or updated (mz_state == "upsert", or a
    positive mz_diff in the keyless fallback)."""

    mz_timestamp: int  # Materialize timestamp of the change
    row: dict  # column name -> value, as returned by SUBSCRIBE
@dataclass(slots=True)
class Delete:
    """Emitted when a row is deleted (mz_state == "delete", or a negative
    mz_diff in the keyless fallback)."""

    mz_timestamp: int  # Materialize timestamp of the change
    row: dict  # row payload as returned by SUBSCRIBE for the delete
@dataclass(slots=True)
class Change:
    """Debezium-style change event carrying optional before/after row images;
    an image is None when there is no corresponding state (e.g. no `before`
    for an insert)."""

    mz_timestamp: int  # Materialize timestamp of the change
    before: Optional[dict]  # row image prior to the change, if any
    after: Optional[dict]  # row image after the change, if any
# Events yielded by Subscription.upsert(...).
type UpsertEvent = Snapshot | Progress | Upsert | Delete
# Events yielded by Subscription.debezium(...).
type DebeziumEvent = Snapshot | Progress | Change
class Envelope(Enum):
    """SUBSCRIBE output envelope selector."""

    UPSERT = 1
    DEBEZIUM = 2

    def sql(self):
        """Return the ENVELOPE clause template; the `{}` placeholder is later
        filled with a comma-separated list of key-column identifiers.

        Note: inside this method body the name `sql` resolves to the psycopg
        `sql` module (function scopes skip the class namespace), not to this
        method — confusing, but correct.
        """
        if self == Envelope.UPSERT:
            return sql.SQL("ENVELOPE UPSERT (KEY ({}))")
        elif self == Envelope.DEBEZIUM:
            # BUG FIX: the template previously read "(KEY {}))" — one open
            # paren, two closing — which renders unbalanced, invalid SQL.
            # Now mirrors the UPSERT template's "(KEY ({}))" shape.
            return sql.SQL("ENVELOPE DEBEZIUM (KEY ({}))")
        else:
            raise NotImplementedError
def _subscribe(
    query: sql.Composable,
    keys: list[str],
    envelope: Optional[Envelope],
    as_of: Optional[int],
    up_to: Optional[int],
):
    """Compose a SUBSCRIBE statement for `query`.

    The envelope clause is emitted only when key columns are supplied (keyless
    subscriptions fall back to raw mz_diff output); AS OF / UP TO clauses are
    emitted only when their timestamps are given. PROGRESS is always requested
    so the consumer can observe timestamp advancement.
    """
    return sql.SQL("""
SUBSCRIBE (
{query}
)
{envelope_clause}
WITH (PROGRESS)
{as_of_clause} {up_to_clause}
""").format(
        query=query,
        envelope_clause=(
            envelope.sql().format(
                sql.SQL(", ").join(map(sql.Identifier, keys)),
            )
            # BUG FIX: also require an envelope — previously a non-empty
            # `keys` with envelope=None raised AttributeError on
            # `envelope.sql()`.
            if keys and envelope is not None
            else sql.SQL("")
        ),
        as_of_clause=(
            sql.SQL("AS OF {}").format(sql.Literal(as_of))
            if as_of is not None
            else sql.SQL("")
        ),
        up_to_clause=(
            sql.SQL("UP TO {}").format(sql.Literal(up_to))
            if up_to is not None
            else sql.SQL("")
        ),
    )
class Subscription:
    """Streams change events from Materialize SUBSCRIBE queries.

    Use as an async context manager: the connection pool is opened on
    __aenter__ and closed on __aexit__. Each subscription dedicates one
    pooled connection for its lifetime.
    """

    def __init__(self, conninfo: str, kwargs: Optional[dict] = None):
        # Construction does no I/O; the pool is created in __aenter__.
        self._conninfo = conninfo
        self._kwargs = kwargs
        self._pool: Optional[AsyncConnectionPool] = None

    async def __aenter__(self):
        # Autocommit lets the cursor transaction be managed explicitly
        # (BEGIN / DECLARE in _run) rather than by psycopg implicitly.
        self._pool = await AsyncConnectionPool(
            conninfo=self._conninfo,
            kwargs=self._kwargs,
            configure=lambda conn: conn.set_autocommit(True),
        ).__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self._pool:
            await self._pool.__aexit__(exc_type, exc_val, exc_tb)

    async def get_mz_environment_id(self):
        """Return the id reported by Materialize's mz_environment_id()."""
        if not self._pool:
            raise RuntimeError("connection not open")
        async with self._pool.connection() as conn:
            async with conn.cursor(row_factory=dict_row) as cur:
                results = await cur.execute("SELECT mz_environment_id() AS id")
                row = await results.fetchone()
                return row["id"]

    def upsert(
        self,
        query: sql.Composable,
        keys: list[str],
        *,
        cluster: Optional[str] = None,
        as_of: Optional[int] = None,
        up_to: Optional[int] = None,
        fetch_size: Optional[int] = None,
    ) -> AsyncGenerator[UpsertEvent, None]:
        """Subscribe to `query` with ENVELOPE UPSERT semantics.

        Yields a single Snapshot (unless `as_of` is given), then Upsert /
        Delete / Progress events. With no `keys`, falls back to raw mz_diff
        processing: positive diff -> Upsert, negative -> Delete.
        """

        def process_monotonic(row):
            diff = int(row.pop("mz_diff"))
            timestamp = int(row.pop("mz_timestamp"))
            upsert = diff > 0
            multiplicity = abs(diff)
            for _ in range(multiplicity):
                # Copy per event so repeated events never alias (and can't
                # mutate) a single shared dict.
                if upsert:
                    yield Upsert(mz_timestamp=timestamp, row=row.copy())
                else:
                    yield Delete(mz_timestamp=timestamp, row=row.copy())

        def process_keyed(row):
            state = row.pop("mz_state")
            timestamp = int(row.pop("mz_timestamp"))
            if state == "upsert":
                yield Upsert(mz_timestamp=timestamp, row=row)
            elif state == "delete":
                yield Delete(mz_timestamp=timestamp, row=row)
            elif state == "key_violation":
                raise RuntimeError(
                    f"key violation: multiple rows with primary key {row}"
                )
            else:
                raise RuntimeError(f"Unknown state: {state}")

        return self._run(
            as_of,
            cluster,
            fetch_size,
            _subscribe(query, keys, Envelope.UPSERT, as_of, up_to),
            process_keyed if keys else process_monotonic,
        )

    def debezium(
        self,
        query: sql.Composable,
        keys: list[str],
        *,
        cluster: Optional[str] = None,
        as_of: Optional[int] = None,
        up_to: Optional[int] = None,
        fetch_size: Optional[int] = None,
    ):
        """Subscribe to `query` with ENVELOPE DEBEZIUM semantics.

        Yields a single Snapshot (unless `as_of` is given), then Change /
        Progress events. With no `keys`, falls back to raw mz_diff
        processing where each diff becomes a one-sided Change.
        """

        def process_monotonic(row):
            diff = int(row.pop("mz_diff"))
            timestamp = int(row.pop("mz_timestamp"))
            # BUG FIX: a positive mz_diff is an addition and must populate
            # `after` (the original had before/after inverted relative to
            # upsert's process_monotonic, which maps diff > 0 to Upsert).
            inserted = diff > 0
            for _ in range(abs(diff)):
                if inserted:
                    yield Change(mz_timestamp=timestamp, before=None, after=row.copy())
                else:
                    yield Change(mz_timestamp=timestamp, before=row.copy(), after=None)

        def process_keyed(row):
            state = row.pop("mz_state")
            timestamp = int(row.pop("mz_timestamp"))
            # Key columns arrive unprefixed; the remaining value columns
            # arrive twice, prefixed "before_" and "after_". Both images
            # share the key columns.
            key_cols = {k: row.pop(k) for k in keys}
            before = key_cols.copy()
            after = key_cols.copy()
            # BUG FIX: the original iterated over the key columns (which
            # never carry a "before_"/"after_" prefix), silently dropping
            # every prefixed value column left in `row`.
            for col, value in row.items():
                if col.startswith("before_"):
                    before[col.removeprefix("before_")] = value
                else:
                    after[col.removeprefix("after_")] = value
            if state == "insert":
                # BUG FIX: an insert has no prior image — the original
                # yielded before=..., after=None (swapped with delete).
                yield Change(mz_timestamp=timestamp, before=None, after=after)
            elif state == "upsert":
                yield Change(mz_timestamp=timestamp, before=before, after=after)
            elif state == "delete":
                # BUG FIX: a delete has no new image but does have a prior
                # one — the original yielded None for both.
                yield Change(mz_timestamp=timestamp, before=before, after=None)
            elif state == "key_violation":
                raise RuntimeError(
                    f"key violation: multiple rows with primary key {row}"
                )
            else:
                raise RuntimeError(f"Unknown state: {state}")

        return self._run(
            as_of,
            cluster,
            fetch_size,
            _subscribe(query, keys, Envelope.DEBEZIUM, as_of, up_to),
            process_keyed if keys else process_monotonic,
        )

    async def _run(
        self,
        as_of: Optional[int],
        cluster: Optional[str],
        fetch_size: Optional[int],
        subscribe_sql: sql.Composed,
        process: Callable[[dict], Generator[t, None]],
    ) -> AsyncGenerator[t, None]:
        """Drive a SUBSCRIBE cursor forever, yielding processed events.

        Progress rows become Progress events; data rows are expanded into
        zero or more events by `process`. Raises RuntimeError if the pool
        has not been opened or the initial progress message is missing.
        """
        if not self._pool:
            raise RuntimeError("connection not open")
        async with self._pool.connection() as conn:
            async with conn.cursor(row_factory=dict_row) as cur:
                if cluster:
                    await cur.execute(
                        sql.SQL("SET CLUSTER = {cluster}").format(
                            cluster=sql.Identifier(cluster)
                        )
                    )
                declare_cursor = sql.SQL("DECLARE c CURSOR FOR {subscribe_sql}").format(
                    subscribe_sql=subscribe_sql
                )
                fetch = (
                    # BUG FIX: sql.SQL.format only accepts Composable
                    # arguments; passing the bare int raised at runtime.
                    sql.SQL("FETCH {fetch_size} c").format(
                        fetch_size=sql.Literal(fetch_size)
                    )
                    if fetch_size
                    else sql.SQL("FETCH ALL c")
                )
                await cur.execute("BEGIN")
                await cur.execute(declare_cursor)
                # BUG FIX: compare against None so an explicit as_of of 0
                # is not mistaken for "no as_of".
                if as_of is None:
                    results = await cur.execute("FETCH 1 c")
                    snapshot = await results.fetchone()
                    if not snapshot or not bool(snapshot["mz_progressed"]):
                        raise RuntimeError("Missing starting progress message")
                    yield Snapshot(mz_timestamp=int(snapshot["mz_timestamp"]))
                while True:
                    rows = await cur.execute(fetch)
                    async for row in rows:
                        progress = row.pop("mz_progressed")
                        if bool(progress):
                            timestamp = row.pop("mz_timestamp")
                            yield Progress(mz_timestamp=int(timestamp))
                        else:
                            for event in process(row):
                                yield event
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment