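"""Async helpers for Materialize SUBSCRIBE over psycopg.

Wraps SUBSCRIBE ... WITH (PROGRESS) in an async connection pool and decodes
the ENVELOPE UPSERT and ENVELOPE DEBEZIUM output columns (mz_timestamp,
mz_diff, mz_state, mz_progressed) into typed events.
"""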
from dataclasses import dataclass
from enum import Enum
from typing import Optional, AsyncGenerator, TypeVar, Callable, Generator

from psycopg import sql
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool

T = TypeVar("T")

@dataclass(slots=True)
class Snapshot:
    """Initial progress message; the timestamp the subscription starts at."""
    mz_timestamp: int


@dataclass(slots=True)
class Progress:
    """The frontier advanced: all data up to mz_timestamp has been emitted."""
    mz_timestamp: int


@dataclass(slots=True)
class Upsert:
    """A row was inserted or updated at mz_timestamp."""
    mz_timestamp: int
    row: dict


@dataclass(slots=True)
class Delete:
    """A row was deleted at mz_timestamp."""
    mz_timestamp: int
    row: dict


@dataclass(slots=True)
class Change:
    """A Debezium-style change event: before/after images of a row."""
    mz_timestamp: int
    before: Optional[dict]
    after: Optional[dict]


type UpsertEvent = Snapshot | Progress | Upsert | Delete
type DebeziumEvent = Snapshot | Progress | Change

class Envelope(Enum):
    UPSERT = 1
    DEBEZIUM = 2

    def sql(self):
        if self == Envelope.UPSERT:
            return sql.SQL("ENVELOPE UPSERT (KEY ({}))")
        elif self == Envelope.DEBEZIUM:
            return sql.SQL("ENVELOPE DEBEZIUM (KEY ({}))")
        else:
            raise NotImplementedError

def _subscribe(
    query: sql.Composable,
    keys: list[str],
    envelope: Optional[Envelope],
    as_of: Optional[int],
    up_to: Optional[int],
):
    """Compose a SUBSCRIBE statement with optional envelope, AS OF, and UP TO clauses."""
    return sql.SQL("""
        SUBSCRIBE (
            {query}
        )
        {envelope_clause}
        WITH (PROGRESS)
        {as_of_clause} {up_to_clause}
    """).format(
        query=query,
        envelope_clause=(
            envelope.sql().format(
                sql.SQL(", ").join(map(sql.Identifier, keys)),
            )
            if keys
            else sql.SQL("")
        ),
        as_of_clause=(
            sql.SQL("AS OF {}").format(sql.Literal(as_of))
            if as_of is not None
            else sql.SQL("")
        ),
        up_to_clause=(
            sql.SQL("UP TO {}").format(sql.Literal(up_to))
            if up_to is not None
            else sql.SQL("")
        ),
    )
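
# For example, _subscribe(sql.SQL("SELECT * FROM t"), ["id"], Envelope.UPSERT,
# None, None) composes (modulo whitespace) to roughly:
#
#   SUBSCRIBE ( SELECT * FROM t ) ENVELOPE UPSERT (KEY ("id")) WITH (PROGRESS)
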
class Subscription:
    def __init__(self, conninfo: str, kwargs: Optional[dict] = None):
        self._conninfo = conninfo
        self._kwargs = kwargs
        self._pool: Optional[AsyncConnectionPool] = None

    async def __aenter__(self):
        self._pool = await AsyncConnectionPool(
            conninfo=self._conninfo,
            kwargs=self._kwargs,
            open=False,  # opened by __aenter__ below, as psycopg_pool recommends
            configure=lambda conn: conn.set_autocommit(True),
        ).__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self._pool:
            await self._pool.__aexit__(exc_type, exc_val, exc_tb)

    async def get_mz_environment_id(self):
        async with self._pool.connection() as conn:
            async with conn.cursor(row_factory=dict_row) as cur:
                results = await cur.execute("SELECT mz_environment_id() AS id")
                row = await results.fetchone()
                return row["id"]

    def upsert(
        self,
        query: sql.Composable,
        keys: list[str],
        *,
        cluster: Optional[str] = None,
        as_of: Optional[int] = None,
        up_to: Optional[int] = None,
        fetch_size: Optional[int] = None,
    ) -> AsyncGenerator[UpsertEvent, None]:
        """Subscribe to `query` with ENVELOPE UPSERT, yielding Upsert/Delete events."""

        def process_monotonic(row):
            # Without keys there is no envelope: decode raw mz_diff multiplicities.
            diff = int(row.pop("mz_diff"))
            timestamp = int(row.pop("mz_timestamp"))
            upsert = diff > 0
            multiplicity = abs(diff)
            for _ in range(multiplicity):
                if upsert:
                    yield Upsert(mz_timestamp=timestamp, row=row)
                else:
                    yield Delete(mz_timestamp=timestamp, row=row)

        def process_keyed(row):
            # With ENVELOPE UPSERT, Materialize emits an mz_state column instead
            # of mz_diff.
            state = row.pop("mz_state")
            timestamp = int(row.pop("mz_timestamp"))
            if state == "upsert":
                yield Upsert(mz_timestamp=timestamp, row=row)
            elif state == "delete":
                yield Delete(mz_timestamp=timestamp, row=row)
            elif state == "key_violation":
                raise RuntimeError(
                    f"key violation: multiple rows with primary key {row}"
                )
            else:
                raise RuntimeError(f"Unknown state: {state}")

        return self._run(
            as_of,
            cluster,
            fetch_size,
            _subscribe(query, keys, Envelope.UPSERT, as_of, up_to),
            process_monotonic if not keys else process_keyed,
        )

    def debezium(
        self,
        query: sql.Composable,
        keys: list[str],
        *,
        cluster: Optional[str] = None,
        as_of: Optional[int] = None,
        up_to: Optional[int] = None,
        fetch_size: Optional[int] = None,
    ) -> AsyncGenerator[DebeziumEvent, None]:
        """Subscribe to `query` with ENVELOPE DEBEZIUM, yielding Change events."""

        def process_monotonic(row):
            # Without keys, a positive mz_diff is an insertion (after image only)
            # and a negative mz_diff is a retraction (before image only).
            diff = int(row.pop("mz_diff"))
            timestamp = int(row.pop("mz_timestamp"))
            insert = diff > 0
            multiplicity = abs(diff)
            for _ in range(multiplicity):
                if insert:
                    yield Change(mz_timestamp=timestamp, before=None, after=row.copy())
                else:
                    yield Change(mz_timestamp=timestamp, before=row.copy(), after=None)

        def process_keyed(row):
            # With ENVELOPE DEBEZIUM, key columns appear once and every other
            # column appears twice, prefixed with before_ and after_.
            state = row.pop("mz_state")
            timestamp = int(row.pop("mz_timestamp"))
            key = {k: row.pop(k) for k in keys}
            before = key.copy()
            after = key.copy()
            for k, v in row.items():
                if k.startswith("before_"):
                    before[k.removeprefix("before_")] = v
                else:
                    after[k.removeprefix("after_")] = v
            if state == "insert":
                yield Change(mz_timestamp=timestamp, before=None, after=after)
            elif state == "upsert":
                yield Change(mz_timestamp=timestamp, before=before, after=after)
            elif state == "delete":
                yield Change(mz_timestamp=timestamp, before=before, after=None)
            elif state == "key_violation":
                raise RuntimeError(
                    f"key violation: multiple rows with primary key {key}"
                )
            else:
                raise RuntimeError(f"Unknown state: {state}")

        return self._run(
            as_of,
            cluster,
            fetch_size,
            _subscribe(query, keys, Envelope.DEBEZIUM, as_of, up_to),
            process_monotonic if not keys else process_keyed,
        )

    async def _run(
        self,
        as_of: Optional[int],
        cluster: Optional[str],
        fetch_size: Optional[int],
        subscribe_sql: sql.Composed,
        process: Callable[[dict], Generator[T, None, None]],
    ) -> AsyncGenerator[T, None]:
        if not self._pool:
            raise RuntimeError("connection not open")
        async with self._pool.connection() as conn:
            async with conn.cursor(row_factory=dict_row) as cur:
                if cluster:
                    await cur.execute(
                        sql.SQL("SET CLUSTER = {cluster}").format(
                            cluster=sql.Identifier(cluster)
                        )
                    )
                declare_cursor = sql.SQL("DECLARE c CURSOR FOR {subscribe_sql}").format(
                    subscribe_sql=subscribe_sql
                )
                fetch = (
                    sql.SQL("FETCH {fetch_size} c").format(
                        fetch_size=sql.Literal(fetch_size)
                    )
                    if fetch_size
                    else sql.SQL("FETCH ALL c")
                )
                await cur.execute("BEGIN")
                await cur.execute(declare_cursor)
                if not as_of:
                    # The first message should be a progress message carrying the
                    # timestamp the subscription starts at.
                    results = await cur.execute("FETCH 1 c")
                    snapshot = await results.fetchone()
                    if not snapshot or not bool(snapshot["mz_progressed"]):
                        raise RuntimeError("Missing starting progress message")
                    yield Snapshot(mz_timestamp=int(snapshot["mz_timestamp"]))
                while True:
                    rows = await cur.execute(fetch)
                    async for row in rows:
                        progress = row.pop("mz_progressed")
                        if bool(progress):
                            timestamp = row.pop("mz_timestamp")
                            yield Progress(mz_timestamp=int(timestamp))
                        else:
                            for event in process(row):
                                yield event
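
# Usage sketch: a minimal consumer of the upsert stream. The DSN and the
# relation/key names below are hypothetical placeholders, not part of the gist.
if __name__ == "__main__":
    import asyncio

    async def main():
        async with Subscription(
            "postgres://materialize@localhost:6875/materialize"  # hypothetical DSN
        ) as sub:
            async for event in sub.upsert(
                sql.SQL("SELECT * FROM my_table"),  # hypothetical relation
                keys=["id"],                        # hypothetical key column
                fetch_size=100,
            ):
                match event:
                    case Snapshot(mz_timestamp=ts):
                        print(f"subscribed as of {ts}")
                    case Upsert(mz_timestamp=ts, row=row):
                        print(f"upsert @ {ts}: {row}")
                    case Delete(mz_timestamp=ts, row=row):
                        print(f"delete @ {ts}: {row}")
                    case Progress(mz_timestamp=ts):
                        pass  # everything up to ts has been seen

    asyncio.run(main())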