Skip to content

Instantly share code, notes, and snippets.

@stephanie-wang
Created December 10, 2020 23:19
Show Gist options
  • Save stephanie-wang/842669be51a263e5637c0e5ba0236132 to your computer and use it in GitHub Desktop.
Save stephanie-wang/842669be51a263e5637c0e5ba0236132 to your computer and use it in GitHub Desktop.
Dask sort
import dask
import dask.dataframe as dd
import json
import pandas as pd
import numpy as np
import os.path
import csv
import fastparquet
from dask.distributed import Client
import time
def load_dataset(client, nbytes, npartitions):
    """Generate (or reuse) a random int64 dataset on the cluster and open it lazily.

    One task per partition is submitted to ``client``.

    - **Contents:** each task writes a gzip-compressed parquet file with a
      single int64 column ``'a'`` of random non-negative values.
    - **Reuse:** a partition file that already exists on disk is reused rather
      than regenerated, so repeated runs with the same sizes skip data
      generation.
    - **Atomic writes:** files are written under a temporary name and renamed
      into place, so an interrupted run cannot leave a truncated file that
      later runs would reuse.

    NOTE(review): the files are written relative to each worker's working
    directory and read back relative to the client's, so this presumably
    assumes the client and all workers share a working directory/filesystem
    — confirm for multi-node setups.

    Args:
        client: ``dask.distributed.Client`` used to run the generation tasks.
        nbytes: Approximate total size of the raw column data, in bytes.
        npartitions: Number of partitions (one parquet file each).

    Returns:
        A lazy ``dask.dataframe`` DataFrame over the generated files.
    """
    num_bytes_per_partition = nbytes // npartitions

    def foo(i):
        # Runs on a worker: materialize partition ``i`` unless already cached.
        filename = "df-{}-{}.parquet.gzip".format(num_bytes_per_partition, i)
        print("Partition file", filename)
        if not os.path.exists(filename):
            # Each int64 value occupies 8 bytes.
            nrows = num_bytes_per_partition // 8
            print("Allocating dataset with {} rows".format(nrows))
            dataset = pd.DataFrame(
                np.random.randint(
                    0, np.iinfo(np.int64).max, size=(nrows, 1), dtype=np.int64
                ),
                columns=['a'],
            )
            print("Done allocating")
            # Write under a temporary name, then atomically rename. A crash
            # mid-write therefore never leaves a truncated file at
            # ``filename`` that the existence check above would reuse.
            tmp_filename = filename + ".tmp"
            dataset.to_parquet(tmp_filename, compression='gzip')
            os.replace(tmp_filename, filename)
            print("Done writing to disk")
        return filename

    futures = [client.submit(foo, i) for i in range(npartitions)]
    filenames = client.gather(futures)
    return dd.read_parquet(filenames)
def trial(client, nbytes, n_partitions, max_trials=10, time_budget_s=60):
    """Time repeated ``set_index('a')`` sorts of a generated dataset.

    The dataset is built once via ``load_dataset``. Up to ``max_trials``
    timed runs follow, each of ``df.set_index('a').compute()``. Both the
    ``set_index`` call and ``compute()`` fall inside the timed region.

    The loop stops early once the total elapsed time exceeds
    ``time_budget_s``. That check happens after each trial, so a trial
    already in progress always finishes and the total can overshoot the
    budget by up to one trial.

    Args:
        client: ``dask.distributed.Client`` passed to ``load_dataset``.
        nbytes: Approximate total dataset size in bytes.
        n_partitions: Number of dataset partitions.
        max_trials: Maximum number of timed runs (default 10).
        time_budget_s: Soft wall-time budget in seconds (default 60).

    Returns:
        List of per-trial durations in seconds. It is non-empty whenever
        ``max_trials >= 1``.
    """
    df = load_dataset(client, nbytes, n_partitions)
    times = []
    # perf_counter is monotonic and high-resolution. time.time() follows the
    # system clock and can jump (NTP/manual adjustments), corrupting durations.
    start = time.perf_counter()
    for i in range(max_trials):
        print("Trial {} start".format(i))
        trial_start = time.perf_counter()
        df.set_index('a').compute()
        duration = time.perf_counter() - trial_start
        times.append(duration)
        print("Trial {} done after {}".format(i, duration))
        if time.perf_counter() - start > time_budget_s:
            break
    return times
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--nbytes", type=int, default=1_000_000)
    parser.add_argument("--npartitions", type=int, default=100, required=False)
    # Max partition size defaults to 1 GB (10**9 bytes).
    parser.add_argument("--max-partition-size", type=int, default=1_000_000_000, required=False)
    parser.add_argument("--use-tasks", action="store_true")
    parser.add_argument("--address", default='127.0.0.1:8786',
                        help="address of the dask scheduler to connect to")
    args = parser.parse_args()
    # Both values are used as divisors below; reject them up front instead of
    # failing later with a ZeroDivisionError.
    if args.npartitions < 1:
        parser.error("--npartitions must be >= 1")
    if args.max_partition_size < 1:
        parser.error("--max-partition-size must be >= 1")

    if args.use_tasks:
        # Select dask's task-based shuffle via the global ``shuffle`` config key.
        dask.config.set(shuffle='tasks')
    client = Client(args.address)

    # Small warmup run before the measured run.
    print("dask", trial(client, 1000, 10))
    print("WARMUP DONE")

    npartitions = args.npartitions
    if args.nbytes // npartitions > args.max_partition_size:
        # Ceiling division guarantees every partition fits under the cap.
        # Floor division would not: 2.5 GB with a 1 GB cap would give
        # 2 partitions of 1.25 GB each.
        npartitions = -(-args.nbytes // args.max_partition_size)
    output = trial(client, args.nbytes, npartitions)
    print("mean over {} trials: {} +- {}".format(len(output), np.mean(output), np.std(output)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment