ArthurDelannoyazerty (last active August 5, 2025)
A utility async class that connects to a DB (or creates one) and sends queries.
import asyncio
import logging
from typing import Optional

import psycopg
from psycopg.sql import SQL, Identifier
from psycopg_pool import AsyncConnectionPool  # pip install psycopg[binary] psycopg_pool
from tqdm.asyncio import tqdm

try:
    from pgvector.psycopg import register_vector_async
    PGVECTOR_INSTALLED = True
except ImportError:
    PGVECTOR_INSTALLED = False

__all__ = ['AsyncInterfaceSQL']
class AsyncInterfaceSQL:
    """
    An awaitable, asynchronous interface for interacting with a PostgreSQL database.
    Syntax: `interface = await AsyncInterfaceSQL(...)`
    """
    def __init__(self,
                 database: str,
                 host: str,
                 port: int,
                 user: str,
                 password: str,
                 create_db_if_not_exists: bool = True,
                 min_pool_size: int = 1,
                 max_pool_size: int = 8,
                 use_pgvector: bool = True):
        self._conninfo = f"dbname={database} host={host} port={port} user={user} password={password}"
        self._db_name = database
        self._db_params_no_db = f"host={host} port={port} user={user} password={password}"
        self._create_db_if_not_exists = create_db_if_not_exists
        self._min_pool_size = min_pool_size
        self._max_pool_size = max_pool_size  # Maximum number of pooled connections
        self.pool: Optional[AsyncConnectionPool] = None
        self.use_pgvector = use_pgvector
        if self.use_pgvector and not PGVECTOR_INSTALLED:
            logging.warning("`use_pgvector` is True, but the 'pgvector' library is not installed. Disabling the feature.")
            self.use_pgvector = False
        # Reflects the library check; the DB extension itself is ensured during _async_init.
        self._pgvector_extension_exists = PGVECTOR_INSTALLED
    def __await__(self):
        # Makes the instance awaitable: `await AsyncInterfaceSQL(...)` runs the
        # async initialization and returns the ready-to-use interface.
        return self._async_init().__await__()
    async def _setup_connection(self, conn: psycopg.AsyncConnection):
        """
        Configures each new pooled connection. The extension check is done once,
        up front, so this callback is fast and runs no queries itself.
        """
        if self.use_pgvector and self._pgvector_extension_exists:
            logging.info(f"Connection (ID: {conn.info.backend_pid}): Registering pgvector type adapter.")
            await register_vector_async(conn)
    async def _ensure_pgvector_extension(self):
        """
        Uses a single, temporary connection to ensure the pgvector extension exists.
        This should be called before the main pool is created.
        """
        if not self.use_pgvector:
            return
        logging.info("Ensuring the 'vector' extension exists...")
        try:
            async with await psycopg.AsyncConnection.connect(self._conninfo, connect_timeout=5) as conn:
                await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
            logging.info("'vector' extension is ready.")
        except psycopg.Error as e:
            logging.error(f"Failed to create 'vector' extension: {e}")
            raise
    async def _async_init(self) -> 'AsyncInterfaceSQL':
        """Performs the asynchronous initialization and returns the configured instance."""
        db_ready = False
        try:
            # First, try a direct connection to see if the database exists.
            async with await psycopg.AsyncConnection.connect(self._conninfo, connect_timeout=5):
                pass
            logging.info(f"Database '{self._db_name}' already exists.")
            db_ready = True
        except psycopg.errors.OperationalError:
            # The initial connection failed. Assume the database does not exist and try to create it.
            logging.warning(f"Initial connection to '{self._db_name}' failed. Assuming it doesn't exist.")
            if self._create_db_if_not_exists:
                logging.info("Attempting to create the database...")
                try:
                    # Connect to the default 'postgres' db to issue the CREATE DATABASE command.
                    async with await psycopg.AsyncConnection.connect(f"dbname=postgres {self._db_params_no_db}", autocommit=True, connect_timeout=5) as conn:
                        # Use Identifier quoting instead of string interpolation for safety.
                        await conn.execute(SQL('CREATE DATABASE {}').format(Identifier(self._db_name)))
                    logging.info(f"Database '{self._db_name}' created successfully.")
                    db_ready = True
                except psycopg.Error as create_error:
                    logging.error("Failed to create database. The server is likely running, but another issue occurred (e.g., permissions).")
                    raise create_error
            else:
                logging.error("Database does not exist and 'create_db_if_not_exists' is False.")
                raise

        # If the database now exists (either initially or after creation), create the pool.
        if db_ready:
            await self._ensure_pgvector_extension()
            logging.info("Initializing connection pool...")
            self.pool = AsyncConnectionPool(self._conninfo,
                                            min_size=self._min_pool_size,
                                            max_size=self._max_pool_size,
                                            open=False,
                                            configure=self._setup_connection)
            await self.pool.open(wait=True)
            logging.info(f"Connection pool to '{self._db_name}' is open and ready.")
        else:
            raise RuntimeError(f"Failed to connect to or create database '{self._db_name}'.")
        return self
    async def send_query(self, query: str, variable: Optional[tuple] = None, fetch: bool = False) -> Optional[list[tuple]]:
        """Executes a single query on a pooled connection; returns the fetched rows if `fetch` is True, or None on error."""
        if not self.pool or self.pool.closed:
            logging.error("Connection pool is not available or has been closed.")
            raise ConnectionError("Connection pool is not open. Cannot execute query.")
        try:
            async with self.pool.connection() as conn:
                async with conn.cursor() as cur:
                    await cur.execute(query, variable)
                    if fetch:
                        return await cur.fetchall()
        except psycopg.Error as e:
            logging.error(f"Error executing query: {e}")
            return None
    async def send_query_bulk(self, query: str, variables: list[tuple], batch_size: int = 1000):
        """
        Executes a query for a sequence of parameters. Ideal for bulk inserts.
        :param query:      The query command, e.g., "INSERT INTO users (name, email) VALUES (%s, %s)"
        :param variables:  A list of tuples with the data, e.g., [('John', 'john@example.com'), ('Jane', 'jane@example.com')]
        :param batch_size: Number of rows passed to each `executemany()` call.
        """
        if not self.pool or self.pool.closed:
            raise ConnectionError("Connection pool is not available.")
        total_records = len(variables)
        logging.info(f"Starting bulk execution for {total_records} records with a batch size of {batch_size}.")
        with tqdm(total=total_records, desc="Bulk Inserting", unit="row", ncols=100) as pbar:
            try:
                async with self.pool.connection() as conn:
                    async with conn.cursor() as cur:
                        for start_index in range(0, total_records, batch_size):
                            end_index = start_index + batch_size
                            batch = variables[start_index:end_index]
                            logging.info(f"Processing batch: records {start_index+1} to {min(end_index, total_records)} of {total_records}...")
                            await cur.executemany(query, batch)
                            pbar.update(len(batch))
            except psycopg.Error as e:
                logging.error(f"Error in send_query_bulk: {e}")
    async def copy_records(self, table_name: str, columns: tuple[str, ...], records: list[tuple]):
        """
        Uses the high-performance COPY command to bulk insert records from a list of tuples.
        This is the fastest method for bulk data loading into PostgreSQL.
        :param table_name: The name of the table to insert into.
        :param columns:    A tuple of column names, e.g., ('name', 'email').
        :param records:    A list of tuples where each tuple is a row to insert.
        """
        if not self.pool or self.pool.closed:
            raise ConnectionError("Connection pool is not available.")
        copy_query = SQL("COPY {table} ({cols}) FROM STDIN WITH (FORMAT BINARY)").format(
            table=Identifier(table_name),
            cols=SQL(', ').join(map(Identifier, columns))
        )
        record_count = len(records)
        logging.info(f"Starting COPY for {record_count:,} records into '{table_name}'.")
        try:
            async with self.pool.connection() as conn:
                async with conn.cursor() as cur:
                    # FORMAT BINARY requires a binary dumper for every value;
                    # register_vector_async provides one for numpy embeddings.
                    async with cur.copy(copy_query) as copy:
                        for record in tqdm(records, desc="Streaming to DB", unit="row", ncols=100):
                            await copy.write_row(record)
            logging.info(f"Successfully copied {record_count:,} records.")
        except psycopg.Error as e:
            logging.error(f"Error in COPY operation: {e}")
            raise
    async def close(self):
        if self.pool and not self.pool.closed:
            await self.pool.close()
            logging.info("Database connection pool closed.")
ArthurDelannoyazerty (Author) commented Aug 4, 2025:

async def main():
    import numpy as np

    # Configure logging so the demo's INFO messages are visible.
    logging.basicConfig(level=logging.INFO)

    # --- Database credentials ---
    DB_NAME = "test-db"
    DB_USER = "user"
    DB_PASSWORD = "password"
    DB_HOST = "127.0.0.1"
    DB_PORT = 5432
    MAX_POOL_SIZE = 10

    interface = None
    try:
        interface = await AsyncInterfaceSQL(database=DB_NAME, host=DB_HOST, port=DB_PORT, user=DB_USER, password=DB_PASSWORD, max_pool_size=MAX_POOL_SIZE, use_pgvector=True)

        await interface.send_query("CREATE TABLE IF NOT EXISTS users (id SERIAL PRIMARY KEY, name VARCHAR(100), email VARCHAR(100))")
        logging.info("Table 'users' created or already exists.")

        await interface.send_query(
            "INSERT INTO users (name, email) VALUES (%s, %s)",
            ("Final Test", "[email protected]")
        )
        logging.info("Inserted data.")

        users = await interface.send_query("SELECT * FROM users", fetch=True)
        if users:
            logging.info(f"Selected users: {users}")
        

        
        logging.info("Inserting bulk data.")
        new_users = [(f"user_{i}", f"user{i}@example.com") for i in range(10000)]
        # Insert in batches via executemany
        await interface.send_query_bulk(
            "INSERT INTO users (name, email) VALUES (%s, %s)",
            new_users
        )
        logging.info("Bulk data inserted")

        logging.info("Test copy_records")
        await interface.send_query("DROP TABLE IF EXISTS products")
        await interface.send_query("""                          
            CREATE EXTENSION IF NOT EXISTS vector;   
            CREATE TABLE products (
                id SERIAL PRIMARY KEY,
                product_name VARCHAR(255),
                category VARCHAR(100),
                price DOUBLE PRECISION,
                embedding VECTOR(1024)
            )
        """)
        
        logging.info("Table 'products' created.")
        PRODUCT_COUNT = 100_000
        PRODUCT_COLUMNS = ('product_name', 'category', 'price', 'embedding')
        logging.info(f"Generating {PRODUCT_COUNT:,} records in memory...")
        product_data = [
            (f"Product #{i}", f"Category {(i % 100)}", i * 1.5, np.random.rand(1024).astype(np.float32))
            for i in tqdm(range(PRODUCT_COUNT), desc="Generating Data", ncols=100)
        ]
        
        await interface.copy_records(table_name='products', columns=PRODUCT_COLUMNS, records=product_data)
        count_result = await interface.send_query("SELECT COUNT(*) FROM products", fetch=True)
        if count_result:
            logging.info(f"Verification: Found {count_result[0][0]:,} rows in 'products' table.")

    except (psycopg.Error, ConnectionError, RuntimeError) as e:
        logging.error(f"A database error occurred: {e}")
        logging.error("Please ensure PostgreSQL is running and the credentials are correct.")
    finally:
        if interface:
            await interface.close()


if __name__ == "__main__":
    asyncio.run(main())

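The demo stores embeddings but never searches them. As a follow-up that could go at the end of main(), a minimal sketch of a nearest-neighbour query using pgvector's `<->` (L2 distance) operator; it assumes the 'products' table created above and the numpy adapter that _setup_connection registers on each pooled connection:

    # Find the 5 products whose embeddings are closest to a random query vector.
    query_vec = np.random.rand(1024).astype(np.float32)
    nearest = await interface.send_query(
        "SELECT id, product_name FROM products ORDER BY embedding <-> %s LIMIT 5",
        (query_vec,),
        fetch=True,
    )
    logging.info(f"5 nearest products: {nearest}")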