Data ingestion using Ray actors with Hugging Face datasets
""" | |
Fast image ingestion tool for Wolt Food dataset. | |
This script efficiently processes the Wolt Food CLIP-ViT-B-32 dataset, | |
uploading images to S3 and storing metadata and embeddings in PostgreSQL. | |
Optimized for processing 100k+ records using: | |
- Parallel image downloads and S3 uploads with connection pooling | |
- PostgreSQL COPY for bulk inserts | |
- Ray actors for efficient resource management | |
""" | |
import os
import ray
import boto3
import psycopg2
import json
import logging
import aiohttp
import traceback
import time
import psutil
import asyncio
from io import StringIO
from datasets import load_dataset
from botocore.config import Config
from typing import Dict, List, Any, Optional, Tuple
import pandas as pd
from datetime import datetime, timezone
from psycopg2.pool import SimpleConnectionPool
from psycopg2.extensions import ISOLATION_LEVEL_READ_COMMITTED

# Configure structured JSON logging
logging.basicConfig(
    level=logging.INFO,
    format='{"timestamp": "%(asctime)s", "level": "%(levelname)s", "message": %(message)s}',
    force=True,
)
logger = logging.getLogger(__name__)

def get_optimal_batch_size():
    """Calculate optimal batch size based on available system memory."""
    available_memory = psutil.virtual_memory().available
    # Estimate 2MB per image on average (increased from 1MB)
    estimated_image_size = 2 * 1024 * 1024  # 2MB in bytes
    # Use 40% of available memory (increased from 20%)
    memory_limit = available_memory * 0.4
    optimal_size = int(memory_limit / estimated_image_size)
    # Increased limits for better throughput
    return max(500, min(optimal_size, 2000))
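
# Illustrative sizing (an assumption, not a measurement): with roughly 8 GiB of
# free RAM, 0.4 * 8 GiB / 2 MiB ~= 1638, which falls inside the [500, 2000]
# clamp, so BATCH_SIZE would be about 1638 records per batch.
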
# Maximum number of records to process from the dataset
MAX_RECORDS = 100000  # Adjust this value to limit the number of records to process

# Performance tuning constants
BATCH_SIZE = get_optimal_batch_size()
MAX_RETRIES = 2
NUM_S3_ACTORS = min(32, psutil.cpu_count() * 2)  # Scale with available CPUs

# Database performance settings
DB_MAX_CONNECTIONS = min(100, psutil.cpu_count() * 4)  # Scale with CPUs

logger.info(
    {
        "event": "performance_config",
        "batch_size": BATCH_SIZE,
        "s3_actors": NUM_S3_ACTORS,
        "available_memory": psutil.virtual_memory().available
        / (1024 * 1024 * 1024),  # GB
        "cpu_cores": psutil.cpu_count(),
    }
)
# Environment and configuration
IS_DEVELOPMENT = os.getenv("NODE_ENV") != "production"
BUCKET_NAME = (
    os.getenv("MINIO_BUCKET_NAME", "projectx-images")
    if IS_DEVELOPMENT
    else os.getenv("AWS_BUCKET_NAME")
)

# Database configuration
DB_CONFIG = {
    "dbname": os.getenv("POSTGRES_DB", "projectx"),
    "user": os.getenv("POSTGRES_USER", "postgres"),
    "password": os.getenv("POSTGRES_PASSWORD", "postgres"),
    "host": "localhost",
    "port": os.getenv("POSTGRES_PORT", "5432"),
    "application_name": "image_ingest",
    # Connection settings
    "keepalives": 1,
    "keepalives_idle": 30,
    "keepalives_interval": 10,
    "keepalives_count": 5,
    # Encoding settings
    "client_encoding": "UTF8",
}

logger.info(
    {
        "event": "config_loaded",
        "db_config": {
            "dbname": DB_CONFIG["dbname"],
            "user": DB_CONFIG["user"],
            "host": DB_CONFIG["host"],
            "port": DB_CONFIG["port"],
        },
        "is_development": IS_DEVELOPMENT,
    }
)
# S3 configuration with optimized settings
S3_CONFIG = {
    "aws_access_key_id": os.getenv("MINIO_ROOT_USER", "minioadmin")
    if IS_DEVELOPMENT
    else os.getenv("AWS_ACCESS_KEY_ID"),
    "aws_secret_access_key": os.getenv("MINIO_ROOT_PASSWORD", "minioadmin")
    if IS_DEVELOPMENT
    else os.getenv("AWS_SECRET_ACCESS_KEY"),
    "endpoint_url": "http://localhost:9000" if IS_DEVELOPMENT else None,
    "region_name": os.getenv("AWS_REGION", "us-east-1"),
    "verify": not IS_DEVELOPMENT,
    "config": Config(
        max_pool_connections=100,
        connect_timeout=5,
        read_timeout=5,
        retries={"max_attempts": MAX_RETRIES},
        tcp_keepalive=True,
    ),
}

@ray.remote
class S3Actor:
    """Actor for handling S3 operations with optimized performance."""

    def __init__(self):
        """Initialize the S3 actor with optimized settings."""
        # Initialize S3 client with existing configuration
        self.client = boto3.client("s3", **S3_CONFIG)
        # Initialize aiohttp session
        self._session = None
        # Initialize the bucket with proper configuration
        self._init_bucket()

    async def _get_session(self):
        """Get or create aiohttp session."""
        if self._session is None:
            timeout = aiohttp.ClientTimeout(total=30, connect=10)
            self._session = aiohttp.ClientSession(timeout=timeout)
        return self._session

    def _init_bucket(self):
        """Initialize S3 bucket with proper configuration."""
        try:
            # Create bucket if it doesn't exist
            try:
                self.client.head_bucket(Bucket=BUCKET_NAME)
            except self.client.exceptions.ClientError:
                self.client.create_bucket(Bucket=BUCKET_NAME)
            # Set bucket policy for public read access
            bucket_policy = {
                "Version": "2012-10-17",
                "Statement": [
                    {
                        "Sid": "PublicReadGetObject",
                        "Effect": "Allow",
                        "Principal": "*",
                        "Action": "s3:GetObject",
                        "Resource": f"arn:aws:s3:::{BUCKET_NAME}/*",
                    }
                ],
            }
            self.client.put_bucket_policy(
                Bucket=BUCKET_NAME, Policy=json.dumps(bucket_policy)
            )
            logger.info({"event": "bucket_initialized", "bucket": BUCKET_NAME})
        except Exception as e:
            logger.error(
                {
                    "event": "bucket_initialization_failed",
                    "error": str(e),
                    "error_type": type(e).__name__,
                }
            )
            raise

    async def upload_image(self, image_url: str, external_id: str) -> Optional[str]:
        """
        Upload an image to S3 with optimized performance.

        Args:
            image_url: The URL of the image to upload
            external_id: External ID for the image

        Returns:
            Optional[str]: The S3 URL of the uploaded image, or None if upload failed
        """
        if not image_url or not external_id:
            return None

        key = f"images/{external_id}.jpg"
        upload_start_time = time.time()
        try:
            # Get aiohttp session
            session = await self._get_session()
            # Download image with timeout and optimized settings
            async with session.get(image_url, ssl=False) as response:
                if response.status != 200:
                    logger.error(
                        {
                            "event": "image_download_failed",
                            "status_code": response.status,
                            "url": image_url,
                            "external_id": external_id,
                        }
                    )
                    return None
                content = await response.read()
                if not content:
                    logger.error(
                        {
                            "event": "empty_image_content",
                            "url": image_url,
                            "external_id": external_id,
                        }
                    )
                    return None

            # Upload to S3 with optimized settings
            self.client.put_object(
                Bucket=BUCKET_NAME,
                Key=key,
                Body=content,
                ContentType="image/jpeg",
                ACL="public-read",
                StorageClass="STANDARD",
            )

            # Generate appropriate S3 URL based on environment
            if IS_DEVELOPMENT:
                s3_url = f"http://localhost:9000/{BUCKET_NAME}/{key}"
            else:
                # For production S3, construct the direct URL
                s3_url = f"https://{BUCKET_NAME}.s3.{S3_CONFIG['region_name']}.amazonaws.com/{key}"

            upload_duration = time.time() - upload_start_time
            logger.info(
                {
                    "event": "image_uploaded",
                    "external_id": external_id,
                    "s3_url": s3_url,
                    "size_bytes": len(content),
                    "duration_seconds": upload_duration,
                }
            )
            return s3_url
        except Exception as e:
            logger.error(
                {
                    "event": "image_upload_failed",
                    "error": str(e),
                    "error_type": type(e).__name__,
                    "url": image_url,
                    "external_id": external_id,
                }
            )
            return None

    async def cleanup(self):
        """Cleanup resources explicitly."""
        try:
            if hasattr(self, "_session") and self._session:
                await self._session.close()
                self._session = None
            logger.info({"event": "s3_actor_cleanup", "status": "completed"})
        except Exception as e:
            logger.error(
                {
                    "event": "s3_actor_cleanup",
                    "status": "failed",
                    "error": str(e),
                    "error_type": type(e).__name__,
                }
            )

    def __del__(self):
        """Safe destructor for Ray actor."""
        try:
            if hasattr(self, "_session") and self._session:
                logger.warning(
                    {
                        "event": "s3_actor_cleanup",
                        "status": "session_not_closed",
                        "message": "S3Actor was destroyed without calling cleanup()",
                    }
                )
        except Exception:
            # Silently ignore any errors in destructor
            pass
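
# Usage sketch (illustrative only, not part of the pipeline): Ray async actor
# methods return ObjectRefs that can be awaited directly from asyncio code,
# which is what process_records_batch below relies on, e.g.
#
#   actor = S3Actor.remote()
#   s3_url = await actor.upload_image.remote("https://example.com/img.jpg", "item-123")
#
# The URL and external_id above are made-up placeholders.
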
def bulk_insert_to_db(
    conn: psycopg2.extensions.connection, records: List[Dict[str, Any]]
) -> None:
    """
    Efficiently insert multiple records into PostgreSQL using pandas and COPY.
    This approach handles JSON serialization automatically through pandas.
    """
    if not records:
        logger.warning({"event": "bulk_insert_skipped", "reason": "empty_records"})
        return

    try:
        # Optimize PostgreSQL settings for bulk insert
        with conn.cursor() as cur:
            cur.execute("SET synchronous_commit = OFF")
            cur.execute("SET maintenance_work_mem = '1GB'")
            cur.execute("SET temp_buffers = '256MB'")
            cur.execute("SET work_mem = '256MB'")
            cur.execute("SET effective_io_concurrency = 8")

        # Convert records to pandas DataFrame
        df = pd.DataFrame(records)

        # Process in smaller chunks for memory efficiency
        chunk_size = 5000
        for chunk_start in range(0, len(df), chunk_size):
            # Copy the slice so the column assignments below don't trigger
            # pandas' SettingWithCopyWarning
            chunk = df[chunk_start : chunk_start + chunk_size].copy()

            # Convert metadata and embedding columns to properly escaped JSON strings
            chunk["metadata"] = chunk["metadata"].apply(
                lambda x: json.dumps(x, ensure_ascii=False) if x else "{}"
            )
            chunk["embedding"] = chunk["embedding"].apply(
                lambda x: json.dumps(x, ensure_ascii=False) if x else "[]"
            )
            # Convert tags list to PostgreSQL array literal format,
            # e.g. ["Pizza", "Burgers"] -> '{"Pizza","Burgers"}'
            chunk["tags"] = chunk["tags"].apply(
                lambda x: "{" + ",".join(f'"{tag}"' for tag in (x or [])) + "}"
            )
            # Format timestamps
            chunk["created_at"] = pd.to_datetime(chunk["created_at"]).dt.strftime(
                "%Y-%m-%d %H:%M:%S.%f%z"
            )
            chunk["updated_at"] = pd.to_datetime(chunk["updated_at"]).dt.strftime(
                "%Y-%m-%d %H:%M:%S.%f%z"
            )

            # Create a buffer for the COPY operation
            buffer = StringIO()
            # Write chunk to buffer in CSV format with tab separator
            chunk.to_csv(
                buffer,
                sep="\t",
                header=False,
                index=False,
                na_rep="\\N",
                quoting=2,
                escapechar="\\",
                doublequote=False,
                columns=[
                    "external_id",
                    "name",
                    "description",
                    "image_url",
                    "external_url",
                    "embedding",
                    "metadata",
                    "tags",
                    "created_at",
                    "updated_at",
                ],
            )
            # Move buffer cursor to start
            buffer.seek(0)

            # Execute COPY command with proper handling of special characters
            with conn.cursor() as cur:
                cur.copy_expert(
                    """
                    COPY image (
                        external_id, name, description, image_url, external_url,
                        embedding, metadata, tags, created_at, updated_at
                    ) FROM STDIN WITH (
                        FORMAT CSV,
                        DELIMITER E'\t',
                        NULL '\\N',
                        QUOTE '"',
                        ESCAPE '\\'
                    )
                    """,
                    buffer,
                )
            conn.commit()
            logger.info(
                {
                    "event": "chunk_insert_completed",
                    "chunk_size": len(chunk),
                    "start_index": chunk_start,
                }
            )

        logger.info({"event": "bulk_insert_success", "total_records": len(df)})
    except Exception as e:
        conn.rollback()
        logger.error(
            {
                "event": "bulk_insert_failed",
                "error": str(e),
                "error_type": type(e).__name__,
                "traceback": traceback.format_exc(),
            }
        )
        raise
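
# Assumed target table (illustrative only; the actual DDL is not part of this
# script). The COPY above expects an "image" table whose columns line up with
# the DataFrame columns, roughly:
#
#   CREATE TABLE image (
#       external_id  TEXT,
#       name         TEXT,
#       description  TEXT,
#       image_url    TEXT,
#       external_url TEXT,
#       embedding    JSONB,        -- CLIP vector serialized as JSON
#       metadata     JSONB,
#       tags         TEXT[],
#       created_at   TIMESTAMPTZ,
#       updated_at   TIMESTAMPTZ
#   );
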
def parse_bytes_json(value: Any, default_value: Any) -> Any:
    """Parse a bytes value as JSON, with a default value if parsing fails."""
    try:
        if isinstance(value, bytes):
            try:
                return json.loads(value.decode("utf-8"))
            except json.JSONDecodeError:
                # If it's not valid JSON, try to parse it as a string
                value_str = value.decode("utf-8")
                # Convert Python-style dict string to valid JSON
                value_str = value_str.replace("'", '"')  # Replace single quotes with double quotes
                value_str = value_str.replace("None", "null")  # Replace Python None with JSON null
                try:
                    return json.loads(value_str)
                except json.JSONDecodeError:
                    return default_value
        return value or default_value
    except Exception:
        return default_value
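
# Illustrative example (not executed anywhere in this script): a field stored
# as a Python-repr byte string such as
#   b"{'categories': ['Pizza', 'Burgers'], 'rating': None}"
# is not valid JSON on its own, but after the quote/None normalization above
# it parses to {"categories": ["Pizza", "Burgers"], "rating": None}.
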
async def create_db_pool(
    min_size: int = 2, max_size: int = DB_MAX_CONNECTIONS
) -> SimpleConnectionPool:
    """
    Create a database connection pool with the specified size limits.

    Args:
        min_size: Minimum number of connections to keep in the pool
        max_size: Maximum number of connections allowed in the pool

    Returns:
        SimpleConnectionPool: A configured database connection pool
    """
    try:
        logger.info(
            {
                "event": "creating_db_pool",
                "min_size": min_size,
                "max_size": max_size,
                "config": {
                    "host": DB_CONFIG["host"],
                    "port": DB_CONFIG["port"],
                    "dbname": DB_CONFIG["dbname"],
                },
            }
        )
        pool = SimpleConnectionPool(minconn=min_size, maxconn=max_size, **DB_CONFIG)

        # Test the pool by getting and returning a connection
        conn = pool.getconn()
        with conn.cursor() as cur:
            cur.execute("SELECT version();")
            version = cur.fetchone()[0]
            logger.info(
                {"event": "db_pool_created", "version": version, "pool_size": max_size}
            )
        pool.putconn(conn)
        return pool
    except Exception as e:
        logger.error(
            {
                "event": "db_pool_creation_failed",
                "error": str(e),
                "error_type": type(e).__name__,
                "traceback": traceback.format_exc(),
            }
        )
        raise
async def process_records_batch(
    records: List[Dict],
    s3_actor: S3Actor,
    db_pool: SimpleConnectionPool,
    batch_start_time: float,
) -> Tuple[int, int]:
    """Process a batch of records, uploading images to S3 and inserting data into PostgreSQL."""
    successful_uploads = 0
    failed_uploads = 0
    conn = None
    try:
        # Pre-process records to prepare for bulk operations
        processed_images = []
        upload_tasks = []
        for record in records:
            external_id = str(record.get("id", ""))
            try:
                # Prepare upload task
                image_url = record.get("image")
                if not image_url:
                    failed_uploads += 1
                    logger.warning(
                        {
                            "event": "skip_record",
                            "reason": "missing_image_url",
                            "external_id": external_id,
                            "image_url": image_url,
                        }
                    )
                    continue
                # Create upload task
                upload_task = s3_actor.upload_image.remote(image_url, external_id)
                upload_tasks.append((record, upload_task))
            except Exception as e:
                failed_uploads += 1
                logger.error(
                    {
                        "event": "record_processing_failed",
                        "error": str(e),
                        "external_id": external_id,
                        "error_type": type(e).__name__,
                    }
                )

        # Wait for all upload tasks to complete
        for record, upload_task in upload_tasks:
            try:
                external_id = str(record.get("id", ""))
                s3_url = await upload_task
                if s3_url:
                    # Prepare record for database insertion
                    cafe_data = parse_bytes_json(record.get("cafe"), {})
                    current_time = datetime.now(timezone.utc)
                    processed_images.append(
                        {
                            "external_id": external_id,
                            "name": record.get("name", ""),
                            "description": record.get("description", ""),
                            "image_url": s3_url,
                            # Use this record's original source URL, not the
                            # loop variable left over from the submission loop
                            "external_url": record.get("image"),
                            "embedding": parse_bytes_json(record.get("vector"), []),
                            "metadata": cafe_data,
                            "tags": cafe_data.get("categories", []),
                            "created_at": current_time,
                            "updated_at": current_time,
                        }
                    )
                    successful_uploads += 1
                else:
                    failed_uploads += 1
            except Exception as e:
                failed_uploads += 1
                logger.error(
                    {
                        "event": "upload_failed",
                        "error": str(e),
                        "external_id": external_id,
                        "error_type": type(e).__name__,
                    }
                )

        # Bulk insert processed records
        if processed_images:
            conn = db_pool.getconn()
            conn.set_isolation_level(ISOLATION_LEVEL_READ_COMMITTED)
            bulk_insert_to_db(conn, processed_images)

        # Log batch metrics
        batch_duration = time.time() - batch_start_time
        total_records = len(records)
        logger.info(
            {
                "event": "batch_completed",
                "successful_uploads": successful_uploads,
                "failed_uploads": failed_uploads,
                "batch_size": total_records,
                "duration_seconds": batch_duration,
                "records_per_second": total_records / batch_duration
                if batch_duration > 0
                else 0,
            }
        )
    except Exception as e:
        total_records = len(records)
        logger.error(
            {
                "event": "batch_processing_failed",
                "error": str(e),
                "error_type": type(e).__name__,
                "traceback": traceback.format_exc(),
                "batch_size": total_records,
            }
        )
        failed_uploads = total_records
        successful_uploads = 0
    finally:
        if conn:
            db_pool.putconn(conn)
    return successful_uploads, failed_uploads
async def main():
    """Main entry point for the ingestion script."""
    start_time = time.time()
    total_processed = 0
    total_successful = 0
    total_failed = 0
    s3_actors = []
    try:
        # Initialize Ray if not already running
        if not ray.is_initialized():
            ray.init()

        # Create S3 actors pool
        s3_actors = [S3Actor.remote() for _ in range(NUM_S3_ACTORS)]
        current_actor = 0

        # Create database connection pool
        db_pool = await create_db_pool(min_size=2, max_size=DB_MAX_CONNECTIONS)

        # Load and process dataset
        dataset = load_dataset(
            "Qdrant/wolt-food-clip-ViT-B-32-embeddings", split="train", streaming=True
        )

        # Calculate optimal batch size
        batch_size = BATCH_SIZE
        logger.info(
            {
                "event": "processing_started",
                "batch_size": batch_size,
                "num_s3_actors": NUM_S3_ACTORS,
                "max_records": MAX_RECORDS,
                "streaming": True,
            }
        )

        # Process records in batches
        current_batch = []
        batch_start_time = time.time()
        for record in dataset:
            # Check if we've reached or would exceed the maximum number of records
            if total_successful >= MAX_RECORDS:
                logger.info(
                    {
                        "event": "max_records_reached",
                        "total_processed": total_processed,
                        "total_successful": total_successful,
                        "total_failed": total_failed,
                        "max_records": MAX_RECORDS,
                    }
                )
                break

            # Adjust batch size if we're approaching MAX_RECORDS
            remaining = MAX_RECORDS - total_successful
            if len(current_batch) >= min(batch_size, remaining):
                # Process batch using round-robin S3 actor assignment
                s3_actor = s3_actors[current_actor]
                current_actor = (current_actor + 1) % NUM_S3_ACTORS
                successful, failed = await process_records_batch(
                    current_batch, s3_actor, db_pool, batch_start_time
                )
                total_successful += successful
                total_failed += failed
                total_processed += len(current_batch)

                # Reset for next batch
                current_batch = []
                batch_start_time = time.time()

                # Log progress
                elapsed_time = time.time() - start_time
                logger.info(
                    {
                        "event": "progress_update",
                        "total_processed": total_processed,
                        "total_successful": total_successful,
                        "total_failed": total_failed,
                        "remaining": MAX_RECORDS - total_successful,
                        "elapsed_minutes": elapsed_time / 60,
                        "records_per_second": total_processed / elapsed_time
                        if elapsed_time > 0
                        else 0,
                    }
                )

                # Break if we've reached the limit
                if total_successful >= MAX_RECORDS:
                    break

            current_batch.append(record)

        # Process remaining records only if we haven't reached MAX_RECORDS
        if current_batch and total_successful < MAX_RECORDS:
            s3_actor = s3_actors[current_actor]
            successful, failed = await process_records_batch(
                current_batch, s3_actor, db_pool, batch_start_time
            )
            total_successful += successful
            total_failed += failed
            total_processed += len(current_batch)

        # Log final statistics
        total_time = time.time() - start_time
        logger.info(
            {
                "event": "processing_completed",
                "total_processed": total_processed,
                "total_successful": total_successful,
                "total_failed": total_failed,
                "max_records": MAX_RECORDS,
                "total_minutes": total_time / 60,
                "average_records_per_second": total_processed / total_time
                if total_time > 0
                else 0,
            }
        )
    except Exception as e:
        logger.error(
            {
                "event": "processing_failed",
                "error": str(e),
                "error_type": type(e).__name__,
                "traceback": traceback.format_exc(),
            }
        )
        raise
    finally:
        # Cleanup S3 actors
        cleanup_tasks = [actor.cleanup.remote() for actor in s3_actors]
        try:
            await asyncio.gather(*cleanup_tasks)
            logger.info(
                {"event": "cleanup_completed", "actors_cleaned": len(s3_actors)}
            )
        except Exception as e:
            logger.error(
                {
                    "event": "cleanup_failed",
                    "error": str(e),
                    "error_type": type(e).__name__,
                }
            )
        # Shutdown Ray
        if ray.is_initialized():
            ray.shutdown()
        # Close database connections
        if "db_pool" in locals():
            db_pool.closeall()


if __name__ == "__main__":
    asyncio.run(main())
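
# How to run (illustrative sketch; the file name is an assumption, the
# environment variables and defaults come from the configuration above):
#
#   # Local development against MinIO + PostgreSQL on localhost
#   export POSTGRES_DB=projectx POSTGRES_USER=postgres POSTGRES_PASSWORD=postgres
#   export MINIO_ROOT_USER=minioadmin MINIO_ROOT_PASSWORD=minioadmin
#   export MINIO_BUCKET_NAME=projectx-images
#   python ingest_wolt_food.py
#
#   # Production against AWS S3
#   export NODE_ENV=production
#   export AWS_BUCKET_NAME=<bucket> AWS_ACCESS_KEY_ID=<key> AWS_SECRET_ACCESS_KEY=<secret> AWS_REGION=us-east-1
#   python ingest_wolt_food.py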