@jdnichollsc
Last active March 26, 2025 15:50
Data ingestion with Hugging Face datasets
"""
Fast image ingestion tool for Wolt Food dataset.
This script efficiently processes the Wolt Food CLIP-ViT-B-32 dataset,
uploading images to S3 and storing metadata and embeddings in PostgreSQL.
Optimized for processing 100k+ records using:
- Parallel image downloads and S3 uploads with connection pooling
- PostgreSQL COPY for bulk inserts
- Streaming the Hugging Face dataset instead of loading it fully into memory

References:
- Qdrant datasets: https://qdrant.tech/documentation/datasets/
- Embeddings: https://huggingface.co/datasets/Qdrant/wolt-food-clip-ViT-B-32-embeddings
- SentenceTransformers: https://www.sbert.net/examples/applications/image-search/README.html
"""
import os
import boto3
import psycopg2
import json
import logging
import aiohttp
import traceback
import time
import psutil
import asyncio
from io import StringIO
from datasets import load_dataset
from botocore.config import Config
from typing import Dict, List, Any, Optional, Tuple
import pandas as pd
from datetime import datetime, timezone
from psycopg2.pool import SimpleConnectionPool
from psycopg2.extensions import ISOLATION_LEVEL_READ_COMMITTED
from tqdm import tqdm
# Configure structured JSON logging
logging.basicConfig(
level=logging.INFO,
format='{"timestamp": "%(asctime)s", "level": "%(levelname)s", "message": %(message)s}',
force=True,
)
logger = logging.getLogger(__name__)
def get_optimal_batch_size():
"""Calculate optimal batch size based on available system memory."""
available_memory = psutil.virtual_memory().available
# Estimate 2MB per image on average
estimated_image_size = 2 * 1024 * 1024 # 2MB in bytes
# Use 40% of available memory
memory_limit = available_memory * 0.4
optimal_size = int(memory_limit / estimated_image_size)
# Set reasonable limits
return max(500, min(optimal_size, 2000))
# Maximum number of records to process from the dataset
MAX_RECORDS = 1000 # Adjust this value to limit the number of records to process
# Performance tuning constants
BATCH_SIZE = get_optimal_batch_size()
MAX_RETRIES = 2
# Database performance settings
DB_MAX_CONNECTIONS = min(100, (psutil.cpu_count() or 4) * 4)  # Scale with CPUs; fall back to 4 if undetectable
logger.info(
{
"event": "performance_config",
"batch_size": BATCH_SIZE,
"available_memory": psutil.virtual_memory().available
/ (1024 * 1024 * 1024), # GB
"cpu_cores": psutil.cpu_count(),
}
)
# Environment and configuration
IS_DEVELOPMENT = os.getenv("NODE_ENV") != "production"
BUCKET_NAME = (
os.getenv("MINIO_BUCKET_NAME", "projectx-images")
if IS_DEVELOPMENT
else os.getenv("AWS_BUCKET_NAME")
)
# Database configuration
DB_CONFIG = {
"dbname": os.getenv("POSTGRES_DB", "projectx"),
"user": os.getenv("POSTGRES_USER", "postgres"),
"password": os.getenv("POSTGRES_PASSWORD", "postgres"),
"host": "localhost",
"port": os.getenv("POSTGRES_PORT", "5432"),
"application_name": "image_ingest",
# Connection settings
"keepalives": 1,
"keepalives_idle": 30,
"keepalives_interval": 10,
"keepalives_count": 5,
# Encoding settings
"client_encoding": "UTF8",
}
logger.info(
{
"event": "config_loaded",
"db_config": {
"dbname": DB_CONFIG["dbname"],
"user": DB_CONFIG["user"],
"host": DB_CONFIG["host"],
"port": DB_CONFIG["port"],
},
"is_development": IS_DEVELOPMENT,
}
)
# S3 configuration with optimized settings
S3_CONFIG = {
"aws_access_key_id": os.getenv("MINIO_ROOT_USER", "minioadmin")
if IS_DEVELOPMENT
else os.getenv("AWS_ACCESS_KEY_ID"),
"aws_secret_access_key": os.getenv("MINIO_ROOT_PASSWORD", "minioadmin")
if IS_DEVELOPMENT
else os.getenv("AWS_SECRET_ACCESS_KEY"),
"endpoint_url": "http://localhost:9000" if IS_DEVELOPMENT else None,
"region_name": os.getenv("AWS_REGION", "us-east-1"),
"verify": not IS_DEVELOPMENT,
"config": Config(
max_pool_connections=100,
connect_timeout=5,
read_timeout=5,
retries={"max_attempts": MAX_RETRIES},
tcp_keepalive=True,
),
}
class S3Client:
"""Handles S3 operations with connection pooling."""
def __init__(self):
self.client = boto3.client("s3", **S3_CONFIG)
self._session = None
self._init_bucket()
def _init_bucket(self):
"""Initialize S3 bucket with proper configuration."""
try:
# Create bucket if it doesn't exist
try:
self.client.head_bucket(Bucket=BUCKET_NAME)
except self.client.exceptions.ClientError:
self.client.create_bucket(Bucket=BUCKET_NAME)
# Set bucket policy for public read access
bucket_policy = {
"Version": "2012-10-17",
"Statement": [
{
"Sid": "PublicReadGetObject",
"Effect": "Allow",
"Principal": "*",
"Action": "s3:GetObject",
"Resource": f"arn:aws:s3:::{BUCKET_NAME}/*",
}
],
}
self.client.put_bucket_policy(
Bucket=BUCKET_NAME, Policy=json.dumps(bucket_policy)
)
logger.info({"event": "bucket_initialized", "bucket": BUCKET_NAME})
except Exception as e:
logger.error(
{
"event": "bucket_initialization_failed",
"error": str(e),
"error_type": type(e).__name__,
}
)
raise
async def get_session(self):
"""Get or create aiohttp session."""
if self._session is None:
timeout = aiohttp.ClientTimeout(total=30, connect=10)
self._session = aiohttp.ClientSession(timeout=timeout)
return self._session
async def upload_image(self, image_url: str, external_id: str) -> Optional[str]:
"""
Upload an image to S3 with optimized performance.
Args:
image_url: The URL of the image to upload
external_id: External ID for the image
Returns:
Optional[str]: The S3 URL of the uploaded image, or None if upload failed
"""
if not image_url or not external_id:
return None
key = f"images/{external_id}.jpg"
try:
session = await self.get_session()
async with session.get(image_url, ssl=False) as response:
if response.status != 200:
logger.error(
{
"event": "image_download_failed",
"status_code": response.status,
"url": image_url,
"external_id": external_id,
}
)
return None
content = await response.read()
if not content:
logger.error(
{
"event": "empty_image_content",
"url": image_url,
"external_id": external_id,
}
)
return None
# Upload to S3 with optimized settings
self.client.put_object(
Bucket=BUCKET_NAME,
Key=key,
Body=content,
ContentType="image/jpeg",
ACL="public-read",
StorageClass="STANDARD",
)
s3_url = (
f"http://localhost:9000/{BUCKET_NAME}/{key}"
if IS_DEVELOPMENT
else f"https://{BUCKET_NAME}.s3.{S3_CONFIG['region_name']}.amazonaws.com/{key}"
)
return s3_url
except Exception as e:
logger.error(
{
"event": "image_upload_failed",
"error": str(e),
"error_type": type(e).__name__,
"url": image_url,
"external_id": external_id,
}
)
return None
async def cleanup(self):
"""Cleanup resources."""
if self._session:
await self._session.close()
self._session = None
def bulk_insert_to_db(
conn: psycopg2.extensions.connection, records: List[Dict[str, Any]]
) -> None:
"""
Efficiently insert multiple records into PostgreSQL using pandas and COPY.
This approach handles JSON serialization automatically through pandas.
"""
if not records:
logger.warning({"event": "bulk_insert_skipped", "reason": "empty_records"})
return
try:
# Optimize PostgreSQL settings for bulk insert
with conn.cursor() as cur:
cur.execute("SET synchronous_commit = OFF")
cur.execute("SET maintenance_work_mem = '1GB'")
cur.execute("SET temp_buffers = '256MB'")
cur.execute("SET work_mem = '256MB'")
cur.execute("SET effective_io_concurrency = 8")
# Convert records to pandas DataFrame
df = pd.DataFrame(records)
# Process in smaller chunks for memory efficiency
chunk_size = 5000
        for chunk_start in range(0, len(df), chunk_size):
            # Copy the slice so the column assignments below don't trigger SettingWithCopyWarning
            chunk = df.iloc[chunk_start: chunk_start + chunk_size].copy()
# Convert metadata and embedding columns to properly escaped JSON strings
chunk["metadata"] = chunk["metadata"].apply(
lambda x: json.dumps(x, ensure_ascii=False) if x else "{}"
)
chunk["embedding"] = chunk["embedding"].apply(
lambda x: json.dumps(x, ensure_ascii=False) if x else "[]"
)
# Convert tags list to PostgreSQL array format
chunk["tags"] = chunk["tags"].apply(
lambda x: "{" + ",".join(f'"{tag}"' for tag in (x or [])) + "}"
)
# Format timestamps
chunk["created_at"] = pd.to_datetime(chunk["created_at"]).dt.strftime(
"%Y-%m-%d %H:%M:%S.%f%z"
)
chunk["updated_at"] = pd.to_datetime(chunk["updated_at"]).dt.strftime(
"%Y-%m-%d %H:%M:%S.%f%z"
)
# Create a buffer for the COPY operation
buffer = StringIO()
# Write chunk to buffer in CSV format with tab separator
chunk.to_csv(
buffer,
sep="\t",
header=False,
index=False,
na_rep="\\N",
quoting=2,
escapechar="\\",
doublequote=False,
columns=[
"external_id",
"name",
"description",
"image_url",
"external_url",
"embedding",
"metadata",
"tags",
"created_at",
"updated_at",
],
)
# Move buffer cursor to start
buffer.seek(0)
# Execute COPY command with proper handling of special characters
with conn.cursor() as cur:
cur.copy_expert(
"""
COPY image (
external_id, name, description, image_url, external_url,
embedding, metadata, tags, created_at, updated_at
) FROM STDIN WITH (
FORMAT CSV,
DELIMITER E'\t',
NULL '\\N',
QUOTE '"',
ESCAPE '\\'
)
""",
buffer,
)
conn.commit()
logger.info(
{
"event": "chunk_insert_completed",
"chunk_size": len(chunk),
"start_index": chunk_start,
}
)
logger.info({"event": "bulk_insert_success", "total_records": len(df)})
except Exception as e:
conn.rollback()
logger.error(
{
"event": "bulk_insert_failed",
"error": str(e),
"error_type": type(e).__name__,
"traceback": traceback.format_exc(),
}
)
raise
def parse_bytes_json(value: Any, default_value: Any) -> Any:
"""Parse a bytes value as JSON, with a default value if parsing fails."""
try:
if isinstance(value, bytes):
try:
return json.loads(value.decode("utf-8"))
except json.JSONDecodeError:
# If it's not valid JSON, try to parse it as a string
value_str = value.decode("utf-8")
# Convert Python-style dict string to valid JSON
value_str = value_str.replace(
"'", '"'
) # Replace single quotes with double quotes
value_str = value_str.replace(
"None", "null"
) # Replace Python None with JSON null
try:
return json.loads(value_str)
except json.JSONDecodeError:
return default_value
return value or default_value
except Exception:
return default_value
async def create_db_pool(
min_size: int = 2, max_size: int = DB_MAX_CONNECTIONS
) -> SimpleConnectionPool:
"""
Create a database connection pool with the specified size limits.
Args:
min_size: Minimum number of connections to keep in the pool
max_size: Maximum number of connections allowed in the pool
Returns:
SimpleConnectionPool: A configured database connection pool
"""
try:
logger.info(
{
"event": "creating_db_pool",
"min_size": min_size,
"max_size": max_size,
"config": {
"host": DB_CONFIG["host"],
"port": DB_CONFIG["port"],
"dbname": DB_CONFIG["dbname"],
},
}
)
pool = SimpleConnectionPool(minconn=min_size, maxconn=max_size, **DB_CONFIG)
# Test the pool by getting and returning a connection
conn = pool.getconn()
with conn.cursor() as cur:
cur.execute("SELECT version();")
version = cur.fetchone()[0]
logger.info(
{"event": "db_pool_created", "version": version, "pool_size": max_size}
)
pool.putconn(conn)
return pool
except Exception as e:
logger.error(
{
"event": "db_pool_creation_failed",
"error": str(e),
"error_type": type(e).__name__,
"traceback": traceback.format_exc(),
}
)
raise
async def process_batch(
records: List[Dict],
s3_client: S3Client,
db_pool: SimpleConnectionPool,
batch_start_time: float,
) -> Tuple[int, int]:
"""Process a batch of records concurrently."""
successful_uploads = 0
failed_uploads = 0
conn = None
try:
# Pre-process records to prepare for bulk operations
processed_images = []
        upload_tasks = []
        upload_records = []  # Records actually queued for upload, kept aligned with upload_tasks
        # Create upload tasks
for record in records:
external_id = str(record.get("id", ""))
image_url = record.get("image")
if not image_url:
failed_uploads += 1
logger.warning(
{
"event": "skip_record",
"reason": "missing_image_url",
"external_id": external_id,
"image_url": image_url,
}
)
continue
            upload_tasks.append(s3_client.upload_image(image_url, external_id))
            upload_records.append(record)
# Wait for all uploads to complete
upload_results = await asyncio.gather(*upload_tasks, return_exceptions=True)
        # Process results (indexes align with upload_records, since skipped records have no task)
        for idx, s3_url in enumerate(upload_results):
            record = upload_records[idx]
            external_id = str(record.get("id", ""))
if isinstance(s3_url, Exception):
failed_uploads += 1
continue
if s3_url:
cafe_data = parse_bytes_json(record.get("cafe"), {})
current_time = datetime.now(timezone.utc)
processed_images.append(
{
"external_id": external_id,
"name": record.get("name", ""),
"description": record.get("description", ""),
"image_url": s3_url,
"external_url": record.get("image"),
"embedding": parse_bytes_json(record.get("vector"), []),
"metadata": cafe_data,
"tags": cafe_data.get("categories", []),
"created_at": current_time,
"updated_at": current_time,
}
)
successful_uploads += 1
else:
failed_uploads += 1
# Bulk insert processed records
if processed_images:
conn = db_pool.getconn()
conn.set_isolation_level(ISOLATION_LEVEL_READ_COMMITTED)
bulk_insert_to_db(conn, processed_images)
# Log batch metrics
batch_duration = time.time() - batch_start_time
total_records = len(records)
logger.info(
{
"event": "batch_completed",
"successful_uploads": successful_uploads,
"failed_uploads": failed_uploads,
"batch_size": total_records,
"duration_seconds": batch_duration,
"records_per_second": total_records / batch_duration
if batch_duration > 0
else 0,
}
)
except Exception as e:
total_records = len(records)
logger.error(
{
"event": "batch_processing_failed",
"error": str(e),
"error_type": type(e).__name__,
"traceback": traceback.format_exc(),
"batch_size": total_records,
}
)
failed_uploads = total_records
successful_uploads = 0
finally:
if conn:
db_pool.putconn(conn)
return successful_uploads, failed_uploads
async def main():
"""Main entry point for the ingestion script."""
start_time = time.time()
total_processed = 0
total_successful = 0
total_failed = 0
try:
# Initialize clients
s3_client = S3Client()
db_pool = await create_db_pool(min_size=2, max_size=DB_MAX_CONNECTIONS)
# Load dataset
dataset = load_dataset(
"Qdrant/wolt-food-clip-ViT-B-32-embeddings", split="train", streaming=True
)
# Process records in batches
batch_size = BATCH_SIZE
current_batch = []
batch_start_time = time.time()
logger.info(
{
"event": "processing_started",
"batch_size": batch_size,
"max_records": MAX_RECORDS,
"streaming": True,
}
)
# Initialize progress bar
progress_bar = tqdm(
total=MAX_RECORDS,
desc="Processing images",
unit="img",
dynamic_ncols=True,
position=0,
)
progress_bar.set_postfix(
{"successful": 0, "failed": 0, "batch_size": batch_size}
)
for record in dataset:
# Check if we've reached or would exceed the maximum number of records
if total_successful >= MAX_RECORDS:
logger.info(
{
"event": "max_records_reached",
"total_processed": total_processed,
"total_successful": total_successful,
"total_failed": total_failed,
"max_records": MAX_RECORDS,
}
)
break
# Adjust batch size if we're approaching MAX_RECORDS
remaining = MAX_RECORDS - total_successful
if len(current_batch) >= min(batch_size, remaining):
successful, failed = await process_batch(
current_batch, s3_client, db_pool, batch_start_time
)
total_successful += successful
total_failed += failed
total_processed += len(current_batch)
# Update progress bar
progress_bar.update(successful)
progress_bar.set_postfix(
{
"successful": total_successful,
"failed": total_failed,
"batch_size": len(current_batch),
"remaining": MAX_RECORDS - total_successful,
}
)
# Reset for next batch
current_batch = []
batch_start_time = time.time()
elapsed_time = time.time() - start_time
logger.info(
{
"event": "progress_update",
"total_processed": total_processed,
"total_successful": total_successful,
"total_failed": total_failed,
"remaining": MAX_RECORDS - total_successful,
"elapsed_minutes": elapsed_time / 60,
"records_per_second": total_processed / elapsed_time
if elapsed_time > 0
else 0,
}
)
current_batch.append(record)
# Process remaining records only if we haven't reached MAX_RECORDS
if current_batch and total_successful < MAX_RECORDS:
successful, failed = await process_batch(
current_batch, s3_client, db_pool, batch_start_time
)
total_successful += successful
total_failed += failed
total_processed += len(current_batch)
# Update progress bar one last time
progress_bar.update(successful)
progress_bar.set_postfix(
{
"successful": total_successful,
"failed": total_failed,
"batch_size": len(current_batch),
}
)
# Close progress bar
progress_bar.close()
# Log final statistics
total_time = time.time() - start_time
logger.info(
{
"event": "processing_completed",
"total_processed": total_processed,
"total_successful": total_successful,
"total_failed": total_failed,
"max_records": MAX_RECORDS,
"total_minutes": total_time / 60,
"average_records_per_second": total_processed / total_time
if total_time > 0
else 0,
}
)
except Exception as e:
logger.error(
{
"event": "processing_failed",
"error": str(e),
"error_type": type(e).__name__,
"traceback": traceback.format_exc(),
}
)
raise
finally:
# Cleanup resources
try:
            if "s3_client" in locals():
                await s3_client.cleanup()
logger.info({"event": "cleanup_completed"})
except Exception as e:
logger.error(
{
"event": "cleanup_failed",
"error": str(e),
"error_type": type(e).__name__,
}
)
# Close database connections
if "db_pool" in locals():
db_pool.closeall()
if __name__ == "__main__":
asyncio.run(main())
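# Docker Compose stack used for local development: PostgreSQL (built from
# Dockerfile.postgres below, with pgvector and PostGIS) and MinIO as an
# S3-compatible object store for the ingestion script above.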
version: '3.6'
x-project-name: &project-name ${COMPOSE_PROJECT_NAME:-projectx}
services:
db:
build:
context: .
dockerfile: Dockerfile.postgres
container_name: ${COMPOSE_PROJECT_NAME:-projectx}-db
restart: always
networks:
- default
volumes:
- postgres_data:/var/lib/postgresql/data
- ./docker-entrypoint-initdb.d:/docker-entrypoint-initdb.d
environment:
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
POSTGRES_USER: ${POSTGRES_USER}
POSTGRES_DB: ${POSTGRES_DB}
PGDATA: /var/lib/postgresql/data/pgdata # Explicit PGDATA path
ports:
- "${POSTGRES_PORT:-5432}:5432"
healthcheck:
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}"]
interval: 10s
timeout: 5s
retries: 5
start_period: 2s # Added start period
labels:
- "com.*project-name.service=database"
expose:
- "${POSTGRES_PORT:-5432}"
tmpfs:
- /tmp
- /run
- /run/postgresql
# Add MinIO service before the other services
minio:
image: minio/minio:latest
container_name: ${COMPOSE_PROJECT_NAME:-projectx}-minio
environment:
MINIO_ROOT_USER: ${MINIO_ROOT_USER:-minioadmin}
MINIO_ROOT_PASSWORD: ${MINIO_ROOT_PASSWORD:-minioadmin}
MINIO_BROWSER_REDIRECT_URL: http://localhost:9001
command: server /data --console-address ":9001"
ports:
- "${MINIO_API_PORT:-9000}:9000" # API port
- "${MINIO_CONSOLE_PORT:-9001}:9001" # Console port
volumes:
- minio_data:/data
healthcheck:
test: ["CMD", "mc", "ready", "local"]
interval: 5s
timeout: 5s
retries: 5
networks:
- default
# MinIO setup service (creates bucket on startup)
minio-setup:
image: minio/mc
container_name: ${COMPOSE_PROJECT_NAME:-projectx}-minio-setup
depends_on:
minio:
condition: service_healthy
entrypoint: >
/bin/sh -c "
/usr/bin/mc config host add myminio http://minio:9000 minioadmin minioadmin;
/usr/bin/mc mb myminio/${MINIO_BUCKET_NAME:-projectx-images} --ignore-existing;
/usr/bin/mc policy set public myminio/${MINIO_BUCKET_NAME:-projectx-images};
exit 0;
"
networks:
- default
networks:
default:
name: ${COMPOSE_PROJECT_NAME:-projectx}-network
driver: bridge
volumes:
postgres_data:
name: ${COMPOSE_PROJECT_NAME:-projectx}-postgres-data
minio_data:
name: ${COMPOSE_PROJECT_NAME:-projectx}-minio-data
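-- Init script mounted from ./docker-entrypoint-initdb.d: enable the extensions
-- required by the schema and the ingestion pipeline.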
-- Enable pg_trgm for text search
CREATE EXTENSION IF NOT EXISTS pg_trgm;
-- Install pgvector for vector operations
CREATE EXTENSION IF NOT EXISTS vector;
-- Enable PostGIS
CREATE EXTENSION IF NOT EXISTS postgis;
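# Dockerfile.postgres: PostgreSQL 17 image with PostGIS and pgvector, referenced
# by the db service in the Compose file above.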
# Use official Postgres image with multi-arch support
FROM postgres:17-bullseye
# Install required packages and build pgvector
RUN apt-get update \
&& apt-get install -y --no-install-recommends \
# SSL certificates
ca-certificates \
# PostGIS
postgis \
postgresql-17-postgis-3 \
# pgvector build dependencies
build-essential \
git \
postgresql-server-dev-17 \
# Update certificates
&& update-ca-certificates \
# Install pgvector
&& git clone --branch v0.8.0 https://github.com/pgvector/pgvector.git \
&& cd pgvector \
&& make \
&& make install \
# Cleanup
&& cd .. \
&& rm -rf pgvector \
&& apt-get remove -y build-essential git postgresql-server-dev-17 \
&& apt-get autoremove -y \
&& rm -rf /var/lib/apt/lists/*
# Add initialization scripts
COPY ./docker-entrypoint-initdb.d /docker-entrypoint-initdb.d/
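-- Schema init script: the "image" table targeted by the ingestion script's COPY,
-- plus its text, tag, and vector indexes.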
CREATE EXTENSION IF NOT EXISTS "vector";
CREATE TABLE "image" (
"id" SERIAL NOT NULL,
"external_id" VARCHAR(100),
"name" VARCHAR(255) NOT NULL,
"description" TEXT,
"image_url" TEXT NOT NULL,
"external_url" TEXT,
"embedding" vector(512),
"metadata" JSONB,
"tags" VARCHAR(50)[],
"created_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMP(3) NOT NULL,
CONSTRAINT "image_pkey" PRIMARY KEY ("id")
);
CREATE INDEX "ix_image_name" ON "image"("name");
CREATE INDEX "ix_image_tags" ON "image" USING GIN ("tags");
-- Vector Search Index
CREATE INDEX "ix_image_embedding" ON "image"
USING hnsw (embedding vector_cosine_ops)
WITH (
m = 16, -- Number of connections per node
ef_construction = 64 -- Build-time exploration factor
);
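
-- Example query (illustrative; not part of the original gist): a cosine
-- nearest-neighbour lookup served by the HNSW index above. $1 stands for a
-- 512-dimensional query embedding literal such as '[0.12, -0.03, ...]'.
-- SET hnsw.ef_search = 64;  -- optional: raise for better recall at some cost in speed
SELECT id, name, image_url, embedding <=> $1::vector AS cosine_distance
FROM image
ORDER BY embedding <=> $1::vector
LIMIT 10;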