Persistence strategies comparison
"""Persistence strategies comparison script. | |
This script compute the speed, memory used and disk space used when dumping and | |
loading arbitrary data. The data are taken among: | |
- scikit-learn Labeled Faces in the Wild dataset (LFW) | |
- a fully random numpy array with 10000x10000 shape | |
- a dictionary with 1M random keys/values | |
- a list containing 10M random value | |
The compared persistence strategies are: | |
- joblib | |
- joblib compressed (using zlib at compress level 3) | |
- pickle | |
- numpy (using savez/load functions) | |
- numpy compressed (using savez_compressed/load functions) | |
""" | |
import os
import shutil
import time
import numpy as np
import joblib
import pickle
from memory_profiler import memory_usage
from sklearn import datasets
from joblib.disk import disk_used
# Script configuration variables:
CSV_FILE = '/tmp/comparison_results.csv'
STRATEGIES = ['joblib', 'pickle', 'numpy',
              'joblib-compressed', 'numpy-compressed']
DATASET = 'lfw_people'
TRIES = 1
SHOW_PLOT = True
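# Note (added): DATASET must be one of the keys understood by
# generate_dataset() below: 'lfw_people', 'big_array', 'big_dict' or
# 'big_list'. STRATEGIES selects which persistence strategies are benchmarked
# and TRIES controls how many times each one is repeated (one CSV row per
# try). CSV_FILE is where each benchmark row is appended; the companion
# plotting script reads it back. SHOW_PLOT is unused in this script (the
# plotting script defines its own copy).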
###############################################################################
# Helper functions
def clear_output_directory():
    """Remove and recreate the generated output directory."""
    if os.path.exists('out'):
        shutil.rmtree('out')
    os.mkdir('out')
def kill_disk_cache():
    """Remove the computation bias introduced by the disk caching mechanism."""
    if os.name == 'posix' and os.uname()[0] == 'Linux':
        # Dropping the page cache requires root privileges.
        status = os.system('sudo sh -c "sync; echo 3 > /proc/sys/vm/drop_caches"')
        if status != 0:
            print('Please run me as root')
    else:
        # Fallback: write a large chunk (~160MB) of random data to the disk
        # to flush the OS cache.
        with open('tmp', 'wb') as f:
            f.write(np.random.random(int(2e7)).tobytes())
def timeit(func, *args, **kwargs):
    """Compute the execution time of func."""
    kill_disk_cache()
    t0 = time.time()
    func(*args, **kwargs)
    t1 = time.time()
    return t1 - t0
def memit(func, *args, **kwargs):
    """Compute the memory usage of func."""
    # Peak minus baseline approximates the extra memory needed by the call.
    mem_use = memory_usage((func, args, kwargs), interval=.001)
    return max(mem_use) - min(mem_use)
def generate_dataset(dataset_str):
    """Generate the requested dataset."""
    if dataset_str == 'lfw_people':
        dataset = datasets.fetch_lfw_people()
    elif dataset_str == 'big_array':
        # Seeded random number generator for reproducibility
        rnd = np.random.RandomState(0)
        dataset = rnd.random_sample((10000, 10000))
    elif dataset_str == 'big_dict':
        dataset = {}
        rnd = np.random.RandomState(0)
        randoms = rnd.random_sample((1000000))
        for key, random in zip(range(1000000), randoms):
            dataset[str(key)] = int(random)
    elif dataset_str == 'big_list':
        dataset = []
        rnd = np.random.RandomState(0)
        for random in rnd.random_sample((10000000)):
            dataset.append(int(random))
    else:
        return None  # should not happen
    return dataset
###############################################################################
# Bench results print/write functions
def write_to_file(fileobj, strategy, dataset, write_time, read_time, mem_write,
                  mem_read, disk_used):
    """Write results of a bench in a file."""
    string = "{0},{1},{2:.3f},{3:.3f},{4:.1f},{5:.1f},{6:.1f}\n".format(
        strategy, dataset, write_time, read_time, mem_write, mem_read,
        disk_used)
    fileobj.write(string)
def print_line(strategy, dataset, write_time, read_time, mem_write, mem_read,
               disk_used):
    """Nice printing function."""
    print('%30s, %10s, % 9.3f, % 9.3f, % 9.1f, % 9.1f, % 5.1f' % (
        strategy, dataset, write_time, read_time, mem_write, mem_read,
        disk_used))
###############################################################################
# Bench functions
def run_joblib_bench(filename, obj, strategy, dataset, output_file, **kwargs):
    """Bench joblib functions."""
    time_write = time_read = du = mem_read = mem_write = []
    clear_output_directory()
    time_write = timeit(joblib.dump, obj, filename, **kwargs)
    mem_write = memit(joblib.dump, obj, filename, **kwargs)
    du = disk_used('out') / 1024.
    time_read = timeit(joblib.load, filename)
    mem_read = memit(joblib.load, filename)
    print_line(strategy, dataset,
               time_write, time_read, mem_write, mem_read, du)
    write_to_file(output_file, strategy, dataset,
                  time_write, time_read, mem_write, mem_read, du)
def run_numpy_bench(filename, obj, strategy, dataset, output_file):
    """Bench numpy functions."""
    time_write = time_read = du = mem_read = mem_write = []
    clear_output_directory()
    time_write = timeit(np.savez, filename, obj)
    mem_write = memit(np.savez, filename, obj)
    du = disk_used('out') / 1024.
    with np.load(filename + '.npz') as npz:
        time_read = timeit(npz.items)
    with np.load(filename + '.npz') as npz:
        mem_read = memit(npz.items)
    print_line(strategy, dataset,
               time_write, time_read, mem_write, mem_read, du)
    write_to_file(output_file, strategy, dataset,
                  time_write, time_read, mem_write, mem_read, du)
def run_numpy_compressed_bench(filename, obj, strategy, dataset, output_file):
    """Bench numpy compressed functions."""
    time_write = time_read = du = mem_read = mem_write = []
    clear_output_directory()
    time_write = timeit(np.savez_compressed, filename, obj)
    mem_write = memit(np.savez_compressed, filename, obj)
    du = disk_used('out') / 1024.
    with np.load(filename + '.npz') as npz:
        time_read = timeit(npz.items)
    with np.load(filename + '.npz') as npz:
        mem_read = memit(npz.items)
    print_line(strategy, dataset,
               time_write, time_read, mem_write, mem_read, du)
    write_to_file(output_file, strategy, dataset,
                  time_write, time_read, mem_write, mem_read, du)
def run_pickle_bench(filename, obj, strategy, dataset, output_file):
    """Bench pickle functions."""
    time_write = time_read = du = mem_read = mem_write = []
    clear_output_directory()
    with open(filename, 'wb') as f:
        time_write = timeit(pickle.dump, obj, f)
    with open(filename, 'wb') as f:
        mem_write = memit(pickle.dump, obj, f)
    du = disk_used('out') / 1024.
    with open(filename, 'rb') as f:
        time_read = timeit(pickle.load, f)
    with open(filename, 'rb') as f:
        mem_read = memit(pickle.load, f)
    print_line(strategy, dataset,
               time_write, time_read, mem_write, mem_read, du)
    write_to_file(output_file, strategy, dataset,
                  time_write, time_read, mem_write, mem_read, du)
def bench():
    """Main function."""
    # Generate the requested dataset
    dataset = generate_dataset(DATASET)
    if len(STRATEGIES) != 0:
        header_str = '%30s, %10s, % 9s, % 9s, % 9s, % 9s, % 5s' % (
            'strategy', 'dataset', 'dump', 'load', 'mem_dump', 'mem_load',
            'disk')
        print(header_str)
        write_header = not os.path.exists(CSV_FILE)
        with open(CSV_FILE, 'w' if write_header else 'a') as f:
            if write_header:
                f.write(header_str.replace(' ', '') + '\n')
            for _ in range(TRIES):
                if 'joblib' in STRATEGIES:
                    run_joblib_bench('out/test.pkl', dataset,
                                     'joblib (%s)' % joblib.__version__,
                                     DATASET, f)
                if 'pickle' in STRATEGIES:
                    run_pickle_bench('out/test.pkl', dataset, 'pickle',
                                     DATASET, f)
                if 'numpy' in STRATEGIES:
                    run_numpy_bench('out/test.pkl', dataset, 'numpy',
                                    DATASET, f)
                if 'joblib-compressed' in STRATEGIES:
                    run_joblib_bench('out/test.pkl', dataset,
                                     ('joblib (%s - zlib 3)' %
                                      joblib.__version__),
                                     DATASET, f, compress=3)
                if 'numpy-compressed' in STRATEGIES:
                    run_numpy_compressed_bench('out/test.pkl', dataset,
                                               'numpy compressed',
                                               DATASET, f)
if __name__ == '__main__':
    bench()
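The two scripts communicate only through CSV_FILE: the benchmark appends one row per strategy and run, and the plotting script below reads the same file back. As a minimal sanity check of that contract (assuming the benchmark has already been run with the default CSV_FILE path):

import pandas as pd

# Columns written by the benchmark script and consumed by the plotting script.
df = pd.read_csv('/tmp/comparison_results.csv')
print(df.columns.tolist())
# Expected: ['strategy', 'dataset', 'dump', 'load', 'mem_dump', 'mem_load', 'disk']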
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
# Script configuration variables:
CSV_FILE = '/tmp/comparison_results.csv'
PNG_FILE = '/tmp/comparison_results.png'
DATASET = 'lfw_people'
TRIES = 1
SHOW_PLOT = True
DATASET_DESC = {'lfw_people': 'Labeled Faces in the Wild dataset (LFW)',
                'big_array': 'Numpy array with random values (~700MB)',
                'big_dict': 'Dictionary with 1M random keys/values',
                'big_list': 'List of 10M random values'}
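# Note (added): DATASET selects which rows of the CSV are plotted, and should
# match the DATASET used by the benchmark script; DATASET_DESC only provides
# the human-readable title printed at the top of the figure.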
###############################################################################
# Plot function
def generate_plots():
    """Generate a nice matplotlib figure."""
    if not os.path.exists(CSV_FILE):
        print("CSV file doesn't exist, exiting")
        return
    df = pd.read_csv(CSV_FILE)
    df = df[df.dataset == DATASET]  # filter on dataset
    if not len(df):
        print("Nothing to plot, exiting")
        return
    # Set up the matplotlib figure
    sns.set(style="whitegrid", context="talk")
    f, (dump_axe, load_axe, mem_dump_axe, mem_load_axe, disk_axe) = \
        plt.subplots(1, 5, figsize=(9, 4.2), sharey=True,
                     gridspec_kw=dict(wspace=.7, right=.947, bottom=.005,
                                      top=.85, left=.14))
    df.strategy = [s.replace(' ', '\n', 1)
                    .replace('0.10.0.dev0', 'dev')
                    .replace(' -', ', ')
                   for s in df.strategy]
    strategies = df.strategy
    dump_times = df.dump
    load_times = df.load
    memory_dump = df.mem_dump
    memory_load = df.mem_load
    disk_used = df.disk
    plt.text(.005, .96, '{0}'.format(DATASET_DESC[DATASET]), size=13,
             transform=f.transFigure)
    sns.barplot(dump_times, strategies, palette="Set3", ax=dump_axe)
    dump_axe.set_title("Dump time")
    dump_axe.set_xlabel("")
    dump_axe.set_ylabel("")
    for i, v in enumerate(strategies.unique()):
        value = df[df.strategy == v].dump.mean()
        dump_axe.text(value + 0.01 * max(dump_times),
                      i + .15, "{0:.2G}s".format(value),
                      color='black', style='italic')
    dump_axe.set_xticks(())
    sns.barplot(load_times, strategies, palette="Set3", ax=load_axe)
    load_axe.set_title("Load time")
    load_axe.set_xlabel("")
    load_axe.set_ylabel("")
    for i, v in enumerate(strategies.unique()):
        value = df[df.strategy == v].load.mean()
        load_axe.text(value + 0.01 * max(load_times),
                      i + .15, "{0:.2G}s".format(value),
                      color='black', style='italic')
    load_axe.set_xticks(())
    sns.barplot(memory_dump, strategies, palette="Set3", ax=mem_dump_axe)
    mem_dump_axe.set_title("Memory used\nwith dump")
    mem_dump_axe.set_xlabel("")
    mem_dump_axe.set_ylabel("")
    for i, v in enumerate(strategies.unique()):
        value = df[df.strategy == v].mem_dump.mean()
        mem_dump_axe.text(value + 0.01 * max(memory_dump),
                          i + .15, "{0:.0f}MB".format(value),
                          color='black', style='italic')
    mem_dump_axe.set_xticks(())
    sns.barplot(memory_load, strategies, palette="Set3", ax=mem_load_axe)
    mem_load_axe.set_title("Memory used\nwith load")
    mem_load_axe.set_xlabel("")
    mem_load_axe.set_ylabel("")
    for i, v in enumerate(strategies.unique()):
        value = df[df.strategy == v].mem_load.mean()
        mem_load_axe.text(value + 0.01 * max(memory_load),
                          i + .15, "{0:.0f}MB".format(value),
                          color='black', style='italic')
    mem_load_axe.set_xticks(())
    sns.barplot(disk_used, strategies, palette="Set3", ax=disk_axe)
    disk_axe.set_title("Disk used")
    disk_axe.set_xlabel("")
    disk_axe.xaxis.tick_top()
    disk_axe.set_ylabel("")
    for i, v in enumerate(strategies.unique()):
        value = df[df.strategy == v].disk.mean()
        disk_axe.text(value + 0.01 * max(disk_used),
                      i + .15, "{0:.0f}MB".format(value),
                      color='black', style='italic')
    disk_axe.set_xticks(())
    sns.despine(bottom=True)
    plt.savefig(PNG_FILE, dpi=100)
    if SHOW_PLOT:
        plt.show()
if __name__ == '__main__':
    generate_plots()
How to use this gist
Run the benchmark script first, then the plotting script. The benchmark results are written to /tmp/comparison_results.csv by default (but one can play with the CSV_FILE variable). You should end up with a figure like the one saved to /tmp/comparison_results.png.
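For reference, a minimal way to drive the two scripts, assuming they are saved as persistence_comparison.py and plot_comparison_results.py (both file names are assumptions, the gist does not name its files):

# File names are assumed -- adapt them to however the two snippets were saved.
# From a shell, run the benchmark first, then build the figure from its CSV:
#
#   python persistence_comparison.py        # calls bench(), writes CSV_FILE
#   python plot_comparison_results.py       # calls generate_plots(), writes PNG_FILE
#
# The same can be done from an interactive session:
from persistence_comparison import bench              # assumed module name
from plot_comparison_results import generate_plots    # assumed module name

bench()            # benchmark DATASET with every strategy listed in STRATEGIES
generate_plots()   # read CSV_FILE back and save the comparison figure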