# benchmark datasets' to_json:
# - normal
# - multiproc version
# - sharded multiproc version
import time
import pathlib
from multiprocessing import Process

from datasets import load_dataset

# batch_sizes = [10_000, 50_000, 100_000, 125_000]  # full sweep
batch_sizes = [10_000, 100_000]  # reduced sweep used for this run
num_procs = [1, 4]  # change this according to your machine
num_shards = [1, 4]
DATASET_NAME = "lama"
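
# alternative sizing sketch (an addition, not in the original gist): scale the
# multiproc benchmark to however many cores this machine has
# from multiprocessing import cpu_count
# num_procs = [1, cpu_count()]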

# benchmark the sharded version: one Process per shard, each writing its own file
for batch in batch_sizes:
    for shards in num_shards:
        local_start = time.time()
        dataset = load_dataset(DATASET_NAME)["train"]
        # print(dataset)

        # note: this relies on the fork start method (the Linux default), since
        # a locally defined function cannot be pickled for spawn-based workers
        def process_shard(idx):
            print(f"Sharding {idx}")
            if shards > 1:
                ds_shard = dataset.shard(shards, idx, contiguous=True)
            else:
                ds_shard = dataset
            # ds_shard = ds_shard.shuffle()  # remove contiguous=True above if shuffling
            print(f"Saving {DATASET_NAME}-{idx}.jsonl")
            ds_shard.to_json(f"{DATASET_NAME}-{idx}.jsonl", batch_size=batch, orient="records", lines=True, force_ascii=False)

        # process_shard(0)  # run a single shard in-process to debug

        processes = [Process(target=process_shard, args=(idx,)) for idx in range(shards)]
        for p in processes:
            p.start()
        for p in processes:
            p.join()

        local_end = time.time() - local_start
        print(f"Time taken on {shards} shards and {batch} batch_size: ", local_end)

# benchmark the multiproc to_json version (the num_proc argument)
SAVE_LOC = "./new_dataset.json"
for batch in batch_sizes:
    for num in num_procs:
        dataset = load_dataset(DATASET_NAME)
        local_start = time.time()
        dataset["train"].to_json(SAVE_LOC, batch_size=batch, num_proc=num, orient="records", lines=True, force_ascii=False)
        local_end = time.time() - local_start
        print(f"Time taken on {num} num_proc and {batch} batch_size: ", local_end)

        # remove the newly generated json so the next run starts fresh
        # (note: this does not purge the dataset from the HF cache)
        pathlib.Path(SAVE_LOC).unlink()
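
# cleanup sketch (an addition, not in the original gist): remove the leftover
# per-shard outputs written by the sharded benchmark, assuming they landed in
# the current working directory
for shard_file in pathlib.Path(".").glob(f"{DATASET_NAME}-*.jsonl"):
    shard_file.unlink()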
Results on the small "lama" dataset (one probably needs a larger dataset to get more realistic numbers):

The same log without the noise: