Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save dvarrazzo/ad8eea3f4476690e3a42e28171b2898b to your computer and use it in GitHub Desktop.
Save dvarrazzo/ad8eea3f4476690e3a42e28171b2898b to your computer and use it in GitHub Desktop.
A production-ready implementation of PostgreSQL logical replication using psycopg3. Works around psycopg3's missing replication support by using a raw socket bridge.
"""
PostgreSQL Logical Replication for psycopg3
A production-ready implementation of PostgreSQL logical replication using psycopg3.
Works around psycopg3's missing replication support by using a raw socket bridge.
Author: Richard Brandes
Date: October 2025
License: MIT
Related: https://github.com/psycopg/psycopg/issues/71
"""
import asyncio
import logging
import socket as socket_module
import struct
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any, AsyncIterator, Dict, List, Optional, Union
from psycopg import AsyncConnection
logger = logging.getLogger(__name__)
# ============================================================================
# Protocol Message Types and Data Structures
# ============================================================================
class MessageType(Enum):
    """Single-character tags identifying each pgoutput logical replication message."""

    # Transaction framing and origin
    BEGIN = "B"
    COMMIT = "C"
    ORIGIN = "O"

    # Schema metadata
    RELATION = "R"
    TYPE = "Y"

    # Data changes
    INSERT = "I"
    UPDATE = "U"
    DELETE = "D"
    TRUNCATE = "T"

    # Generic logical decoding message
    MESSAGE = "M"

    # Streaming of in-progress transactions
    STREAM_START = "S"
    STREAM_STOP = "E"
    STREAM_COMMIT = "c"
    STREAM_ABORT = "A"
@dataclass
class ReplicationMessage:
    """A single frame of WAL stream data, before any pgoutput decoding."""

    # Undecoded bytes carried by the frame
    payload: bytes
    # LSN at which the carried data begins
    data_start: int
    # LSN at which the carried data ends
    data_end: int
    # Send timestamp reported by the server
    send_time: int
@dataclass
class ColumnDefinition:
    """Description of one column as announced by a Relation message."""

    flags: int
    name: str
    type_id: int
    type_modifier: int

    @property
    def is_key(self) -> bool:
        """Return True when the lowest flag bit marks the column as a key column."""
        return (self.flags & 0x01) != 0
@dataclass
class RelationMessage:
    """Table schema announcement: identity of a relation plus its columns."""

    # Identifier that row-change messages use to refer to this relation
    relation_id: int
    # Schema (namespace) name and table name
    namespace: str
    relation_name: str
    # Replica identity setting, kept as the raw byte value
    replica_identity: int
    # Column descriptions, in the order the message lists them
    columns: List[ColumnDefinition]
@dataclass
class BeginMessage:
    """Marks the start of a transaction in the change stream."""

    # Final LSN of the transaction
    final_lsn: int
    # Commit timestamp of the transaction
    commit_ts: datetime
    # Transaction id
    xid: int
@dataclass
class CommitMessage:
    """Marks the end (commit) of a transaction in the change stream."""

    # Flags byte from the message
    flags: int
    # LSN of the commit
    commit_lsn: int
    # End LSN of the transaction
    end_lsn: int
    # Commit timestamp of the transaction
    commit_ts: datetime
@dataclass
class InsertMessage:
    """A row inserted into the relation identified by ``relation_id``."""

    relation_id: int
    # Column name -> decoded value for the inserted row
    new_tuple: Dict[str, Any]
@dataclass
class UpdateMessage:
    """A row updated in the relation identified by ``relation_id``."""

    relation_id: int
    # Previous row image (or key), when the message carried one
    old_tuple: Optional[Dict[str, Any]]
    # Column name -> decoded value for the row after the update
    new_tuple: Dict[str, Any]
@dataclass
class DeleteMessage:
    """A row deleted from the relation identified by ``relation_id``."""

    relation_id: int
    # Full old row image, when the message carried one
    old_tuple: Optional[Dict[str, Any]]
    # Key columns only, when the message carried a key tuple
    key_tuple: Optional[Dict[str, Any]]
@dataclass
class TruncateMessage:
    """A TRUNCATE applied to one or more relations."""

    # Option bits from the message, kept as the raw byte value
    options: int
    # Identifiers of every truncated relation
    relation_ids: List[int]
