Skip to content

Instantly share code, notes, and snippets.

@maxidl
Created August 28, 2025 14:57
Show Gist options
  • Save maxidl/a7f1c118d8470844b23006d5af2df08c to your computer and use it in GitHub Desktop.
Re-shard a parquet dataset into a fixed number of output files (shards).
from pathlib import Path
import pyarrow.parquet as pq
import pyarrow.dataset as pads
from tqdm.auto import tqdm
# --- Configuration -----------------------------------------------------------
d_in = Path("path-to-input-dataset")  # directory containing the input parquet files
d_out = Path("path-to-write-sharded-dataset-to")  # directory to write the shards to
num_shards = 8  # TODO: set the desired number of output files (shards)

# exist_ok=False: fail fast rather than silently mixing new shards into an
# existing directory.
d_out.mkdir(exist_ok=False, parents=True)

# sorted() accepts any iterable, so the intermediate list() was unnecessary.
data_files = sorted(d_in.glob("**/*.parquet"))
print(f"Found {len(data_files)} files in {d_in}")

ds = pads.dataset(d_in)
num_rows = ds.count_rows()
print(f"Number of rows in dataset: {num_rows:_}")

# Integer division: every shard gets this many rows; the last shard also
# absorbs the remainder.
num_rows_per_shard = num_rows // num_shards
print(f"Number of rows per output shard: {num_rows_per_shard:_}")

schema = ds.schema
batches_iter = ds.to_batches(batch_size=10_000)
def create_writer():
    """Open a fresh ParquetWriter on the shared temp file using the dataset schema.

    Relies on module-level ``temp_output_file`` and ``schema``; zstd keeps the
    shards compact.
    """
    return pq.ParquetWriter(temp_output_file, schema, compression="zstd")
writer = None
shard = 0
# Rows are staged in a temp file and only renamed to their final shard name
# once the shard is complete, so a crash never leaves a truncated .parquet.
temp_output_file = d_out / "shard.incomplete"
rows_written = 0  # rows written into the current shard so far

with tqdm(total=num_rows, desc="Sharding dataset") as pbar:
    for batch in batches_iter:
        if writer is None:
            writer = create_writer()
        offset_in_batch = 0
        remaining_in_batch = batch.num_rows
        # Write the batch in slices so we never exceed the shard size.
        while remaining_in_batch > 0:
            # All shards except the last are capped at num_rows_per_shard;
            # the final shard takes whatever is left (including the remainder
            # of the integer division).
            if shard < num_shards - 1:
                remaining_in_shard = num_rows_per_shard - rows_written
            else:
                remaining_in_shard = remaining_in_batch
            rows_to_write = min(remaining_in_batch, remaining_in_shard)
            if rows_to_write > 0:
                slice_to_write = batch.slice(offset_in_batch, rows_to_write)
                writer.write_batch(slice_to_write)
                rows_written += rows_to_write
                offset_in_batch += rows_to_write
                remaining_in_batch -= rows_to_write
                pbar.update(rows_to_write)
            # Current shard is full (applies to all but the last shard):
            # finalize it under its final name and roll over to the next one.
            if (shard < num_shards - 1) and (rows_written == num_rows_per_shard):
                writer.close()
                temp_output_file.rename(d_out / f"shard_{shard:06d}.parquet")
                shard += 1
                writer = create_writer()
                rows_written = 0

# Finalize the last (possibly smaller or larger) shard, if any rows were seen.
if writer is not None:
    writer.close()
    temp_output_file.rename(d_out / f"shard_{shard:06d}.parquet")
# Sanity check: re-open the written dataset and report row and file counts,
# which should match num_rows and num_shards respectively.
out_ds = pads.dataset(d_out)
print(f"Number of rows in output dataset: {out_ds.count_rows():_}")
print(f"Number of output files: {len(list(out_ds.get_fragments()))}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment