Last active
August 5, 2025 15:58
-
-
Save ArthurDelannoyazerty/102595e4d88d5acbf018b1dad72644de to your computer and use it in GitHub Desktop.
Async utility class that connects to a database (or creates one) and sends queries
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import logging
from typing import Optional

import psycopg
from psycopg.conninfo import make_conninfo
from psycopg.sql import SQL, Identifier
from psycopg_pool import AsyncConnectionPool  # pip install psycopg[binary] psycopg_pool
from tqdm.asyncio import tqdm

try:
    from pgvector.psycopg import register_vector_async
    PGVECTOR_INSTALLED = True
except ImportError:
    PGVECTOR_INSTALLED = False

__all__ = ['AsyncInterfaceSQL']
class AsyncInterfaceSQL:
    """
    An awaitable, asynchronous interface for interacting with a PostgreSQL database.

    Awaiting an instance checks that the target database can be reached
    (creating it when allowed), ensures the pgvector extension when requested,
    and opens the connection pool.

    syntax: `interface = await AsyncInterfaceSQL(...)`
    """

    def __init__(self,
                 database: str,
                 host: str,
                 port: int,
                 user: str,
                 password: str,
                 create_db_if_not_exists: bool = True,
                 min_pool_size: int = 1,
                 max_pool_size: int = 8,
                 use_pgvector: bool = True):
        """
        Store the connection settings. No I/O happens here; await the instance to connect.

        :param database: Name of the database to connect to (and create if missing).
        :param host: Server host name or address.
        :param port: Server port.
        :param user: Role used for every connection.
        :param password: Password for ``user``.
        :param create_db_if_not_exists: Try ``CREATE DATABASE`` when the first connection fails.
        :param min_pool_size: Minimum number of pooled connections.
        :param max_pool_size: Maximum number of pooled connections.
        :param use_pgvector: Ensure the ``vector`` extension and register its adapter on each connection.
        """
        # make_conninfo quotes/escapes each value, so spaces, quotes or
        # backslashes in a password cannot corrupt the connection string.
        self._conninfo = make_conninfo(dbname=database, host=host, port=port, user=user, password=password)
        self._db_name = database
        self._db_params_no_db = make_conninfo(host=host, port=port, user=user, password=password)
        self._create_db_if_not_exists = create_db_if_not_exists
        self._min_pool_size = min_pool_size
        self._max_pool_size = max_pool_size  # Max pool connection size
        self.pool: Optional[AsyncConnectionPool] = None
        self.use_pgvector = use_pgvector
        if self.use_pgvector and not PGVECTOR_INSTALLED:
            logging.warning("`use_pgvector` is True, but the 'pgvector' library is not installed. Disabling feature.")
            self.use_pgvector = False
        # Becomes True only once CREATE EXTENSION has succeeded.
        self._pgvector_extension_exists = False

    def __await__(self):
        """Make the instance awaitable: ``await obj`` runs the async initialization."""
        return self._async_init().__await__()

    async def _setup_connection(self, conn: "psycopg.AsyncConnection"):
        """
        Configure a new pooled connection (passed to the pool as ``configure``).

        Runs no query: it only registers the pgvector type adapter when the
        one-time extension check has already succeeded.
        """
        if self.use_pgvector and self._pgvector_extension_exists:
            logging.info(f"Connection (ID: {conn.info.backend_pid}): Registering pgvector type adapter.")
            await register_vector_async(conn)

    async def _ensure_pgvector_extension(self):
        """
        Use a single, temporary connection to ensure the pgvector extension exists.

        Called before the main pool is created. Does nothing when pgvector
        support is disabled.

        :raises psycopg.Error: If the extension cannot be created.
        """
        if not self.use_pgvector:
            return
        logging.info("Checking for and ensuring 'vector' extension exists...")
        try:
            async with await psycopg.AsyncConnection.connect(self._conninfo, connect_timeout=5) as conn:
                await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
        except psycopg.Error as e:
            logging.error(f"Failed to create 'vector' extension: {e}")
            raise
        self._pgvector_extension_exists = True
        logging.info("'vector' extension is ready.")

    async def _create_database(self):
        """Issue CREATE DATABASE for the target database through the 'postgres' database."""
        logging.info("Attempting to create the database...")
        try:
            # CREATE DATABASE cannot run inside a transaction block, hence autocommit.
            async with await psycopg.AsyncConnection.connect(
                self._db_params_no_db, dbname="postgres", autocommit=True, connect_timeout=5
            ) as conn:
                # Identifier() quotes the name safely instead of f-string interpolation.
                await conn.execute(SQL("CREATE DATABASE {}").format(Identifier(self._db_name)))
        except psycopg.Error:
            logging.error("Failed to create database. The server is likely running, but another issue occurred (e.g., permissions).")
            raise
        logging.info(f"Database '{self._db_name}' created successfully.")

    async def _ensure_database(self):
        """Probe the target database with a short-lived connection, creating it if allowed."""
        try:
            async with await psycopg.AsyncConnection.connect(self._conninfo, connect_timeout=5):
                pass
        except psycopg.errors.OperationalError:
            # Any operational failure is treated as "database missing"; a real
            # server problem then surfaces from the creation attempt.
            logging.warning(f"Initial connection to '{self._db_name}' failed. Assuming it doesn't exist.")
            if not self._create_db_if_not_exists:
                logging.error("Database does not exist and 'create_db_if_not_exists' is False.")
                raise
            await self._create_database()
        else:
            logging.info(f"Database '{self._db_name}' already exists.")

    async def _open_pool(self):
        """Create the connection pool and wait until it is open."""
        logging.info("Initializing connection pool...")
        self.pool = AsyncConnectionPool(
            self._conninfo,
            min_size=self._min_pool_size,
            max_size=self._max_pool_size,
            open=False,
            configure=self._setup_connection,
        )
        await self.pool.open(wait=True)
        logging.info(f"Connection pool to '{self._db_name}' is open and ready.")

    async def _async_init(self) -> 'AsyncInterfaceSQL':
        """
        Perform the asynchronous initialization and return the configured instance.

        :raises psycopg.Error: If the database cannot be reached or created,
            or the pgvector extension cannot be ensured.
        """
        # Awaiting an already-initialised instance must not build a second
        # pool and orphan the first one.
        if self.pool is not None and not self.pool.closed:
            return self
        await self._ensure_database()
        await self._ensure_pgvector_extension()
        await self._open_pool()
        return self

    async def send_query(self, query: str, variable: Optional[tuple] = None, fetch: bool = False) -> Optional[list[tuple]]:
        """
        Execute a single query on a pooled connection.

        :param query: The SQL statement, with placeholders for ``variable`` if any.
        :param variable: Optional parameters bound to the query.
        :param fetch: When True, return all rows produced by the query.
        :return: The fetched rows when ``fetch`` is True; otherwise None. None is
            also returned when a psycopg error occurs (the error is logged).
        :raises ConnectionError: If the pool has not been opened or is closed.
        """
        if self.pool is None or self.pool.closed:
            logging.error("Connection pool is not available or has been closed.")
            raise ConnectionError("Connection pool is not open. Cannot execute query.")
        try:
            async with self.pool.connection() as conn:
                async with conn.cursor() as cur:
                    await cur.execute(query, variable)
                    if fetch:
                        return await cur.fetchall()
        except psycopg.Error as e:
            logging.error(f"Error executing query: {e}")
        return None

    async def send_query_bulk(self, query: str, variables: list[tuple], batch_size: int = 1000):
        """
        Executes a query for a sequence of parameters. Ideal for bulk inserts.

        All batches run on one pooled connection. A psycopg error is logged
        and not re-raised.

        :param query: The query command, e.g., "INSERT INTO users (name, email) VALUES (%s, %s)"
        :param variables: A list of tuples with the data, e.g., [('John', 'john@example.com'), ('Jane', 'jane@example.com')]
        :param batch_size: Number of parameter tuples sent per ``executemany`` call; must be >= 1.
        :raises ValueError: If ``batch_size`` is smaller than 1.
        :raises ConnectionError: If the pool has not been opened or is closed.
        """
        # A negative step would make range() empty and silently insert nothing.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        if self.pool is None or self.pool.closed:
            raise ConnectionError("Connection pool is not available.")
        total_records = len(variables)
        logging.info(f"Starting bulk execution for {total_records} records with a batch size of {batch_size}.")
        with tqdm(total=total_records, desc="Bulk Inserting", unit="row", ncols=100) as pbar:
            try:
                async with self.pool.connection() as conn:
                    async with conn.cursor() as cur:
                        for start_index in range(0, total_records, batch_size):
                            end_index = start_index + batch_size
                            batch = variables[start_index:end_index]
                            logging.info(f"Processing batch: records {start_index+1} to {min(end_index, total_records)} of {total_records}...")
                            await cur.executemany(query, batch)
                            pbar.update(len(batch))
            except psycopg.Error as e:
                logging.error(f"Error in execute_many: {e}")

    async def copy_records(self, table_name: str, columns: tuple[str, ...], records: list[tuple]):
        """
        Uses the COPY command (binary format) to bulk insert records from a list of tuples.

        :param table_name: The name of the table to insert into.
        :param columns: A tuple of column names, e.g., ('name', 'email').
        :param records: A list of tuples where each tuple is a row to insert.
        :raises ConnectionError: If the pool has not been opened or is closed.
        :raises psycopg.Error: If the COPY operation fails (logged, then re-raised).
        """
        if self.pool is None or self.pool.closed:
            raise ConnectionError("Connection pool is not available.")
        copy_query = SQL("COPY {table} ({cols}) FROM STDIN WITH (FORMAT BINARY)").format(
            table=Identifier(table_name),
            cols=SQL(', ').join(map(Identifier, columns)),
        )
        record_count = len(records)
        logging.info(f"Starting COPY for {record_count:,} records into '{table_name}'.")
        try:
            async with self.pool.connection() as conn:
                async with conn.cursor() as cur:
                    async with cur.copy(copy_query) as copy:
                        for record in tqdm(records, desc="Streaming to DB", unit="row", ncols=100):
                            await copy.write_row(record)
            logging.info(f"Successfully copied {record_count:,} records.")
        except psycopg.Error as e:
            logging.error(f"Error in COPY operation: {e}")
            raise

    async def close(self):
        """Close the connection pool if it is open; otherwise do nothing."""
        if self.pool is not None and not self.pool.closed:
            await self.pool.close()
            logging.info("Database connection pool closed.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Uh oh!
There was an error while loading. Please reload this page.