Last active
April 26, 2017 14:57
-
-
Save ClementC/6bc1824e363741fa023ef3b245422270 to your computer and use it in GitHub Desktop.
Minimal reproducible example with proposed solutions and a benchmark.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2017-04-24T18:44:49.077314Z", | |
"start_time": "2017-04-24T18:44:49.069926Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.pipeline import Pipeline\n", | |
"from sklearn.preprocessing import FunctionTransformer\n", | |
"from sklearn.feature_extraction.text import TfidfTransformer\n", | |
"from sklearn.naive_bayes import MultinomialNB\n", | |
"from sklearn.cross_validation import cross_val_predict\n", | |
"from scipy import sparse\n", | |
"import random\n", | |
"import math\n", | |
"import numpy as np\n", | |
"from unittest.mock import patch\n", | |
"import pickle\n", | |
"from uuid import uuid4\n", | |
"import time" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2017-04-24T18:00:47.988517Z", | |
"start_time": "2017-04-24T18:00:47.946694Z" | |
}, | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def get_columns(X, cols=None):\n", | |
" return X[:, cols]\n", | |
"\n", | |
"n_rows = 100\n", | |
"n_cols = 10e7\n", | |
"sparsity = 0.0000001\n", | |
"X = sparse.rand(n_rows, n_cols, sparsity, \"csr\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 39, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2017-04-24T19:03:27.144434Z", | |
"start_time": "2017-04-24T19:02:51.702399Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"99900000\r" | |
] | |
} | |
], | |
"source": [ | |
"cols_choice = []\n", | |
"for elem in range(int(n_cols)):\n", | |
" if elem % 100000 == 0:\n", | |
" print(elem, end=\"\\r\")\n", | |
" if random.random() >= .2:\n", | |
" cols_choice.append(elem)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2017-04-24T18:40:37.847166Z", | |
"start_time": "2017-04-24T18:40:36.452360Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"p = Pipeline([\n", | |
" (\"columns_selection\", FunctionTransformer(get_columns, accept_sparse=True, kw_args={\"cols\": cols_choice})),\n", | |
" (\"TF.IDF\", TfidfTransformer()),\n", | |
" (\"classifier\", MultinomialNB())\n", | |
"])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2017-04-24T18:01:25.842442Z", | |
"start_time": "2017-04-24T18:01:25.665204Z" | |
}, | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"y = np.random.randint(0, 2, size=(n_rows,))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2017-04-24T18:01:25.978724Z", | |
"start_time": "2017-04-24T18:01:25.844029Z" | |
}, | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"np.random.seed(1234)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2017-04-24T18:13:24.094364Z", | |
"start_time": "2017-04-24T18:01:25.983908Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[Parallel(n_jobs=-1)]: Done 1 tasks | elapsed: 1.9min\n", | |
"[Parallel(n_jobs=-1)]: Done 2 out of 10 | elapsed: 11.3min remaining: 45.1min\n", | |
"[Parallel(n_jobs=-1)]: Done 3 out of 10 | elapsed: 11.3min remaining: 26.4min\n", | |
"[Parallel(n_jobs=-1)]: Done 4 out of 10 | elapsed: 11.3min remaining: 17.0min\n", | |
"[Parallel(n_jobs=-1)]: Done 5 out of 10 | elapsed: 11.3min remaining: 11.3min\n", | |
"[Parallel(n_jobs=-1)]: Done 6 out of 10 | elapsed: 11.3min remaining: 7.5min\n", | |
"[Parallel(n_jobs=-1)]: Done 7 out of 10 | elapsed: 11.3min remaining: 4.9min\n", | |
"[Parallel(n_jobs=-1)]: Done 8 out of 10 | elapsed: 11.3min remaining: 2.8min\n", | |
"[Parallel(n_jobs=-1)]: Done 10 out of 10 | elapsed: 11.9min remaining: 0.0s\n", | |
"[Parallel(n_jobs=-1)]: Done 10 out of 10 | elapsed: 11.9min finished\n", | |
"CPU times: user 11min 17s, sys: 11.7 s, total: 11min 28s\n", | |
"Wall time: 11min 57s\n" | |
] | |
} | |
], | |
"source": [ | |
"%time y_predict = cross_val_predict(p, X, y, cv=10, n_jobs=-1, verbose=50)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2017-04-24T18:15:40.835623Z", | |
"start_time": "2017-04-24T18:15:40.828193Z" | |
}, | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"np.random.seed(1234)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2017-04-24T18:17:09.280836Z", | |
"start_time": "2017-04-24T18:15:41.003637Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[Parallel(n_jobs=-1)]: Done 1 tasks | elapsed: 45.4s\n", | |
"[Parallel(n_jobs=-1)]: Done 2 out of 10 | elapsed: 51.8s remaining: 3.5min\n", | |
"[Parallel(n_jobs=-1)]: Done 3 out of 10 | elapsed: 52.2s remaining: 2.0min\n", | |
"[Parallel(n_jobs=-1)]: Done 4 out of 10 | elapsed: 56.7s remaining: 1.4min\n", | |
"[Parallel(n_jobs=-1)]: Done 5 out of 10 | elapsed: 1.1min remaining: 1.1min\n", | |
"[Parallel(n_jobs=-1)]: Done 6 out of 10 | elapsed: 1.1min remaining: 43.9s\n", | |
"[Parallel(n_jobs=-1)]: Done 7 out of 10 | elapsed: 1.2min remaining: 31.2s\n", | |
"[Parallel(n_jobs=-1)]: Done 8 out of 10 | elapsed: 1.3min remaining: 19.0s\n", | |
"[Parallel(n_jobs=-1)]: Done 10 out of 10 | elapsed: 1.4min remaining: 0.0s\n", | |
"[Parallel(n_jobs=-1)]: Done 10 out of 10 | elapsed: 1.4min finished\n", | |
"CPU times: user 26.2 s, sys: 6.5 s, total: 32.7 s\n", | |
"Wall time: 1min 28s\n" | |
] | |
} | |
], | |
"source": [ | |
"%time with patch(\"sklearn.cross_validation.clone\", lambda x: x):\\\n", | |
" y_predict_no_clone = cross_val_predict(p, X, y, cv=10, n_jobs=-1, verbose=50)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2017-04-24T18:42:36.017501Z", | |
"start_time": "2017-04-24T18:42:36.010079Z" | |
}, | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"np.random.seed(1234)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2017-04-24T18:46:01.063110Z", | |
"start_time": "2017-04-24T18:45:09.095433Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[Parallel(n_jobs=-1)]: Done 1 tasks | elapsed: 35.1s\n", | |
"[Parallel(n_jobs=-1)]: Done 2 out of 10 | elapsed: 37.3s remaining: 2.5min\n", | |
"[Parallel(n_jobs=-1)]: Done 3 out of 10 | elapsed: 38.7s remaining: 1.5min\n", | |
"[Parallel(n_jobs=-1)]: Done 4 out of 10 | elapsed: 39.2s remaining: 58.8s\n", | |
"[Parallel(n_jobs=-1)]: Done 5 out of 10 | elapsed: 40.6s remaining: 40.6s\n", | |
"[Parallel(n_jobs=-1)]: Done 6 out of 10 | elapsed: 42.9s remaining: 28.6s\n", | |
"[Parallel(n_jobs=-1)]: Done 7 out of 10 | elapsed: 44.0s remaining: 18.9s\n", | |
"[Parallel(n_jobs=-1)]: Done 8 out of 10 | elapsed: 44.6s remaining: 11.2s\n", | |
"[Parallel(n_jobs=-1)]: Done 10 out of 10 | elapsed: 49.0s remaining: 0.0s\n", | |
"[Parallel(n_jobs=-1)]: Done 10 out of 10 | elapsed: 49.0s finished\n", | |
"52.0s\n" | |
] | |
} | |
], | |
"source": [ | |
"start_time = time.time()\n", | |
"with open(\"/dev/shm/test.pickle\", \"wb\") as f:\n", | |
" pickle.dump(p, f)\n", | |
"\n", | |
"with patch(\"sklearn.cross_validation.clone\", lambda x: pickle.load(open(\"/dev/shm/test.pickle\", \"rb\"))):\n", | |
" y_predict_load_pickle = cross_val_predict(p, X, y, cv=10, n_jobs=-1, verbose=50)\n", | |
"print(\"%.1fs\" % (time.time() - start_time))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2017-04-24T18:46:01.068681Z", | |
"start_time": "2017-04-24T18:46:01.065058Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"assert (y_predict == y_predict_no_clone).all()\n", | |
"assert (y_predict == y_predict_load_pickle).all()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 36, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2017-04-24T18:53:53.149409Z", | |
"start_time": "2017-04-24T18:53:53.070914Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Linux-3.16.0-4-amd64-x86_64-with-debian-8.4\n", | |
"Python 3.4.0 (default, Aug 6 2014, 08:30:09) \n", | |
"[GCC 4.9.0]\n", | |
"NumPy 1.12.1\n", | |
"SciPy 0.19.0\n", | |
"Scikit-Learn 0.18.1\n" | |
] | |
} | |
], | |
"source": [ | |
"import platform; print(platform.platform())\n", | |
"import sys; print(\"Python\", sys.version)\n", | |
"import numpy; print(\"NumPy\", numpy.__version__)\n", | |
"import scipy; print(\"SciPy\", scipy.__version__)\n", | |
"import sklearn; print(\"Scikit-Learn\", sklearn.__version__)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.4.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment