-
-
Save GaelVaroquaux/0d1a73810fee41171aee57a31a31e86a to your computer and use it in GitHub Desktop.
Persistence strategies comparison
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Persistence strategies comparison script.

This script computes the speed, memory used and disk space used when dumping
and loading arbitrary data. The data are taken among:
- scikit-learn Labeled Faces in the Wild dataset (LFW)
- a fully random numpy array with 10000x10000 shape
- a dictionary with 1M string keys and random values
- a list containing 10M random values

The compared persistence strategies are:
- joblib
- joblib compressed (using zlib at compress level 3)
- pickle
- numpy (using savez/load functions)
- numpy compressed (using savez_compressed/load functions)
"""
import os | |
import argparse | |
import shutil | |
import time | |
import numpy as np | |
import joblib | |
import pickle | |
import pandas as pd | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
from memory_profiler import memory_usage | |
from sklearn import datasets | |
from joblib.disk import disk_used | |
# Human-readable description of each dataset, keyed by the accepted
# --dataset choices; used as the figure title when plotting.
DATASET_DESC = {
    'lfw_people': 'Labeled Faces in the Wild dataset (LFW)',
    # 10000 x 10000 float64 values = 8e8 bytes, i.e. ~800MB (not ~700MB).
    'big_array': 'Numpy array with random values (~800MB)',
    'big_dict': 'Dictionary with 1M random keys/values',
    'big_list': 'List of 10M random values',
}
def clear_output_directory():
    """Delete the ``out`` directory (if present) and recreate it empty."""
    out_dir = 'out'
    if os.path.exists(out_dir):
        # rmtree removes the directory together with everything inside it.
        shutil.rmtree(out_dir)
    os.mkdir(out_dir)
def kill_disk_cache():
    """Try to remove the bias that the OS disk cache adds to I/O timings.

    On Linux, dirty pages are synced and the page cache is dropped through
    ``/proc/sys/vm/drop_caches``. That requires root, so the command goes
    through ``sudo``, which may prompt for a password. If the command fails,
    a hint is printed instead of raising.

    On any other platform, ~160 MB of random float64 data is written to a
    scratch file named ``tmp`` in the current directory. This is presumably
    meant as a crude way to push previously cached data out of the cache.
    """
    if os.name == 'posix' and os.uname()[0] == 'Linux':
        command = 'sudo sh -c "sync; echo 3 > /proc/sys/vm/drop_caches"'
        # os.system reports failure through its return value; it never
        # raises IOError for a failing command, so check the status.
        status = os.system(command)
        if status != 0:
            print('Please run me as root')
    else:
        # 2e7 float64 values ~= 160 MB. numpy needs an integer size, and
        # an ndarray must be written as raw bytes to a binary file:
        # text-mode write() only accepts str.
        with open('tmp', 'wb') as f:
            np.random.random(int(2e7)).tofile(f)
def timeit(func, *args, **kwargs):
    """Return the elapsed time, in seconds, of one call ``func(*args, **kwargs)``.

    ``kill_disk_cache`` runs first, outside the measured interval, so data
    cached by earlier steps does not skew the measurement. The return value
    of ``func`` is discarded.
    """
    kill_disk_cache()
    # perf_counter is monotonic and has the highest available resolution,
    # unlike time.time(), which can jump when the system clock is adjusted.
    start = time.perf_counter()
    func(*args, **kwargs)
    return time.perf_counter() - start
def memit(func, *args, **kwargs):
    """Return the memory footprint growth observed while running ``func``.

    memory_profiler samples the process memory every millisecond while
    ``func(*args, **kwargs)`` runs. The result is the spread between the
    highest and lowest sample, presumably in MiB (memory_profiler's unit).
    """
    samples = memory_usage((func, args, kwargs), interval=.001)
    lowest = min(samples)
    highest = max(samples)
    return highest - lowest
def write_to_file(fileobj, strategy, dataset, write_time, read_time, mem_write,
                  mem_read, disk_used):
    """Write one CSV record of benchmark results to ``fileobj``.

    Timings are written with 3 decimals; memory and disk figures with 1.
    Note: the ``disk_used`` parameter shadows the module-level function of
    the same name inside this function only.
    """
    fields = [
        str(strategy),
        str(dataset),
        '%.3f' % write_time,
        '%.3f' % read_time,
        '%.1f' % mem_write,
        '%.1f' % mem_read,
        '%.1f' % disk_used,
    ]
    fileobj.write(','.join(fields) + '\n')
def print_line(strategy, dataset, write_time, read_time, mem_write, mem_read,
               disk_used):
    """Print one right-aligned, comma-separated row of benchmark results.

    The column widths line up with the header row printed by ``main``.
    """
    template = ('{!s:>30}, {!s:>10}, {: 9.3f}, {: 9.3f}, {: 9.1f}, '
                '{: 9.1f}, {: 5.1f}')
    row = template.format(strategy, dataset, write_time, read_time,
                          mem_write, mem_read, disk_used)
    print(row)
def generate_dataset(dataset_str, n_items=None):
    """Build the object to benchmark for ``dataset_str``.

    Parameters
    ----------
    dataset_str : str
        One of 'lfw_people', 'big_array', 'big_dict' or 'big_list'.
    n_items : int, optional
        Number of entries for 'big_dict' (default 1M) and 'big_list'
        (default 10M). Ignored for the other datasets.

    Returns
    -------
    One of:
    - for 'lfw_people', the object returned by sklearn's
      ``datasets.fetch_lfw_people()``;
    - for 'big_array', a (10000, 10000) float64 array;
    - for 'big_dict', a dict mapping ``str(i)`` to a random float in [0, 1);
    - for 'big_list', a list of random floats in [0, 1).
    Random data come from ``RandomState(0)``, so repeated calls return
    identical values.

    Raises
    ------
    ValueError
        If ``dataset_str`` is not a known dataset name.
    """
    if dataset_str == 'lfw_people':
        return datasets.fetch_lfw_people()
    rnd = np.random.RandomState(0)
    if dataset_str == 'big_array':
        return rnd.random_sample((10000, 10000))
    if dataset_str == 'big_dict':
        n = 1000000 if n_items is None else n_items
        # tolist() yields plain Python floats. Converting each sample with
        # int() would truncate every value in [0, 1) to 0.
        values = rnd.random_sample(n).tolist()
        return {str(key): value for key, value in enumerate(values)}
    if dataset_str == 'big_list':
        n = 10000000 if n_items is None else n_items
        return rnd.random_sample(n).tolist()
    raise ValueError('Unknown dataset: %r' % (dataset_str,))
def run_joblib_bench(filename, obj, strategy, dataset, output_file, **kwargs):
    """Benchmark ``joblib.dump`` / ``joblib.load`` of ``obj`` and report it.

    Extra ``kwargs`` (e.g. ``compress=3``) are forwarded to ``joblib.dump``.
    The result row is printed and written to ``output_file`` as CSV.
    """
    clear_output_directory()
    # The object is dumped twice, overwriting the same file: once for
    # timing, once for memory profiling.
    dump_seconds = timeit(joblib.dump, obj, filename, **kwargs)
    dump_mem = memit(joblib.dump, obj, filename, **kwargs)
    # disk_used('out') scaled down by 1024 (reported as MB in the plots).
    disk = disk_used('out') / 1024.
    load_seconds = timeit(joblib.load, filename)
    load_mem = memit(joblib.load, filename)
    row = (strategy, dataset, dump_seconds, load_seconds,
           dump_mem, load_mem, disk)
    print_line(*row)
    write_to_file(output_file, *row)
def run_numpy_bench(filename, obj, strategy, dataset, output_file):
    """Benchmark ``np.savez`` / ``np.load`` of ``obj`` and report it.

    ``obj`` is stored as the single entry of an uncompressed ``.npz``
    archive. ``np.savez`` appends '.npz' to ``filename`` when it is not
    already there; the read side uses the same resolved path. The result
    row is printed and written to ``output_file`` as CSV.
    """
    def load_all(path):
        # Read every stored array so the measurement covers real
        # deserialization: NpzFile is lazy, and on recent NumPy ``.items()``
        # alone does not read any array data.
        # allow_pickle=True: non-array objects (dicts, lists of objects,
        # sklearn bunches) are stored as object arrays, which np.load
        # refuses by default. Acceptable here because we only read back the
        # file this function just wrote.
        with np.load(path, allow_pickle=True) as npz:
            return [npz[name] for name in npz.files]

    npz_path = filename if filename.endswith('.npz') else filename + '.npz'
    clear_output_directory()
    time_write = timeit(np.savez, filename, obj)
    mem_write = memit(np.savez, filename, obj)
    # disk_used('out') scaled down by 1024 (reported as MB in the plots).
    du = disk_used('out') / 1024.
    time_read = timeit(load_all, npz_path)
    mem_read = memit(load_all, npz_path)
    print_line(strategy, dataset,
               time_write, time_read, mem_write, mem_read, du)
    write_to_file(output_file, strategy, dataset,
                  time_write, time_read, mem_write, mem_read, du)
def run_numpy_compressed_bench(filename, obj, strategy, dataset, output_file):
    """Benchmark ``np.savez_compressed`` / ``np.load`` of ``obj`` and report it.

    ``obj`` is stored as the single entry of a compressed ``.npz`` archive.
    ``np.savez_compressed`` appends '.npz' to ``filename`` when it is not
    already there; the read side uses the same resolved path. The result
    row is printed and written to ``output_file`` as CSV.
    """
    def load_all(path):
        # Read (and decompress) every stored array so the measurement
        # covers real deserialization: NpzFile is lazy, and on recent NumPy
        # ``.items()`` alone does not read any array data.
        # allow_pickle=True: non-array objects are stored as object arrays,
        # which np.load refuses by default. Acceptable here because we only
        # read back the file this function just wrote.
        with np.load(path, allow_pickle=True) as npz:
            return [npz[name] for name in npz.files]

    npz_path = filename if filename.endswith('.npz') else filename + '.npz'
    clear_output_directory()
    time_write = timeit(np.savez_compressed, filename, obj)
    mem_write = memit(np.savez_compressed, filename, obj)
    # disk_used('out') scaled down by 1024 (reported as MB in the plots).
    du = disk_used('out') / 1024.
    time_read = timeit(load_all, npz_path)
    mem_read = memit(load_all, npz_path)
    print_line(strategy, dataset,
               time_write, time_read, mem_write, mem_read, du)
    write_to_file(output_file, strategy, dataset,
                  time_write, time_read, mem_write, mem_read, du)
def run_pickle_bench(filename, obj, strategy, dataset, output_file):
    """Benchmark ``pickle.dump`` / ``pickle.load`` of ``obj`` and report it.

    The file is reopened for every measurement, so each timed or profiled
    call starts from a fresh file handle. The result row is printed and
    written to ``output_file`` as CSV.
    """
    clear_output_directory()
    with open(filename, 'wb') as handle:
        dump_seconds = timeit(pickle.dump, obj, handle)
    with open(filename, 'wb') as handle:
        dump_mem = memit(pickle.dump, obj, handle)
    # disk_used('out') scaled down by 1024 (reported as MB in the plots).
    disk = disk_used('out') / 1024.
    with open(filename, 'rb') as handle:
        load_seconds = timeit(pickle.load, handle)
    with open(filename, 'rb') as handle:
        load_mem = memit(pickle.load, handle)
    row = (strategy, dataset, dump_seconds, load_seconds,
           dump_mem, load_mem, disk)
    print_line(*row)
    write_to_file(output_file, *row)
def _annotated_barplot(ax, df, column, title, label_fmt):
    """Draw one horizontal bar panel of ``df[column]`` grouped by strategy.

    Each bar is annotated with the mean of ``column`` for its strategy,
    formatted with ``label_fmt``. The x ticks are removed and both axis
    labels cleared.
    """
    values = df[column]
    strategies = df['strategy']
    # x/y are passed by keyword: recent seaborn releases make every
    # parameter except ``data`` keyword-only.
    sns.barplot(x=values, y=strategies, palette="Set3", ax=ax)
    ax.set_title(title)
    ax.set_xlabel("")
    ax.set_ylabel("")
    offset = 0.01 * values.max()
    # Categorical bars follow order of first appearance, as unique() does.
    for i, strategy in enumerate(strategies.unique()):
        mean = df.loc[strategies == strategy, column].mean()
        ax.text(mean + offset, i + .15, label_fmt.format(mean),
                color='black', style='italic')
    ax.set_xticks(())


def generate_plots(filename, dataset, output='tmp.png'):
    """Plot the benchmark results stored in ``filename`` for ``dataset``.

    Reads the CSV written by the benchmark and keeps only rows for
    ``dataset``. Draws five panels (dump time, load time, memory used on
    dump and on load, disk used) and saves the figure to ``output``.
    Prints a message and returns early if there are no rows for
    ``dataset``.
    """
    sns.set(style="whitegrid", context="talk")
    df = pd.read_csv(filename)
    # copy() so the column rewrite below acts on an independent frame, not
    # on a slice of the full CSV (avoids SettingWithCopyWarning).
    df = df[df['dataset'] == dataset].copy()
    if df.empty:
        print("Nothing to plot, exiting")
        return
    f, (dump_axe, load_axe, mem_dump_axe, mem_load_axe, disk_axe) = \
        plt.subplots(1, 5, figsize=(9, 4.2), sharey=True,
                     gridspec_kw=dict(wspace=.7, right=.947, bottom=.005,
                                      top=.85, left=.14))
    # Shorten strategy names so they fit as y tick labels.
    df['strategy'] = [s.replace(' ', '\n', 1)
                       .replace('0.10.0.dev0', 'dev')
                       .replace(' -', ',')
                      for s in df['strategy']]
    plt.text(.005, .96, '{0}'.format(DATASET_DESC[dataset]), size=13,
             transform=f.transFigure)

    _annotated_barplot(dump_axe, df, 'dump', "Dump time", "{0:.2G}s")
    # One label per distinct strategy: with several tries the column holds
    # duplicates, which would not match the number of y ticks. sharey=True
    # means these labels serve the whole row of panels.
    dump_axe.set_yticklabels(list(df['strategy'].unique()), size=14)
    _annotated_barplot(load_axe, df, 'load', "Load time", "{0:.2G}s")
    _annotated_barplot(mem_dump_axe, df, 'mem_dump',
                       "Memory used\nwith dump", "{0:.0f}MB")
    _annotated_barplot(mem_load_axe, df, 'mem_load',
                       "Memory used\nwith load", "{0:.0f}MB")
    _annotated_barplot(disk_axe, df, 'disk', "Disk used", "{0:.0f}MB")
    disk_axe.xaxis.tick_top()

    sns.despine(bottom=True)
    plt.savefig(output, dpi=100)
def parse_args(argv=None):
    """Parse command line arguments.

    ``argv`` is the list of arguments to parse. When None, argparse falls
    back to ``sys.argv[1:]``, which is the original behavior.
    """
    # The text is the parser's description. Passed positionally it would
    # become ``prog`` and replace the program name in the usage line.
    parser = argparse.ArgumentParser(
        description="Persistence strategies benchmark script")
    parser.add_argument('--tries', type=int, default=1,
                        help="Number of tries on each strategy.")
    parser.add_argument('--all', action='store_true',
                        help='Perform all benchmarks in a row (joblib, numpy, '
                             'pickle, joblib compressed, numpy compressed).')
    parser.add_argument('--joblib', action='store_true',
                        help='Run joblib persistence in raw file')
    parser.add_argument('--joblib-compressed', action='store_true',
                        help='Run joblib persistence in compressed file')
    parser.add_argument('--pickle', action='store_true',
                        help='Run pickle persistence in raw file')
    parser.add_argument('--numpy', action='store_true',
                        help='Run numpy persistence in raw file')
    parser.add_argument('--numpy-compressed', action='store_true',
                        help='Run numpy persistence in compressed file')
    parser.add_argument('--filename', type=str,
                        default='/tmp/persistence_bench.csv',
                        help='Output csv file where benchmark results are '
                             'stored')
    parser.add_argument('--append-results', action='store_true',
                        help='Append results at the end of csv file.')
    parser.add_argument('--plot', action='store_true',
                        help='Display results from the csv file in nice '
                             'plots.')
    parser.add_argument('--dataset', type=str, default='lfw_people',
                        choices=['lfw_people', 'big_array', 'big_dict',
                                 'big_list'],
                        help="Dataset to bench.")
    return parser.parse_args(argv)
def main():
    """Run the selected benchmarks, then optionally plot the results.

    Each benchmark row is printed and written to the CSV file
    ``args.filename``. With ``--plot``, earlier results in that file are
    kept (rows are appended) and a figure is drawn for ``args.dataset``.
    """
    args = parse_args()
    do_joblib = args.joblib or args.all
    do_joblib_compressed = args.joblib_compressed or args.all
    do_pickle = args.pickle or args.all
    do_numpy = args.numpy or args.all
    do_numpy_compressed = args.numpy_compressed or args.all

    if any((do_joblib, do_joblib_compressed, do_pickle, do_numpy,
            do_numpy_compressed)):
        # Only build the dataset when something is benchmarked: it may be a
        # download (LFW) or a very large in-memory object.
        dataset = generate_dataset(args.dataset)
        # Defined whenever a benchmark runs (including --all), because the
        # CSV header below is derived from it.
        header_str = '%30s, %10s, % 9s, % 9s, % 9s, % 9s, % 5s' % (
            'strategy', 'dataset', 'dump', 'load', 'mem_dump', 'mem_load',
            'disk')
        print(header_str)
        # When plotting, append so that previously collected results are
        # not wiped out.
        append = args.append_results or args.plot
        # A missing or empty output file still needs the CSV header line.
        if (not os.path.exists(args.filename)
                or os.path.getsize(args.filename) == 0):
            append = False
        with open(args.filename, 'a' if append else 'w') as f:
            if not append:
                f.write(header_str.replace(' ', '') + '\n')
            for _ in range(args.tries):
                if do_joblib:
                    run_joblib_bench('out/test.pkl', dataset,
                                     'joblib (%s)' % joblib.__version__,
                                     args.dataset, f)
                if do_joblib_compressed:
                    run_joblib_bench('out/test.pkl', dataset,
                                     'joblib (%s - zlib 3)' % joblib.__version__,
                                     args.dataset, f, compress=3)
                if do_pickle:
                    run_pickle_bench('out/test.pkl', dataset, 'pickle',
                                     args.dataset, f)
                if do_numpy:
                    run_numpy_bench('out/test.pkl', dataset, 'numpy',
                                    args.dataset, f)
                if do_numpy_compressed:
                    run_numpy_compressed_bench('out/test.pkl', dataset,
                                               'numpy compressed',
                                               args.dataset, f)

    # Generate the matplotlib figure
    if args.plot:
        generate_plots(args.filename, args.dataset)
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment