Last active
October 15, 2025 20:54
-
-
Save RPG-fan/b6d578e45712ae9467b05d6ac4e8dbc6 to your computer and use it in GitHub Desktop.
A production-ready implementation of PostgreSQL logical replication using psycopg3. Works around psycopg3's missing replication support by using a raw socket bridge.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| PostgreSQL Logical Replication for psycopg3 | |
| A production-ready implementation of PostgreSQL logical replication using psycopg3. | |
| Works around psycopg3's missing replication support by using a raw socket bridge. | |
| Author: Richard Brandes | |
| Date: October 2025 | |
| License: MIT | |
| Related: https://github.com/psycopg/psycopg/issues/71 | |
| UPDATE (October 15, 2025): Added production stability features: | |
| - Reconnection logic with exponential backoff | |
| - LSN tracking across reconnections | |
| - Keepalive handling (both primary and unsolicited) | |
| - Proper initialization handshake | |
| - Graceful connection cleanup | |
| """ | |
| import asyncio | |
| import logging | |
| import socket as socket_module | |
| import struct | |
| from dataclasses import dataclass | |
| from datetime import datetime, timedelta, timezone | |
| from enum import Enum | |
| from typing import Any, AsyncIterator, Dict, List, Optional, Union | |
| from psycopg import AsyncConnection | |
| logger = logging.getLogger(__name__) | |
| # ============================================================================ | |
| # Utility Functions | |
| # ============================================================================ | |
def lsn_to_string(lsn: int) -> str:
    """Render an integer LSN in PostgreSQL's textual 'X/X' form.

    Args:
        lsn: Log sequence number as a single integer.

    Returns:
        The upper and lower 32-bit halves as upper-case hex, joined by '/'.
    """
    # Masking to 64 bits and splitting at 2**32 yields the two 32-bit halves;
    # zero naturally formats as "0/0".
    upper, lower = divmod(lsn & 0xFFFFFFFFFFFFFFFF, 1 << 32)
    return f"{upper:X}/{lower:X}"
def string_to_lsn(lsn_str: str) -> int:
    """Convert PostgreSQL's 'X/X' LSN string format to an integer.

    Args:
        lsn_str: Two hexadecimal numbers separated by a single '/'.

    Returns:
        The LSN as one integer: (high << 32) | low.

    Raises:
        ValueError: If the text is not two '/'-separated hex numbers, or if
            either half is negative or does not fit in 32 bits.
    """
    parts = lsn_str.split("/")
    if len(parts) != 2:
        raise ValueError(f"Invalid LSN format: {lsn_str}")
    high = int(parts[0], 16)
    low = int(parts[1], 16)
    # Each half is a 32-bit quantity; anything else would silently bleed into
    # the other half (or go negative) when the two are combined below.
    if not (0 <= high <= 0xFFFFFFFF and 0 <= low <= 0xFFFFFFFF):
        raise ValueError(f"Invalid LSN format: {lsn_str}")
    return (high << 32) | low
| # ============================================================================ | |
| # Protocol Message Types and Data Structures | |
| # ============================================================================ | |
class MessageType(Enum):
    """Leading tag characters of pgoutput logical replication messages."""

    # Transaction framing and origin
    BEGIN = "B"
    COMMIT = "C"
    ORIGIN = "O"
    # Schema metadata
    RELATION = "R"
    TYPE = "Y"
    # Row-level and table-level changes
    INSERT = "I"
    UPDATE = "U"
    DELETE = "D"
    TRUNCATE = "T"
    # Generic logical message
    MESSAGE = "M"
    # Streamed (in-progress) transaction control
    STREAM_START = "S"
    STREAM_STOP = "E"
    STREAM_COMMIT = "c"
    STREAM_ABORT = "A"
@dataclass
class ReplicationMessage:
    """One raw WAL-stream message: the payload plus its header fields."""

    # Undecoded message body
    payload: bytes
    # LSN at which this data begins
    data_start: int
    # LSN at which this data ends
    data_end: int
    # Send timestamp as reported by the server (raw integer)
    send_time: int
@dataclass
class ColumnDefinition:
    """Description of a single column as carried by a Relation message."""

    # Column flag bits; bit 0 marks a key column
    flags: int
    # Column name
    name: str
    # Type OID of the column
    type_id: int
    # Type modifier of the column
    type_modifier: int

    @property
    def is_key(self) -> bool:
        """Tell whether the key bit (0x01) is set in ``flags``."""
        return (self.flags & 0x01) != 0
@dataclass
class RelationMessage:
    """Table schema announcement: identifies a relation and lists its columns."""

    # Numeric ID that later row messages use to refer to this relation
    relation_id: int
    # Schema (namespace) name
    namespace: str
    # Table name
    relation_name: str
    # Replica identity setting, kept as the raw byte value from the message
    replica_identity: int
    # Column definitions in message order
    columns: List[ColumnDefinition]
@dataclass
class BeginMessage:
    """Start-of-transaction marker."""

    # Final LSN of the transaction
    final_lsn: int
    # Commit timestamp of the transaction
    commit_ts: datetime
    # Transaction ID
    xid: int
@dataclass
class CommitMessage:
    """End-of-transaction (commit) marker."""

    # Flags byte from the message
    flags: int
    # LSN of the commit
    commit_lsn: int
    # End LSN of the transaction
    end_lsn: int
    # Commit timestamp of the transaction
    commit_ts: datetime
@dataclass
class InsertMessage:
    """A row inserted into a relation."""

    # ID of the relation the row belongs to
    relation_id: int
    # Inserted row, keyed by column name
    new_tuple: Dict[str, Any]
@dataclass
class UpdateMessage:
    """A row updated in a relation."""

    # ID of the relation the row belongs to
    relation_id: int
    # Previous row (or key) values when the message carried them, else None
    old_tuple: Optional[Dict[str, Any]]
    # Row contents after the update, keyed by column name
    new_tuple: Dict[str, Any]
@dataclass
class DeleteMessage:
    """A row deleted from a relation."""

    # ID of the relation the row belonged to
    relation_id: int
    # Old row values when the message carried an 'O' tuple, else None
    old_tuple: Optional[Dict[str, Any]]
    # Key values when the message carried a 'K' tuple, else None
    key_tuple: Optional[Dict[str, Any]]
@dataclass
class TruncateMessage:
    """A TRUNCATE applied to one or more relations."""

    # Option bits byte from the message
    options: int
    # IDs of the truncated relations
    relation_ids: List[int]
| # ============================================================================ | |
| # Raw Socket Bridge - Direct PostgreSQL Wire Protocol Access | |
| # ============================================================================ | |
class RawSocketReplicationBridge:
    """
    Direct socket bridge for reading PostgreSQL replication data.

    This completely bypasses psycopg's COPY handling after connection setup,
    reading directly from the socket to avoid psycopg's incomplete replication support.

    Features:
    - Non-blocking socket I/O
    - Proper message framing and buffering
    - Server error (ErrorResponse) detection
    - Graceful cleanup that leaves the file descriptor to psycopg
    """

    # Number of bytes requested from the socket per read.
    _RECV_SIZE = 8192
    # Frame header: 1-byte message type followed by a 4-byte big-endian length.
    _HEADER_SIZE = 5

    def __init__(self, pgconn: Any) -> None:
        """
        Initialize bridge with libpq connection.

        Args:
            pgconn: The libpq PGconn object from psycopg (typed as Any due to protocol limitation)
        """
        self.pgconn: Any = pgconn
        self.socket_fd: int = int(pgconn.socket)
        # The socket object wraps the descriptor libpq is using; close()
        # detaches it again so the descriptor is never closed from here.
        # NOTE(review): bytes libpq already pulled into its own input buffer
        # before the hand-off would not be visible here - confirm with caller.
        self.socket_obj = socket_module.socket(fileno=self.socket_fd)
        self.socket_obj.setblocking(False)
        self.read_buffer = bytearray()
        self._initialized = False
        self._closed = False

    async def _fill_buffer(self, wanted: int) -> bool:
        """Read from the socket until the buffer holds at least ``wanted`` bytes.

        Returns:
            True once enough bytes are buffered, False if the peer closed the
            connection first.

        Raises:
            Exception: If the socket read fails.
        """
        loop = asyncio.get_running_loop()
        while len(self.read_buffer) < wanted:
            try:
                chunk = await loop.sock_recv(self.socket_obj, self._RECV_SIZE)
            except BlockingIOError:
                await asyncio.sleep(0.01)
                continue
            except Exception as e:
                raise Exception(f"Socket read error: {e}") from e
            if not chunk:
                return False
            self.read_buffer.extend(chunk)
        return True

    @staticmethod
    def _format_error_response(body: bytes) -> str:
        """Summarize an ErrorResponse body (type-byte + C-string fields)."""
        fields: dict = {}
        pos = 0
        # Fields repeat until a lone zero byte terminates the list.
        while pos < len(body) and body[pos] != 0:
            end = body.find(b"\x00", pos + 1)
            if end == -1:
                break
            fields[chr(body[pos])] = body[pos + 1 : end].decode(
                "utf-8", errors="replace"
            )
            pos = end + 1
        prefix = " ".join(fields[key] for key in ("S", "C") if key in fields)
        message = fields.get("M", "unknown error")
        return f"{prefix}: {message}" if prefix else message

    async def read_copy_data_message(self) -> Optional[bytes]:
        """
        Read one complete CopyData message directly from socket.

        Non-CopyData frames are handled internally: CopyDone ends the stream,
        CopyFail and ErrorResponse raise, anything else is skipped.

        Returns:
            Message payload or None if stream ended (or the bridge is closed)

        Raises:
            Exception: On socket errors, malformed frames, CopyFail, or a
                server ErrorResponse.
        """
        if self._closed:
            return None
        while True:
            # Ensure we have at least the message header (1 byte type + 4 bytes length)
            if not await self._fill_buffer(self._HEADER_SIZE):
                return None
            msg_type = chr(self.read_buffer[0])
            (msg_length,) = struct.unpack_from(">I", self.read_buffer, 1)
            # Message length includes itself (4 bytes) but not the type byte,
            # so anything below 4 cannot be a valid frame.
            if msg_length < 4:
                raise Exception(
                    f"Invalid message length {msg_length} for message type {msg_type!r}"
                )
            total_msg_size = 1 + msg_length
            # Read until we have the complete message
            if not await self._fill_buffer(total_msg_size):
                return None
            body = bytes(self.read_buffer[self._HEADER_SIZE : total_msg_size])
            # Drop the consumed frame in place instead of copying the remainder.
            del self.read_buffer[:total_msg_size]

            if msg_type == "d":  # CopyData - return payload
                return body
            if msg_type == "c":  # CopyDone - end of stream
                logger.info("Received CopyDone message")
                return None
            if msg_type == "f":  # CopyFail - error
                raise Exception("Server sent CopyFail message")
            if msg_type == "E":  # ErrorResponse - surface instead of skipping
                raise Exception(
                    f"Server sent ErrorResponse: {self._format_error_response(body)}"
                )
            if msg_type == "W":  # CopyBothResponse - initial handshake
                # First message after START_REPLICATION - consume but don't return
                if not self._initialized:
                    logger.debug("Received CopyBothResponse, initialization complete")
                    self._initialized = True
                else:
                    logger.warning("Received unexpected CopyBothResponse")
                continue
            logger.debug(f"Skipping unknown message type: {msg_type}")

    async def send_standby_status_update(
        self, received_lsn: int, flushed_lsn: int, applied_lsn: int, reply_requested: bool = False
    ) -> None:
        """
        Send standby status update message directly to socket.

        This acknowledges WAL positions to PostgreSQL. Does nothing (beyond a
        warning) when the bridge has already been closed.

        Args:
            received_lsn: Last WAL position received
            flushed_lsn: Last WAL position flushed to disk
            applied_lsn: Last WAL position applied
            reply_requested: Whether to request immediate reply from primary

        Raises:
            Exception: Whatever the socket send raised, after logging it.
        """
        if self._closed:
            logger.warning("Attempted to send feedback on closed socket")
            return
        # Build standby status update payload
        status_msg = struct.pack(
            ">cQQQQB",
            b"r",  # Standby status update
            received_lsn,  # Last WAL received
            flushed_lsn,  # Last WAL flushed
            applied_lsn,  # Last WAL applied
            0,  # Client clock field, left as 0 (no timestamp reported)
            1 if reply_requested else 0,  # Reply requested flag
        )
        # Wrap in CopyData message (type 'd' + length + payload)
        copy_data_msg = struct.pack(">cI", b"d", len(status_msg) + 4) + status_msg
        try:
            await asyncio.get_running_loop().sock_sendall(
                self.socket_obj, copy_data_msg
            )
        except Exception as e:
            logger.error(f"Failed to send feedback: {e}")
            raise

    def close(self) -> None:
        """Close the socket bridge gracefully.

        The file descriptor itself is left open for psycopg to close; the
        Python socket object is detached so that it cannot close the
        descriptor a second time when it is garbage collected.
        """
        if self._closed:
            return
        self._closed = True
        self.socket_obj.detach()
        logger.debug("Socket bridge closed")
| # ============================================================================ | |
| # pgoutput Binary Protocol Decoder | |
| # ============================================================================ | |
class PgOutputDecoder:
    """
    Decoder for PostgreSQL pgoutput binary protocol.

    Relation messages are cached in ``relations`` so that later row messages,
    which only carry a relation ID, can be mapped to column names.

    Based on: https://www.postgresql.org/docs/current/protocol-logicalrep-message-formats.html
    """

    # Message tag -> name of the method that decodes the rest of the message.
    _DISPATCH: Dict[str, str] = {
        MessageType.BEGIN.value: "_decode_begin",
        MessageType.COMMIT.value: "_decode_commit",
        MessageType.RELATION.value: "_decode_relation",
        MessageType.INSERT.value: "_decode_insert",
        MessageType.UPDATE.value: "_decode_update",
        MessageType.DELETE.value: "_decode_delete",
        MessageType.TRUNCATE.value: "_decode_truncate",
    }
    # Tags that are recognised but deliberately not decoded.
    _IGNORED = (MessageType.TYPE.value, MessageType.ORIGIN.value)

    def __init__(self) -> None:
        self.relations: Dict[int, RelationMessage] = {}
        self.types: Dict[int, str] = {}

    def decode(
        self, data: bytes
    ) -> Optional[
        Union[
            BeginMessage,
            CommitMessage,
            RelationMessage,
            InsertMessage,
            UpdateMessage,
            DeleteMessage,
            TruncateMessage,
        ]
    ]:
        """Decode a pgoutput binary message.

        Args:
            data: Complete pgoutput message, starting with its tag byte.

        Returns:
            The decoded message object, or None for empty input, ignored or
            unhandled message types, and messages that fail to decode (the
            failure is logged rather than raised).
        """
        if not data:
            return None
        msg_type = chr(data[0])
        if msg_type in self._IGNORED:
            return None
        handler_name = self._DISPATCH.get(msg_type)
        if handler_name is None:
            logger.debug(f"Unhandled message type: {msg_type}")
            return None
        try:
            return getattr(self, handler_name)(data[1:])
        except Exception as e:
            logger.error(f"Error decoding message type {msg_type}: {e}", exc_info=True)
            return None

    def _decode_begin(self, data: bytes) -> BeginMessage:
        """Decode BEGIN message: final LSN, commit timestamp, transaction ID."""
        # The timestamp is a signed 64-bit value, hence 'q' rather than 'Q'.
        final_lsn, commit_ts_raw, xid = struct.unpack_from(">QqI", data, 0)
        return BeginMessage(
            final_lsn=final_lsn,
            commit_ts=self._pg_timestamp_to_datetime(commit_ts_raw),
            xid=xid,
        )

    def _decode_commit(self, data: bytes) -> CommitMessage:
        """Decode COMMIT message: flags, commit LSN, end LSN, commit timestamp."""
        flags, commit_lsn, end_lsn, commit_ts_raw = struct.unpack_from(">BQQq", data, 0)
        return CommitMessage(
            flags=flags,
            commit_lsn=commit_lsn,
            end_lsn=end_lsn,
            commit_ts=self._pg_timestamp_to_datetime(commit_ts_raw),
        )

    def _decode_relation(self, data: bytes) -> RelationMessage:
        """Decode RELATION message and remember it in ``self.relations``."""
        (relation_id,) = struct.unpack_from(">I", data, 0)
        offset = 4
        namespace, offset = self._read_string(data, offset)
        relation_name, offset = self._read_string(data, offset)
        replica_identity = data[offset]
        offset += 1
        (num_columns,) = struct.unpack_from(">H", data, offset)
        offset += 2
        columns: List[ColumnDefinition] = []
        for _ in range(num_columns):
            # Per column: flags byte, name (C string), type OID, type modifier.
            flags = data[offset]
            col_name, offset = self._read_string(data, offset + 1)
            type_id, type_mod = struct.unpack_from(">Ii", data, offset)
            offset += 8
            columns.append(
                ColumnDefinition(
                    flags=flags, name=col_name, type_id=type_id, type_modifier=type_mod
                )
            )
        relation = RelationMessage(
            relation_id=relation_id,
            namespace=namespace,
            relation_name=relation_name,
            replica_identity=replica_identity,
            columns=columns,
        )
        self.relations[relation_id] = relation
        return relation

    def _decode_insert(self, data: bytes) -> InsertMessage:
        """Decode INSERT message.

        Raises:
            ValueError: If the tuple is not introduced by 'N'.
        """
        (relation_id,) = struct.unpack_from(">I", data, 0)
        tuple_kind = chr(data[4])
        if tuple_kind != "N":
            raise ValueError(f"Unexpected tuple kind in INSERT: {tuple_kind}")
        new_tuple, _ = self._decode_tuple_data(data, 5, relation_id)
        return InsertMessage(relation_id=relation_id, new_tuple=new_tuple)

    def _decode_update(self, data: bytes) -> UpdateMessage:
        """Decode UPDATE message.

        An optional old ('O') or key ('K') tuple may precede the mandatory
        new ('N') tuple; either is stored as ``old_tuple``.

        Raises:
            ValueError: If the new tuple is not introduced by 'N'.
        """
        (relation_id,) = struct.unpack_from(">I", data, 0)
        offset = 4
        old_tuple: Optional[Dict[str, Any]] = None
        tuple_kind = chr(data[offset])
        offset += 1
        if tuple_kind in ("O", "K"):
            old_tuple, offset = self._decode_tuple_data(data, offset, relation_id)
            tuple_kind = chr(data[offset])
            offset += 1
        if tuple_kind != "N":
            raise ValueError(f"Expected 'N' tuple kind, got: {tuple_kind}")
        new_tuple, _ = self._decode_tuple_data(data, offset, relation_id)
        return UpdateMessage(
            relation_id=relation_id, old_tuple=old_tuple, new_tuple=new_tuple
        )

    def _decode_delete(self, data: bytes) -> DeleteMessage:
        """Decode DELETE message.

        Raises:
            ValueError: If the tuple is introduced by neither 'O' nor 'K'.
        """
        (relation_id,) = struct.unpack_from(">I", data, 0)
        tuple_kind = chr(data[4])
        old_tuple: Optional[Dict[str, Any]] = None
        key_tuple: Optional[Dict[str, Any]] = None
        if tuple_kind == "O":
            old_tuple, _ = self._decode_tuple_data(data, 5, relation_id)
        elif tuple_kind == "K":
            key_tuple, _ = self._decode_tuple_data(data, 5, relation_id)
        else:
            # Fail like INSERT/UPDATE do rather than emit an empty delete.
            raise ValueError(f"Unexpected tuple kind in DELETE: {tuple_kind}")
        return DeleteMessage(
            relation_id=relation_id, old_tuple=old_tuple, key_tuple=key_tuple
        )

    def _decode_truncate(self, data: bytes) -> TruncateMessage:
        """Decode TRUNCATE message: relation count, option bits, relation IDs."""
        num_relations, options = struct.unpack_from(">IB", data, 0)
        relation_ids = list(struct.unpack_from(f">{num_relations}I", data, 5))
        return TruncateMessage(options=options, relation_ids=relation_ids)

    def _decode_tuple_data(
        self, data: bytes, offset: int, relation_id: int
    ) -> tuple[Dict[str, Any], int]:
        """Decode tuple data from INSERT/UPDATE/DELETE messages.

        Args:
            data: Message body the tuple is embedded in.
            offset: Position of the tuple's column-count field.
            relation_id: Relation whose cached columns name the values.

        Returns:
            (column name -> decoded value, offset just past the tuple).

        Raises:
            ValueError: Unknown relation, column count mismatch, unknown
                column kind, or a column value running past the end of data.
        """
        relation = self.relations.get(relation_id)
        if relation is None:
            raise ValueError(f"Unknown relation ID: {relation_id}")
        (num_columns,) = struct.unpack_from(">H", data, offset)
        offset += 2
        if num_columns != len(relation.columns):
            raise ValueError(
                f"Column count mismatch: expected {len(relation.columns)}, got {num_columns}"
            )
        tuple_data: Dict[str, Any] = {}
        for col_def in relation.columns:
            col_type = chr(data[offset])
            offset += 1
            if col_type in ("n", "u"):
                # 'n' is NULL; 'u' (unchanged value, not sent) is also
                # reported as None, so the two are indistinguishable here.
                tuple_data[col_def.name] = None
            elif col_type == "t":
                (col_len,) = struct.unpack_from(">I", data, offset)
                offset += 4
                end = offset + col_len
                # A slice past the end would silently yield a short value.
                if end > len(data):
                    raise ValueError(
                        f"Truncated value for column {col_def.name}: "
                        f"need {col_len} bytes, have {len(data) - offset}"
                    )
                tuple_data[col_def.name] = self._decode_value(
                    data[offset:end], col_def.type_id
                )
                offset = end
            else:
                raise ValueError(f"Unknown column data type: {col_type}")
        return tuple_data, offset

    def _decode_value(self, value: bytes, type_id: int) -> Any:
        """Decode column value based on PostgreSQL type OID.

        Booleans, integers, floats and JSON are converted to Python values;
        every other type (including date, timestamp and uuid) is returned as
        its text representation.
        """
        text_value = value.decode("utf-8")
        if type_id == 16:  # bool
            return text_value == "t"
        if type_id in (20, 21, 23):  # int8, int2, int4
            return int(text_value)
        if type_id in (700, 701):  # float4, float8
            return float(text_value)
        if type_id in (114, 3802):  # json, jsonb
            import json

            return json.loads(text_value)
        return text_value

    def _read_string(self, data: bytes, offset: int) -> tuple[str, int]:
        """Read null-terminated string from data.

        Returns:
            (decoded string, offset just past the terminator).

        Raises:
            ValueError: If no terminator is found at or after ``offset``.
        """
        end = data.find(b"\x00", offset)
        if end == -1:
            raise ValueError("Null terminator not found")
        string_value = data[offset:end].decode("utf-8")
        return string_value, end + 1

    @staticmethod
    def _pg_timestamp_to_datetime(pg_ts: int) -> datetime:
        """Convert PostgreSQL timestamp to Python datetime.

        Args:
            pg_ts: Microseconds relative to 2000-01-01 00:00:00 UTC.
        """
        pg_epoch = datetime(2000, 1, 1, tzinfo=timezone.utc)
        return pg_epoch + timedelta(microseconds=pg_ts)
| # ============================================================================ | |
| # High-Level Replication Stream Interface | |
| # ============================================================================ | |
class LogicalReplicationStream:
    """
    Async logical replication consumer using raw socket bridge.

    This is the main interface for consuming logical replication changes.

    Features:
    - Automatic reconnection with exponential backoff
    - LSN tracking across reconnections
    - Keepalive handling (primary and unsolicited)
    - Graceful cleanup
    """

    # Result status value expected once the connection is in COPY BOTH mode.
    _PGRES_COPY_BOTH = 8
    # Seconds without incoming messages before an unsolicited status update.
    _KEEPALIVE_INTERVAL = 5.0

    def __init__(
        self,
        connection_params: Dict[str, Any],
        slot_name: str,
        publication_name: str,
        start_lsn: str = "0/0",
    ) -> None:
        """
        Initialize replication stream.

        Args:
            connection_params: Dict with host, port, dbname, user, password
            slot_name: Name of the replication slot to use
            publication_name: Name of the publication to stream
            start_lsn: LSN to start streaming from (default: beginning)
        """
        self.connection_params = connection_params
        self.slot_name = slot_name
        self.publication_name = publication_name
        self.start_lsn = start_lsn
        self.current_lsn = string_to_lsn(start_lsn)
        self.decoder = PgOutputDecoder()
        self.bridge: Optional[RawSocketReplicationBridge] = None
        self.conn: Optional[AsyncConnection] = None
        self._last_feedback_time = 0.0
        self._last_keepalive_time = 0.0
        self._message_count = 0

    async def _create_connection(self) -> AsyncConnection:
        """Create a new replication connection.

        Parameters are passed as keyword arguments so that psycopg builds and
        escapes the connection string; values containing spaces or quotes
        (e.g. passwords) therefore stay intact.
        """
        params = self.connection_params
        return await AsyncConnection.connect(
            host=params["host"],
            port=params["port"],
            dbname=params["dbname"],
            user=params["user"],
            password=params["password"],
            replication="database",
            sslmode="disable",  # Raw socket bridge cannot handle TLS
            autocommit=True,
        )

    async def _enter_copy_both(self, pgconn: Any) -> None:
        """Send START_REPLICATION and wait until the connection is in COPY BOTH.

        Raises:
            Exception: If flushing the command fails or the server does not
                answer with a COPY BOTH result.
        """
        # The publication list sits inside a single-quoted literal.
        publication = self.publication_name.replace("'", "''")
        start_cmd = (
            f"START_REPLICATION SLOT {self.slot_name} LOGICAL {lsn_to_string(self.current_lsn)} "
            f"(proto_version '1', publication_names '{publication}')"
        ).encode("utf-8")
        # Send command through psycopg
        pgconn.send_query(start_cmd)
        # Flush output buffer
        while True:
            result = pgconn.flush()
            if result == 0:
                break
            elif result == -1:
                error = pgconn.error_message.decode("utf-8", errors="ignore")
                raise Exception(f"Flush error: {error}")
            await asyncio.sleep(0.001)
        # Wait for response
        while pgconn.is_busy():
            pgconn.consume_input()
            await asyncio.sleep(0.001)
        # Get result
        result = pgconn.get_result()
        if result is None or result.status != self._PGRES_COPY_BOTH:
            error = pgconn.error_message.decode("utf-8", errors="ignore")
            raise Exception(f"Failed to enter COPY_BOTH: {error}")

    async def _cleanup(self) -> None:
        """Release the bridge and connection, if any, and forget them."""
        if self.bridge:
            self.bridge.close()
            self.bridge = None
        if self.conn:
            # Clear the attribute first so a failing close is not retried.
            conn, self.conn = self.conn, None
            await conn.close()

    async def start_replication(
        self,
    ) -> AsyncIterator[
        tuple[
            ReplicationMessage,
            Optional[
                Union[
                    BeginMessage,
                    CommitMessage,
                    RelationMessage,
                    InsertMessage,
                    UpdateMessage,
                    DeleteMessage,
                    TruncateMessage,
                ]
            ],
        ]
    ]:
        """
        Start consuming replication stream with automatic reconnection.

        Runs until the consumer stops iterating. Whenever the stream fails or
        ends, the connection is released and re-established from
        ``current_lsn`` after a delay that doubles up to 60 seconds and is
        reset by each successful start.

        Yields:
            Tuples of (ReplicationMessage, decoded_message)

        Example:
            stream = LogicalReplicationStream(
                connection_params={"host": "localhost", "port": 5432, ...},
                slot_name="my_slot",
                publication_name="my_pub"
            )
            async for raw_msg, decoded_msg in stream.start_replication():
                if isinstance(decoded_msg, InsertMessage):
                    print(f"INSERT: {decoded_msg.new_tuple}")
                await stream.send_feedback(raw_msg.data_end)
        """
        retry_delay = 1.0
        max_retry_delay = 60.0
        while True:
            try:
                # Create connection
                self.conn = await self._create_connection()
                pgconn = self.conn.pgconn
                await self._enter_copy_both(pgconn)
                logger.info(
                    f"Started replication on slot {self.slot_name} from LSN {lsn_to_string(self.current_lsn)}"
                )
                # Hand off to raw socket bridge
                self.bridge = RawSocketReplicationBridge(pgconn)
                # Reset retry delay on successful connection
                retry_delay = 1.0
                # Consume stream
                async for raw_msg, decoded_msg in self._read_from_bridge():
                    # Update current LSN
                    self.current_lsn = raw_msg.data_end
                    self._message_count += 1
                    yield raw_msg, decoded_msg
                # The reader stopped without raising; reconnect like any
                # other interruption instead of looping immediately.
                logger.warning(
                    f"Replication stream ended. Reconnecting in {retry_delay}s..."
                )
            except Exception as e:
                logger.error(f"Replication error: {e}. Reconnecting in {retry_delay}s...")
            # Clean up on every path so no connection is leaked on reconnect
            try:
                await self._cleanup()
            except Exception as e:
                logger.warning(f"Error while closing replication connection: {e}")
            # Wait before retry
            await asyncio.sleep(retry_delay)
            # Exponential backoff
            retry_delay = min(retry_delay * 2, max_retry_delay)

    async def _read_from_bridge(
        self,
    ) -> AsyncIterator[tuple[ReplicationMessage, Optional[Union[BeginMessage, CommitMessage, RelationMessage, InsertMessage, UpdateMessage, DeleteMessage, TruncateMessage]]]]:
        """Read replication messages using raw socket bridge.

        Yields XLogData messages; primary keepalives are handled internally.
        When nothing arrives for ``_KEEPALIVE_INTERVAL`` seconds an
        unsolicited status update for ``current_lsn`` is sent. The generator
        finishes when the stream ends or a status update cannot be sent.

        Raises:
            Exception: If no bridge is set, or whatever the bridge read raised.
        """
        if not self.bridge:
            raise Exception("Bridge not initialized")
        bridge = self.bridge
        loop = asyncio.get_running_loop()
        now = loop.time()
        self._last_feedback_time = now
        self._last_keepalive_time = now
        # The read runs as its own task so that a keepalive timeout never
        # cancels a socket read part-way (which could drop received bytes).
        read_task: Optional[asyncio.Future] = None
        try:
            while True:
                if read_task is None:
                    read_task = asyncio.ensure_future(bridge.read_copy_data_message())
                remaining = self._KEEPALIVE_INTERVAL - (
                    loop.time() - self._last_keepalive_time
                )
                await asyncio.wait({read_task}, timeout=max(remaining, 0.0))
                if not read_task.done():
                    # Idle: send periodic unsolicited keepalive
                    try:
                        await self.send_feedback(self.current_lsn, flush=True)
                    except Exception as e:
                        logger.error(f"Failed to send keepalive: {e}")
                        break
                    self._last_keepalive_time = loop.time()
                    continue
                msg_data = read_task.result()
                read_task = None
                if msg_data is None:
                    logger.info("Replication stream ended")
                    break
                # Parse message
                if len(msg_data) < 1:
                    continue
                msg_type = chr(msg_data[0])
                if msg_type == "w":  # XLogData
                    if len(msg_data) < 25:
                        continue
                    wal_start, wal_end, send_time = struct.unpack_from(">QQQ", msg_data, 1)
                    payload = msg_data[25:]
                    raw_msg = ReplicationMessage(
                        payload=payload,
                        data_start=wal_start,
                        data_end=wal_end,
                        send_time=send_time,
                    )
                    decoded_msg = self.decoder.decode(payload)
                    # Update keepalive time
                    self._last_keepalive_time = loop.time()
                    yield raw_msg, decoded_msg
                elif msg_type == "k":  # Primary keepalive
                    if len(msg_data) < 18:
                        continue
                    wal_end, send_time = struct.unpack_from(">QQ", msg_data, 1)
                    reply_requested = msg_data[17] == 1
                    logger.debug(
                        f"Received keepalive at LSN {lsn_to_string(wal_end)}, "
                        f"reply_requested={reply_requested}"
                    )
                    # Update LSN if keepalive advanced it
                    if wal_end > self.current_lsn:
                        self.current_lsn = wal_end
                    # Reply if requested
                    if reply_requested:
                        try:
                            await self.send_feedback(self.current_lsn, flush=True)
                        except Exception as e:
                            logger.error(f"Failed to reply to keepalive: {e}")
                            break
                    # Update keepalive time
                    self._last_keepalive_time = loop.time()
        finally:
            # Do not leave a reader behind on a socket that is being torn down.
            if read_task is not None and not read_task.done():
                read_task.cancel()

    async def send_feedback(self, lsn: int, flush: bool = True) -> None:
        """
        Send feedback to acknowledge WAL position.

        Does nothing when no bridge is active.

        Args:
            lsn: Log Sequence Number to acknowledge
            flush: Whether this LSN has been flushed to disk; when False the
                flushed position is reported as 0

        Raises:
            Exception: Whatever the bridge raised while sending, after logging.
        """
        if not self.bridge:
            return
        try:
            await self.bridge.send_standby_status_update(
                received_lsn=lsn, flushed_lsn=lsn if flush else 0, applied_lsn=lsn
            )
            self._last_feedback_time = asyncio.get_running_loop().time()
        except Exception as e:
            logger.error(f"Failed to send feedback: {e}")
            raise

    async def close(self) -> None:
        """Close the replication stream gracefully."""
        await self._cleanup()
| # ============================================================================ | |
| # Example Usage | |
| # ============================================================================ | |
async def example_usage():
    """
    Demonstrate the stream: print every row change and acknowledge each message.
    """
    # Placeholder connection settings; streaming starts at LSN 0/0.
    stream = LogicalReplicationStream(
        connection_params={
            "host": "localhost",
            "port": 5432,
            "dbname": "mydb",
            "user": "myuser",
            "password": "mypass",
        },
        slot_name="my_replication_slot",
        publication_name="my_publication",
        start_lsn="0/0",
    )
    row_changes = (InsertMessage, UpdateMessage, DeleteMessage)
    try:
        async for raw_msg, decoded_msg in stream.start_replication():
            if isinstance(decoded_msg, row_changes):
                # Row messages only carry an ID; the decoder's cache has the names.
                relation = stream.decoder.relations.get(decoded_msg.relation_id)
                if relation and isinstance(decoded_msg, InsertMessage):
                    print(
                        f"INSERT into {relation.namespace}.{relation.relation_name}: "
                        f"{decoded_msg.new_tuple}"
                    )
                elif relation and isinstance(decoded_msg, UpdateMessage):
                    print(
                        f"UPDATE {relation.namespace}.{relation.relation_name}: "
                        f"{decoded_msg.old_tuple} -> {decoded_msg.new_tuple}"
                    )
                elif relation and isinstance(decoded_msg, DeleteMessage):
                    print(
                        f"DELETE from {relation.namespace}.{relation.relation_name}: "
                        f"{decoded_msg.old_tuple or decoded_msg.key_tuple}"
                    )
            # Every message is acknowledged, whatever its type.
            await stream.send_feedback(raw_msg.data_end)
    finally:
        await stream.close()


if __name__ == "__main__":
    # Script entry point: run the example.
    asyncio.run(example_usage())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment