Data ingestion using Ray actors with Hugging Face datasets
"""
Fast image ingestion tool for Wolt Food dataset.
This script efficiently processes the Wolt Food CLIP-ViT-B-32 dataset,
uploading images to S3 and storing metadata and embeddings in PostgreSQL.
Optimized for processing 100k+ records using:
- Parallel image downloads and S3 uploads with connection pooling
- PostgreSQL COPY for bulk inserts
- Ray actors for efficient resource management
"""
import os
import ray
import boto3
import psycopg2
import json
import logging
import aiohttp
import traceback
import time
import psutil
import asyncio
from io import StringIO
from datasets import load_dataset
from botocore.config import Config
from typing import Dict, List, Any, Optional, Tuple
import pandas as pd
from datetime import datetime, timezone
from psycopg2.pool import SimpleConnectionPool
from psycopg2.extensions import ISOLATION_LEVEL_READ_COMMITTED
# Configure structured logging with a JSON-style envelope (dict messages are
# interpolated via their repr, so the output is JSON-like rather than strictly valid JSON)
logging.basicConfig(
level=logging.INFO,
format='{"timestamp": "%(asctime)s", "level": "%(levelname)s", "message": %(message)s}',
force=True,
)
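# Illustrative example of an emitted line (values are made up): the dict passed
# to logger.info() is rendered with its repr inside the envelope:
#   {"timestamp": "2025-02-10 21:51:00,123", "level": "INFO", "message": {'event': 'bucket_initialized', 'bucket': 'projectx-images'}}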
logger = logging.getLogger(__name__)
def get_optimal_batch_size():
"""Calculate optimal batch size based on available system memory."""
available_memory = psutil.virtual_memory().available
# Estimate 2MB per image on average (increased from 1MB)
estimated_image_size = 2 * 1024 * 1024 # 2MB in bytes
# Use 40% of available memory (increased from 20%)
memory_limit = available_memory * 0.4
optimal_size = int(memory_limit / estimated_image_size)
# Increased limits for better throughput
return max(500, min(optimal_size, 2000))
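# Worked example (illustrative numbers): with ~16 GB of available memory the
# limit is 0.4 * 16 GB ≈ 6.4 GB, which at the assumed 2 MB per image allows
# ~3200 records and is clamped to the 2000 ceiling; with ~2 GB available the
# same math gives ~400, so the 500 floor applies instead.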
# Maximum number of records to process from the dataset
MAX_RECORDS = 100000 # Adjust this value to limit the number of records to process
# Performance tuning constants
BATCH_SIZE = get_optimal_batch_size()
MAX_RETRIES = 2
NUM_S3_ACTORS = min(32, psutil.cpu_count() * 2) # Scale with available CPUs
# Database performance settings
DB_MAX_CONNECTIONS = min(100, psutil.cpu_count() * 4) # Scale with CPUs
logger.info(
{
"event": "performance_config",
"batch_size": BATCH_SIZE,
"s3_actors": NUM_S3_ACTORS,
"available_memory": psutil.virtual_memory().available
/ (1024 * 1024 * 1024), # GB
"cpu_cores": psutil.cpu_count(),
}
)
# Environment and configuration
IS_DEVELOPMENT = os.getenv("NODE_ENV") != "production"
BUCKET_NAME = (
os.getenv("MINIO_BUCKET_NAME", "projectx-images")
if IS_DEVELOPMENT
else os.getenv("AWS_BUCKET_NAME")
)
# Database configuration
DB_CONFIG = {
"dbname": os.getenv("POSTGRES_DB", "projectx"),
"user": os.getenv("POSTGRES_USER", "postgres"),
"password": os.getenv("POSTGRES_PASSWORD", "postgres"),
"host": "localhost",
"port": os.getenv("POSTGRES_PORT", "5432"),
"application_name": "image_ingest",
# Connection settings
"keepalives": 1,
"keepalives_idle": 30,
"keepalives_interval": 10,
"keepalives_count": 5,
# Encoding settings
"client_encoding": "UTF8",
}
logger.info(
{
"event": "config_loaded",
"db_config": {
"dbname": DB_CONFIG["dbname"],
"user": DB_CONFIG["user"],
"host": DB_CONFIG["host"],
"port": DB_CONFIG["port"],
},
"is_development": IS_DEVELOPMENT,
}
)
# S3 configuration with optimized settings
S3_CONFIG = {
"aws_access_key_id": os.getenv("MINIO_ROOT_USER", "minioadmin")
if IS_DEVELOPMENT
else os.getenv("AWS_ACCESS_KEY_ID"),
"aws_secret_access_key": os.getenv("MINIO_ROOT_PASSWORD", "minioadmin")
if IS_DEVELOPMENT
else os.getenv("AWS_SECRET_ACCESS_KEY"),
"endpoint_url": "http://localhost:9000" if IS_DEVELOPMENT else None,
"region_name": os.getenv("AWS_REGION", "us-east-1"),
"verify": not IS_DEVELOPMENT,
"config": Config(
max_pool_connections=100,
connect_timeout=5,
read_timeout=5,
retries={"max_attempts": MAX_RETRIES},
tcp_keepalive=True,
),
}
@ray.remote
class S3Actor:
"""Actor for handling S3 operations with optimized performance."""
def __init__(self):
"""Initialize the S3 actor with optimized settings."""
# Initialize S3 client with existing configuration
self.client = boto3.client("s3", **S3_CONFIG)
# Initialize aiohttp session
self._session = None
# Initialize the bucket with proper configuration
self._init_bucket()
async def _get_session(self):
"""Get or create aiohttp session."""
if self._session is None:
timeout = aiohttp.ClientTimeout(total=30, connect=10)
self._session = aiohttp.ClientSession(timeout=timeout)
return self._session
def _init_bucket(self):
"""Initialize S3 bucket with proper configuration."""
try:
# Create bucket if it doesn't exist
try:
self.client.head_bucket(Bucket=BUCKET_NAME)
except self.client.exceptions.ClientError:
self.client.create_bucket(Bucket=BUCKET_NAME)
# Set bucket policy for public read access
bucket_policy = {
"Version": "2012-10-17",
"Statement": [
{
"Sid": "PublicReadGetObject",
"Effect": "Allow",
"Principal": "*",
"Action": "s3:GetObject",
"Resource": f"arn:aws:s3:::{BUCKET_NAME}/*",
}
],
}
self.client.put_bucket_policy(
Bucket=BUCKET_NAME, Policy=json.dumps(bucket_policy)
)
logger.info({"event": "bucket_initialized", "bucket": BUCKET_NAME})
except Exception as e:
logger.error(
{
"event": "bucket_initialization_failed",
"error": str(e),
"error_type": type(e).__name__,
}
)
raise
async def upload_image(self, image_url: str, external_id: str) -> Optional[str]:
"""
Upload an image to S3 with optimized performance.
Args:
image_url: The URL of the image to upload
external_id: External ID for the image
Returns:
Optional[str]: The S3 URL of the uploaded image, or None if upload failed
"""
if not image_url or not external_id:
return None
key = f"images/{external_id}.jpg"
upload_start_time = time.time()
try:
# Get aiohttp session
session = await self._get_session()
# Download image with timeout and optimized settings
async with session.get(image_url, ssl=False) as response:
if response.status != 200:
logger.error(
{
"event": "image_download_failed",
"status_code": response.status,
"url": image_url,
"external_id": external_id,
}
)
return None
content = await response.read()
if not content:
logger.error(
{
"event": "empty_image_content",
"url": image_url,
"external_id": external_id,
}
)
return None
# Upload to S3 with optimized settings
self.client.put_object(
Bucket=BUCKET_NAME,
Key=key,
Body=content,
ContentType="image/jpeg",
ACL="public-read",
StorageClass="STANDARD",
)
# Generate appropriate S3 URL based on environment
if IS_DEVELOPMENT:
s3_url = f"http://localhost:9000/{BUCKET_NAME}/{key}"
else:
# For production S3, construct the direct URL
s3_url = f"https://{BUCKET_NAME}.s3.{S3_CONFIG['region_name']}.amazonaws.com/{key}"
upload_duration = time.time() - upload_start_time
logger.info(
{
"event": "image_uploaded",
"external_id": external_id,
"s3_url": s3_url,
"size_bytes": len(content),
"duration_seconds": upload_duration,
}
)
return s3_url
except Exception as e:
logger.error(
{
"event": "image_upload_failed",
"error": str(e),
"error_type": type(e).__name__,
"url": image_url,
"external_id": external_id,
}
)
return None
async def cleanup(self):
"""Cleanup resources explicitly."""
try:
if hasattr(self, "_session") and self._session:
await self._session.close()
self._session = None
logger.info({"event": "s3_actor_cleanup", "status": "completed"})
except Exception as e:
logger.error(
{
"event": "s3_actor_cleanup",
"status": "failed",
"error": str(e),
"error_type": type(e).__name__,
}
)
def __del__(self):
"""Safe destructor for Ray actor."""
try:
if hasattr(self, "_session") and self._session:
logger.warning(
{
"event": "s3_actor_cleanup",
"status": "session_not_closed",
"message": "S3Actor was destroyed without calling cleanup()",
}
)
except Exception:
# Silently ignore any errors in destructor
pass
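# Minimal usage sketch for S3Actor (assumes Ray is initialized and the
# credentials in S3_CONFIG are valid; the URL and ID below are placeholders):
#
#   actor = S3Actor.remote()
#   ref = actor.upload_image.remote("https://example.com/pic.jpg", "demo-1")
#   s3_url = ray.get(ref)              # or `await ref` inside a coroutine
#   ray.get(actor.cleanup.remote())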
def bulk_insert_to_db(
conn: psycopg2.extensions.connection, records: List[Dict[str, Any]]
) -> None:
"""
Efficiently insert multiple records into PostgreSQL using pandas and COPY.
This approach handles JSON serialization automatically through pandas.
"""
if not records:
logger.warning({"event": "bulk_insert_skipped", "reason": "empty_records"})
return
try:
# Optimize PostgreSQL settings for bulk insert
with conn.cursor() as cur:
cur.execute("SET synchronous_commit = OFF")
cur.execute("SET maintenance_work_mem = '1GB'")
cur.execute("SET temp_buffers = '256MB'")
cur.execute("SET work_mem = '256MB'")
cur.execute("SET effective_io_concurrency = 8")
# Convert records to pandas DataFrame
df = pd.DataFrame(records)
# Process in smaller chunks for memory efficiency
chunk_size = 5000
for chunk_start in range(0, len(df), chunk_size):
            chunk = df.iloc[chunk_start : chunk_start + chunk_size].copy()  # explicit copy so the column assignments below do not warn about chained assignment
# Convert metadata and embedding columns to properly escaped JSON strings
chunk["metadata"] = chunk["metadata"].apply(
lambda x: json.dumps(x, ensure_ascii=False) if x else "{}"
)
chunk["embedding"] = chunk["embedding"].apply(
lambda x: json.dumps(x, ensure_ascii=False) if x else "[]"
)
# Convert tags list to PostgreSQL array format
chunk["tags"] = chunk["tags"].apply(
lambda x: "{" + ",".join(f'"{tag}"' for tag in (x or [])) + "}"
)
# Format timestamps
chunk["created_at"] = pd.to_datetime(chunk["created_at"]).dt.strftime(
"%Y-%m-%d %H:%M:%S.%f%z"
)
chunk["updated_at"] = pd.to_datetime(chunk["updated_at"]).dt.strftime(
"%Y-%m-%d %H:%M:%S.%f%z"
)
# Create a buffer for the COPY operation
buffer = StringIO()
# Write chunk to buffer in CSV format with tab separator
chunk.to_csv(
buffer,
sep="\t",
header=False,
index=False,
na_rep="\\N",
quoting=2,
escapechar="\\",
doublequote=False,
columns=[
"external_id",
"name",
"description",
"image_url",
"external_url",
"embedding",
"metadata",
"tags",
"created_at",
"updated_at",
],
)
# Move buffer cursor to start
buffer.seek(0)
# Execute COPY command with proper handling of special characters
with conn.cursor() as cur:
cur.copy_expert(
"""
COPY image (
external_id, name, description, image_url, external_url,
embedding, metadata, tags, created_at, updated_at
) FROM STDIN WITH (
FORMAT CSV,
DELIMITER E'\t',
NULL '\\N',
QUOTE '"',
ESCAPE '\\'
)
""",
buffer,
)
conn.commit()
logger.info(
{
"event": "chunk_insert_completed",
"chunk_size": len(chunk),
"start_index": chunk_start,
}
)
logger.info({"event": "bulk_insert_success", "total_records": len(df)})
except Exception as e:
conn.rollback()
logger.error(
{
"event": "bulk_insert_failed",
"error": str(e),
"error_type": type(e).__name__,
"traceback": traceback.format_exc(),
}
)
raise
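# For reference, each row streamed through COPY above is equivalent to an
# INSERT of the form below; COPY is preferred here because it avoids per-row
# round-trips and statement parsing when loading thousands of rows at once:
#
#   INSERT INTO image (external_id, name, description, image_url, external_url,
#                      embedding, metadata, tags, created_at, updated_at)
#   VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s);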
def parse_bytes_json(value: Any, default_value: Any) -> Any:
"""Parse a bytes value as JSON, with a default value if parsing fails."""
try:
if isinstance(value, bytes):
try:
return json.loads(value.decode("utf-8"))
except json.JSONDecodeError:
# If it's not valid JSON, try to parse it as a string
value_str = value.decode("utf-8")
# Convert Python-style dict string to valid JSON
value_str = value_str.replace(
"'", '"'
) # Replace single quotes with double quotes
value_str = value_str.replace(
"None", "null"
) # Replace Python None with JSON null
try:
return json.loads(value_str)
except json.JSONDecodeError:
return default_value
return value or default_value
except Exception:
return default_value
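# Illustrative examples (inputs are made up, not taken from the dataset):
#   parse_bytes_json(b'{"slug": "cafe-1"}', {})                  -> {'slug': 'cafe-1'}
#   parse_bytes_json(b"{'slug': 'cafe-1', 'rating': None}", {})  -> {'slug': 'cafe-1', 'rating': None}
#   parse_bytes_json(b"not json", {})                            -> {}
# The naive quote replacement cannot recover values that contain apostrophes
# (e.g. "Joe's Cafe"); those fall back to the supplied default.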
async def create_db_pool(
min_size: int = 2, max_size: int = DB_MAX_CONNECTIONS
) -> SimpleConnectionPool:
"""
Create a database connection pool with the specified size limits.
Args:
min_size: Minimum number of connections to keep in the pool
max_size: Maximum number of connections allowed in the pool
Returns:
SimpleConnectionPool: A configured database connection pool
"""
try:
logger.info(
{
"event": "creating_db_pool",
"min_size": min_size,
"max_size": max_size,
"config": {
"host": DB_CONFIG["host"],
"port": DB_CONFIG["port"],
"dbname": DB_CONFIG["dbname"],
},
}
)
pool = SimpleConnectionPool(minconn=min_size, maxconn=max_size, **DB_CONFIG)
# Test the pool by getting and returning a connection
conn = pool.getconn()
with conn.cursor() as cur:
cur.execute("SELECT version();")
version = cur.fetchone()[0]
logger.info(
{"event": "db_pool_created", "version": version, "pool_size": max_size}
)
pool.putconn(conn)
return pool
except Exception as e:
logger.error(
{
"event": "db_pool_creation_failed",
"error": str(e),
"error_type": type(e).__name__,
"traceback": traceback.format_exc(),
}
)
raise
async def process_records_batch(
records: List[Dict],
s3_actor: S3Actor,
db_pool: SimpleConnectionPool,
batch_start_time: float,
) -> Tuple[int, int]:
"""Process a batch of records, uploading images to S3 and inserting data into PostgreSQL."""
successful_uploads = 0
failed_uploads = 0
conn = None
try:
# Pre-process records to prepare for bulk operations
processed_images = []
upload_tasks = []
for record in records:
external_id = str(record.get("id", ""))
try:
# Prepare upload task
image_url = record.get("image")
if not image_url:
failed_uploads += 1
logger.warning(
{
"event": "skip_record",
"reason": "missing_image_url",
"external_id": external_id,
"image_url": image_url,
}
)
continue
# Create upload task
upload_task = s3_actor.upload_image.remote(image_url, external_id)
upload_tasks.append((record, upload_task))
except Exception as e:
failed_uploads += 1
logger.error(
{
"event": "record_processing_failed",
"error": str(e),
"external_id": external_id,
"error_type": type(e).__name__,
}
)
# Wait for all upload tasks to complete
for record, upload_task in upload_tasks:
try:
                external_id = str(record.get("id", ""))
                image_url = record.get("image")  # re-read for this record; the loop variable from the scheduling pass above would hold the last record's URL
s3_url = await upload_task
if s3_url:
# Prepare record for database insertion
cafe_data = parse_bytes_json(record.get("cafe"), {})
current_time = datetime.now(timezone.utc)
processed_images.append(
{
"external_id": external_id,
"name": record.get("name", ""),
"description": record.get("description", ""),
"image_url": s3_url,
"external_url": image_url,
"embedding": parse_bytes_json(record.get("vector"), []),
"metadata": cafe_data,
"tags": cafe_data.get("categories", []),
"created_at": current_time,
"updated_at": current_time,
}
)
successful_uploads += 1
else:
failed_uploads += 1
except Exception as e:
failed_uploads += 1
logger.error(
{
"event": "upload_failed",
"error": str(e),
"external_id": external_id,
"error_type": type(e).__name__,
}
)
# Bulk insert processed records
if processed_images:
conn = db_pool.getconn()
conn.set_isolation_level(ISOLATION_LEVEL_READ_COMMITTED)
bulk_insert_to_db(conn, processed_images)
# Log batch metrics
batch_duration = time.time() - batch_start_time
total_records = len(records)
logger.info(
{
"event": "batch_completed",
"successful_uploads": successful_uploads,
"failed_uploads": failed_uploads,
"batch_size": total_records,
"duration_seconds": batch_duration,
"records_per_second": total_records / batch_duration
if batch_duration > 0
else 0,
}
)
except Exception as e:
total_records = len(records)
logger.error(
{
"event": "batch_processing_failed",
"error": str(e),
"error_type": type(e).__name__,
"traceback": traceback.format_exc(),
"batch_size": total_records,
}
)
failed_uploads = total_records
successful_uploads = 0
finally:
if conn:
db_pool.putconn(conn)
return successful_uploads, failed_uploads
async def main():
"""Main entry point for the ingestion script."""
start_time = time.time()
total_processed = 0
total_successful = 0
total_failed = 0
s3_actors = []
try:
# Initialize Ray if not already running
if not ray.is_initialized():
ray.init()
# Create S3 actors pool
s3_actors = [S3Actor.remote() for _ in range(NUM_S3_ACTORS)]
current_actor = 0
# Create database connection pool
db_pool = await create_db_pool(min_size=2, max_size=DB_MAX_CONNECTIONS)
# Load and process dataset
dataset = load_dataset(
"Qdrant/wolt-food-clip-ViT-B-32-embeddings", split="train", streaming=True
)
# Calculate optimal batch size
batch_size = BATCH_SIZE
logger.info(
{
"event": "processing_started",
"batch_size": batch_size,
"num_s3_actors": NUM_S3_ACTORS,
"max_records": MAX_RECORDS,
"streaming": True,
}
)
# Process records in batches
current_batch = []
batch_start_time = time.time()
for record in dataset:
# Check if we've reached or would exceed the maximum number of records
if total_successful >= MAX_RECORDS:
logger.info(
{
"event": "max_records_reached",
"total_processed": total_processed,
"total_successful": total_successful,
"total_failed": total_failed,
"max_records": MAX_RECORDS,
}
)
break
# Adjust batch size if we're approaching MAX_RECORDS
remaining = MAX_RECORDS - total_successful
if len(current_batch) >= min(batch_size, remaining):
# Process batch using round-robin S3 actor assignment
s3_actor = s3_actors[current_actor]
current_actor = (current_actor + 1) % NUM_S3_ACTORS
successful, failed = await process_records_batch(
current_batch, s3_actor, db_pool, batch_start_time
)
total_successful += successful
total_failed += failed
total_processed += len(current_batch)
# Reset for next batch
current_batch = []
batch_start_time = time.time()
# Log progress
elapsed_time = time.time() - start_time
logger.info(
{
"event": "progress_update",
"total_processed": total_processed,
"total_successful": total_successful,
"total_failed": total_failed,
"remaining": MAX_RECORDS - total_successful,
"elapsed_minutes": elapsed_time / 60,
"records_per_second": total_processed / elapsed_time
if elapsed_time > 0
else 0,
}
)
# Break if we've reached the limit
if total_successful >= MAX_RECORDS:
break
current_batch.append(record)
# Process remaining records only if we haven't reached MAX_RECORDS
if current_batch and total_successful < MAX_RECORDS:
s3_actor = s3_actors[current_actor]
successful, failed = await process_records_batch(
current_batch, s3_actor, db_pool, batch_start_time
)
total_successful += successful
total_failed += failed
total_processed += len(current_batch)
# Log final statistics
total_time = time.time() - start_time
logger.info(
{
"event": "processing_completed",
"total_processed": total_processed,
"total_successful": total_successful,
"total_failed": total_failed,
"max_records": MAX_RECORDS,
"total_minutes": total_time / 60,
"average_records_per_second": total_processed / total_time
if total_time > 0
else 0,
}
)
except Exception as e:
logger.error(
{
"event": "processing_failed",
"error": str(e),
"error_type": type(e).__name__,
"traceback": traceback.format_exc(),
}
)
raise
finally:
# Cleanup S3 actors
cleanup_tasks = [actor.cleanup.remote() for actor in s3_actors]
try:
await asyncio.gather(*cleanup_tasks)
logger.info(
{"event": "cleanup_completed", "actors_cleaned": len(s3_actors)}
)
except Exception as e:
logger.error(
{
"event": "cleanup_failed",
"error": str(e),
"error_type": type(e).__name__,
}
)
# Shutdown Ray
if ray.is_initialized():
ray.shutdown()
# Close database connections
if "db_pool" in locals():
db_pool.closeall()
if __name__ == "__main__":
asyncio.run(main())
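# Example invocation (assumed workflow, not part of the original gist). With
# NODE_ENV unset the script expects a local PostgreSQL on localhost:5432 and a
# MinIO server on localhost:9000:
#
#   pip install ray boto3 psycopg2-binary datasets aiohttp psutil pandas
#   export POSTGRES_DB=projectx POSTGRES_USER=postgres POSTGRES_PASSWORD=postgres
#   export MINIO_ROOT_USER=minioadmin MINIO_ROOT_PASSWORD=minioadmin
#   python ingest_wolt_images.py   # any filename works; the gist does not name one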