
@stas00
Created August 6, 2021 19:27
# benchmark datasets' to_json:
# - normal version (num_proc=1)
# - multiproc version (to_json's num_proc > 1)
# - sharded multiproc version (one process and one output file per shard)
import pathlib
import time
from multiprocessing import Process

from datasets import load_dataset

# batch_sizes = [10_000, 50_000, 100_000, 125_000]  # fuller sweep
batch_sizes = [10_000, 100_000]
num_procs = [1, 4]  # change this according to your machine
num_shards = [1, 4]
DATASET_NAME = "lama"

# benchmark sharded version
for batch in batch_sizes:
    for shards in num_shards:
        # note: unlike the num_proc benchmark below, this timing also includes load_dataset()
        local_start = time.time()
        dataset = load_dataset(DATASET_NAME)["train"]

        # nested target function: only works with the "fork" start method,
        # since a closure cannot be pickled for "spawn"/"forkserver"
        def process_shard(idx):
            print(f"Sharding {idx}")
            if shards > 1:
                ds_shard = dataset.shard(shards, idx, contiguous=True)
            else:
                ds_shard = dataset
            # ds_shard = ds_shard.shuffle()  # remove contiguous=True above if shuffling
            print(f"Saving {DATASET_NAME}-{idx}.jsonl")
            ds_shard.to_json(f"{DATASET_NAME}-{idx}.jsonl", batch_size=batch, orient="records", lines=True, force_ascii=False)

        # one process per shard, each writing its own .jsonl file
        processes = [Process(target=process_shard, args=(idx,)) for idx in range(shards)]
        for p in processes:
            p.start()
        for p in processes:
            p.join()
        local_end = time.time() - local_start
        print(f"Time taken on {shards} shards and {batch} batch_size: ", local_end)

# benchmark multiproc to_json version
SAVE_LOC = "./new_dataset.json"
for batch in batch_sizes:
    for num in num_procs:
        dataset = load_dataset(DATASET_NAME)
        local_start = time.time()
        dataset["train"].to_json(SAVE_LOC, batch_size=batch, num_proc=num, orient="records", lines=True, force_ascii=False)
        local_end = time.time() - local_start
        print(f"Time taken on {num} num_proc and {batch} batch_size: ", local_end)
        # remove the newly generated json file (the cached dataset is left in place)
        pathlib.Path(SAVE_LOC).unlink()
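The progress bars in the log below count batches, so a run should show roughly `ceil(rows / batch_size)` batches per shard. The sketch below reproduces that arithmetic, assuming `shard(..., contiguous=True)` splits the rows into consecutive blocks whose sizes differ by at most one (the first `n % num_shards` shards get one extra row). `contiguous_shard_bounds`, `num_batches` and `N_ROWS` are hypothetical names, not `datasets` APIs, and `N_ROWS` is a made-up row count chosen only because it matches the batch counts in the log.

```python
def contiguous_shard_bounds(n_rows: int, num_shards: int, index: int) -> tuple[int, int]:
    """Return [start, end) of shard `index` when n_rows rows are split into
    num_shards consecutive blocks whose sizes differ by at most one
    (assumed to mirror Dataset.shard(..., contiguous=True))."""
    div, mod = divmod(n_rows, num_shards)
    start = div * index + min(index, mod)
    end = start + div + (1 if index < mod else 0)
    return start, end


def num_batches(n_rows: int, batch_size: int) -> int:
    """Number of batches needed to cover n_rows rows (integer ceiling division)."""
    return (n_rows + batch_size - 1) // batch_size


# hypothetical row count, picked only because it reproduces the log's batch counts
N_ROWS = 1_305_000

for shards in (1, 4):
    for batch in (10_000, 100_000):
        bounds = [contiguous_shard_bounds(N_ROWS, shards, i) for i in range(shards)]
        per_shard = [num_batches(end - start, batch) for start, end in bounds]
        print(f"{shards} shards, batch_size={batch}: batches per shard = {per_shard}")
```

With this row count the loop prints 131 and 14 batches for a single shard, and 33 and 4 batches per shard for four shards, the same totals the progress bars show.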

stas00 commented Aug 6, 2021

Results on a small dataset, "lama"; a larger dataset is probably needed to get more realistic numbers:

No config specified, defaulting to: lama/trex
Reusing dataset lama (/gpfswork/rech/six/commun/datasets/lama/trex/1.1.0/e6c2daf22ebc3a92f694a68c7dd45c99b85637a65001e28faf6a99a8401bebf3)
Sharding 0
Saving lama-0.jsonl
Creating json from Arrow format: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 131/131 [00:09<00:00, 14.01ba/s]
Time taken on 1 shards and 10000 batch_size:  10.352300643920898
No config specified, defaulting to: lama/trex
Reusing dataset lama (/gpfswork/rech/six/commun/datasets/lama/trex/1.1.0/e6c2daf22ebc3a92f694a68c7dd45c99b85637a65001e28faf6a99a8401bebf3)
Sharding 0
Sharding 1
Sharding 2
Sharding 3
Saving lama-0.jsonl
Saving lama-1.jsonl
Creating json from Arrow format:   0%|                                                                                                    | 0/33 [00:00<?, ?ba/s]Saving lama-2.jsonl
Creating json from Arrow format:   0%|                                                                                                    | 0/33 [00:00<?, ?ba/s]Saving lama-3.jsonl
Creating json from Arrow format: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 33/33 [00:09<00:00,  3.38ba/s]
Creating json from Arrow format: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 33/33 [00:12<00:00,  2.57ba/s]
Creating json from Arrow format: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 33/33 [00:20<00:00,  1.59ba/s]
Creating json from Arrow format: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 33/33 [00:21<00:00,  1.51ba/s]
Time taken on 4 shards and 10000 batch_size:  22.517746210098267
No config specified, defaulting to: lama/trex
Reusing dataset lama (/gpfswork/rech/six/commun/datasets/lama/trex/1.1.0/e6c2daf22ebc3a92f694a68c7dd45c99b85637a65001e28faf6a99a8401bebf3)
Sharding 0
Saving lama-0.jsonl
Creating json from Arrow format: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:10<00:00,  1.29ba/s]
Time taken on 1 shards and 100000 batch_size:  11.42315673828125
No config specified, defaulting to: lama/trex
Reusing dataset lama (/gpfswork/rech/six/commun/datasets/lama/trex/1.1.0/e6c2daf22ebc3a92f694a68c7dd45c99b85637a65001e28faf6a99a8401bebf3)
Sharding 0
Sharding 1
Sharding 2
Sharding 3
Saving lama-0.jsonl
Saving lama-1.jsonl
Creating json from Arrow format:   0%|                                                                                                     | 0/4 [00:00<?, ?ba/s]Saving lama-2.jsonl
Creating json from Arrow format:   0%|                                                                                                     | 0/4 [00:00<?, ?ba/s]Saving lama-3.jsonl
Creating json from Arrow format: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:10<00:00,  2.63s/ba]
Creating json from Arrow format: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:12<00:00,  3.02s/ba]
Creating json from Arrow format: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:20<00:00,  5.15s/ba]
Creating json from Arrow format: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:21<00:00,  5.31s/ba]
Time taken on 4 shards and 100000 batch_size:  22.530734539031982
No config specified, defaulting to: lama/trex
Reusing dataset lama (/gpfswork/rech/six/commun/datasets/lama/trex/1.1.0/e6c2daf22ebc3a92f694a68c7dd45c99b85637a65001e28faf6a99a8401bebf3)
Creating json from Arrow format: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 131/131 [00:09<00:00, 14.10ba/s]
Time taken on 1 num_proc and 10000 batch_size:  9.298408508300781
No config specified, defaulting to: lama/trex
Reusing dataset lama (/gpfswork/rech/six/commun/datasets/lama/trex/1.1.0/e6c2daf22ebc3a92f694a68c7dd45c99b85637a65001e28faf6a99a8401bebf3)
Creating json from Arrow format: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 131/131 [00:04<00:00, 32.45ba/s]
Time taken on 4 num_proc and 10000 batch_size:  4.837364435195923
No config specified, defaulting to: lama/trex
Reusing dataset lama (/gpfswork/rech/six/commun/datasets/lama/trex/1.1.0/e6c2daf22ebc3a92f694a68c7dd45c99b85637a65001e28faf6a99a8401bebf3)
Creating json from Arrow format: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:10<00:00,  1.30ba/s]
Time taken on 1 num_proc and 100000 batch_size:  10.799099445343018
No config specified, defaulting to: lama/trex
Reusing dataset lama (/gpfswork/rech/six/commun/datasets/lama/trex/1.1.0/e6c2daf22ebc3a92f694a68c7dd45c99b85637a65001e28faf6a99a8401bebf3)
Creating json from Arrow format: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 14/14 [00:04<00:00,  3.39ba/s]
Time taken on 4 num_proc and 100000 batch_size:  4.747768402099609

Same log without the noise:

Time taken on 1 shards and 10000 batch_size:  10.352300643920898
Time taken on 4 shards and 10000 batch_size:  22.517746210098267
Time taken on 1 shards and 100000 batch_size:  11.42315673828125
Time taken on 4 shards and 100000 batch_size:  22.530734539031982
Time taken on 1 num_proc and 10000 batch_size:  9.298408508300781
Time taken on 4 num_proc and 10000 batch_size:  4.837364435195923
Time taken on 1 num_proc and 100000 batch_size:  10.799099445343018
Time taken on 4 num_proc and 100000 batch_size:  4.747768402099609
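To read the summary as speedups, divide the 1-worker time by the 4-worker time for the same mode and batch_size. A small sketch that parses the summary lines above verbatim; `speedups` and `LINE_RE` are hypothetical names, and each figure comes from a single run, so treat the ratios as rough.

```python
import re

SUMMARY = """\
Time taken on 1 shards and 10000 batch_size:  10.352300643920898
Time taken on 4 shards and 10000 batch_size:  22.517746210098267
Time taken on 1 shards and 100000 batch_size:  11.42315673828125
Time taken on 4 shards and 100000 batch_size:  22.530734539031982
Time taken on 1 num_proc and 10000 batch_size:  9.298408508300781
Time taken on 4 num_proc and 10000 batch_size:  4.837364435195923
Time taken on 1 num_proc and 100000 batch_size:  10.799099445343018
Time taken on 4 num_proc and 100000 batch_size:  4.747768402099609
"""

# captures: workers, mode, batch_size, seconds
LINE_RE = re.compile(r"Time taken on (\d+) (shards|num_proc) and (\d+) batch_size:\s+([\d.]+)")


def speedups(summary: str) -> dict[tuple[str, int, int], float]:
    """Map (mode, workers, batch_size) -> baseline_time / time, where the
    baseline is the 1-worker run of the same mode and batch_size."""
    times = {}
    for workers, mode, batch, seconds in LINE_RE.findall(summary):
        times[(mode, int(workers), int(batch))] = float(seconds)
    return {key: times[(key[0], 1, key[2])] / t for key, t in times.items()}


for (mode, workers, batch), s in sorted(speedups(SUMMARY).items()):
    print(f"{mode:>8} x{workers}, batch_size={batch:>6}: {s:.2f}x")
```

With these numbers, four `num_proc` workers come out about 1.92x and 2.27x faster than one, while four separate shard processes come out at 0.46x and 0.51x, i.e. roughly 2x slower. Keep in mind that the sharded timings also include the cached `load_dataset()` call.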