# ============================================================================
# Raw Socket Bridge - Direct PostgreSQL Wire Protocol Access
# ============================================================================
class RawSocketReplicationBridge:
    """
    Direct socket bridge for reading PostgreSQL replication data.

    This bypasses psycopg's COPY handling after connection setup: protocol
    frames are read from, and written to, the connection's socket directly.

    The bridge works on a *duplicate* of libpq's socket descriptor, so closing
    or garbage-collecting the bridge never closes the descriptor libpq owns.

    Limitations:
        * Only unencrypted connections are supported: the bridge sees the raw
          bytes on the wire, which are TLS records when SSL is in use.
        * NOTE(review): bytes that libpq already pulled into its own input
          buffer before the hand-off are not visible on the socket - confirm
          whether callers need to drain them first.
    """

    # 1 byte message type + 4 byte big-endian length
    _HEADER_SIZE = 5
    # Maximum number of bytes requested per socket read
    _RECV_CHUNK = 8192
    # Origin of the protocol's microsecond timestamps
    _PG_EPOCH = datetime(2000, 1, 1, tzinfo=timezone.utc)

    def __init__(self, pgconn: Any) -> None:
        """
        Initialize bridge with libpq connection.

        Args:
            pgconn: The libpq PGconn object from psycopg (typed as Any due to
                protocol limitation). Its ``socket`` attribute must be the
                connection's file descriptor.

        Raises:
            RuntimeError: If the connection reports that SSL is in use.
        """
        if getattr(pgconn, "ssl_in_use", False):
            raise RuntimeError(
                "Raw socket replication bridge cannot be used on an SSL connection"
            )
        self.pgconn: Any = pgconn
        self.socket_fd: int = int(pgconn.socket)
        # Wrap libpq's descriptor only long enough to duplicate it, then
        # detach so the temporary wrapper can never close libpq's descriptor.
        borrowed = socket_module.socket(fileno=self.socket_fd)
        try:
            self.socket_obj = borrowed.dup()
        finally:
            borrowed.detach()
        self.socket_obj.setblocking(False)
        self.read_buffer = bytearray()

    def close(self) -> None:
        """Close the bridge's own socket; libpq's descriptor is left open."""
        self.socket_obj.close()

    async def _fill_buffer(self, wanted: int) -> bool:
        """
        Receive from the socket until the buffer holds at least ``wanted`` bytes.

        Returns:
            True once enough bytes are buffered, False if the peer closed the
            connection first.

        Raises:
            ConnectionError: If the socket read fails.
        """
        loop = asyncio.get_running_loop()
        while len(self.read_buffer) < wanted:
            try:
                chunk = await loop.sock_recv(self.socket_obj, self._RECV_CHUNK)
            except BlockingIOError:
                await asyncio.sleep(0.01)
                continue
            except OSError as e:
                raise ConnectionError(f"Socket read error: {e}") from e
            if not chunk:
                return False
            self.read_buffer.extend(chunk)
        return True

    @staticmethod
    def _describe_error_response(body: bytes) -> str:
        """Render the severity, SQLSTATE and message fields of an ErrorResponse body."""
        fields: Dict[str, str] = {}
        # The body is a run of "<1-byte code><text>\0" fields ended by a lone \0.
        for raw_field in body.split(b"\x00"):
            if not raw_field:
                break
            fields[chr(raw_field[0])] = raw_field[1:].decode("utf-8", errors="replace")
        prefix = " ".join(fields[code] for code in ("S", "C") if code in fields)
        message = fields.get("M", "unknown error")
        return f"{prefix}: {message}" if prefix else message

    async def read_copy_data_message(self) -> Optional[bytes]:
        """
        Read one complete CopyData message directly from socket.

        Frames of other types are handled here: CopyDone ends the stream,
        CopyFail and ErrorResponse raise, anything else is skipped.

        Returns:
            Message payload or None if stream ended

        Raises:
            ConnectionError: If the socket read fails.
            RuntimeError: If a frame has an invalid length or the server sent
                an ErrorResponse.
            Exception: If a CopyFail frame is received.
        """
        while True:
            if not await self._fill_buffer(self._HEADER_SIZE):
                return None

            msg_type = chr(self.read_buffer[0])
            (msg_length,) = struct.unpack_from(">I", self.read_buffer, 1)
            # The length counts its own 4 bytes, so anything smaller is corrupt.
            if msg_length < 4:
                raise RuntimeError(
                    f"Invalid length {msg_length} for message type {msg_type!r}"
                )
            # Message length includes itself (4 bytes) but not the type byte
            total_msg_size = 1 + msg_length

            if not await self._fill_buffer(total_msg_size):
                return None

            body = bytes(self.read_buffer[self._HEADER_SIZE : total_msg_size])
            # Drop the consumed frame in place instead of copying the remainder.
            del self.read_buffer[:total_msg_size]

            if msg_type == "d":  # CopyData - return payload
                return body
            if msg_type == "c":  # CopyDone - end of stream
                logger.info("Received CopyDone message")
                return None
            if msg_type == "f":  # CopyFail - error
                raise Exception("Server sent CopyFail message")
            if msg_type == "E":  # ErrorResponse - the server aborted the stream
                raise RuntimeError(
                    f"Server sent ErrorResponse: {self._describe_error_response(body)}"
                )
            if msg_type in ("W", "H"):  # CopyBothResponse is 'W'; 'H' kept for compatibility
                logger.debug("Received CopyBothResponse")
                continue
            logger.debug(f"Skipping unknown message type: {msg_type}")

    async def send_standby_status_update(
        self, received_lsn: int, flushed_lsn: int, applied_lsn: int
    ) -> None:
        """
        Send standby status update message directly to socket.

        This acknowledges WAL positions to PostgreSQL. Sending is best effort:
        a failure is logged, not raised.

        Args:
            received_lsn: Last WAL position received.
            flushed_lsn: Last WAL position flushed.
            applied_lsn: Last WAL position applied.
        """
        # Client clock as microseconds since 2000-01-01, as the protocol specifies.
        client_time = (datetime.now(timezone.utc) - self._PG_EPOCH) // timedelta(
            microseconds=1
        )
        status_msg = struct.pack(
            ">cQQQqB",
            b"r",  # Standby status update
            received_lsn,
            flushed_lsn,
            applied_lsn,
            client_time,
            0,  # Reply requested (0 = no)
        )
        # Wrap in CopyData message (type 'd' + length + payload)
        copy_data_msg = struct.pack(">cI", b"d", len(status_msg) + 4) + status_msg
        try:
            await asyncio.get_running_loop().sock_sendall(self.socket_obj, copy_data_msg)
        except Exception as e:
            logger.error(f"Failed to send feedback: {e}")
# ============================================================================
# pgoutput Binary Protocol Decoder
# ============================================================================
class PgOutputDecoder:
    """
    Decoder for PostgreSQL pgoutput binary protocol.

    Based on: https://www.postgresql.org/docs/current/protocol-logicalrep-message-formats.html

    The decoder is stateful: every decoded RELATION message is cached in
    ``relations`` and used to name the columns of later row messages.
    """

    # Type OIDs that get converted from their text form
    _BOOL_OID = 16
    _INTEGER_OIDS = frozenset((20, 21, 23))  # int8, int2, int4
    _FLOAT_OIDS = frozenset((700, 701))  # float4, float8
    _JSON_OIDS = frozenset((114, 3802))  # json, jsonb

    def __init__(self) -> None:
        self.relations: Dict[int, "RelationMessage"] = {}
        self.types: Dict[int, str] = {}

    def decode(
        self, data: bytes
    ) -> "Optional[Union[BeginMessage, CommitMessage, RelationMessage, InsertMessage, UpdateMessage, DeleteMessage, TruncateMessage]]":
        """
        Decode a pgoutput binary message.

        Args:
            data: One pgoutput message, starting with its type byte.

        Returns:
            The decoded message, or None when the input is empty, the message
            type is not handled, or decoding failed (the failure is logged).
        """
        if not data:
            return None
        msg_type = chr(data[0])
        handlers = {
            MessageType.BEGIN.value: self._decode_begin,
            MessageType.COMMIT.value: self._decode_commit,
            MessageType.RELATION.value: self._decode_relation,
            MessageType.INSERT.value: self._decode_insert,
            MessageType.UPDATE.value: self._decode_update,
            MessageType.DELETE.value: self._decode_delete,
            MessageType.TRUNCATE.value: self._decode_truncate,
        }
        handler = handlers.get(msg_type)
        if handler is None:
            # TYPE and ORIGIN are recognised but deliberately ignored.
            if msg_type not in (MessageType.TYPE.value, MessageType.ORIGIN.value):
                logger.debug(f"Unhandled message type: {msg_type}")
            return None
        try:
            return handler(data[1:])
        except Exception as e:
            logger.error(f"Error decoding message type {msg_type}: {e}", exc_info=True)
            return None

    def _decode_begin(self, data: bytes) -> "BeginMessage":
        """Decode BEGIN message: final LSN, commit timestamp, transaction id."""
        (final_lsn,) = struct.unpack_from(">Q", data, 0)
        commit_ts, offset = self._read_timestamp(data, 8)
        (xid,) = struct.unpack_from(">I", data, offset)
        return BeginMessage(final_lsn=final_lsn, commit_ts=commit_ts, xid=xid)

    def _decode_commit(self, data: bytes) -> "CommitMessage":
        """Decode COMMIT message: flags, commit LSN, end LSN, commit timestamp."""
        flags, commit_lsn, end_lsn = struct.unpack_from(">BQQ", data, 0)
        commit_ts, _ = self._read_timestamp(data, 17)
        return CommitMessage(
            flags=flags, commit_lsn=commit_lsn, end_lsn=end_lsn, commit_ts=commit_ts
        )

    def _decode_relation(self, data: bytes) -> "RelationMessage":
        """Decode RELATION message and remember it for later row messages."""
        (relation_id,) = struct.unpack_from(">I", data, 0)
        offset = 4
        namespace, offset = self._read_string(data, offset)
        relation_name, offset = self._read_string(data, offset)
        replica_identity = data[offset]
        offset += 1
        (num_columns,) = struct.unpack_from(">H", data, offset)
        offset += 2

        columns: List["ColumnDefinition"] = []
        for _ in range(num_columns):
            flags = data[offset]
            offset += 1
            col_name, offset = self._read_string(data, offset)
            type_id, type_mod = struct.unpack_from(">Ii", data, offset)
            offset += 8
            columns.append(
                ColumnDefinition(
                    flags=flags, name=col_name, type_id=type_id, type_modifier=type_mod
                )
            )

        relation = RelationMessage(
            relation_id=relation_id,
            namespace=namespace,
            relation_name=relation_name,
            replica_identity=replica_identity,
            columns=columns,
        )
        self.relations[relation_id] = relation
        return relation

    def _decode_insert(self, data: bytes) -> "InsertMessage":
        """Decode INSERT message; the tuple must be marked 'N' (new)."""
        (relation_id,) = struct.unpack_from(">I", data, 0)
        tuple_kind = chr(data[4])
        if tuple_kind != "N":
            raise ValueError(f"Unexpected tuple kind in INSERT: {tuple_kind}")
        new_tuple, _ = self._decode_tuple_data(data, 5, relation_id)
        return InsertMessage(relation_id=relation_id, new_tuple=new_tuple)

    def _decode_update(self, data: bytes) -> "UpdateMessage":
        """Decode UPDATE message: optional 'O'/'K' old tuple, then the 'N' tuple."""
        (relation_id,) = struct.unpack_from(">I", data, 0)
        offset = 4
        old_tuple: Optional[Dict[str, Any]] = None
        tuple_kind = chr(data[offset])
        offset += 1
        # Both the full old row ('O') and the key-only form ('K') land in old_tuple.
        if tuple_kind in ("O", "K"):
            old_tuple, offset = self._decode_tuple_data(data, offset, relation_id)
            tuple_kind = chr(data[offset])
            offset += 1
        if tuple_kind != "N":
            raise ValueError(f"Expected 'N' tuple kind, got: {tuple_kind}")
        new_tuple, _ = self._decode_tuple_data(data, offset, relation_id)
        return UpdateMessage(
            relation_id=relation_id, old_tuple=old_tuple, new_tuple=new_tuple
        )

    def _decode_delete(self, data: bytes) -> "DeleteMessage":
        """Decode DELETE message: an 'O' old tuple or a 'K' key tuple."""
        (relation_id,) = struct.unpack_from(">I", data, 0)
        tuple_kind = chr(data[4])
        old_tuple: Optional[Dict[str, Any]] = None
        key_tuple: Optional[Dict[str, Any]] = None
        if tuple_kind == "O":
            old_tuple, _ = self._decode_tuple_data(data, 5, relation_id)
        elif tuple_kind == "K":
            key_tuple, _ = self._decode_tuple_data(data, 5, relation_id)
        return DeleteMessage(
            relation_id=relation_id, old_tuple=old_tuple, key_tuple=key_tuple
        )

    def _decode_truncate(self, data: bytes) -> "TruncateMessage":
        """Decode TRUNCATE message: relation count, option bits, relation ids."""
        num_relations, options = struct.unpack_from(">IB", data, 0)
        relation_ids: List[int] = []
        offset = 5
        for _ in range(num_relations):
            (rel_id,) = struct.unpack_from(">I", data, offset)
            offset += 4
            relation_ids.append(rel_id)
        return TruncateMessage(options=options, relation_ids=relation_ids)

    def _decode_tuple_data(
        self, data: bytes, offset: int, relation_id: int
    ) -> tuple[Dict[str, Any], int]:
        """
        Decode tuple data from INSERT/UPDATE/DELETE messages.

        Args:
            data: Message body containing the tuple.
            offset: Position of the tuple's 2-byte column count.
            relation_id: Relation whose cached columns name the values.

        Returns:
            ``(values, next_offset)`` where ``values`` maps column name to the
            decoded value and ``next_offset`` is the first byte after the tuple.

        Raises:
            ValueError: For an unknown relation, a column count that differs
                from the cached relation, an unknown column kind, or a column
                value that extends past the end of ``data``.
        """
        relation = self.relations.get(relation_id)
        if relation is None:
            raise ValueError(f"Unknown relation ID: {relation_id}")
        (num_columns,) = struct.unpack_from(">H", data, offset)
        offset += 2
        if num_columns != len(relation.columns):
            raise ValueError(
                f"Column count mismatch: expected {len(relation.columns)}, got {num_columns}"
            )

        tuple_data: Dict[str, Any] = {}
        for col_def in relation.columns:
            col_type = chr(data[offset])
            offset += 1
            if col_type in ("n", "u"):
                # 'n' is NULL; 'u' is an unchanged TOASTed value that was not sent.
                # NOTE: both are reported as None, so callers cannot tell them apart.
                tuple_data[col_def.name] = None
            elif col_type == "t":
                (col_len,) = struct.unpack_from(">I", data, offset)
                offset += 4
                end = offset + col_len
                # Slicing would silently shorten the value; fail loudly instead.
                if end > len(data):
                    raise ValueError(
                        f"Truncated value for column {col_def.name}: "
                        f"expected {col_len} bytes, got {len(data) - offset}"
                    )
                tuple_data[col_def.name] = self._decode_value(
                    data[offset:end], col_def.type_id
                )
                offset = end
            else:
                raise ValueError(f"Unknown column data type: {col_type}")
        return tuple_data, offset

    def _decode_value(self, value: bytes, type_id: int) -> Any:
        """
        Decode column value based on PostgreSQL type OID.

        bool, integer, float and json/jsonb values are converted to Python
        objects; every other type is returned as its UTF-8 text.
        """
        text_value = value.decode("utf-8")
        if type_id == self._BOOL_OID:
            return text_value == "t"
        if type_id in self._INTEGER_OIDS:
            return int(text_value)
        if type_id in self._FLOAT_OIDS:
            return float(text_value)
        if type_id in self._JSON_OIDS:
            import json

            return json.loads(text_value)
        return text_value

    def _read_string(self, data: bytes, offset: int) -> tuple[str, int]:
        """
        Read null-terminated string from data.

        Returns:
            ``(text, next_offset)`` with ``next_offset`` just past the terminator.

        Raises:
            ValueError: If no terminator follows ``offset``.
        """
        end = data.find(b"\x00", offset)
        if end == -1:
            raise ValueError("Null terminator not found")
        return data[offset:end].decode("utf-8"), end + 1

    @staticmethod
    def _read_timestamp(data: bytes, offset: int) -> tuple[datetime, int]:
        """
        Read an 8-byte timestamp at ``offset``.

        The field is a signed 64-bit count of microseconds since 2000-01-01,
        so it is unpacked as signed; negative values are earlier instants.

        Returns:
            ``(timestamp, next_offset)``.
        """
        (raw,) = struct.unpack_from(">q", data, offset)
        return PgOutputDecoder._pg_timestamp_to_datetime(raw), offset + 8

    @staticmethod
    def _pg_timestamp_to_datetime(pg_ts: int) -> datetime:
        """Convert PostgreSQL timestamp (microseconds since 2000-01-01 UTC) to Python datetime."""
        pg_epoch = datetime(2000, 1, 1, tzinfo=timezone.utc)
        return pg_epoch + timedelta(microseconds=pg_ts)
# ============================================================================
# High-Level Replication Stream Interface
# ============================================================================
class LogicalReplicationStream:
    """
    Async logical replication consumer using raw socket bridge.

    This is the main interface for consuming logical replication changes.
    Iterate ``start_replication()`` to receive changes and call
    ``send_feedback()`` to acknowledge them. Keepalives from the server that
    ask for a reply are answered automatically with the positions most
    recently acknowledged through ``send_feedback()``.
    """

    # libpq ExecStatusType value reported once the connection is in COPY_BOTH mode
    _PGRES_COPY_BOTH = 8
    # XLogData: 1 type byte + WAL start + WAL end + send time (8 bytes each)
    _XLOGDATA_HEADER_SIZE = 25
    # Keepalive: 1 type byte + WAL end + send time (8 bytes each) + reply-requested byte
    _KEEPALIVE_SIZE = 18

    def __init__(
        self,
        conn: AsyncConnection,
        slot_name: str,
        publication_name: str,
        start_lsn: str = "0/0",
    ) -> None:
        """
        Initialize replication stream.

        Args:
            conn: psycopg AsyncConnection with replication=database
            slot_name: Name of the replication slot to use
            publication_name: Name of the publication to stream
            start_lsn: LSN to start streaming from (default: beginning)
        """
        self.conn = conn
        self.slot_name = slot_name
        self.publication_name = publication_name
        self.start_lsn = start_lsn
        self.decoder = PgOutputDecoder()
        self.bridge: Optional[RawSocketReplicationBridge] = None
        # Highest positions acknowledged so far; re-sent when the server
        # requests a reply. 0 means nothing has been acknowledged yet.
        self._received_lsn = 0
        self._flushed_lsn = 0
        self._applied_lsn = 0

    async def start_replication(
        self,
    ) -> AsyncIterator[
        tuple[
            ReplicationMessage,
            Optional[
                Union[
                    BeginMessage,
                    CommitMessage,
                    RelationMessage,
                    InsertMessage,
                    UpdateMessage,
                    DeleteMessage,
                    TruncateMessage,
                ]
            ],
        ]
    ]:
        """
        Start consuming replication stream.

        Sends START_REPLICATION through libpq, waits for the connection to
        enter COPY_BOTH mode, then reads the stream through the raw socket
        bridge.

        Yields:
            Tuples of (ReplicationMessage, decoded_message)

        Raises:
            RuntimeError: If the command cannot be flushed or the server does
                not enter COPY_BOTH mode.

        Example:
            async for raw_msg, decoded_msg in stream.start_replication():
                if isinstance(decoded_msg, InsertMessage):
                    print(f"INSERT into {decoded_msg.relation_id}: {decoded_msg.new_tuple}")
                await stream.send_feedback(raw_msg.data_end)
        """
        pgconn = self.conn.pgconn

        # The publication list is a single-quoted literal: double any embedded
        # quote so the name cannot terminate the literal early.
        publication_literal = self.publication_name.replace("'", "''")
        start_cmd = (
            f"START_REPLICATION SLOT {self.slot_name} LOGICAL {self.start_lsn} "
            f"(proto_version '1', publication_names '{publication_literal}')"
        ).encode("utf-8")

        # Send command through psycopg
        pgconn.send_query(start_cmd)

        # Flush output buffer: 0 = everything sent, -1 = failure, else retry
        while True:
            flush_status = pgconn.flush()
            if flush_status == 0:
                break
            if flush_status == -1:
                error = pgconn.error_message.decode("utf-8", errors="ignore")
                raise RuntimeError(f"Flush error: {error}")
            await asyncio.sleep(0.001)

        # Wait for response
        while pgconn.is_busy():
            pgconn.consume_input()
            await asyncio.sleep(0.001)

        result = pgconn.get_result()
        if result is None or result.status != self._PGRES_COPY_BOTH:
            error = pgconn.error_message.decode("utf-8", errors="ignore")
            raise RuntimeError(f"Failed to enter COPY_BOTH: {error}")
        logger.info(f"Started replication on slot {self.slot_name} (COPY_BOTH mode)")

        # Hand off to raw socket bridge - psycopg no longer involved
        self.bridge = RawSocketReplicationBridge(pgconn)
        try:
            async for raw_msg in self._read_from_bridge():
                yield raw_msg, self.decoder.decode(raw_msg.payload)
        finally:
            self.bridge = None

    async def _read_from_bridge(self) -> AsyncIterator[ReplicationMessage]:
        """
        Read replication messages using raw socket bridge.

        XLogData frames are yielded as ReplicationMessage objects. Keepalive
        frames are consumed here; when one asks for a reply, a status update
        with the last acknowledged positions is sent.

        Raises:
            RuntimeError: If called before the bridge has been created.
        """
        if not self.bridge:
            raise RuntimeError("Bridge not initialized")

        while True:
            # Read message directly from socket - bypassing psycopg completely
            msg_data = await self.bridge.read_copy_data_message()
            if msg_data is None:
                logger.info("Replication stream ended")
                break
            if not msg_data:
                continue

            msg_type = chr(msg_data[0])
            if msg_type == "w":  # XLogData
                if len(msg_data) < self._XLOGDATA_HEADER_SIZE:
                    logger.warning(
                        f"Dropping short XLogData frame ({len(msg_data)} bytes)"
                    )
                    continue
                wal_start, wal_end, send_time = struct.unpack_from(">QQQ", msg_data, 1)
                yield ReplicationMessage(
                    payload=msg_data[self._XLOGDATA_HEADER_SIZE :],
                    data_start=wal_start,
                    data_end=wal_end,
                    send_time=send_time,
                )
            elif msg_type == "k":  # Keepalive
                logger.debug("Received keepalive")
                reply_requested = (
                    len(msg_data) >= self._KEEPALIVE_SIZE and msg_data[17] != 0
                )
                if reply_requested:
                    # Without a prompt reply the server drops the connection
                    # once its WAL sender timeout expires.
                    await self.bridge.send_standby_status_update(
                        received_lsn=self._received_lsn,
                        flushed_lsn=self._flushed_lsn,
                        applied_lsn=self._applied_lsn,
                    )

    async def send_feedback(self, lsn: int, flush: bool = True) -> None:
        """
        Send feedback to acknowledge WAL position.

        Does nothing when replication is not running. Failures are logged,
        not raised.

        Args:
            lsn: Log Sequence Number to acknowledge
            flush: Whether this LSN has been flushed to disk
        """
        if not self.bridge:
            return

        # Remember the furthest acknowledged positions for keepalive replies.
        self._received_lsn = max(self._received_lsn, lsn)
        self._applied_lsn = max(self._applied_lsn, lsn)
        if flush:
            self._flushed_lsn = max(self._flushed_lsn, lsn)

        try:
            await self.bridge.send_standby_status_update(
                received_lsn=lsn, flushed_lsn=lsn if flush else 0, applied_lsn=lsn
            )
        except Exception as e:
            logger.error(f"Failed to send feedback: {e}")
# ============================================================================
# Example Usage
# ============================================================================
async def example_usage():
    """
    Demonstrate the replication stream end to end.

    Connects with a replication DSN, prints every insert, update and delete
    whose relation is known to the decoder, acknowledges each message, and
    closes the connection on exit.
    """
    # The connection must be opened with replication=database
    dsn = "host=localhost port=5432 dbname=mydb user=myuser password=mypass replication=database"
    conn = await AsyncConnection.connect(dsn, autocommit=True)
    try:
        stream = LogicalReplicationStream(
            conn=conn,
            slot_name="my_replication_slot",
            publication_name="my_publication",
            start_lsn="0/0",  # Start from beginning
        )

        async for wal_msg, change in stream.start_replication():
            if isinstance(change, (InsertMessage, UpdateMessage, DeleteMessage)):
                # Row changes are only reported once their table is known
                table = stream.decoder.relations.get(change.relation_id)
                if table:
                    qualified = f"{table.namespace}.{table.relation_name}"
                    if isinstance(change, InsertMessage):
                        print(f"INSERT into {qualified}: {change.new_tuple}")
                    elif isinstance(change, UpdateMessage):
                        print(f"UPDATE {qualified}: {change.old_tuple} -> {change.new_tuple}")
                    else:
                        print(f"DELETE from {qualified}: {change.old_tuple or change.key_tuple}")

            # Every message is acknowledged, printed or not
            await stream.send_feedback(wal_msg.data_end)
    finally:
        await conn.close()
if __name__ == "__main__":
    # Script entry point: run the example coroutine to completion.
    asyncio.run(example_usage())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment