Created
March 18, 2026 12:10
-
-
Save recalde/7b6362a0ca00b9413d65ea35ac0586e7 to your computer and use it in GitHub Desktop.
repartition_batching.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """ | |
| Build repartition manifest files in S3 from source parquet objects. | |
| What this program does | |
| ---------------------- | |
| 1. Uses an AWS named profile (default: AWS_PROFILE=default). | |
| 2. Discovers table names under: clean/<schema>/<table>/ | |
| 3. Lists every parquet object under each table using a ThreadPoolExecutor. | |
| 4. Sorts all source objects by (table_name, last_modified, key). | |
| 5. Packs source objects into manifest parts using ONLY source parquet file sizes. | |
| 6. Writes manifest parquet files back to the same bucket under: | |
| repartition/YYYYMMDD/part-00001.parquet | |
| 7. Optionally invokes a Lambda asynchronously after each manifest upload. | |
| Important behavior | |
| ------------------ | |
| - This program does NOT inspect or compare source parquet schemas. | |
| - Mixing tables in the same manifest part is allowed. | |
| - The size thresholds are based on the sum of SOURCE object sizes in each part, | |
| not the tiny manifest file size itself. | |
| - If a single source parquet file is larger than SOLO_OBJECT_THRESHOLD_BYTES, | |
| it is emitted in a manifest part by itself. | |
| Typical downstream pattern | |
| -------------------------- | |
| A downstream Lambda can consume each manifest part, inspect the listed source | |
| objects, and apply any table-specific or schema-specific processing rules. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import os | |
| import re | |
| import tempfile | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from dataclasses import dataclass | |
| from datetime import datetime, timezone | |
| from typing import Iterable, List, Optional | |
| import boto3 | |
| import pyarrow as pa | |
| import pyarrow.parquet as pq | |
| from botocore.config import Config | |
| from botocore.exceptions import BotoCoreError, ClientError | |
# =============================================================================
# Configuration
# =============================================================================
# Everything is driven by environment variables with defaults so the script
# can be re-targeted per deployment without code changes.

# Named AWS profile and region for the boto3 session.
AWS_PROFILE = os.getenv("AWS_PROFILE", "default")
AWS_REGION = os.getenv("AWS_REGION", "us-east-1")

# Deployment environment; used only to derive the default bucket name.
ENV = os.getenv("ENV", "dev")
BUCKET_NAME = os.getenv("BUCKET_NAME", f"my-{ENV}-bucket")

# Source layout: parquet objects live under <SOURCE_ROOT>/<SCHEMA>/<table>/...
SCHEMA = os.getenv("SCHEMA", "my-schema")
SOURCE_ROOT = os.getenv("SOURCE_ROOT", "clean")

# Destination layout: manifests are written under <DEST_ROOT>/<RUN_DATE>/.
DEST_ROOT = os.getenv("DEST_ROOT", "repartition")
RUN_DATE = os.getenv("RUN_DATE", datetime.now(timezone.utc).strftime("%Y%m%d"))
DEST_PREFIX = os.getenv("DEST_PREFIX", f"{DEST_ROOT}/{RUN_DATE}")

# Thread-pool sizes for parallel S3 listing and Lambda notification fan-out.
LIST_MAX_WORKERS = int(os.getenv("LIST_MAX_WORKERS", "10"))
NOTIFY_MAX_WORKERS = int(os.getenv("NOTIFY_MAX_WORKERS", "10"))

# Advisory minimum and hard maximum thresholds based on sum(source parquet file sizes)
MIN_PART_TARGET_BYTES = int(os.getenv("MIN_PART_TARGET_BYTES", str(512 * 1024 * 1024)))
MAX_PART_TARGET_BYTES = int(os.getenv("MAX_PART_TARGET_BYTES", str(1 * 1024 * 1024 * 1024)))
# A single source object at or above this size is emitted in its own part.
SOLO_OBJECT_THRESHOLD_BYTES = int(
    os.getenv("SOLO_OBJECT_THRESHOLD_BYTES", str(MAX_PART_TARGET_BYTES))
)

# Parquet writer settings for the (small) manifest files themselves.
MANIFEST_ROW_GROUP_SIZE = int(os.getenv("MANIFEST_ROW_GROUP_SIZE", "100000"))
MANIFEST_COMPRESSION = os.getenv("MANIFEST_COMPRESSION", "snappy")

# Optional async Lambda notification after each manifest upload.
ENABLE_ASYNC_LAMBDA_NOTIFY = os.getenv("ENABLE_ASYNC_LAMBDA_NOTIFY", "false").lower() == "true"
NOTIFY_LAMBDA_NAME = os.getenv("NOTIFY_LAMBDA_NAME", "")
NOTIFY_EXTRA_JSON = os.getenv("NOTIFY_EXTRA_JSON", "{}")

LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
# =============================================================================
# Logging
# =============================================================================
logging.basicConfig(
    # Fall back to INFO if LOG_LEVEL names an unknown logging level.
    level=getattr(logging, LOG_LEVEL, logging.INFO),
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger("s3-repartitioner")
# =============================================================================
# Data classes
# =============================================================================
@dataclass(frozen=True)
class SourceObject:
    """One source parquet object discovered under clean/<schema>/<table>/."""

    bucket: str  # bucket the object lives in
    key: str  # full S3 key of the parquet file
    table_name: str  # table folder the key was listed under
    last_modified: datetime  # LastModified from the S3 listing
    size_bytes: int  # object size from the S3 listing
@dataclass
class PartPlan:
    """A planned manifest part: which source objects it covers."""

    part_number: int  # 1-based sequence number, used in the output key
    objects: List[SourceObject]  # source objects listed in this part
    total_source_bytes: int  # sum of the source object sizes in this part
@dataclass
class WriteResult:
    """Outcome of uploading one manifest part to S3."""

    bucket: str  # destination bucket of the manifest file
    key: str  # destination key of the manifest file
    part_number: int  # matches PartPlan.part_number
    row_count: int  # rows written to the manifest parquet
    total_source_bytes: int  # carried through from the PartPlan
# =============================================================================
# AWS helpers
# =============================================================================
def build_session() -> boto3.session.Session:
    """Create a boto3 session bound to the configured profile and region."""
    session = boto3.Session(profile_name=AWS_PROFILE, region_name=AWS_REGION)
    return session
# One module-level session and its clients, shared by all worker threads.
_SESSION = build_session()

# S3 client with standard-mode retries; the connection pool is sized to
# cover both thread pools plus headroom so listings don't starve uploads.
_S3 = _SESSION.client(
    "s3",
    config=Config(
        retries={"max_attempts": 10, "mode": "standard"},
        max_pool_connections=max(32, LIST_MAX_WORKERS + NOTIFY_MAX_WORKERS + 10),
    ),
)

# Lambda client used only for the optional async notifications.
_LAMBDA = _SESSION.client(
    "lambda",
    config=Config(
        retries={"max_attempts": 10, "mode": "standard"},
        max_pool_connections=max(16, NOTIFY_MAX_WORKERS + 4),
    ),
)
# =============================================================================
# S3 discovery
# =============================================================================
def schema_prefix() -> str:
    """Return the key prefix (with trailing slash) containing all tables."""
    return "/".join((SOURCE_ROOT, SCHEMA)) + "/"
def discover_table_names(bucket: str, prefix: str) -> List[str]:
    """
    Return the sorted, de-duplicated table folder names under clean/<schema>/.

    Delimiter='/' makes S3 report only first-level CommonPrefixes, so each
    entry corresponds to one immediate child "folder".
    """
    logger.info("Discovering tables under s3://%s/%s", bucket, prefix)
    paginator = _S3.get_paginator("list_objects_v2")
    names = set()
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter="/"):
        for common in page.get("CommonPrefixes", []):
            leaf = common.get("Prefix", "").rstrip("/").split("/")[-1]
            if leaf:
                names.add(leaf)
    ordered = sorted(names)
    logger.info("Discovered %d tables", len(ordered))
    return ordered
# Expected source key layout:
#   <root>/<schema>/<table>/YYYY/MM/DD/HH/<file>.parquet
# Keys that do not match this hourly partition layout are skipped.
_PARTITION_RE = re.compile(
    r"^(?P<root>[^/]+)/(?P<schema>[^/]+)/(?P<table>[^/]+)/"
    r"(?P<year>\d{4})/(?P<month>\d{2})/(?P<day>\d{2})/(?P<hour>\d{2})/[^/]+\.parquet$",
    flags=re.IGNORECASE,
)
def list_table_objects(bucket: str, table_name: str) -> List[SourceObject]:
    """
    Return a SourceObject for every parquet key under one table prefix.

    Keys that do not end in .parquet or do not match the expected hourly
    partition layout are ignored.
    """
    prefix = f"{SOURCE_ROOT}/{SCHEMA}/{table_name}/"
    logger.info("Listing parquet objects under s3://%s/%s", bucket, prefix)
    paginator = _S3.get_paginator("list_objects_v2")
    results: List[SourceObject] = []
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for entry in page.get("Contents", []):
            object_key = entry["Key"]
            # Keep only parquet keys in the expected partition layout.
            if not object_key.lower().endswith(".parquet"):
                continue
            if _PARTITION_RE.match(object_key) is None:
                continue
            results.append(
                SourceObject(
                    bucket=bucket,
                    key=object_key,
                    table_name=table_name,
                    last_modified=entry["LastModified"],
                    size_bytes=int(entry["Size"]),
                )
            )
    logger.info("Found %d parquet objects for table %s", len(results), table_name)
    return results
# =============================================================================
# Packing
# =============================================================================
def pack_objects_into_parts(objects: List[SourceObject]) -> List[PartPlan]:
    """
    Pack objects into manifest parts using ONLY source object sizes.

    Rules:
      - Objects are assumed to already be globally sorted.
      - An object whose size >= SOLO_OBJECT_THRESHOLD_BYTES is emitted alone.
      - Otherwise objects accumulate until the next one would push the part
        past MAX_PART_TARGET_BYTES (the hard cap).
      - MIN_PART_TARGET_BYTES is advisory only and not enforced here.
    """
    plans: List[PartPlan] = []
    pending: List[SourceObject] = []
    pending_bytes = 0

    def emit(batch: List[SourceObject], batch_bytes: int) -> None:
        # Part numbers are sequential, derived from how many parts exist.
        plans.append(
            PartPlan(
                part_number=len(plans) + 1,
                objects=batch,
                total_source_bytes=batch_bytes,
            )
        )

    for candidate in objects:
        if candidate.size_bytes >= SOLO_OBJECT_THRESHOLD_BYTES:
            # Oversized object: close the open part, then emit it solo.
            if pending:
                emit(pending, pending_bytes)
                pending, pending_bytes = [], 0
            emit([candidate], candidate.size_bytes)
        elif pending and pending_bytes + candidate.size_bytes > MAX_PART_TARGET_BYTES:
            # Adding this object would exceed the hard cap: start a new part.
            emit(pending, pending_bytes)
            pending, pending_bytes = [candidate], candidate.size_bytes
        else:
            pending.append(candidate)
            pending_bytes += candidate.size_bytes

    if pending:
        emit(pending, pending_bytes)
    return plans
# =============================================================================
# Manifest writing
# =============================================================================
def manifest_key_for_part(part_number: int) -> str:
    """Return the destination S3 key for one zero-padded manifest part."""
    filename = "part-{:05d}.parquet".format(part_number)
    return f"{DEST_PREFIX}/{filename}"
# Arrow schema for manifest rows: one row per listed source parquet object.
MANIFEST_SCHEMA = pa.schema(
    [
        pa.field("source_bucket", pa.string(), nullable=False),
        pa.field("source_key", pa.string(), nullable=False),
        pa.field("table_name", pa.string(), nullable=False),
        # Microsecond-precision, UTC-normalized timestamps.
        pa.field("source_last_modified_utc", pa.timestamp("us", tz="UTC"), nullable=False),
        pa.field("source_size_bytes", pa.int64(), nullable=False),
    ]
)
def build_manifest_table(plan: PartPlan) -> pa.Table:
    """Convert one PartPlan into an Arrow table matching MANIFEST_SCHEMA."""
    records = [
        {
            "source_bucket": src.bucket,
            "source_key": src.key,
            "table_name": src.table_name,
            # Normalize to UTC to match the schema's tz-aware timestamp.
            "source_last_modified_utc": src.last_modified.astimezone(timezone.utc),
            "source_size_bytes": src.size_bytes,
        }
        for src in plan.objects
    ]
    return pa.Table.from_pylist(records, schema=MANIFEST_SCHEMA)
def write_manifest_part(bucket: str, plan: PartPlan) -> WriteResult:
    """
    Serialize one part plan to a local parquet file and upload it to S3.
    """
    manifest_key = manifest_key_for_part(plan.part_number)
    arrow_table = build_manifest_table(plan)
    tmp_prefix = f"part-{plan.part_number:05d}-"
    with tempfile.NamedTemporaryFile(prefix=tmp_prefix, suffix=".parquet", delete=True) as tmp:
        # Write by path; the temp file is removed when the context exits.
        pq.write_table(
            arrow_table,
            tmp.name,
            compression=MANIFEST_COMPRESSION,
            row_group_size=MANIFEST_ROW_GROUP_SIZE,
        )
        _S3.upload_file(tmp.name, bucket, manifest_key)
    logger.info(
        "Wrote manifest s3://%s/%s with %d rows covering %d source bytes",
        bucket,
        manifest_key,
        arrow_table.num_rows,
        plan.total_source_bytes,
    )
    return WriteResult(
        bucket=bucket,
        key=manifest_key,
        part_number=plan.part_number,
        row_count=arrow_table.num_rows,
        total_source_bytes=plan.total_source_bytes,
    )
# =============================================================================
# Optional notify Lambda
# =============================================================================
def build_notify_payload(result: WriteResult) -> dict:
    """Build the JSON payload describing one uploaded manifest part."""
    payload = {
        "eventType": "s3_repartition_manifest_created",
        "runDate": RUN_DATE,
        "bucket": result.bucket,
        "key": result.key,
        "partNumber": result.part_number,
        "rowCount": result.row_count,
        "totalSourceBytes": result.total_source_bytes,
        "schema": SCHEMA,
        "sourceRoot": SOURCE_ROOT,
        "destRoot": DEST_ROOT,
    }
    # Optional operator-supplied extras; only a JSON object merges in,
    # and malformed JSON is ignored with a warning.
    if NOTIFY_EXTRA_JSON.strip():
        try:
            extra = json.loads(NOTIFY_EXTRA_JSON)
        except json.JSONDecodeError:
            logger.warning("Ignoring invalid NOTIFY_EXTRA_JSON")
        else:
            if isinstance(extra, dict):
                payload.update(extra)
    return payload
def invoke_notify_lambda_async(result: WriteResult) -> None:
    """
    Fire-and-forget invoke of the notify Lambda for one manifest file.

    Raises:
        ValueError: if notification is enabled but no Lambda name is set.
    """
    if not ENABLE_ASYNC_LAMBDA_NOTIFY:
        return
    if not NOTIFY_LAMBDA_NAME:
        raise ValueError("ENABLE_ASYNC_LAMBDA_NOTIFY=true but NOTIFY_LAMBDA_NAME is empty")
    body = json.dumps(build_notify_payload(result)).encode("utf-8")
    # InvocationType="Event" makes the invocation asynchronous.
    _LAMBDA.invoke(
        FunctionName=NOTIFY_LAMBDA_NAME,
        InvocationType="Event",
        Payload=body,
    )
    logger.info("Invoked Lambda asynchronously for s3://%s/%s", result.bucket, result.key)
# =============================================================================
# Main flow
# =============================================================================
def collect_all_source_objects(bucket: str, tables: Iterable[str]) -> List[SourceObject]:
    """List all tables in parallel and return one globally sorted list."""
    collected: List[SourceObject] = []
    with ThreadPoolExecutor(max_workers=LIST_MAX_WORKERS) as executor:
        pending = {
            executor.submit(list_table_objects, bucket, table): table
            for table in tables
        }
        for done in as_completed(pending):
            table_name = pending[done]
            try:
                collected.extend(done.result())
            except Exception as exc:
                # Log with traceback, then propagate to abort the run.
                logger.exception("Failed listing table %s: %s", table_name, exc)
                raise
    # Stable global ordering drives deterministic packing downstream.
    collected.sort(key=lambda item: (item.table_name, item.last_modified, item.key))
    return collected
def log_plan(parts: List[PartPlan]) -> None:
    """Log a summary of the packing plan (per-part detail capped at 20)."""
    total_objects = sum(len(plan.objects) for plan in parts)
    total_bytes = sum(plan.total_source_bytes for plan in parts)
    logger.info("Planned %d manifest parts", len(parts))
    logger.info("Total source objects: %d", total_objects)
    logger.info("Total source bytes: %d", total_bytes)
    for plan in parts[:20]:
        first_key = plan.objects[0].key if plan.objects else ""
        last_key = plan.objects[-1].key if plan.objects else ""
        logger.info(
            "Plan part-%05d: rows=%d source_bytes=%d first=%s last=%s",
            plan.part_number,
            len(plan.objects),
            plan.total_source_bytes,
            first_key,
            last_key,
        )
    if len(parts) > 20:
        logger.info("... plan logging truncated after first 20 parts ...")
def main() -> int:
    """Run the end-to-end flow; return a process exit code (0 ok, 2 fatal)."""
    try:
        logger.info("Starting S3 repartition manifest builder")
        logger.info("AWS_PROFILE=%s AWS_REGION=%s", AWS_PROFILE, AWS_REGION)
        logger.info("Bucket=%s Schema=%s SourcePrefix=%s", BUCKET_NAME, SCHEMA, schema_prefix())
        logger.info(
            "Thresholds: min=%d max=%d solo=%d",
            MIN_PART_TARGET_BYTES,
            MAX_PART_TARGET_BYTES,
            SOLO_OBJECT_THRESHOLD_BYTES,
        )

        tables = discover_table_names(BUCKET_NAME, schema_prefix())
        if not tables:
            logger.warning("No tables found under s3://%s/%s", BUCKET_NAME, schema_prefix())
            return 0

        all_objects = collect_all_source_objects(BUCKET_NAME, tables)
        if not all_objects:
            logger.warning("No parquet objects found under s3://%s/%s", BUCKET_NAME, schema_prefix())
            return 0
        logger.info("Collected %d parquet source objects", len(all_objects))

        parts = pack_objects_into_parts(all_objects)
        log_plan(parts)

        # Upload manifests sequentially so part numbering stays stable.
        results: List[WriteResult] = [write_manifest_part(BUCKET_NAME, plan) for plan in parts]

        if ENABLE_ASYNC_LAMBDA_NOTIFY:
            logger.info(
                "Async Lambda notify enabled for %d manifest files with max_workers=%d",
                len(results),
                NOTIFY_MAX_WORKERS,
            )
            with ThreadPoolExecutor(max_workers=NOTIFY_MAX_WORKERS) as executor:
                futures = [executor.submit(invoke_notify_lambda_async, item) for item in results]
                for future in as_completed(futures):
                    # Surface the first notification failure, if any.
                    future.result()

        logger.info("Completed successfully. Wrote %d manifest files.", len(results))
        return 0
    except (BotoCoreError, ClientError, ValueError) as exc:
        logger.exception("Fatal AWS/configuration error: %s", exc)
        return 2
if __name__ == "__main__":
    # Propagate main()'s return code as the process exit status.
    raise SystemExit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment