@recalde
Created March 18, 2026 12:10
repartition_batching.py
#!/usr/bin/env python3
"""
Build repartition manifest files in S3 from source parquet objects.
What this program does
----------------------
1. Uses an AWS named profile (default: AWS_PROFILE=default).
2. Discovers table names under: clean/<schema>/<table>/
3. Lists every parquet object under each table using a ThreadPoolExecutor.
4. Sorts all source objects by (table_name, last_modified, key).
5. Packs source objects into manifest parts using ONLY source parquet file sizes.
6. Writes manifest parquet files back to the same bucket under:
   repartition/YYYYMMDD/part-00001.parquet
7. Optionally invokes a Lambda asynchronously after each manifest upload.
Important behavior
------------------
- This program does NOT inspect or compare source parquet schemas.
- Mixing tables in the same manifest part is allowed.
- The size thresholds are based on the sum of SOURCE object sizes in each part,
  not the tiny manifest file size itself.
- If a single source parquet file is larger than SOLO_OBJECT_THRESHOLD_BYTES,
it is emitted in a manifest part by itself.
Typical downstream pattern
--------------------------
A downstream Lambda can consume each manifest part, inspect the listed source
objects, and apply any table-specific or schema-specific processing rules.
"""
from __future__ import annotations
import json
import logging
import os
import re
import tempfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Iterable, List
import boto3
import pyarrow as pa
import pyarrow.parquet as pq
from botocore.config import Config
from botocore.exceptions import BotoCoreError, ClientError
# =============================================================================
# Configuration
# =============================================================================
AWS_PROFILE = os.getenv("AWS_PROFILE", "default")
AWS_REGION = os.getenv("AWS_REGION", "us-east-1")
ENV = os.getenv("ENV", "dev")
BUCKET_NAME = os.getenv("BUCKET_NAME", f"my-{ENV}-bucket")
SCHEMA = os.getenv("SCHEMA", "my-schema")
SOURCE_ROOT = os.getenv("SOURCE_ROOT", "clean")
DEST_ROOT = os.getenv("DEST_ROOT", "repartition")
RUN_DATE = os.getenv("RUN_DATE", datetime.now(timezone.utc).strftime("%Y%m%d"))
DEST_PREFIX = os.getenv("DEST_PREFIX", f"{DEST_ROOT}/{RUN_DATE}")
LIST_MAX_WORKERS = int(os.getenv("LIST_MAX_WORKERS", "10"))
NOTIFY_MAX_WORKERS = int(os.getenv("NOTIFY_MAX_WORKERS", "10"))
# Advisory minimum and hard maximum thresholds based on sum(source parquet file sizes)
MIN_PART_TARGET_BYTES = int(os.getenv("MIN_PART_TARGET_BYTES", str(512 * 1024 * 1024)))
MAX_PART_TARGET_BYTES = int(os.getenv("MAX_PART_TARGET_BYTES", str(1 * 1024 * 1024 * 1024)))
SOLO_OBJECT_THRESHOLD_BYTES = int(
    os.getenv("SOLO_OBJECT_THRESHOLD_BYTES", str(MAX_PART_TARGET_BYTES))
)
MANIFEST_ROW_GROUP_SIZE = int(os.getenv("MANIFEST_ROW_GROUP_SIZE", "100000"))
MANIFEST_COMPRESSION = os.getenv("MANIFEST_COMPRESSION", "snappy")
ENABLE_ASYNC_LAMBDA_NOTIFY = os.getenv("ENABLE_ASYNC_LAMBDA_NOTIFY", "false").lower() == "true"
NOTIFY_LAMBDA_NAME = os.getenv("NOTIFY_LAMBDA_NAME", "")
NOTIFY_EXTRA_JSON = os.getenv("NOTIFY_EXTRA_JSON", "{}")
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
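# Example invocation (all values illustrative):
#   AWS_PROFILE=prod AWS_REGION=us-east-1 BUCKET_NAME=my-prod-bucket SCHEMA=sales \
#   ENABLE_ASYNC_LAMBDA_NOTIFY=true NOTIFY_LAMBDA_NAME=repartition-notify \
#   python repartition_batching.py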
# =============================================================================
# Logging
# =============================================================================
logging.basicConfig(
    level=getattr(logging, LOG_LEVEL, logging.INFO),
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger("s3-repartitioner")
# =============================================================================
# Data classes
# =============================================================================
@dataclass(frozen=True)
class SourceObject:
    bucket: str
    key: str
    table_name: str
    last_modified: datetime
    size_bytes: int


@dataclass
class PartPlan:
    part_number: int
    objects: List[SourceObject]
    total_source_bytes: int


@dataclass
class WriteResult:
    bucket: str
    key: str
    part_number: int
    row_count: int
    total_source_bytes: int
# =============================================================================
# AWS helpers
# =============================================================================
def build_session() -> boto3.session.Session:
    return boto3.Session(profile_name=AWS_PROFILE, region_name=AWS_REGION)


_SESSION = build_session()
_S3 = _SESSION.client(
    "s3",
    config=Config(
        retries={"max_attempts": 10, "mode": "standard"},
        max_pool_connections=max(32, LIST_MAX_WORKERS + NOTIFY_MAX_WORKERS + 10),
    ),
)
_LAMBDA = _SESSION.client(
    "lambda",
    config=Config(
        retries={"max_attempts": 10, "mode": "standard"},
        max_pool_connections=max(16, NOTIFY_MAX_WORKERS + 4),
    ),
)
# =============================================================================
# S3 discovery
# =============================================================================
def schema_prefix() -> str:
    return f"{SOURCE_ROOT}/{SCHEMA}/"


def discover_table_names(bucket: str, prefix: str) -> List[str]:
    """
    Discover immediate table-name folders under clean/<schema>/ using Delimiter='/'.
    """
    logger.info("Discovering tables under s3://%s/%s", bucket, prefix)
    paginator = _S3.get_paginator("list_objects_v2")
    tables: List[str] = []
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter="/"):
        for cp in page.get("CommonPrefixes", []):
            child_prefix = cp.get("Prefix", "")
            table_name = child_prefix.rstrip("/").split("/")[-1]
            if table_name:
                tables.append(table_name)
    tables = sorted(set(tables))
    logger.info("Discovered %d tables", len(tables))
    return tables


_PARTITION_RE = re.compile(
    r"^(?P<root>[^/]+)/(?P<schema>[^/]+)/(?P<table>[^/]+)/"
    r"(?P<year>\d{4})/(?P<month>\d{2})/(?P<day>\d{2})/(?P<hour>\d{2})/[^/]+\.parquet$",
    flags=re.IGNORECASE,
)
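# Illustrative key accepted by the pattern above (names are examples only):
#   clean/my-schema/orders/2026/03/18/12/data-0001.parquet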
def list_table_objects(bucket: str, table_name: str) -> List[SourceObject]:
    """
    List every parquet object under a single table prefix.
    """
    prefix = f"{SOURCE_ROOT}/{SCHEMA}/{table_name}/"
    logger.info("Listing parquet objects under s3://%s/%s", bucket, prefix)
    paginator = _S3.get_paginator("list_objects_v2")
    found: List[SourceObject] = []
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get("Contents", []):
            key = obj["Key"]
            if not key.lower().endswith(".parquet"):
                continue
            # Keep only keys matching the expected partition layout.
            if not _PARTITION_RE.match(key):
                continue
            found.append(
                SourceObject(
                    bucket=bucket,
                    key=key,
                    table_name=table_name,
                    last_modified=obj["LastModified"],
                    size_bytes=int(obj["Size"]),
                )
            )
    logger.info("Found %d parquet objects for table %s", len(found), table_name)
    return found
# =============================================================================
# Packing
# =============================================================================
def pack_objects_into_parts(objects: List[SourceObject]) -> List[PartPlan]:
    """
    Pack objects into manifest parts using ONLY source object sizes.
    Rules:
    - Objects are assumed to already be globally sorted.
    - If an object's size >= SOLO_OBJECT_THRESHOLD_BYTES, emit it alone.
    - Otherwise, accumulate until adding the next object would exceed MAX_PART_TARGET_BYTES.
    - MIN_PART_TARGET_BYTES is advisory only; the hard stop is MAX_PART_TARGET_BYTES.
    """
    parts: List[PartPlan] = []
    current: List[SourceObject] = []
    current_bytes = 0
    next_part_number = 1

    def flush_current() -> None:
        nonlocal current, current_bytes, next_part_number
        if not current:
            return
        parts.append(
            PartPlan(
                part_number=next_part_number,
                objects=current,
                total_source_bytes=current_bytes,
            )
        )
        next_part_number += 1
        current = []
        current_bytes = 0

    for obj in objects:
        # Very large source object: force into its own manifest part.
        if obj.size_bytes >= SOLO_OBJECT_THRESHOLD_BYTES:
            flush_current()
            parts.append(
                PartPlan(
                    part_number=next_part_number,
                    objects=[obj],
                    total_source_bytes=obj.size_bytes,
                )
            )
            next_part_number += 1
            continue
        if not current:
            current = [obj]
            current_bytes = obj.size_bytes
            continue
        if current_bytes + obj.size_bytes <= MAX_PART_TARGET_BYTES:
            current.append(obj)
            current_bytes += obj.size_bytes
            continue
        # Adding this object would exceed the hard cap.
        flush_current()
        current = [obj]
        current_bytes = obj.size_bytes

    flush_current()
    return parts
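# Worked example with the default 1 GiB cap (sizes illustrative): sorted inputs
# of 400 MB, 500 MB, 300 MB, 1.2 GB pack into three parts:
#   part-00001 = [400 MB, 500 MB]  (adding 300 MB would exceed the cap)
#   part-00002 = [300 MB]
#   part-00003 = [1.2 GB]          (>= SOLO_OBJECT_THRESHOLD_BYTES, emitted alone)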
# =============================================================================
# Manifest writing
# =============================================================================
def manifest_key_for_part(part_number: int) -> str:
    return f"{DEST_PREFIX}/part-{part_number:05d}.parquet"


MANIFEST_SCHEMA = pa.schema(
    [
        pa.field("source_bucket", pa.string(), nullable=False),
        pa.field("source_key", pa.string(), nullable=False),
        pa.field("table_name", pa.string(), nullable=False),
        pa.field("source_last_modified_utc", pa.timestamp("us", tz="UTC"), nullable=False),
        pa.field("source_size_bytes", pa.int64(), nullable=False),
    ]
)
def build_manifest_table(plan: PartPlan) -> pa.Table:
    rows = []
    for obj in plan.objects:
        rows.append(
            {
                "source_bucket": obj.bucket,
                "source_key": obj.key,
                "table_name": obj.table_name,
                "source_last_modified_utc": obj.last_modified.astimezone(timezone.utc),
                "source_size_bytes": obj.size_bytes,
            }
        )
    return pa.Table.from_pylist(rows, schema=MANIFEST_SCHEMA)
def write_manifest_part(bucket: str, plan: PartPlan) -> WriteResult:
    """
    Write one manifest parquet file and upload it to S3.
    """
    manifest_key = manifest_key_for_part(plan.part_number)
    table = build_manifest_table(plan)
    with tempfile.NamedTemporaryFile(
        prefix=f"part-{plan.part_number:05d}-", suffix=".parquet", delete=True
    ) as tmp:
        pq.write_table(
            table,
            tmp.name,
            compression=MANIFEST_COMPRESSION,
            row_group_size=MANIFEST_ROW_GROUP_SIZE,
        )
        _S3.upload_file(tmp.name, bucket, manifest_key)
    logger.info(
        "Wrote manifest s3://%s/%s with %d rows covering %d source bytes",
        bucket,
        manifest_key,
        table.num_rows,
        plan.total_source_bytes,
    )
    return WriteResult(
        bucket=bucket,
        key=manifest_key,
        part_number=plan.part_number,
        row_count=table.num_rows,
        total_source_bytes=plan.total_source_bytes,
    )
# =============================================================================
# Optional notify Lambda
# =============================================================================
def build_notify_payload(result: WriteResult) -> dict:
    payload = {
        "eventType": "s3_repartition_manifest_created",
        "runDate": RUN_DATE,
        "bucket": result.bucket,
        "key": result.key,
        "partNumber": result.part_number,
        "rowCount": result.row_count,
        "totalSourceBytes": result.total_source_bytes,
        "schema": SCHEMA,
        "sourceRoot": SOURCE_ROOT,
        "destRoot": DEST_ROOT,
    }
    if NOTIFY_EXTRA_JSON.strip():
        try:
            extra = json.loads(NOTIFY_EXTRA_JSON)
            if isinstance(extra, dict):
                payload.update(extra)
        except json.JSONDecodeError:
            logger.warning("Ignoring invalid NOTIFY_EXTRA_JSON")
    return payload
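# Example (illustrative): NOTIFY_EXTRA_JSON='{"pipeline": "nightly", "priority": "low"}'
# merges those keys into every notify payload.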
def invoke_notify_lambda_async(result: WriteResult) -> None:
    if not ENABLE_ASYNC_LAMBDA_NOTIFY:
        return
    if not NOTIFY_LAMBDA_NAME:
        raise ValueError("ENABLE_ASYNC_LAMBDA_NOTIFY=true but NOTIFY_LAMBDA_NAME is empty")
    payload = build_notify_payload(result)
    _LAMBDA.invoke(
        FunctionName=NOTIFY_LAMBDA_NAME,
        InvocationType="Event",
        Payload=json.dumps(payload).encode("utf-8"),
    )
    logger.info("Invoked Lambda asynchronously for s3://%s/%s", result.bucket, result.key)
# =============================================================================
# Main flow
# =============================================================================
def collect_all_source_objects(bucket: str, tables: Iterable[str]) -> List[SourceObject]:
    all_objects: List[SourceObject] = []
    with ThreadPoolExecutor(max_workers=LIST_MAX_WORKERS) as executor:
        future_to_table = {executor.submit(list_table_objects, bucket, table): table for table in tables}
        for future in as_completed(future_to_table):
            table_name = future_to_table[future]
            try:
                all_objects.extend(future.result())
            except Exception as exc:
                logger.exception("Failed listing table %s: %s", table_name, exc)
                raise
    all_objects.sort(key=lambda o: (o.table_name, o.last_modified, o.key))
    return all_objects
def log_plan(parts: List[PartPlan]) -> None:
    total_objects = sum(len(p.objects) for p in parts)
    total_bytes = sum(p.total_source_bytes for p in parts)
    logger.info("Planned %d manifest parts", len(parts))
    logger.info("Total source objects: %d", total_objects)
    logger.info("Total source bytes: %d", total_bytes)
    for plan in parts[:20]:
        logger.info(
            "Plan part-%05d: rows=%d source_bytes=%d first=%s last=%s",
            plan.part_number,
            len(plan.objects),
            plan.total_source_bytes,
            plan.objects[0].key if plan.objects else "",
            plan.objects[-1].key if plan.objects else "",
        )
    if len(parts) > 20:
        logger.info("... plan logging truncated after first 20 parts ...")
def main() -> int:
    try:
        logger.info("Starting S3 repartition manifest builder")
        logger.info("AWS_PROFILE=%s AWS_REGION=%s", AWS_PROFILE, AWS_REGION)
        logger.info("Bucket=%s Schema=%s SourcePrefix=%s", BUCKET_NAME, SCHEMA, schema_prefix())
        logger.info(
            "Thresholds: min=%d max=%d solo=%d",
            MIN_PART_TARGET_BYTES,
            MAX_PART_TARGET_BYTES,
            SOLO_OBJECT_THRESHOLD_BYTES,
        )
        tables = discover_table_names(BUCKET_NAME, schema_prefix())
        if not tables:
            logger.warning("No tables found under s3://%s/%s", BUCKET_NAME, schema_prefix())
            return 0
        all_objects = collect_all_source_objects(BUCKET_NAME, tables)
        if not all_objects:
            logger.warning("No parquet objects found under s3://%s/%s", BUCKET_NAME, schema_prefix())
            return 0
        logger.info("Collected %d parquet source objects", len(all_objects))
        parts = pack_objects_into_parts(all_objects)
        log_plan(parts)
        results: List[WriteResult] = []
        for plan in parts:
            results.append(write_manifest_part(BUCKET_NAME, plan))
        if ENABLE_ASYNC_LAMBDA_NOTIFY:
            logger.info(
                "Async Lambda notify enabled for %d manifest files with max_workers=%d",
                len(results),
                NOTIFY_MAX_WORKERS,
            )
            with ThreadPoolExecutor(max_workers=NOTIFY_MAX_WORKERS) as executor:
                futures = [executor.submit(invoke_notify_lambda_async, result) for result in results]
                for future in as_completed(futures):
                    future.result()
        logger.info("Completed successfully. Wrote %d manifest files.", len(results))
        return 0
    except (BotoCoreError, ClientError, ValueError) as exc:
        logger.exception("Fatal AWS/configuration error: %s", exc)
        return 2


if __name__ == "__main__":
    raise SystemExit(main())