Text classification with SVM example
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Support vector machines and machine learning on documents"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The support vector machine (SVM) is widely regarded as one of the best text classification algorithms, though it is also a bit slower than naïve Bayes."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The support vector classifier (SVC) is a powerful and widely used memory-based nonlinear classifier. Like k-NN, a nonlinear SVC makes predictions from a weighted average of the labels of similar examples, where similarity is measured by a kernel function. However, only the support vectors, i.e., the examples falling on or inside the margin, can have nonzero weights and need to be remembered. In practice, an SVC usually remembers far fewer examples than k-NN does. Another difference is that an SVC is not a lazy learner: the weights are trained eagerly in the training phase."
]
},
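{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of this property on a synthetic dataset (`make_classification`, not the newsgroups data used below): after fitting, the `n_support_` attribute shows how many of the training examples the classifier actually keeps as support vectors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.datasets import make_classification\n",
"from sklearn.svm import SVC\n",
"\n",
"# Toy data: 200 examples, 5 features (sizes chosen only for illustration)\n",
"X, y = make_classification(n_samples=200, n_features=5, random_state=0)\n",
"\n",
"# Nonlinear SVC with the default RBF kernel\n",
"svc = SVC(kernel='rbf').fit(X, y)\n",
"\n",
"# Support vectors per class; typically far fewer than the 200 training examples\n",
"svc.n_support_"
]
},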
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"['alt.atheism',\n",
" 'comp.graphics',\n",
" 'comp.os.ms-windows.misc',\n",
" 'comp.sys.ibm.pc.hardware',\n",
" 'comp.sys.mac.hardware',\n",
" 'comp.windows.x',\n",
" 'misc.forsale',\n",
" 'rec.autos',\n",
" 'rec.motorcycles',\n",
" 'rec.sport.baseball',\n",
" 'rec.sport.hockey',\n",
" 'sci.crypt',\n",
" 'sci.electronics',\n",
" 'sci.med',\n",
" 'sci.space',\n",
" 'soc.religion.christian',\n",
" 'talk.politics.guns',\n",
" 'talk.politics.mideast',\n",
" 'talk.politics.misc',\n",
" 'talk.religion.misc']"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.datasets import load_files\n",
"twenty_train = load_files('twenty_newsgroups/20news-bydate-train', encoding='latin1')\n",
"twenty_train.target_names"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Extracting features from text files"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to perform machine learning on text documents, we first need to turn the text content into numerical feature vectors."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Tokenizing text with scikit-learn"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Text preprocessing, tokenizing and filtering of stopwords are included in a high-level component that is able to build a dictionary of features and transform documents to feature vectors."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"(11314, 130107)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.feature_extraction.text import CountVectorizer\n",
"count_vect = CountVectorizer()\n",
"X_train_counts = count_vect.fit_transform(twenty_train.data)\n",
"X_train_counts.shape"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"56283"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"count_vect.vocabulary_.get('for')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"CountVectorizer supports counts of N-grams of words or consecutive characters. Once fitted, the vectorizer has built a dictionary of feature indices:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"(11314, 8069416)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ngram_count_vect = CountVectorizer(ngram_range=(1, 5))\n",
"XX_train_counts = ngram_count_vect.fit_transform(twenty_train.data)\n",
"XX_train_counts.shape"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"627642"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ngram_count_vect.vocabulary_.get('algorithm for')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### From occurrences to frequencies"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Occurrence count is a good start, but there is an issue: longer documents will have higher average count values than shorter documents, even though they might talk about the same topics.\n",
"\n",
"To avoid these potential discrepancies, it suffices to divide the number of occurrences of each word in a document by the total number of words in the document: these new features are called tf, for Term Frequencies.\n",
"\n",
"Another refinement on top of tf is to downscale weights for words that occur in many documents in the corpus and are therefore less informative than those that occur only in a smaller portion of the corpus.\n",
"This downscaling is called tf–idf for “Term Frequency times Inverse Document Frequency”."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both tf and tf–idf can be computed as follows:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### TfidfTransformer\n",
"Transforms the count matrix produced by CountVectorizer into a tf or tf–idf representation"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"(11314, 130107)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.feature_extraction.text import TfidfTransformer\n",
"tf_transformer = TfidfTransformer(use_idf=False).fit(X_train_counts)\n",
"X_train_tf = tf_transformer.transform(X_train_counts)\n",
"X_train_tf.shape"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"(11314, 130107)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tfidf_transformer = TfidfTransformer()\n",
"X_train_tfidf = tfidf_transformer.fit_transform(X_train_counts)\n",
"X_train_tfidf.shape"
]
},
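{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check, the idf weights can be recomputed by hand. With scikit-learn's default `smooth_idf=True`, idf(t) = ln((1 + n) / (1 + df(t))) + 1, where n is the number of documents and df(t) is the number of documents containing term t. A minimal sketch, assuming the `X_train_counts` and `tfidf_transformer` objects from above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Document frequency: the number of documents in which each term occurs\n",
"n = X_train_counts.shape[0]\n",
"df = np.bincount(X_train_counts.nonzero()[1], minlength=X_train_counts.shape[1])\n",
"\n",
"# Smoothed idf, following scikit-learn's default formula\n",
"idf = np.log((1 + n) / (1 + df)) + 1\n",
"\n",
"# Should match the fitted transformer's idf_ vector\n",
"np.allclose(idf, tfidf_transformer.idf_)"
]
},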
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### TfidfVectorizer\n",
"Equivalent to CountVectorizer followed by TfidfTransformer"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"(11314, 130107)"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"\n",
"tfidf_vectorizer = TfidfVectorizer()\n",
"X_train_tfidf = tfidf_vectorizer.fit_transform(twenty_train.data)\n",
"X_train_tfidf.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building a pipeline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to make the vectorizer => transformer => classifier sequence easier to work with, scikit-learn provides a Pipeline class that behaves like a compound classifier:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.pipeline import Pipeline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As noted above, the SVM is widely regarded as one of the best text classification algorithms (although it is also a bit slower than naïve Bayes). We can change the learner by simply plugging a different classifier object into our pipeline:"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.linear_model import SGDClassifier\n",
"from sklearn.svm import SVC, LinearSVC, NuSVC"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"text_clf = Pipeline([('vect', TfidfVectorizer()),\n",
"                     ('clf', LinearSVC()),\n",
"])"
]
},
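{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, a minimal sketch plugging in one of the other imported learners, `SGDClassifier` with hinge loss, which trains a linear SVM by stochastic gradient descent (the name `text_clf_sgd` is just for illustration):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Same vectorizer, different learner: a linear SVM fit by SGD\n",
"text_clf_sgd = Pipeline([('vect', TfidfVectorizer()),\n",
"                         ('clf', SGDClassifier(loss='hinge', penalty='l2')),\n",
"])"
]
},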
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"Pipeline(steps=[('vect', TfidfVectorizer(analyzer='word', binary=False, decode_error='strict',\n",
" dtype=<class 'numpy.int64'>, encoding='utf-8', input='content',\n",
" lowercase=True, max_df=1.0, max_features=None, min_df=1,\n",
" ngram_range=(1, 1), norm='l2', preprocessor=None, smooth_idf=True,\n",
" ...ax_iter=1000,\n",
" multi_class='ovr', penalty='l2', random_state=None, tol=0.0001,\n",
" verbose=0))])"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text_clf.fit(twenty_train.data, twenty_train.target)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load test data"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"twenty_test = load_files('twenty_newsgroups/20news-bydate-test', encoding='latin1')"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"predicted = text_clf.predict(twenty_test.data)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0.99902775322609161"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"np.mean(predicted == twenty_test.target)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" alt.atheism 1.00 1.00 1.00 480\n",
" comp.graphics 1.00 1.00 1.00 584\n",
" comp.os.ms-windows.misc 1.00 1.00 1.00 591\n",
"comp.sys.ibm.pc.hardware 0.99 1.00 1.00 590\n",
" comp.sys.mac.hardware 1.00 1.00 1.00 578\n",
" comp.windows.x 1.00 1.00 1.00 593\n",
" misc.forsale 1.00 1.00 1.00 585\n",
" rec.autos 1.00 1.00 1.00 594\n",
" rec.motorcycles 1.00 1.00 1.00 598\n",
" rec.sport.baseball 1.00 1.00 1.00 597\n",
" rec.sport.hockey 1.00 1.00 1.00 600\n",
" sci.crypt 1.00 1.00 1.00 595\n",
" sci.electronics 1.00 0.99 1.00 591\n",
" sci.med 1.00 1.00 1.00 594\n",
" sci.space 1.00 1.00 1.00 593\n",
" soc.religion.christian 1.00 1.00 1.00 599\n",
" talk.politics.guns 1.00 1.00 1.00 546\n",
" talk.politics.mideast 1.00 1.00 1.00 564\n",
" talk.politics.misc 1.00 1.00 1.00 465\n",
" talk.religion.misc 1.00 1.00 1.00 377\n",
"\n",
" avg / total 1.00 1.00 1.00 11314\n",
"\n"
]
}
],
"source": [
"from sklearn import metrics\n",
"print(metrics.classification_report(twenty_test.target, predicted,\n",
"    target_names=twenty_test.target_names))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Parameter tuning using grid search"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We’ve already encountered some parameters such as use_idf in the TfidfTransformer. Classifiers tend to have many parameters as well; e.g., MultinomialNB includes a smoothing parameter alpha, and SGDClassifier has a penalty parameter alpha and configurable loss and penalty terms in the objective function (see the module documentation, or use the Python help function, to get a description of these).\n",
"\n",
"Instead of tweaking the parameters of the various components of the chain, it is possible to run an exhaustive search for the best parameters on a grid of possible values. Here we try the classifier on either unigrams or bigrams, with or without idf, and with the linear SVM’s regularization parameter C set to 1.0, 0.1, 0.01, or 0.001:"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"parameters = {'vect__ngram_range': [(1, 1), (1, 2)],\n",
"              'vect__use_idf': (True, False),\n",
"              'clf__C': (1.0, 0.1, 1e-2, 1e-3),\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Obviously, such an exhaustive search can be expensive. If we have multiple CPU cores at our disposal, we can tell the grid searcher to try these sixteen parameter combinations in parallel with the n_jobs parameter. If we give this parameter a value of -1, grid search will detect how many cores are installed and use them all:"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"gs_clf = GridSearchCV(text_clf, parameters, n_jobs=-1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The grid search instance behaves like a normal scikit-learn model. Let’s perform the search on the training data:"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"gs_clf = gs_clf.fit(twenty_train.data, twenty_train.target)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The result of calling fit on a GridSearchCV object is a classifier that we can use to predict:"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"'soc.religion.christian'"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"twenty_train.target_names[gs_clf.predict(['God is love'])[0]]"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0.9231041187908785"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gs_clf.best_score_"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"clf__C: 1.0\n",
"vect__ngram_range: (1, 2)\n",
"vect__use_idf: True\n"
]
}
],
"source": [
"for param_name in sorted(parameters.keys()):\n",
"    print(\"%s: %r\" % (param_name, gs_clf.best_params_[param_name]))"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"clf = gs_clf.best_estimator_"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"predicted = clf.predict(twenty_test.data)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"0.9996464557185788"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"np.mean(predicted == twenty_test.target)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" alt.atheism 1.00 1.00 1.00 480\n",
" comp.graphics 1.00 1.00 1.00 584\n",
" comp.os.ms-windows.misc 1.00 1.00 1.00 591\n",
"comp.sys.ibm.pc.hardware 1.00 1.00 1.00 590\n",
" comp.sys.mac.hardware 1.00 1.00 1.00 578\n",
" comp.windows.x 1.00 1.00 1.00 593\n",
" misc.forsale 1.00 1.00 1.00 585\n",
" rec.autos 1.00 1.00 1.00 594\n",
" rec.motorcycles 1.00 1.00 1.00 598\n",
" rec.sport.baseball 1.00 1.00 1.00 597\n",
" rec.sport.hockey 1.00 1.00 1.00 600\n",
" sci.crypt 1.00 1.00 1.00 595\n",
" sci.electronics 1.00 1.00 1.00 591\n",
" sci.med 1.00 1.00 1.00 594\n",
" sci.space 1.00 1.00 1.00 593\n",
" soc.religion.christian 1.00 1.00 1.00 599\n",
" talk.politics.guns 1.00 1.00 1.00 546\n",
" talk.politics.mideast 1.00 1.00 1.00 564\n",
" talk.politics.misc 1.00 1.00 1.00 465\n",
" talk.religion.misc 1.00 1.00 1.00 377\n",
"\n",
" avg / total 1.00 1.00 1.00 11314\n",
"\n"
]
}
],
"source": [
"from sklearn import metrics\n",
"print(metrics.classification_report(twenty_test.target, predicted,\n",
"    target_names=twenty_test.target_names))"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"{'mean_fit_time': array([ 6.76014423, 6.60558637, 30.79757826, 29.02015003,\n",
" 5.89655495, 5.21485591, 25.65494792, 22.93828098,\n",
" 5.59868479, 5.59316071, 30.86746566, 24.40903004,\n",
" 5.04613773, 5.09751471, 26.44545277, 19.00444237]),\n",
" 'mean_score_time': array([ 1.84185378, 1.83289027, 5.56224203, 4.75552535, 1.99876022,\n",
" 1.82617879, 5.16473365, 4.67290862, 1.92498589, 1.83058763,\n",
" 5.06621178, 4.66901326, 1.92053254, 1.84296004, 5.24075532,\n",
" 3.45949546]),\n",
" 'mean_test_score': array([ 0.91894997, 0.8902245 , 0.92310412, 0.89985858, 0.8959696 ,\n",
" 0.82950327, 0.89853279, 0.8432915 , 0.82791232, 0.64221319,\n",
" 0.83515998, 0.65494078, 0.70673502, 0.39199222, 0.71813682,\n",
" 0.38403748]),\n",
" 'mean_train_score': array([ 0.99929296, 0.99500614, 0.99973487, 0.99907207, 0.98258771,\n",
" 0.93167821, 0.9913825 , 0.96199421, 0.90569246, 0.71398305,\n",
" 0.94175374, 0.74602304, 0.76993005, 0.42487153, 0.82477497,\n",
" 0.43676011]),\n",
" 'param_clf__C': masked_array(data = [1.0 1.0 1.0 1.0 0.1 0.1 0.1 0.1 0.01 0.01 0.01 0.01 0.001 0.001 0.001\n",
" 0.001],\n",
" mask = [False False False False False False False False False False False False\n",
" False False False False],\n",
" fill_value = ?),\n",
" 'param_vect__ngram_range': masked_array(data = [(1, 1) (1, 1) (1, 2) (1, 2) (1, 1) (1, 1) (1, 2) (1, 2) (1, 1) (1, 1)\n",
" (1, 2) (1, 2) (1, 1) (1, 1) (1, 2) (1, 2)],\n",
" mask = [False False False False False False False False False False False False\n",
" False False False False],\n",
" fill_value = ?),\n",
" 'param_vect__use_idf': masked_array(data = [True False True False True False True False True False True False True\n",
" False True False],\n",
" mask = [False False False False False False False False False False False False\n",
" False False False False],\n",
" fill_value = ?),\n",
" 'params': ({'clf__C': 1.0,\n",
" 'vect__ngram_range': (1, 1),\n",
" 'vect__use_idf': True},\n",
" {'clf__C': 1.0, 'vect__ngram_range': (1, 1), 'vect__use_idf': False},\n",
" {'clf__C': 1.0, 'vect__ngram_range': (1, 2), 'vect__use_idf': True},\n",
" {'clf__C': 1.0, 'vect__ngram_range': (1, 2), 'vect__use_idf': False},\n",
" {'clf__C': 0.1, 'vect__ngram_range': (1, 1), 'vect__use_idf': True},\n",
" {'clf__C': 0.1, 'vect__ngram_range': (1, 1), 'vect__use_idf': False},\n",
" {'clf__C': 0.1, 'vect__ngram_range': (1, 2), 'vect__use_idf': True},\n",
" {'clf__C': 0.1, 'vect__ngram_range': (1, 2), 'vect__use_idf': False},\n",
" {'clf__C': 0.01, 'vect__ngram_range': (1, 1), 'vect__use_idf': True},\n",
" {'clf__C': 0.01, 'vect__ngram_range': (1, 1), 'vect__use_idf': False},\n",
" {'clf__C': 0.01, 'vect__ngram_range': (1, 2), 'vect__use_idf': True},\n",
" {'clf__C': 0.01, 'vect__ngram_range': (1, 2), 'vect__use_idf': False},\n",
" {'clf__C': 0.001, 'vect__ngram_range': (1, 1), 'vect__use_idf': True},\n",
" {'clf__C': 0.001, 'vect__ngram_range': (1, 1), 'vect__use_idf': False},\n",
" {'clf__C': 0.001, 'vect__ngram_range': (1, 2), 'vect__use_idf': True},\n",
" {'clf__C': 0.001, 'vect__ngram_range': (1, 2), 'vect__use_idf': False}),\n",
" 'rank_test_score': array([ 2, 6, 1, 3, 5, 9, 4, 7, 10, 14, 8, 13, 12, 15, 11, 16], dtype=int32),\n",
" 'split0_test_score': array([ 0.9189404 , 0.88715232, 0.92450331, 0.89536424, 0.89456954,\n",
" 0.83125828, 0.89933775, 0.8384106 , 0.82701987, 0.64211921,\n",
" 0.83417219, 0.65271523, 0.70728477, 0.38913907, 0.7202649 ,\n",
" 0.38172185]),\n",
" 'split0_train_score': array([ 0.99907149, 0.9948269 , 0.99960207, 0.99907149, 0.98063404,\n",
" 0.93142327, 0.99071495, 0.96206393, 0.9068842 , 0.71348985,\n",
" 0.94057567, 0.7446611 , 0.76760844, 0.4260512 , 0.82809391,\n",
" 0.43891763]),\n",
" 'split1_test_score': array([ 0.91704214, 0.88682746, 0.92022263, 0.89689902, 0.89239332,\n",
" 0.82745826, 0.8915982 , 0.84097535, 0.82719321, 0.6315929 ,\n",
" 0.82878346, 0.6477604 , 0.70076862, 0.38271932, 0.71137026,\n",
" 0.38112908]),\n",
" 'split1_train_score': array([ 0.99973478, 0.99509349, 1. , 0.99946957, 0.98435221,\n",
" 0.93409362, 0.99270654, 0.96260443, 0.90545021, 0.71582018,\n",
" 0.94377404, 0.74910489, 0.7696592 , 0.42235778, 0.82164169,\n",
" 0.43508818]),\n",
" 'split2_test_score': array([ 0.92087095, 0.89670738, 0.92458842, 0.90732873, 0.90095592,\n",
" 0.82979288, 0.90467339, 0.85050451, 0.82952735, 0.65294742,\n",
" 0.8425385 , 0.66436537, 0.71216144, 0.40414233, 0.72278279,\n",
" 0.38927244]),\n",
" 'split2_train_score': array([ 0.9990726 , 0.99509804, 0.99960254, 0.99867515, 0.98277689,\n",
" 0.92951775, 0.99072602, 0.96131426, 0.90474298, 0.71263911,\n",
" 0.9409115 , 0.74430313, 0.77252252, 0.42620562, 0.8245893 ,\n",
" 0.43627451]),\n",
" 'std_fit_time': array([ 0.04834512, 0.32860575, 0.84135849, 1.65121325, 0.25290831,\n",
" 0.21772469, 0.44783278, 1.50370171, 0.44456331, 0.65811513,\n",
" 0.15568821, 4.48567001, 0.30451781, 0.51476054, 0.43081499,\n",
" 4.96984552]),\n",
" 'std_score_time': array([ 0.0252354 , 0.0134656 , 0.42357746, 0.02507232, 0.16907588,\n",
" 0.05125834, 0.26886112, 0.05047332, 0.12568297, 0.0278848 ,\n",
" 0.14044775, 0.0266534 , 0.11198455, 0.05223177, 0.2779491 ,\n",
" 0.88350279]),\n",
" 'std_test_score': array([ 0.00156274, 0.00458115, 0.00203849, 0.00531369, 0.00363253,\n",
" 0.00156533, 0.00536694, 0.00520152, 0.00114298, 0.00871608,\n",
" 0.00565745, 0.00695787, 0.00466621, 0.00897384, 0.00489538,\n",
" 0.00370567]),\n",
" 'std_train_score': array([ 0.00031242, 0.00012676, 0.00018747, 0.00032432, 0.00152382,\n",
" 0.00187677, 0.00093625, 0.00052901, 0.00089078, 0.00134468,\n",
" 0.00143513, 0.0021841 , 0.00201529, 0.00177861, 0.00263738,\n",
" 0.00160063])}"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gs_clf.cv_results_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The *cv_results_* attribute can easily be imported into pandas as a DataFrame for further inspection, for example:"
]
},
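{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# One row per parameter combination, one column per recorded quantity\n",
"pd.DataFrame(gs_clf.cv_results_).head()"
]
},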
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
} |