Last active
November 13, 2020 22:14
-
-
Save ZaxR/8e36b7e4ddc761003097dd723ff83a72 to your computer and use it in GitHub Desktop.
Advanced sklearn Pipeline / GridSearchCV Workflow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re
from collections import Counter

import nltk
import numpy as np
import pandas as pd
from google.cloud import storage
from nltk import pos_tag  # do this on token and get part of speech
from nltk.corpus import stopwords
from nltk.stem.porter import PorterStemmer
from nltk.tokenize import word_tokenize
from sklearn.base import BaseEstimator
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction import FeatureHasher
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
from sklearn.model_selection import train_test_split, ShuffleSplit, GridSearchCV, cross_val_score, StratifiedShuffleSplit
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer, StandardScaler, LabelEncoder
from sklearn.tree import DecisionTreeClassifier
# Fetch the nltk data packages this module relies on: 'punkt' is needed by
# word_tokenize and 'stopwords' by the default English stopword list.
for _resource in ('punkt', 'stopwords'):
    nltk.download(_resource)
# TODO: check whether nltk offers an alphanumeric check for tokens.
def keep_alphanumeric(ser, regex=r'[^A-Za-z]+'):
    """Replace every run of unwanted characters in a pd.Series of strings with a space.

    Despite the function name, the default pattern keeps ASCII letters only
    (digits are replaced as well); pass a different ``regex`` to change which
    characters are removed.

    Args:
        ser: pd.Series whose values are all strings (fill missing values first).
        regex: Pattern matching the character runs to replace with ' '.

    Returns:
        A new pd.Series of cleaned strings carrying the same index as ``ser``.
    """
    pattern = re.compile(regex)
    # Carry the caller's index over: a fresh RangeIndex would silently
    # misalign the result with the input whenever ``ser`` has a custom index.
    return pd.Series([pattern.sub(' ', text) for text in ser.tolist()],
                     index=ser.index)
def tokenize(ser):
    """Split each string of a pd.Series into tokens.

    Every value is passed through nltk's ``word_tokenize``; the result of
    that call replaces the original string in the returned Series.
    """
    tokenized = ser.apply(word_tokenize)
    return tokenized
def remove_stopwords(ser, stopwords=None):
    """Drop stopwords from every token list in a pd.Series.

    Args:
        ser: pd.Series whose values are iterables of tokens.
        stopwords: Iterable of words to remove. When None, nltk's English
            stopword list is used.

    Returns:
        pd.Series of lists containing only the tokens that are not stopwords.
    """
    if stopwords is None:
        # The ``stopwords`` parameter shadows the nltk corpus module of the
        # same name, so import the corpus locally under a distinct alias.
        from nltk.corpus import stopwords as nltk_stopwords
        stopwords = nltk_stopwords.words('english')
    # Build the lookup once: set membership is O(1), whereas the raw list
    # would be rescanned for every single token.
    excluded = set(stopwords)
    return ser.apply(lambda tokens: [word for word in tokens if word not in excluded])
def remove_len_n_words(ser, n):
    """Filter short words out of every token list in a pd.Series.

    A word survives only when its length is strictly greater than ``n``.
    """
    def _drop_short(tokens):
        return [token for token in tokens if len(token) > n]

    return ser.apply(_drop_short)
def lemmatize(ser):
    """Reduce every token in a pd.Series of token lists to its Porter stem.

    NOTE: despite the function name, this applies nltk's ``PorterStemmer``
    (stemming) rather than a lemmatizer.
    """
    stem = PorterStemmer().stem

    def _stem_all(tokens):
        return [stem(token) for token in tokens]

    return ser.apply(_stem_all)
# def filter_out_part_of_speech(ser, pos_list):
#     """Removes specified parts of speech from pd.Series."""
#     return ser.apply(lambda x: [token for token in x if pos_tag(token) not in pos_list])
def sklearn_ingest(ser):
    """Join each token sequence in ``ser`` into one space-separated string.

    Returns a plain list with one string per element of ``ser``, the form
    consumed by the text vectorizers in the pipeline below.
    """
    return list(map(' '.join, ser))
def hasher_ingest(ser):
    """Turn every token list in a pd.Series into a token -> count mapping.

    Each element becomes a ``collections.Counter`` so it can be fed to a
    sklearn FeatureHasher.
    """
    def _count_tokens(tokens):
        return Counter(tokens)

    return ser.apply(_count_tokens)
def clean_text_series(ser, inplace=False):
    """Preprocess a pd.Series of raw strings into token lists for an sklearn Pipeline.

    Steps, in order: missing values are replaced with "", characters matched
    by ``keep_alphanumeric``'s default pattern are replaced with spaces, the
    text is tokenized, stopwords are dropped and the remaining tokens are
    stemmed via ``lemmatize``.

    Args:
        ser: pd.Series of text; missing values are allowed.
        inplace: Accepted for backward compatibility only. Every step returns
            a new Series, so ``ser`` is never modified whatever this flag is.

    Returns:
        A new pd.Series holding one list of processed tokens per input row.
    """
    # ``fillna`` already returns a new object, so a defensive ``ser.copy()``
    # beforehand would only duplicate the data without protecting anything.
    cleaned = ser.fillna("")
    cleaned = keep_alphanumeric(cleaned)
    cleaned = tokenize(cleaned)
    cleaned = remove_stopwords(cleaned)
    cleaned = lemmatize(cleaned)
    return cleaned
class DummyEstimator(BaseEstimator):
    """Placeholder final Pipeline step; the grid search swaps in the real classifier via 'clf'."""

    def fit(self, X=None, y=None, **fit_params):
        """Accept the standard sklearn ``fit`` arguments and learn nothing.

        Returns ``self`` so the call can be chained, per the sklearn
        estimator convention.
        """
        return self

    def score(self, X=None, y=None):
        """Accept the standard sklearn ``score`` arguments; there is no score, so return None."""
        return None
# Seed for the cross-validation splitter so repeated runs use the same splits.
random_state = 42

# Text-classification pipeline. 'count' and 'tfidf' are alternative
# vectorizers: every grid below disables exactly one of them, and 'clf' is a
# placeholder that the grid replaces with a real classifier.
pipe = Pipeline([
    # Default kw_args keeps the step callable outside the grid search;
    # n=0 only drops empty tokens.
    ('len_n', FunctionTransformer(remove_len_n_words, validate=False, kw_args={"n": 0})),
    ('stringify_tok', FunctionTransformer(sklearn_ingest, validate=False)),
    ('count', CountVectorizer()),
    ('tfidf', TfidfVectorizer()),
    ('clf', DummyEstimator()),
])

# parameters to be tested for all models
base_params = [
    # Variant 1: CountVectorizer only ('tfidf' step removed).
    {'len_n__kw_args': ({"n": 0}, {"n": 1}, {"n": 2}, {"n": 3}),
     'count__ngram_range': [(1, 1), (1, 2)],
     'count__max_features': [700, 1000, 1500],
     'tfidf': [None]},
    # Variant 2: TfidfVectorizer only ('count' step removed). The tuned
    # parameters must target 'tfidf': setting count__* on a step that has
    # been replaced by None fails inside set_params.
    {'len_n__kw_args': ({"n": 0}, {"n": 1}, {"n": 2}, {"n": 3}),
     'tfidf__ngram_range': [(1, 1), (1, 2)],
     'tfidf__max_features': [700, 1000, 1500],
     'count': [None]},
]

# parameters specific to the model
model_params = [
    {'clf': [DecisionTreeClassifier()], 'clf__criterion': ['gini', 'entropy']},
    {'clf': [RandomForestClassifier()], 'clf__criterion': ['gini', 'entropy']},
]

# model params will overwrite base parameters if there is a conflicting key
params = [{**base, **d} for base in base_params for d in model_params]

n_jobs = -1  # use every available core
cv = ShuffleSplit(n_splits=3, random_state=random_state)
grid = GridSearchCV(estimator=pipe,
                    param_grid=params,
                    scoring='accuracy',
                    cv=cv,
                    verbose=1,
                    n_jobs=n_jobs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Useful resources: