Last active
March 17, 2022 23:50
-
-
Save stephanie-wang/138557a42968ede41cd16592ca5df2d3 to your computer and use it in GitHub Desktop.
Sort/shuffle benchmark for Ray Datasets.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ray | |
import pandas as pd | |
import numpy as np | |
import time | |
import builtins | |
from typing import Any, Generic, List, Callable, Union, Tuple, Iterable | |
import os | |
import psutil | |
import resource | |
import numpy as np | |
import ray | |
from ray.types import ObjectRef | |
from ray.data.block import ( | |
Block, | |
BlockAccessor, | |
BlockMetadata, | |
T, | |
BlockPartition, | |
BlockPartitionMetadata, | |
MaybeBlockPartition, | |
) | |
from ray.data.context import DatasetContext | |
from ray.data.impl.arrow_block import ArrowRow | |
from ray.data.impl.delegating_block_builder import DelegatingBlockBuilder | |
from ray.data.impl.util import _check_pyarrow_version | |
from ray.util.annotations import DeveloperAPI | |
from ray.data.datasource import Datasource, ReadTask | |
from ray.internal.internal_api import memory_summary | |
class RandomIntRowDatasource(Datasource[ArrowRow]):
    """Datasource generating rows of random int64 columns named ``c_0..c_{k-1}``.

    Examples:
        >>> source = RandomIntRowDatasource()
        >>> ray.data.read_datasource(source, n=10, num_columns=2).take()
        ... {'c_0': 1717767200176864416, 'c_1': 999657309586757214}
        ... {'c_0': 4983608804013926748, 'c_1': 1160140066899844087}
    """

    def prepare_read(
        self, parallelism: int, n: int, num_columns: int
    ) -> List[ReadTask]:
        """Split n rows into ~parallelism read tasks of random int64 data."""
        _check_pyarrow_version()
        import pyarrow

        def build_block(num_rows: int, cols: int) -> Block:
            # One random int64 array per column, assembled into an Arrow table.
            data = np.random.randint(
                np.iinfo(np.int64).max, size=(cols, num_rows), dtype=np.int64
            )
            return pyarrow.Table.from_arrays(
                data, names=[f"c_{j}" for j in range(cols)]
            )

        # Derive the schema from a throwaway single-row table.
        schema = pyarrow.Table.from_pydict(
            {f"c_{j}": [0] for j in range(num_columns)}
        ).schema

        rows_per_task = max(1, n // parallelism)
        tasks: List[ReadTask] = []
        offset = 0
        while offset < n:
            num_rows = min(rows_per_task, n - offset)
            metadata = BlockMetadata(
                num_rows=num_rows,
                size_bytes=8 * num_rows * num_columns,
                schema=schema,
                input_files=None,
                exec_stats=None,
            )
            # Bind loop values as lambda defaults so each task keeps its own
            # row count instead of the final loop value (late-binding pitfall).
            tasks.append(
                ReadTask(
                    lambda num_rows=num_rows, cols=num_columns: [
                        build_block(num_rows, cols)
                    ],
                    metadata,
                )
            )
            offset += rows_per_task
        return tasks
def memory_bytes():
    """Return [("max", peak_rss), ("rss", current_rss)] for this process.

    NOTE(review): ``ru_maxrss`` is reported in kilobytes on Linux but bytes on
    macOS, so the ``1e3`` scaling assumes a Linux host -- confirm if ported.
    """
    peak = 1e3 * resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    current = psutil.Process(os.getpid()).memory_info().rss
    return [("max", peak), ("rss", current)]
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-partitions", help="number of partitions", default="50", type=str
    )
    parser.add_argument(
        "--partition-size",
        help="partition size (bytes)",
        default="200e6",
        type=str,
    )
    parser.add_argument(
        "--shuffle", help="shuffle instead of sort", action="store_true"
    )
    args = parser.parse_args()

    # Sizes are accepted as strings so scientific notation like "200e6" parses.
    num_partitions = int(args.num_partitions)
    partition_size = int(float(args.partition_size))
    print(
        f"Dataset size: {num_partitions} partitions, "
        f"{partition_size / 1e9}GB partition size, "
        f"{num_partitions * partition_size / 1e9}GB total"
    )

    start = time.time()
    source = RandomIntRowDatasource()
    # Dataset has a single int64 column -> 8 bytes per row.
    num_rows_per_partition = partition_size // 8
    ds = ray.data.read_datasource(
        source,
        parallelism=num_partitions,
        n=num_rows_per_partition * num_partitions,
        num_columns=1,
    )

    # Deliberately broad catch: keep any failure so the timing and memory
    # report below still print, then re-raise it at the end.
    exc = None
    try:
        if args.shuffle:
            ds = ds.random_shuffle()
        else:
            ds = ds.sort(key="c_0")
    except Exception as e:
        exc = e

    end = time.time()
    print("Finished in", end - start)
    print("")
    print("==== Driver memory summary ====")
    for key, val in memory_bytes():
        print(f"{key}: {val / 1e9}GB")
    print("")
    print(memory_summary(stats_only=True))

    if exc:
        raise exc
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment