Last active
February 8, 2017 14:22
-
-
Save rth/b34e814baceddbdd7fa362268f6b629e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This script aims to benchmark different parallelization options for pairwise metrics in scikit-learn. | |
The results can be found in https://github.com/scikit-learn/scikit-learn/issues/8216 | |
The environment is set up with,
conda create -n sklearn-env scikit-learn==0.18.1 jupyter python==3.5 | |
and this benchmark should be run with, | |
ipython pairwise_distances_benchmark.py False euclidean | |
^^ ^^ | |
sparse_input metric | |
""" | |
import tempfile | |
import os | |
import sys | |
import shutil | |
import numpy as np | |
import scipy.sparse | |
from sklearn.metrics.pairwise import PAIRWISE_DISTANCE_FUNCTIONS | |
from sklearn.externals.joblib import Parallel, delayed, dump, load | |
from sklearn.utils import gen_even_slices | |
from IPython import get_ipython | |
# Grab the running IPython shell so we can use %timeit magics below
# (this script must be launched with `ipython`, not plain `python`).
ipython = get_ipython()
# Fixed seed so the random benchmark matrices are reproducible across runs.
np.random.seed(99999)
# Command-line arguments, as documented in the module docstring:
#   argv[1]: 'True'/'False' — whether to benchmark sparse CSR inputs
#   argv[2]: metric name, e.g. 'euclidean'
sparse_input = sys.argv[1] == 'True'
metric = sys.argv[2]
def mmap_pdist_func(func, X, Y, Z, s, **kwds):
    """Compute one slice of the pairwise matrix into a shared output array.

    Writes ``func(X, Y[s])`` into the column slice ``Z[:, s]`` so that each
    worker fills its part of a (possibly memmapped) result array in place.

    Parameters
    ----------
    func : callable
        Pairwise-distance function taking two 2-D arrays.
    X : array-like
        Full left operand, shape (n_x, n_dim).
    Y : array-like
        Full right operand, shape (n_y, n_dim).
    Z : ndarray
        Writable output of shape (X.shape[0], Y.shape[0]).
    s : slice
        Column slice (rows of Y) assigned to this worker.
    **kwds
        Extra keyword arguments forwarded to ``func``.
    """
    # Bug fix: _parallel_pairwise calls this helper with **kwds, but the
    # original signature accepted none, raising TypeError whenever metric
    # keyword arguments were supplied. Accept and forward them.
    Z[:, s] = func(X, Y[s], **kwds)
# The _parallel_pairwise function from scikit-learn | |
# updated with extra arguments to Parallel | |
def _parallel_pairwise(X, Y, func, n_jobs, backend=None, mmap_result=False, **kwds): | |
"""Break the pairwise matrix in n_jobs even slices | |
and compute them in parallel""" | |
if n_jobs < 0: | |
n_jobs = max(cpu_count() + 1 + n_jobs, 1) | |
if Y is None: | |
Y = X | |
if n_jobs == 1: | |
# Special case to avoid picklability checks in delayed | |
return func(X, Y, **kwds) | |
# TODO: in some cases, backend='threading' may be appropriate | |
if not mmap_result: | |
ret = Parallel(n_jobs=n_jobs, verbose=0, backend=backend, mmap_mode='r')( | |
delayed(func)(X, Y[s], **kwds) | |
for s in gen_even_slices(Y.shape[0], n_jobs)) | |
return np.hstack(ret) | |
else: | |
Z = np.empty((X.shape[0], Y.shape[0]), dtype=X.dtype) | |
Parallel(n_jobs=n_jobs, verbose=0, backend=backend, mmap_mode='r+')( | |
delayed(mmap_pdist_func)(func, X, Y, Z, s, **kwds) | |
for s in gen_even_slices(Y.shape[0], n_jobs)) | |
return Z | |
# Benchmark driver: for each problem size, build random X/Y (dense or
# sparse CSR), then time _parallel_pairwise across joblib backends,
# MKL thread counts, and n_jobs values.
for n_x, n_y, n_dim in [(100000, 1000, 1000),
                        (10000, 10000, 1000),
                        (10000, 10000, 10)]:
    if sparse_input:
        n_dim *= 10  # as by default density=0.01
        X = scipy.sparse.random(n_x, n_dim, format='csr')
        Y = scipy.sparse.random(n_y, n_dim, format='csr')
    else:
        X = np.random.rand(n_x, n_dim)
        Y = np.random.rand(n_y, n_dim)
    print('='*80)
    print('\n# sparse={}, n_x={}, n_y={}, n_dim={}'.format(sparse_input, n_x, n_y, n_dim))
    print('# X array: {} GB, Y array {} GB, result array {} GB'.format(
        # sparse arrays take ~twice as much space, as we need to store data, indices, indptr
        X.data.nbytes*1e-9 if not sparse_input else X.data.nbytes*1e-9*2,
        Y.data.nbytes*1e-9 if not sparse_input else Y.data.nbytes*1e-9*2,
        (8*X.shape[0]*Y.shape[0])*1e-9))
    print("# metric =", metric)
    print('='*80)
    # Each dict is one benchmark configuration; MKL_NUM_THREADS is popped out
    # below and applied via the environment rather than passed to joblib.
    for parallel_pars in [{'backend': 'multiprocessing', 'mmap_result': False, 'MKL_NUM_THREADS': 8},
                          {'backend': 'multiprocessing', 'mmap_result': True, 'MKL_NUM_THREADS': 8},
                          {'backend': 'threading', 'MKL_NUM_THREADS': 8},
                          {'backend': 'multiprocessing', 'mmap_result': False, 'MKL_NUM_THREADS': 1},
                          {'backend': 'multiprocessing', 'mmap_result': True, 'MKL_NUM_THREADS': 1},
                          {'backend': 'threading', 'MKL_NUM_THREADS': 1},
                          ]:
        print('\n## ', parallel_pars)
        # NOTE(review): pop mutates the dict, so only the remaining keys are
        # forwarded to _parallel_pairwise as **parallel_pars.
        MKL_NUM_THREADS = parallel_pars.pop('MKL_NUM_THREADS', 8)
        # set the number of threads used by numpy
        # NOTE(review): MKL reads this env var when worker processes start;
        # presumably it has no effect on the already-initialized parent —
        # verify for the 'threading' backend.
        os.environ['MKL_NUM_THREADS'] = str(MKL_NUM_THREADS)
        pdist_func = PAIRWISE_DISTANCE_FUNCTIONS[metric]
        for n_jobs in [1, 2, 4, 8, 16]:
            print('n_jobs=', n_jobs, ' => ', end='')
            # Single-run timing (-n 1 -r 1): these computations are too large
            # to repeat, at the cost of noisier measurements.
            ipython.magic("timeit -n 1 -r 1 _parallel_pairwise(X, Y, pdist_func, n_jobs=n_jobs, **parallel_pars)")
        # Clean up so later configurations start from an unset variable.
        del os.environ['MKL_NUM_THREADS']
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment